Alexander Bagus committed · Commit d2c9b66 · Parent: df103d9
initial commit

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- .gitignore +19 -0
- README.md +8 -5
- app.py +224 -4
- examples/depth.jpg +0 -0
- examples/hed.jpg +0 -0
- examples/pose.jpg +0 -0
- examples/pose2.jpg +0 -0
- image_utils.py +70 -0
- predict_t2i_control.py +228 -0
- requirements.txt +7 -0
- static/data.json +8 -0
- static/footer.html +16 -0
- static/header.html +11 -0
- videox_fun/__init__.py +0 -0
- videox_fun/api/api.py +226 -0
- videox_fun/api/api_multi_nodes.py +320 -0
- videox_fun/data/__init__.py +9 -0
- videox_fun/data/bucket_sampler.py +379 -0
- videox_fun/data/dataset_image.py +191 -0
- videox_fun/data/dataset_image_video.py +657 -0
- videox_fun/data/dataset_video.py +901 -0
- videox_fun/data/utils.py +347 -0
- videox_fun/pipeline/__init__.py +62 -0
- videox_fun/pipeline/pipeline_cogvideox_fun.py +862 -0
- videox_fun/pipeline/pipeline_cogvideox_fun_control.py +956 -0
- videox_fun/pipeline/pipeline_cogvideox_fun_inpaint.py +1136 -0
- videox_fun/pipeline/pipeline_fantasy_talking.py +754 -0
- videox_fun/pipeline/pipeline_flux.py +978 -0
- videox_fun/pipeline/pipeline_flux2.py +900 -0
- videox_fun/pipeline/pipeline_flux2_control.py +973 -0
- videox_fun/pipeline/pipeline_hunyuanvideo.py +805 -0
- videox_fun/pipeline/pipeline_hunyuanvideo_i2v.py +972 -0
- videox_fun/pipeline/pipeline_qwenimage.py +767 -0
- videox_fun/pipeline/pipeline_qwenimage_edit.py +952 -0
- videox_fun/pipeline/pipeline_qwenimage_edit_plus.py +937 -0
- videox_fun/pipeline/pipeline_wan.py +576 -0
- videox_fun/pipeline/pipeline_wan2_2.py +591 -0
- videox_fun/pipeline/pipeline_wan2_2_animate.py +929 -0
- videox_fun/pipeline/pipeline_wan2_2_fun_control.py +903 -0
- videox_fun/pipeline/pipeline_wan2_2_fun_inpaint.py +752 -0
- videox_fun/pipeline/pipeline_wan2_2_s2v.py +815 -0
- videox_fun/pipeline/pipeline_wan2_2_ti2v.py +732 -0
- videox_fun/pipeline/pipeline_wan2_2_vace_fun.py +801 -0
- videox_fun/pipeline/pipeline_wan_fun_control.py +799 -0
- videox_fun/pipeline/pipeline_wan_fun_inpaint.py +734 -0
- videox_fun/pipeline/pipeline_wan_phantom.py +695 -0
- videox_fun/pipeline/pipeline_wan_vace.py +787 -0
- videox_fun/pipeline/pipeline_z_image.py +613 -0
- videox_fun/pipeline/pipeline_z_image_control.py +633 -0
- videox_fun/reward/MPS/README.md +1 -0
.gitignore ADDED
@@ -0,0 +1,19 @@



models/

# Packages
*.egg
*.egg-info
dist
build
eggs
parts
bin
var
sdist
develop-eggs
.installed.cfg
lib64
__pycache__
README.md CHANGED
@@ -1,14 +1,17 @@
 ---
 title: ZIT Controlnet
-emoji:
+emoji: 🖼
-colorFrom:
+colorFrom: purple
-colorTo:
+colorTo: red
 sdk: gradio
-sdk_version:
+sdk_version: 5.44.0
 app_file: app.py
 pinned: false
 license: apache-2.0
 short_description: Supports Canny, HED, Depth, Pose and MLSD
+models:
+- Tongyi-MAI/Z-Image-Turbo
+- alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,7 +1,227 @@
The placeholder demo (def greet(name): return "Hello " + name + "!!") is removed and replaced with the full app:

import gradio as gr
import numpy as np
import random
import json
import spaces
import torch
from diffusers import DiffusionPipeline
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from videox_fun.pipeline import ZImageControlPipeline
from videox_fun.models import ZImageControlTransformer2DModel
from transformers import AutoTokenizer, Qwen3ForCausalLM
from diffusers import AutoencoderKL
from image_utils import get_image_latent, scale_image
# from videox_fun.utils.utils import get_image_latent


MODEL_REPO = "Tongyi-MAI/Z-Image-Turbo"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1280

MODEL_LOCAL = "models/Z-Image-Turbo/"
TRANSFORMER_LOCAL = "models/Z-Image-Turbo-Fun-Controlnet-Union.safetensors"


weight_dtype = torch.bfloat16

# load transformer
transformer = ZImageControlTransformer2DModel.from_pretrained(
    MODEL_LOCAL,
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    transformer_additional_kwargs={
        "control_layers_places": [0, 5, 10, 15, 20, 25],
        "control_in_dim": 16
    },
).to(torch.bfloat16)

if TRANSFORMER_LOCAL is not None:
    print(f"From checkpoint: {TRANSFORMER_LOCAL}")
    if TRANSFORMER_LOCAL.endswith("safetensors"):
        from safetensors.torch import load_file, safe_open
        state_dict = load_file(TRANSFORMER_LOCAL)
    else:
        state_dict = torch.load(TRANSFORMER_LOCAL, map_location="cpu")
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict

    m, u = transformer.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")

# Load MODEL_REPO
# Get Vae
vae = AutoencoderKL.from_pretrained(
    MODEL_LOCAL,
    subfolder="vae"
).to(weight_dtype)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_LOCAL, subfolder="tokenizer"
)
text_encoder = Qwen3ForCausalLM.from_pretrained(
    MODEL_LOCAL, subfolder="text_encoder", torch_dtype=weight_dtype,
    low_cpu_mem_usage=True,
)
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3)
pipe = ZImageControlPipeline(
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
)
pipe.transformer = transformer
pipe.to("cuda")

# ======== AoTI compilation + FA3 ========
pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
spaces.aoti_blocks_load(pipe.transformer.layers,
                        "zerogpu-aoti/Z-Image", variant="fa3")


@spaces.GPU
def inference(
    prompt,
    input_image,
    image_scale=1.0,
    control_context_scale=0.75,
    seed=42,
    randomize_seed=True,
    guidance_scale=1.5,
    num_inference_steps=8,
    progress=gr.Progress(track_tqdm=True),
):
    # process image
    if input_image is None:
        print("Error: input_image is empty.")
        return None

    input_image, width, height = scale_image(input_image, image_scale)

    control_image = get_image_latent(input_image, sample_size=[height, width])[:, :, 0]

    # generation
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator().manual_seed(seed)

    image = pipe(
        prompt=prompt,
        height=height,
        width=width,
        generator=generator,
        guidance_scale=guidance_scale,
        control_image=control_image,
        num_inference_steps=num_inference_steps,
        control_context_scale=control_context_scale,
    ).images[0]

    return image, seed


def read_file(path: str) -> str:
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()
    return content


css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
"""

with open('static/data.json', 'r') as file:
    data = json.load(file)
examples = data['examples']

with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Column():
            gr.HTML(read_file("static/header.html"))
        with gr.Row(equal_height=True):
            with gr.Column():
                input_image = gr.Image(
                    height=290, sources=['upload', 'clipboard'],
                    image_mode='RGB',
                    # elem_id="image_upload",
                    type="pil", label="Upload")

                prompt = gr.Textbox(
                    label="Prompt",
                    show_label=False,
                    lines=2,
                    placeholder="Enter your prompt",
                    container=False,
                )

                run_button = gr.Button("Run", variant="primary")
            with gr.Column():
                output_image = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                image_scale = gr.Slider(
                    label="Image scale",
                    minimum=0.5,
                    maximum=2.0,
                    step=0.1,
                    value=1.0,
                )
                control_context_scale = gr.Slider(
                    label="Control context scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.75,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=2.5,
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=30,
                    step=1,
                    value=8,
                )
        gr.Examples(examples=examples, inputs=[input_image, prompt])

        gr.HTML(read_file("static/footer.html"))
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=inference,
        inputs=[
            prompt,
            input_image,
            image_scale,
            control_context_scale,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[output_image, seed],
    )

if __name__ == "__main__":
    demo.launch(mcp_server=True)
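Because the run button is wired through gr.on to inference() and the app launches with mcp_server=True, the Space can also be called programmatically. Below is a minimal client sketch, assuming the placeholder Space ID "user/zit-controlnet" and the default endpoint name "/inference" derived from the function name; both are assumptions, so check the Space's "Use via API" page for the exact values.

from gradio_client import Client, handle_file

client = Client("user/zit-controlnet")           # hypothetical Space ID, replace with the real one
result = client.predict(
    "A majestic female paladin in gleaming plate armor, holding a radiant sword aloft.",  # prompt
    handle_file("examples/pose2.jpg"),           # input_image (a pose control image from the repo)
    1.0,                                         # image_scale
    0.75,                                        # control_context_scale
    0,                                           # seed
    True,                                        # randomize_seed
    2.5,                                         # guidance_scale
    8,                                           # num_inference_steps
    api_name="/inference",                       # assumed endpoint name for inference()
)
image_path, used_seed = result                   # the app returns the image and the seed it used
print(image_path, used_seed)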
examples/depth.jpg ADDED
examples/hed.jpg ADDED
examples/pose.jpg ADDED
examples/pose2.jpg ADDED
image_utils.py ADDED
@@ -0,0 +1,70 @@
import torch
from PIL import Image
import numpy as np

def scale_image(img, scale):
    w, h = img.size
    new_w = int(w * scale)
    new_h = int(h * scale)

    # Adjust to nearest multiple of 32
    new_w = (new_w // 32) * 32
    new_h = (new_h // 32) * 32

    return img.resize((new_w, new_h), Image.LANCZOS), new_w, new_h

def padding_image(images, new_width, new_height):
    new_image = Image.new('RGB', (new_width, new_height), (255, 255, 255))

    aspect_ratio = images.width / images.height
    if new_width / new_height > 1:
        if aspect_ratio > new_width / new_height:
            new_img_width = new_width
            new_img_height = int(new_img_width / aspect_ratio)
        else:
            new_img_height = new_height
            new_img_width = int(new_img_height * aspect_ratio)
    else:
        if aspect_ratio > new_width / new_height:
            new_img_width = new_width
            new_img_height = int(new_img_width / aspect_ratio)
        else:
            new_img_height = new_height
            new_img_width = int(new_img_height * aspect_ratio)

    resized_img = images.resize((new_img_width, new_img_height))

    paste_x = (new_width - new_img_width) // 2
    paste_y = (new_height - new_img_height) // 2

    new_image.paste(resized_img, (paste_x, paste_y))

    return new_image


def get_image_latent(ref_image=None, sample_size=None, padding=False):
    if ref_image is not None:
        if isinstance(ref_image, str):
            ref_image = Image.open(ref_image).convert("RGB")
            if padding:
                ref_image = padding_image(
                    ref_image, sample_size[1], sample_size[0])
            ref_image = ref_image.resize((sample_size[1], sample_size[0]))
            ref_image = torch.from_numpy(np.array(ref_image))
            ref_image = ref_image.unsqueeze(0).permute(
                [3, 0, 1, 2]).unsqueeze(0) / 255
        elif isinstance(ref_image, Image.Image):
            ref_image = ref_image.convert("RGB")
            if padding:
                ref_image = padding_image(
                    ref_image, sample_size[1], sample_size[0])
            ref_image = ref_image.resize((sample_size[1], sample_size[0]))
            ref_image = torch.from_numpy(np.array(ref_image))
            ref_image = ref_image.unsqueeze(0).permute(
                [3, 0, 1, 2]).unsqueeze(0) / 255
        else:
            ref_image = torch.from_numpy(np.array(ref_image))
            ref_image = ref_image.unsqueeze(0).permute(
                [3, 0, 1, 2]).unsqueeze(0) / 255

    return ref_image
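For reference, get_image_latent() returns a 5D tensor laid out as (batch, channels, frames, height, width): the permute([3, 0, 1, 2]) plus the two unsqueeze(0) calls turn an (H, W, 3) array into (1, 3, 1, H, W), scaled to [0, 1]. That is why app.py indexes the result with [:, :, 0] to drop the single frame dimension before passing it as control_image. A small shape sanity check (not part of the commit, just derived from the code above):

from PIL import Image
from image_utils import scale_image, get_image_latent

img = Image.new("RGB", (1000, 750))           # stand-in for an uploaded control image
img, w, h = scale_image(img, 1.0)             # snapped down to multiples of 32 -> 992 x 736
latent = get_image_latent(img, sample_size=[h, w])
print(latent.shape)                           # torch.Size([1, 3, 1, 736, 992])
print(latent[:, :, 0].shape)                  # torch.Size([1, 3, 736, 992]), as used in app.py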
predict_t2i_control.py ADDED
@@ -0,0 +1,228 @@
import os
import sys

import numpy as np
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from omegaconf import OmegaConf
from PIL import Image

current_file_path = os.path.abspath(__file__)
project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))]
for project_root in project_roots:
    sys.path.insert(0, project_root) if project_root not in sys.path else None

from videox_fun.dist import set_multi_gpus_devices, shard_model
from videox_fun.models import (AutoencoderKL, AutoTokenizer,
                               Qwen3ForCausalLM, ZImageControlTransformer2DModel)
from videox_fun.models.cache_utils import get_teacache_coefficients
from videox_fun.pipeline import ZImageControlPipeline
from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8,
                                               convert_weight_dtype_wrapper)
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
from videox_fun.utils.utils import (filter_kwargs, get_image_to_video_latent, get_image_latent, get_image,
                                    get_video_to_video_latent,
                                    save_videos_grid)

# GPU memory mode, which can be chosen in [model_full_load, model_full_load_and_qfloat8, model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload].
# model_full_load means that the entire model will be moved to the GPU.
#
# model_full_load_and_qfloat8 means that the entire model will be moved to the GPU,
# and the transformer model has been quantized to float8, which can save more GPU memory.
#
# model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory.
#
# model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use,
# and the transformer model has been quantized to float8, which can save more GPU memory.
#
# sequential_cpu_offload means that each layer of the model will be moved to the CPU after use,
# resulting in slower speeds but saving a large amount of GPU memory.
GPU_memory_mode = "model_cpu_offload"
# Multi GPUs config
# Please ensure that the product of ulysses_degree and ring_degree equals the number of GPUs used.
# For example, if you are using 8 GPUs, you can set ulysses_degree = 2 and ring_degree = 4.
# If you are using 1 GPU, you can set ulysses_degree = 1 and ring_degree = 1.
ulysses_degree = 1
ring_degree = 1
# Use FSDP to save more GPU memory in multi gpus.
fsdp_dit = False
fsdp_text_encoder = False
# Compile will give a speedup in fixed resolution and need a little GPU memory.
# The compile_dit is not compatible with the fsdp_dit and sequential_cpu_offload.
compile_dit = False

# Config and model path
config_path = "config/z_image/z_image_control.yaml"
# model path
model_name = "models/Diffusion_Transformer/Z-Image-Turbo/"

# Choose the sampler in "Flow", "Flow_Unipc", "Flow_DPM++"
sampler_name = "Flow"

# Load pretrained model if need
transformer_path = "models/Personalized_Model/Z-Image-Turbo-Fun-Controlnet-Union.safetensors"
vae_path = None
lora_path = None

# Other params
sample_size = [1728, 992]

# Use torch.float16 if GPU does not support torch.bfloat16
# Some graphics cards, such as V100 and 2080Ti, do not support torch.bfloat16
weight_dtype = torch.bfloat16
control_image = "asset/pose.jpg"
control_context_scale = 0.75

# Using a longer negative prompt such as "模糊,突变,变形,失真,画面暗,文本字幕,画面固定,连环画,漫画,线稿,没有主体。" can improve stability.
# Adding words such as "安静,固定" (quiet, static) to the negative prompt can increase dynamism.
prompt = "一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动。她拥有一头鲜艳的紫色长发,在风中轻盈舞动,发间系着一个精致的黑色蝴蝶结,与身后柔和的蔚蓝天空形成鲜明对比。她面容清秀,眉目精致,透着一股甜美的青春气息;神情柔和,略带羞涩,目光静静地凝望着远方的地平线,双手自然交叠于身前,仿佛沉浸在思绪之中。在她身后,是辽阔无垠、波光粼粼的大海,阳光洒在海面上,映出温暖的金色光晕。"
negative_prompt = " "
guidance_scale = 0.00
seed = 43
num_inference_steps = 9
lora_weight = 0.55
save_path = "samples/z-image-t2i-control"

device = set_multi_gpus_devices(ulysses_degree, ring_degree)
config = OmegaConf.load(config_path)

transformer = ZImageControlTransformer2DModel.from_pretrained(
    model_name,
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
    transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
).to(weight_dtype)

if transformer_path is not None:
    print(f"From checkpoint: {transformer_path}")
    if transformer_path.endswith("safetensors"):
        from safetensors.torch import load_file, safe_open
        state_dict = load_file(transformer_path)
    else:
        state_dict = torch.load(transformer_path, map_location="cpu")
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict

    m, u = transformer.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")

# Get Vae
vae = AutoencoderKL.from_pretrained(
    model_name,
    subfolder="vae"
).to(weight_dtype)

if vae_path is not None:
    print(f"From checkpoint: {vae_path}")
    if vae_path.endswith("safetensors"):
        from safetensors.torch import load_file, safe_open
        state_dict = load_file(vae_path)
    else:
        state_dict = torch.load(vae_path, map_location="cpu")
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict

    m, u = vae.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")

# Get tokenizer and text_encoder
tokenizer = AutoTokenizer.from_pretrained(
    model_name, subfolder="tokenizer"
)
text_encoder = Qwen3ForCausalLM.from_pretrained(
    model_name, subfolder="text_encoder", torch_dtype=weight_dtype,
    low_cpu_mem_usage=True,
)

# Get Scheduler
Chosen_Scheduler = scheduler_dict = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}[sampler_name]
scheduler = Chosen_Scheduler.from_pretrained(
    model_name,
    subfolder="scheduler"
)

pipeline = ZImageControlPipeline(
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
)

if ulysses_degree > 1 or ring_degree > 1:
    from functools import partial
    transformer.enable_multi_gpus_inference()
    if fsdp_dit:
        shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, module_to_wrapper=list(transformer.transformer_blocks) + list(transformer.single_transformer_blocks))
        pipeline.transformer = shard_fn(pipeline.transformer)
        print("Add FSDP DIT")
    if fsdp_text_encoder:
        shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, module_to_wrapper=text_encoder.language_model.layers, ignored_modules=[text_encoder.language_model.embed_tokens], transformer_layer_cls_to_wrap=["MistralDecoderLayer", "PixtralTransformer"])
        text_encoder = shard_fn(text_encoder)
        print("Add FSDP TEXT ENCODER")

if compile_dit:
    for i in range(len(pipeline.transformer.transformer_blocks)):
        pipeline.transformer.transformer_blocks[i] = torch.compile(pipeline.transformer.transformer_blocks[i])
    print("Add Compile")

if GPU_memory_mode == "sequential_cpu_offload":
    pipeline.enable_sequential_cpu_offload(device=device)
elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
    convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device)
    convert_weight_dtype_wrapper(transformer, weight_dtype)
    pipeline.enable_model_cpu_offload(device=device)
elif GPU_memory_mode == "model_cpu_offload":
    pipeline.enable_model_cpu_offload(device=device)
elif GPU_memory_mode == "model_full_load_and_qfloat8":
    convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device)
    convert_weight_dtype_wrapper(transformer, weight_dtype)
    pipeline.to(device=device)
else:
    pipeline.to(device=device)

generator = torch.Generator(device=device).manual_seed(seed)

if lora_path is not None:
    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)

with torch.no_grad():
    if control_image is not None:
        control_image = get_image_latent(control_image, sample_size=sample_size)[:, :, 0]

    sample = pipeline(
        prompt = prompt,
        negative_prompt = negative_prompt,
        height = sample_size[0],
        width = sample_size[1],
        generator = generator,
        guidance_scale = guidance_scale,
        control_image = control_image,
        num_inference_steps = num_inference_steps,
        control_context_scale = control_context_scale,
    ).images

if lora_path is not None:
    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)

def save_results():
    if not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)

    index = len([path for path in os.listdir(save_path)]) + 1
    prefix = str(index).zfill(8)
    video_path = os.path.join(save_path, prefix + ".png")
    image = sample[0]
    image.save(video_path)

if ulysses_degree * ring_degree > 1:
    import torch.distributed as dist
    if dist.get_rank() == 0:
        save_results()
else:
    save_results()
requirements.txt ADDED
@@ -0,0 +1,7 @@
gradio
torch
transformers
accelerate
spaces
git+https://github.com/huggingface/diffusers.git
kernels
static/data.json ADDED
@@ -0,0 +1,8 @@
{
    "examples": [
        ["examples/hed.jpg", "A middle-aged man with a short beard, wearing a casual button-down shirt, sitting at a polished dark wooden table, holding a tumbler of whiskey with ice and taking a thoughtful sip. The background is a softly lit."],
        ["examples/depth.jpg", "Modern minimalist, clean lines, open plan, natural light, spacious, serene, contemporary, elegant, architectural, inviting, sophisticated, light-filled, harmonious, texture, shadows, high ceilings."],
        ["examples/pose.jpg", "A fit, athletic young woman, squatting low, glancing confidently at the camera. She's on a picturesque tropical beach with gentle waves lapping the shore. The image has the crisp, high-contrast look of a fashion magazine cover. Dynamic pose, bright and inviting."],
        ["examples/pose2.jpg", "A majestic female paladin in gleaming plate armor, standing tall and proud, bathed in a celestial glow, with a determined expression, holding a radiant sword aloft against a backdrop of a sun-drenched, ancient castle."]
    ]
}
static/footer.html ADDED
@@ -0,0 +1,16 @@
<div>
    I made this space after seeing a Reddit post about using ControlNet editing with Z-Image from Alibaba.
    The code looks solid and serves as a great example. I believe there’s a lot of potential to build on top of this, add new features, and explore even more creative ideas using this technique.

    <h2>Usage</h2>
    You can change control_context_scale for more control and better detail.
    For best results, use a detailed prompt.
    The recommended control_context_scale range is 0.65 to 0.80.

    <h2>Reference</h2>
    <ul>
        <li>Tongyi-MAI/Z-Image-Turbo: <a href="https://huggingface.co/Tongyi-MAI/Z-Image-Turbo">https://huggingface.co/Tongyi-MAI/Z-Image-Turbo</a></li>
        <li>alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union: <a href="https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union">https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union</a></li>
        <li>VideoX-Fun: <a href="https://github.com/aigc-apps/VideoX-Fun">https://github.com/aigc-apps/VideoX-Fun</a></li>
    </ul>
</div>
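The footer's usage note (control_context_scale recommended between 0.65 and 0.80) is easy to verify with a small sweep. The sketch below assumes the ZImageControlPipeline from app.py is already loaded as pipe; the prompt and output filenames are placeholders.

import torch
from PIL import Image
from image_utils import get_image_latent, scale_image
# `pipe` is assumed to be the ZImageControlPipeline built in app.py.

prompt = "A majestic female paladin in gleaming plate armor, holding a radiant sword aloft."
image, width, height = scale_image(Image.open("examples/pose2.jpg"), 1.0)
control_image = get_image_latent(image, sample_size=[height, width])[:, :, 0]

# Sweep the recommended range; a fixed seed isolates the effect of the control scale.
for ccs in (0.65, 0.70, 0.75, 0.80):
    result = pipe(
        prompt=prompt,
        height=height,
        width=width,
        generator=torch.Generator().manual_seed(42),
        guidance_scale=2.5,
        control_image=control_image,
        num_inference_steps=8,
        control_context_scale=ccs,
    ).images[0]
    result.save(f"ccs_{ccs:.2f}.png")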
static/header.html ADDED
@@ -0,0 +1,11 @@
<div style="text-align: center; max-width: 600px; margin: 0 auto;">
    <h1>
        Z Image Turbo (ZIT) - Controlnet
    </h1>
    <div class="grid-container" >
        <p>
            Supports multiple control conditions - including Canny, HED, Depth, Pose and MLSD.
            <br>
            If you like my spaces, please support me by visiting <a href="https://aisudo.com/" target="_blank">AiSudo</a> for more image generation 😊
    </div>
</div>
videox_fun/__init__.py ADDED
(empty file)
videox_fun/api/api.py ADDED
@@ -0,0 +1,226 @@
import base64
import gc
import hashlib
import io
import os
import tempfile
from io import BytesIO

import gradio as gr
import requests
import torch
from fastapi import FastAPI
from PIL import Image


# Function to encode a file to Base64
def encode_file_to_base64(file_path):
    with open(file_path, "rb") as file:
        # Encode the data to Base64
        file_base64 = base64.b64encode(file.read())
    return file_base64

def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller):
    @app.post("/videox_fun/update_diffusion_transformer")
    def _update_diffusion_transformer_api(
        datas: dict,
    ):
        diffusion_transformer_path = datas.get('diffusion_transformer_path', 'none')

        try:
            controller.update_diffusion_transformer(
                diffusion_transformer_path
            )
            comment = "Success"
        except Exception as e:
            torch.cuda.empty_cache()
            comment = f"Error. error information is {str(e)}"

        return {"message": comment}

def download_from_url(url, timeout=10):
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Check whether the request succeeded
        return response.content
    except requests.exceptions.RequestException as e:
        print(f"Error downloading from {url}: {e}")
        return None

def save_base64_video(base64_string):
    video_data = base64.b64decode(base64_string)

    md5_hash = hashlib.md5(video_data).hexdigest()
    filename = f"{md5_hash}.mp4"

    temp_dir = tempfile.gettempdir()
    file_path = os.path.join(temp_dir, filename)

    with open(file_path, 'wb') as video_file:
        video_file.write(video_data)

    return file_path

def save_base64_image(base64_string):
    video_data = base64.b64decode(base64_string)

    md5_hash = hashlib.md5(video_data).hexdigest()
    filename = f"{md5_hash}.jpg"

    temp_dir = tempfile.gettempdir()
    file_path = os.path.join(temp_dir, filename)

    with open(file_path, 'wb') as video_file:
        video_file.write(video_data)

    return file_path

def save_url_video(url):
    video_data = download_from_url(url)
    if video_data:
        return save_base64_video(base64.b64encode(video_data))
    return None

def save_url_image(url):
    image_data = download_from_url(url)
    if image_data:
        return save_base64_image(base64.b64encode(image_data))
    return None

def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
    @app.post("/videox_fun/infer_forward")
    def _infer_forward_api(
        datas: dict,
    ):
        base_model_path = datas.get('base_model_path', 'none')
        base_model_2_path = datas.get('base_model_2_path', 'none')
        lora_model_path = datas.get('lora_model_path', 'none')
        lora_model_2_path = datas.get('lora_model_2_path', 'none')
        lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
        prompt_textbox = datas.get('prompt_textbox', None)
        negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
        sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
        sample_step_slider = datas.get('sample_step_slider', 30)
        resize_method = datas.get('resize_method', "Generate by")
        width_slider = datas.get('width_slider', 672)
        height_slider = datas.get('height_slider', 384)
        base_resolution = datas.get('base_resolution', 512)
        is_image = datas.get('is_image', False)
        generation_method = datas.get('generation_method', False)
        length_slider = datas.get('length_slider', 49)
        overlap_video_length = datas.get('overlap_video_length', 4)
        partial_video_length = datas.get('partial_video_length', 72)
        cfg_scale_slider = datas.get('cfg_scale_slider', 6)
        start_image = datas.get('start_image', None)
        end_image = datas.get('end_image', None)
        validation_video = datas.get('validation_video', None)
        validation_video_mask = datas.get('validation_video_mask', None)
        control_video = datas.get('control_video', None)
        denoise_strength = datas.get('denoise_strength', 0.70)
        seed_textbox = datas.get("seed_textbox", 43)

        ref_image = datas.get('ref_image', None)
        enable_teacache = datas.get('enable_teacache', True)
        teacache_threshold = datas.get('teacache_threshold', 0.10)
        num_skip_start_steps = datas.get('num_skip_start_steps', 1)
        teacache_offload = datas.get('teacache_offload', False)
        cfg_skip_ratio = datas.get('cfg_skip_ratio', 0)
        enable_riflex = datas.get('enable_riflex', False)
        riflex_k = datas.get('riflex_k', 6)
        fps = datas.get('fps', None)

        generation_method = "Image Generation" if is_image else generation_method

        if start_image is not None:
            if start_image.startswith('http'):
                start_image = save_url_image(start_image)
                start_image = [Image.open(start_image).convert("RGB")]
            else:
                start_image = base64.b64decode(start_image)
                start_image = [Image.open(BytesIO(start_image)).convert("RGB")]

        if end_image is not None:
            if end_image.startswith('http'):
                end_image = save_url_image(end_image)
                end_image = [Image.open(end_image).convert("RGB")]
            else:
                end_image = base64.b64decode(end_image)
                end_image = [Image.open(BytesIO(end_image)).convert("RGB")]

        if validation_video is not None:
            if validation_video.startswith('http'):
                validation_video = save_url_video(validation_video)
            else:
                validation_video = save_base64_video(validation_video)

        if validation_video_mask is not None:
            if validation_video_mask.startswith('http'):
                validation_video_mask = save_url_image(validation_video_mask)
            else:
                validation_video_mask = save_base64_image(validation_video_mask)

        if control_video is not None:
            if control_video.startswith('http'):
                control_video = save_url_video(control_video)
            else:
                control_video = save_base64_video(control_video)

        if ref_image is not None:
            if ref_image.startswith('http'):
                ref_image = save_url_image(ref_image)
                ref_image = [Image.open(ref_image).convert("RGB")]
            else:
                ref_image = base64.b64decode(ref_image)
                ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")]

        try:
            save_sample_path, comment = controller.generate(
                "",
                base_model_path,
                lora_model_path,
                lora_alpha_slider,
                prompt_textbox,
                negative_prompt_textbox,
                sampler_dropdown,
                sample_step_slider,
                resize_method,
                width_slider,
                height_slider,
                base_resolution,
                generation_method,
                length_slider,
                overlap_video_length,
                partial_video_length,
                cfg_scale_slider,
                start_image,
                end_image,
                validation_video,
                validation_video_mask,
                control_video,
                denoise_strength,
                seed_textbox,
                ref_image = ref_image,
                enable_teacache = enable_teacache,
                teacache_threshold = teacache_threshold,
                num_skip_start_steps = num_skip_start_steps,
                teacache_offload = teacache_offload,
                cfg_skip_ratio = cfg_skip_ratio,
                enable_riflex = enable_riflex,
                riflex_k = riflex_k,
                base_model_2_dropdown = base_model_2_path,
                lora_model_2_dropdown = lora_model_2_path,
                fps = fps,
                is_api = True,
            )
        except Exception as e:
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            save_sample_path = ""
            comment = f"Error. error information is {str(e)}"
            return {"message": comment, "save_sample_path": None, "base64_encoding": None}

        if save_sample_path != "":
            return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
        else:
            return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": None}
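This file is the generic VideoX-Fun REST wrapper rather than something app.py wires up for this Space, but the request shape is visible from the handler: a flat JSON dict whose fields mirror the Gradio controls, with a response containing a message, the sample path, and a Base64 copy of the result. A minimal client sketch follows; the host/port and the choice of ref_image as the control condition are assumptions, not something this commit demonstrates.

import base64
import requests

# Hypothetical endpoint; the port and the controller's expected fields are assumptions.
URL = "http://localhost:7860/videox_fun/infer_forward"

with open("examples/pose.jpg", "rb") as f:
    ref_b64 = base64.b64encode(f.read()).decode()

payload = {
    "prompt_textbox": "A fit, athletic young woman, squatting low, glancing confidently at the camera.",
    "negative_prompt_textbox": " ",
    "sample_step_slider": 8,
    "width_slider": 992,
    "height_slider": 1728,
    "is_image": True,                 # forces generation_method = "Image Generation"
    "cfg_scale_slider": 1.5,
    "ref_image": ref_b64,             # non-http strings are decoded as Base64 server-side
    "seed_textbox": 43,
}

resp = requests.post(URL, json=payload, timeout=600)
data = resp.json()
print(data["message"], data["save_sample_path"])
if data.get("base64_encoding"):
    with open("result.png", "wb") as f:
        f.write(base64.b64decode(data["base64_encoding"]))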
videox_fun/api/api_multi_nodes.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file is modified from https://github.com/xdit-project/xDiT/blob/main/entrypoints/launch.py
|
| 2 |
+
import base64
|
| 3 |
+
import gc
|
| 4 |
+
import hashlib
|
| 5 |
+
import io
|
| 6 |
+
import os
|
| 7 |
+
import tempfile
|
| 8 |
+
from io import BytesIO
|
| 9 |
+
|
| 10 |
+
import gradio as gr
|
| 11 |
+
import requests
|
| 12 |
+
import torch
|
| 13 |
+
import torch.distributed as dist
|
| 14 |
+
from fastapi import FastAPI, HTTPException
|
| 15 |
+
from PIL import Image
|
| 16 |
+
|
| 17 |
+
from .api import download_from_url, encode_file_to_base64
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
import ray
|
| 21 |
+
except:
|
| 22 |
+
print("Ray is not installed. If you want to use multi gpus api. Please install it by running 'pip install ray'.")
|
| 23 |
+
ray = None
|
| 24 |
+
|
| 25 |
+
def save_base64_video_dist(base64_string):
|
| 26 |
+
video_data = base64.b64decode(base64_string)
|
| 27 |
+
|
| 28 |
+
md5_hash = hashlib.md5(video_data).hexdigest()
|
| 29 |
+
filename = f"{md5_hash}.mp4"
|
| 30 |
+
|
| 31 |
+
temp_dir = tempfile.gettempdir()
|
| 32 |
+
file_path = os.path.join(temp_dir, filename)
|
| 33 |
+
|
| 34 |
+
if dist.is_initialized():
|
| 35 |
+
if dist.get_rank() == 0:
|
| 36 |
+
with open(file_path, 'wb') as video_file:
|
| 37 |
+
video_file.write(video_data)
|
| 38 |
+
dist.barrier()
|
| 39 |
+
else:
|
| 40 |
+
with open(file_path, 'wb') as video_file:
|
| 41 |
+
video_file.write(video_data)
|
| 42 |
+
return file_path
|
| 43 |
+
|
| 44 |
+
def save_base64_image_dist(base64_string):
|
| 45 |
+
video_data = base64.b64decode(base64_string)
|
| 46 |
+
|
| 47 |
+
md5_hash = hashlib.md5(video_data).hexdigest()
|
| 48 |
+
filename = f"{md5_hash}.jpg"
|
| 49 |
+
|
| 50 |
+
temp_dir = tempfile.gettempdir()
|
| 51 |
+
file_path = os.path.join(temp_dir, filename)
|
| 52 |
+
|
| 53 |
+
if dist.is_initialized():
|
| 54 |
+
if dist.get_rank() == 0:
|
| 55 |
+
with open(file_path, 'wb') as video_file:
|
| 56 |
+
video_file.write(video_data)
|
| 57 |
+
dist.barrier()
|
| 58 |
+
else:
|
| 59 |
+
with open(file_path, 'wb') as video_file:
|
| 60 |
+
video_file.write(video_data)
|
| 61 |
+
return file_path
|
| 62 |
+
|
| 63 |
+
def save_url_video_dist(url):
|
| 64 |
+
video_data = download_from_url(url)
|
| 65 |
+
if video_data:
|
| 66 |
+
return save_base64_video_dist(base64.b64encode(video_data))
|
| 67 |
+
return None
|
| 68 |
+
|
| 69 |
+
def save_url_image_dist(url):
|
| 70 |
+
image_data = download_from_url(url)
|
| 71 |
+
if image_data:
|
| 72 |
+
return save_base64_image_dist(base64.b64encode(image_data))
|
| 73 |
+
return None
|
| 74 |
+
|
| 75 |
+
if ray is not None:
|
| 76 |
+
@ray.remote(num_gpus=1)
|
| 77 |
+
class MultiNodesGenerator:
|
| 78 |
+
def __init__(
|
| 79 |
+
self, rank: int, world_size: int, Controller,
|
| 80 |
+
GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
|
| 81 |
+
config_path=None, ulysses_degree=1, ring_degree=1,
|
| 82 |
+
fsdp_dit=False, fsdp_text_encoder=False, compile_dit=False,
|
| 83 |
+
weight_dtype=None, savedir_sample=None,
|
| 84 |
+
):
|
| 85 |
+
# Set PyTorch distributed environment variables
|
| 86 |
+
os.environ["RANK"] = str(rank)
|
| 87 |
+
os.environ["WORLD_SIZE"] = str(world_size)
|
| 88 |
+
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
| 89 |
+
os.environ["MASTER_PORT"] = "29500"
|
| 90 |
+
|
| 91 |
+
self.rank = rank
|
| 92 |
+
self.controller = Controller(
|
| 93 |
+
GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
|
| 94 |
+
ulysses_degree=ulysses_degree, ring_degree=ring_degree,
|
| 95 |
+
fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit,
|
| 96 |
+
weight_dtype=weight_dtype, savedir_sample=savedir_sample,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
def generate(self, datas):
|
| 100 |
+
try:
|
| 101 |
+
base_model_path = datas.get('base_model_path', 'none')
|
| 102 |
+
base_model_2_path = datas.get('base_model_2_path', 'none')
|
| 103 |
+
lora_model_path = datas.get('lora_model_path', 'none')
|
| 104 |
+
lora_model_2_path = datas.get('lora_model_2_path', 'none')
|
| 105 |
+
lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
|
| 106 |
+
prompt_textbox = datas.get('prompt_textbox', None)
|
| 107 |
+
negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
|
| 108 |
+
sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
|
| 109 |
+
sample_step_slider = datas.get('sample_step_slider', 30)
|
| 110 |
+
resize_method = datas.get('resize_method', "Generate by")
|
| 111 |
+
width_slider = datas.get('width_slider', 672)
|
| 112 |
+
height_slider = datas.get('height_slider', 384)
|
| 113 |
+
base_resolution = datas.get('base_resolution', 512)
|
| 114 |
+
is_image = datas.get('is_image', False)
|
| 115 |
+
generation_method = datas.get('generation_method', False)
|
| 116 |
+
length_slider = datas.get('length_slider', 49)
|
| 117 |
+
overlap_video_length = datas.get('overlap_video_length', 4)
|
| 118 |
+
partial_video_length = datas.get('partial_video_length', 72)
|
| 119 |
+
cfg_scale_slider = datas.get('cfg_scale_slider', 6)
|
| 120 |
+
start_image = datas.get('start_image', None)
|
| 121 |
+
end_image = datas.get('end_image', None)
|
| 122 |
+
validation_video = datas.get('validation_video', None)
|
| 123 |
+
validation_video_mask = datas.get('validation_video_mask', None)
|
| 124 |
+
control_video = datas.get('control_video', None)
|
| 125 |
+
denoise_strength = datas.get('denoise_strength', 0.70)
|
| 126 |
+
seed_textbox = datas.get("seed_textbox", 43)
|
| 127 |
+
|
| 128 |
+
ref_image = datas.get('ref_image', None)
|
| 129 |
+
enable_teacache = datas.get('enable_teacache', True)
|
| 130 |
+
teacache_threshold = datas.get('teacache_threshold', 0.10)
|
| 131 |
+
num_skip_start_steps = datas.get('num_skip_start_steps', 1)
|
| 132 |
+
teacache_offload = datas.get('teacache_offload', False)
|
| 133 |
+
cfg_skip_ratio = datas.get('cfg_skip_ratio', 0)
|
| 134 |
+
enable_riflex = datas.get('enable_riflex', False)
|
| 135 |
+
riflex_k = datas.get('riflex_k', 6)
|
| 136 |
+
fps = datas.get('fps', None)
|
| 137 |
+
|
| 138 |
+
generation_method = "Image Generation" if is_image else generation_method
|
| 139 |
+
|
| 140 |
+
if start_image is not None:
|
| 141 |
+
if start_image.startswith('http'):
|
| 142 |
+
start_image = save_url_image_dist(start_image)
|
| 143 |
+
start_image = [Image.open(start_image).convert("RGB")]
|
| 144 |
+
else:
|
| 145 |
+
start_image = base64.b64decode(start_image)
|
| 146 |
+
start_image = [Image.open(BytesIO(start_image)).convert("RGB")]
|
| 147 |
+
|
| 148 |
+
if end_image is not None:
|
| 149 |
+
if end_image.startswith('http'):
|
| 150 |
+
end_image = save_url_image_dist(end_image)
|
| 151 |
+
end_image = [Image.open(end_image).convert("RGB")]
|
| 152 |
+
else:
|
| 153 |
+
end_image = base64.b64decode(end_image)
|
| 154 |
+
end_image = [Image.open(BytesIO(end_image)).convert("RGB")]
|
| 155 |
+
|
| 156 |
+
if validation_video is not None:
|
| 157 |
+
if validation_video.startswith('http'):
|
| 158 |
+
validation_video = save_url_video_dist(validation_video)
|
| 159 |
+
else:
|
| 160 |
+
validation_video = save_base64_video_dist(validation_video)
|
| 161 |
+
|
| 162 |
+
if validation_video_mask is not None:
|
| 163 |
+
if validation_video_mask.startswith('http'):
|
| 164 |
+
validation_video_mask = save_url_image_dist(validation_video_mask)
|
| 165 |
+
else:
|
| 166 |
+
                        validation_video_mask = save_base64_image_dist(validation_video_mask)

                if control_video is not None:
                    if control_video.startswith('http'):
                        control_video = save_url_video_dist(control_video)
                    else:
                        control_video = save_base64_video_dist(control_video)

                if ref_image is not None:
                    if ref_image.startswith('http'):
                        ref_image = save_url_image_dist(ref_image)
                        ref_image = [Image.open(ref_image).convert("RGB")]
                    else:
                        ref_image = base64.b64decode(ref_image)
                        ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")]

                try:
                    save_sample_path, comment = self.controller.generate(
                        "",
                        base_model_path,
                        lora_model_path,
                        lora_alpha_slider,
                        prompt_textbox,
                        negative_prompt_textbox,
                        sampler_dropdown,
                        sample_step_slider,
                        resize_method,
                        width_slider,
                        height_slider,
                        base_resolution,
                        generation_method,
                        length_slider,
                        overlap_video_length,
                        partial_video_length,
                        cfg_scale_slider,
                        start_image,
                        end_image,
                        validation_video,
                        validation_video_mask,
                        control_video,
                        denoise_strength,
                        seed_textbox,
                        ref_image = ref_image,
                        enable_teacache = enable_teacache,
                        teacache_threshold = teacache_threshold,
                        num_skip_start_steps = num_skip_start_steps,
                        teacache_offload = teacache_offload,
                        cfg_skip_ratio = cfg_skip_ratio,
                        enable_riflex = enable_riflex,
                        riflex_k = riflex_k,
                        base_model_2_dropdown = base_model_2_path,
                        lora_model_2_dropdown = lora_model_2_path,
                        fps = fps,
                        is_api = True,
                    )
                except Exception as e:
                    gc.collect()
                    torch.cuda.empty_cache()
                    torch.cuda.ipc_collect()
                    save_sample_path = ""
                    comment = f"Error. error information is {str(e)}"
                    if dist.is_initialized():
                        if dist.get_rank() == 0:
                            return {"message": comment, "save_sample_path": None, "base64_encoding": None}
                        else:
                            return None
                    else:
                        return {"message": comment, "save_sample_path": None, "base64_encoding": None}

                if dist.is_initialized():
                    if dist.get_rank() == 0:
                        if save_sample_path != "":
                            return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
                        else:
                            return {"message": comment, "save_sample_path": None, "base64_encoding": None}
                    else:
                        return None
                else:
                    if save_sample_path != "":
                        return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
                    else:
                        return {"message": comment, "save_sample_path": None, "base64_encoding": None}

            except Exception as e:
                print(f"Error generating: {str(e)}")
                comment = f"Error generating: {str(e)}"
                if dist.is_initialized():
                    if dist.get_rank() == 0:
                        return {"message": comment, "save_sample_path": None, "base64_encoding": None}
                    else:
                        return None
                else:
                    return {"message": comment, "save_sample_path": None, "base64_encoding": None}

    class MultiNodesEngine:
        def __init__(
            self,
            world_size,
            Controller,
            GPU_memory_mode,
            scheduler_dict,
            model_name,
            model_type,
            config_path,
            ulysses_degree=1,
            ring_degree=1,
            fsdp_dit=False,
            fsdp_text_encoder=False,
            compile_dit=False,
            weight_dtype=torch.bfloat16,
            savedir_sample="samples"
        ):
            # Ensure Ray is initialized
            if not ray.is_initialized():
                ray.init()

            num_workers = world_size
            self.workers = [
                MultiNodesGenerator.remote(
                    rank, world_size, Controller,
                    GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
                    ulysses_degree=ulysses_degree, ring_degree=ring_degree,
                    fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit,
                    weight_dtype=weight_dtype, savedir_sample=savedir_sample,
                )
                for rank in range(num_workers)
            ]
            print("Update workers done")

        async def generate(self, data):
            results = ray.get([
                worker.generate.remote(data)
                for worker in self.workers
            ])

            return next(path for path in results if path is not None)

    def multi_nodes_infer_forward_api(_: gr.Blocks, app: FastAPI, engine):

        @app.post("/videox_fun/infer_forward")
        async def _multi_nodes_infer_forward_api(
            datas: dict,
        ):
            try:
                result = await engine.generate(datas)
                return result
            except Exception as e:
                if isinstance(e, HTTPException):
                    raise e
                raise HTTPException(status_code=500, detail=str(e))
else:
    MultiNodesEngine = None
    MultiNodesGenerator = None
    multi_nodes_infer_forward_api = None
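The endpoint registered above accepts a single JSON dictionary and fans it out to every Ray worker; rank 0 returns the result, the other ranks return None. Below is a minimal client sketch. The payload keys mirror the argument names forwarded to controller.generate (prompt_textbox, sample_step_slider, and so on), but the exact schema accepted by a deployed service, as well as the host and port, are assumptions rather than values fixed by this commit.

# Hypothetical client for the /videox_fun/infer_forward endpoint above.
import base64
import requests

def infer_forward(prompt, host="http://127.0.0.1:7860"):  # host/port are placeholders
    # Field names follow the arguments passed to controller.generate; the full
    # schema expected by the running service is an assumption here.
    datas = {
        "base_model_path": "none",
        "prompt_textbox": prompt,
        "negative_prompt_textbox": "blurry, low quality",
        "sampler_dropdown": "Flow",
        "sample_step_slider": 30,
        "width_slider": 832,
        "height_slider": 480,
        "seed_textbox": 43,
    }
    response = requests.post(f"{host}/videox_fun/infer_forward", json=datas, timeout=600)
    response.raise_for_status()
    result = response.json()
    # The worker returns the sample as a base64 string; decode and save it locally.
    if result.get("base64_encoding") is not None:
        with open("sample_output.bin", "wb") as f:
            f.write(base64.b64decode(result["base64_encoding"]))
    return result.get("message")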
videox_fun/data/__init__.py
ADDED
@@ -0,0 +1,9 @@
from .dataset_image import CC15M, ImageEditDataset
from .dataset_image_video import (ImageVideoControlDataset, ImageVideoDataset, TextDataset,
                                  ImageVideoSampler)
from .dataset_video import VideoDataset, VideoSpeechDataset, VideoAnimateDataset, WebVid10M
from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager,
                    custom_meshgrid, get_random_mask, get_relative_pose,
                    get_video_reader_batch, padding_image, process_pose_file,
                    process_pose_params, ray_condition, resize_frame,
                    resize_image_with_target_area)
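For orientation, here is a rough sketch of how these exports are typically wired together for data loading: ImageVideoSampler keeps each batch all-image or all-video, and RandomSampler from bucket_sampler feeds it indices. The annotation path, data root, and sizes below are placeholders, not files shipped with this Space, and the sketch assumes the package's requirements (decord, albumentations, etc.) are installed.

# A usage sketch with placeholder paths; not part of this commit.
import torch

from videox_fun.data import ImageVideoDataset, ImageVideoSampler
from videox_fun.data.bucket_sampler import RandomSampler

dataset = ImageVideoDataset(
    ann_path="datasets/metadata.json",   # hypothetical annotation file
    data_root="datasets",                # hypothetical data root
    video_sample_size=512, video_sample_n_frames=16,
    image_sample_size=512,
)
# Each yielded batch contains only images or only videos, so the default
# collate function never has to mix tensor shapes.
batch_sampler = ImageVideoSampler(RandomSampler(dataset), dataset, batch_size=2)
dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=4)

for batch in dataloader:
    # "pixel_values" is a (B, T, C, H, W) video clip or a (B, 1, C, H, W) image batch.
    print(batch["data_type"], batch["pixel_values"].shape)
    break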
videox_fun/data/bucket_sampler.py
ADDED
@@ -0,0 +1,379 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
                    Sized, TypeVar, Union)

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import BatchSampler, Dataset, Sampler

ASPECT_RATIO_512 = {
    '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
    '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
    '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
    '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
    '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
    '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
    '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
    '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
    '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
    '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
}
ASPECT_RATIO_RANDOM_CROP_512 = {
    '0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
    '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
    '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
    '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
    '2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
}
ASPECT_RATIO_RANDOM_CROP_PROB = [
    1, 2,
    4, 4, 4, 4,
    8, 8, 8,
    4, 4, 4, 4,
    2, 1
]
ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)

def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
    aspect_ratio = height / width
    closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
    return ratios[closest_ratio], float(closest_ratio)

def get_image_size_without_loading(path):
    with Image.open(path) as img:
        return img.size  # (width, height)

class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.

    If with replacement, then user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.
    """

    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator
        self._pos_start = 0

        if not isinstance(self.replacement, bool):
            raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            for _ in range(self.num_samples // n):
                xx = torch.randperm(n, generator=generator).tolist()
                if self._pos_start >= n:
                    self._pos_start = 0
                print("xx top 10", xx[:10], self._pos_start)
                for idx in range(self._pos_start, n):
                    yield xx[idx]
                    self._pos_start = (self._pos_start + 1) % n
                self._pos_start = 0
            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

    def __len__(self) -> int:
        return self.num_samples

class AspectRatioBatchImageSampler(BatchSampler):
    """A sampler wrapper for grouping images with similar aspect ratio into a same batch.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
        aspect_ratios (dict): The predefined aspect ratios.
    """
    def __init__(
        self,
        sampler: Sampler,
        dataset: Dataset,
        batch_size: int,
        train_folder: str = None,
        aspect_ratios: dict = ASPECT_RATIO_512,
        drop_last: bool = False,
        config=None,
        **kwargs
    ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.train_folder = train_folder
        self.batch_size = batch_size
        self.aspect_ratios = aspect_ratios
        self.drop_last = drop_last
        self.config = config
        # buckets for each aspect ratio
        self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
        # [str(k) for k, v in aspect_ratios]
        self.current_available_bucket_keys = list(aspect_ratios.keys())

    def __iter__(self):
        for idx in self.sampler:
            try:
                image_dict = self.dataset[idx]

                width, height = image_dict.get("width", None), image_dict.get("height", None)
                if width is None or height is None:
                    image_id, name = image_dict['file_path'], image_dict['text']
                    if self.train_folder is None:
                        image_dir = image_id
                    else:
                        image_dir = os.path.join(self.train_folder, image_id)

                    width, height = get_image_size_without_loading(image_dir)

                    ratio = height / width  # self.dataset[idx]
                else:
                    height = int(height)
                    width = int(width)
                    ratio = height / width  # self.dataset[idx]
            except Exception as e:
                print(e)
                continue
            # find the closest aspect ratio
            closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
            if closest_ratio not in self.current_available_bucket_keys:
                continue
            bucket = self._aspect_ratio_buckets[closest_ratio]
            bucket.append(idx)
            # yield a batch of indices in the same aspect ratio group
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]

class AspectRatioBatchSampler(BatchSampler):
    """A sampler wrapper for grouping images with similar aspect ratio into a same batch.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
        aspect_ratios (dict): The predefined aspect ratios.
    """
    def __init__(
        self,
        sampler: Sampler,
        dataset: Dataset,
        batch_size: int,
        video_folder: str = None,
        train_data_format: str = "webvid",
        aspect_ratios: dict = ASPECT_RATIO_512,
        drop_last: bool = False,
        config=None,
        **kwargs
    ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.video_folder = video_folder
        self.train_data_format = train_data_format
        self.batch_size = batch_size
        self.aspect_ratios = aspect_ratios
        self.drop_last = drop_last
        self.config = config
        # buckets for each aspect ratio
        self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
        # [str(k) for k, v in aspect_ratios]
        self.current_available_bucket_keys = list(aspect_ratios.keys())

    def __iter__(self):
        for idx in self.sampler:
            try:
                video_dict = self.dataset[idx]
                width, height = video_dict.get("width", None), video_dict.get("height", None)

                if width is None or height is None:
                    if self.train_data_format == "normal":
                        video_id, name = video_dict['file_path'], video_dict['text']
                        if self.video_folder is None:
                            video_dir = video_id
                        else:
                            video_dir = os.path.join(self.video_folder, video_id)
                    else:
                        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
                        video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
                    cap = cv2.VideoCapture(video_dir)

                    # Get the video size
                    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))    # cast float to int
                    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # cast float to int

                    ratio = height / width  # self.dataset[idx]
                else:
                    height = int(height)
                    width = int(width)
                    ratio = height / width  # self.dataset[idx]
            except Exception as e:
                print(e, self.dataset[idx], "This item is invalid, please check it.")
                continue
            # find the closest aspect ratio
            closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
            if closest_ratio not in self.current_available_bucket_keys:
                continue
            bucket = self._aspect_ratio_buckets[closest_ratio]
            bucket.append(idx)
            # yield a batch of indices in the same aspect ratio group
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]

class AspectRatioBatchImageVideoSampler(BatchSampler):
    """A sampler wrapper for grouping images with similar aspect ratio into a same batch.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
        aspect_ratios (dict): The predefined aspect ratios.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 train_folder: str = None,
                 aspect_ratios: dict = ASPECT_RATIO_512,
                 drop_last: bool = False
                ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.train_folder = train_folder
        self.batch_size = batch_size
        self.aspect_ratios = aspect_ratios
        self.drop_last = drop_last

        # buckets for each aspect ratio
        self.current_available_bucket_keys = list(aspect_ratios.keys())
        self.bucket = {
            'image':{ratio: [] for ratio in aspect_ratios},
            'video':{ratio: [] for ratio in aspect_ratios}
        }

    def __iter__(self):
        for idx in self.sampler:
            content_type = self.dataset[idx].get('type', 'image')
            if content_type == 'image':
                try:
                    image_dict = self.dataset[idx]

                    width, height = image_dict.get("width", None), image_dict.get("height", None)
                    if width is None or height is None:
                        image_id, name = image_dict['file_path'], image_dict['text']
                        if self.train_folder is None:
                            image_dir = image_id
                        else:
                            image_dir = os.path.join(self.train_folder, image_id)

                        width, height = get_image_size_without_loading(image_dir)

                        ratio = height / width  # self.dataset[idx]
                    else:
                        height = int(height)
                        width = int(width)
                        ratio = height / width  # self.dataset[idx]
                except Exception as e:
                    print(e, self.dataset[idx], "This item is invalid, please check it.")
                    continue
                # find the closest aspect ratio
                closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
                if closest_ratio not in self.current_available_bucket_keys:
                    continue
                bucket = self.bucket['image'][closest_ratio]
                bucket.append(idx)
                # yield a batch of indices in the same aspect ratio group
                if len(bucket) == self.batch_size:
                    yield bucket[:]
                    del bucket[:]
            else:
                try:
                    video_dict = self.dataset[idx]
                    width, height = video_dict.get("width", None), video_dict.get("height", None)

                    if width is None or height is None:
                        video_id, name = video_dict['file_path'], video_dict['text']
                        if self.train_folder is None:
                            video_dir = video_id
                        else:
                            video_dir = os.path.join(self.train_folder, video_id)
                        cap = cv2.VideoCapture(video_dir)

                        # Get the video size
                        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))    # cast float to int
                        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # cast float to int

                        ratio = height / width  # self.dataset[idx]
                    else:
                        height = int(height)
                        width = int(width)
                        ratio = height / width  # self.dataset[idx]
                except Exception as e:
                    print(e, self.dataset[idx], "This item is invalid, please check it.")
                    continue
                # find the closest aspect ratio
                closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
                if closest_ratio not in self.current_available_bucket_keys:
                    continue
                bucket = self.bucket['video'][closest_ratio]
                bucket.append(idx)
                # yield a batch of indices in the same aspect ratio group
                if len(bucket) == self.batch_size:
                    yield bucket[:]
                    del bucket[:]
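A small illustration of the bucketing logic defined above, under the assumption that dataset items expose width/height metadata in the dicts the samplers read: get_closest_ratio snaps an arbitrary size onto one of the predefined buckets, and AspectRatioBatchImageSampler yields index batches that share a bucket. ToyImageDataset is a made-up stand-in for a real annotation-backed dataset, and the import assumes the package's requirements are installed.

# Illustrative sketch; ToyImageDataset and its sizes are invented for this example.
from torch.utils.data import Dataset

from videox_fun.data.bucket_sampler import (ASPECT_RATIO_512, AspectRatioBatchImageSampler,
                                             RandomSampler, get_closest_ratio)

class ToyImageDataset(Dataset):
    """Metadata-only dataset; each item mimics the dicts the sampler inspects."""
    def __init__(self, sizes):
        self.items = [
            {"width": w, "height": h, "file_path": f"{i}.jpg", "text": ""}
            for i, (w, h) in enumerate(sizes)
        ]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

# get_closest_ratio maps (height, width) onto one of the predefined buckets.
print(get_closest_ratio(576, 1024))   # -> ([384.0, 672.0], 0.57)

sizes = [(1024, 576), (512, 512), (576, 1024), (960, 544)] * 4
dataset = ToyImageDataset(sizes)
sampler = AspectRatioBatchImageSampler(
    RandomSampler(dataset), dataset, batch_size=4, aspect_ratios=ASPECT_RATIO_512
)
for batch_indices in sampler:
    # Every yielded batch shares a single aspect-ratio bucket.
    print(batch_indices, [dataset[i]["height"] / dataset[i]["width"] for i in batch_indices])
    break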
videox_fun/data/dataset_image.py
ADDED
@@ -0,0 +1,191 @@
import csv
import json
import os
import random

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data.dataset import Dataset

from .utils import get_random_mask


class CC15M(Dataset):
    def __init__(
        self,
        json_path,
        video_folder=None,
        resolution=512,
        enable_bucket=False,
    ):
        print(f"loading annotations from {json_path} ...")
        self.dataset = json.load(open(json_path, 'r'))
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.enable_bucket = enable_bucket
        self.video_folder = video_folder

        resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
        self.pixel_transforms = transforms.Compose([
            transforms.Resize(resolution[0]),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        video_id, name = video_dict['file_path'], video_dict['text']

        if self.video_folder is None:
            video_dir = video_id
        else:
            video_dir = os.path.join(self.video_folder, video_id)

        pixel_values = Image.open(video_dir).convert("RGB")
        return pixel_values, name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break
            except Exception as e:
                print(e)
                idx = random.randint(0, self.length-1)

        if not self.enable_bucket:
            pixel_values = self.pixel_transforms(pixel_values)
        else:
            pixel_values = np.array(pixel_values)

        sample = dict(pixel_values=pixel_values, text=name)
        return sample

class ImageEditDataset(Dataset):
    def __init__(
        self,
        ann_path, data_root=None,
        image_sample_size=512,
        text_drop_ratio=0.1,
        enable_bucket=False,
        enable_inpaint=False,
        return_file_name=False,
    ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            dataset = json.load(open(ann_path))

        self.data_root = data_root
        self.dataset = dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        self.return_file_name = return_file_name

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
        ])

    def get_batch(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]

        image_path, text = data_info['file_path'], data_info['text']
        if self.data_root is not None:
            image_path = os.path.join(self.data_root, image_path)
        image = Image.open(image_path).convert('RGB')

        if not self.enable_bucket:
            raise ValueError("enable_bucket=False is not supported yet.")
        else:
            image = np.expand_dims(np.array(image), 0)

        source_image_path = data_info.get('source_file_path', [])
        source_image = []
        if isinstance(source_image_path, list):
            for _source_image_path in source_image_path:
                if self.data_root is not None:
                    _source_image_path = os.path.join(self.data_root, _source_image_path)
                _source_image = Image.open(_source_image_path).convert('RGB')
                source_image.append(_source_image)
        else:
            if self.data_root is not None:
                _source_image_path = os.path.join(self.data_root, source_image_path)
            else:
                _source_image_path = source_image_path
            _source_image = Image.open(_source_image_path).convert('RGB')
            source_image.append(_source_image)

        if not self.enable_bucket:
            raise ValueError("enable_bucket=False is not supported yet.")
        else:
            source_image = [np.array(_source_image) for _source_image in source_image]

        if random.random() < self.text_drop_ratio:
            text = ''
        return image, source_image, text, 'image', image_path

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, source_pixel_values, name, data_type, file_path = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["source_pixel_values"] = source_pixel_values
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx
                if self.return_file_name:
                    sample["file_name"] = os.path.basename(file_path)

                if len(sample) > 0:
                    break
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length-1)

        if self.enable_inpaint and not self.enable_bucket:
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

        return sample

if __name__ == "__main__":
    dataset = CC15M(
        json_path="./cc15m_add_index.json",
        resolution=512,
    )

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
    for idx, batch in enumerate(dataloader):
        print(batch["pixel_values"].shape, len(batch["text"]))
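A quick sketch of driving ImageEditDataset follows. The annotation layout (file_path / source_file_path / text) mirrors the keys read in get_batch, while the concrete paths are placeholders that must point at real images under data_root; enable_bucket=True is required, since the non-bucket path currently raises.

# Usage sketch with placeholder paths; the referenced image files are assumed to exist.
import json

from videox_fun.data import ImageEditDataset

annotations = [
    {"file_path": "edited/0001.png",            # hypothetical edited (target) image
     "source_file_path": ["source/0001.png"],   # hypothetical source image(s)
     "text": "make the sky pink",
     "type": "image"},
]
with open("edit_metadata.json", "w") as f:
    json.dump(annotations, f)

dataset = ImageEditDataset(
    ann_path="edit_metadata.json",
    data_root="datasets",      # hypothetical root containing edited/ and source/
    enable_bucket=True,
)
sample = dataset[0]
# pixel_values is a (1, H, W, 3) array; source_pixel_values is a list of arrays.
print(sample["text"], sample["pixel_values"].shape, len(sample["source_pixel_values"]))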
videox_fun/data/dataset_image_video.py
ADDED
@@ -0,0 +1,657 @@
import csv
import gc
import io
import json
import math
import os
import random
from contextlib import contextmanager
from random import shuffle
from threading import Thread

import albumentations
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from decord import VideoReader
from einops import rearrange
from func_timeout import FunctionTimedOut, func_timeout
from packaging import version as pver
from PIL import Image
from safetensors.torch import load_file
from torch.utils.data import BatchSampler, Sampler
from torch.utils.data.dataset import Dataset

from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager,
                    custom_meshgrid, get_random_mask, get_relative_pose,
                    get_video_reader_batch, padding_image, process_pose_file,
                    process_pose_params, ray_condition, resize_frame,
                    resize_image_with_target_area)


class ImageVideoSampler(BatchSampler):
    """A sampler wrapper for grouping images with similar aspect ratio into a same batch.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
        aspect_ratios (dict): The predefined aspect ratios.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 drop_last: bool = False
                ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        # buckets for each aspect ratio
        self.bucket = {'image':[], 'video':[]}

    def __iter__(self):
        for idx in self.sampler:
            content_type = self.dataset.dataset[idx].get('type', 'image')
            self.bucket[content_type].append(idx)

            # yield a batch of indices in the same aspect ratio group
            if len(self.bucket['video']) == self.batch_size:
                bucket = self.bucket['video']
                yield bucket[:]
                del bucket[:]
            elif len(self.bucket['image']) == self.batch_size:
                bucket = self.bucket['image']
                yield bucket[:]
                del bucket[:]


class ImageVideoDataset(Dataset):
    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.0,
        video_length_drop_end=1.0,
        enable_inpaint=False,
        return_file_name=False,
    ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            dataset = json.load(open(ann_path))

        self.data_root = data_root

        # It's used to balance num of images and videos.
        if video_repeat > 0:
            self.dataset = []
            for data in dataset:
                if data.get('type', 'image') != 'video':
                    self.dataset.append(data)

            for _ in range(video_repeat):
                for data in dataset:
                    if data.get('type', 'image') == 'video':
                        self.dataset.append(data)
        else:
            self.dataset = dataset
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        self.return_file_name = return_file_name

        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
        ])

        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    def get_batch(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]

        if data_info.get('type', 'image')=='video':
            video_id, text = data_info['file_path'], data_info['text']

            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError(f"No Frames in video.")

                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(pixel_values)):
                        frame = pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    del video_reader
                else:
                    pixel_values = pixel_values

                if not self.enable_bucket:
                    pixel_values = self.video_transforms(pixel_values)

                # Random use no text generation
                if random.random() < self.text_drop_ratio:
                    text = ''
                return pixel_values, text, 'video', video_dir
        else:
            image_path, text = data_info['file_path'], data_info['text']
            if self.data_root is not None:
                image_path = os.path.join(self.data_root, image_path)
            image = Image.open(image_path).convert('RGB')
            if not self.enable_bucket:
                image = self.image_transforms(image).unsqueeze(0)
            else:
                image = np.expand_dims(np.array(image), 0)
            if random.random() < self.text_drop_ratio:
                text = ''
            return image, text, 'image', image_path

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, name, data_type, file_path = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx
                if self.return_file_name:
                    sample["file_name"] = os.path.basename(file_path)

                if len(sample) > 0:
                    break
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length-1)

        if self.enable_inpaint and not self.enable_bucket:
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample["mask_pixel_values"] = mask_pixel_values
            sample["mask"] = mask

            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            sample["clip_pixel_values"] = clip_pixel_values

        return sample


class ImageVideoControlDataset(Dataset):
    def __init__(
        self,
        ann_path, data_root=None,
        video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
        image_sample_size=512,
        video_repeat=0,
        text_drop_ratio=0.1,
        enable_bucket=False,
        video_length_drop_start=0.1,
        video_length_drop_end=0.9,
        enable_inpaint=False,
        enable_camera_info=False,
        return_file_name=False,
        enable_subject_info=False,
        padding_subject_info=True,
    ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                dataset = list(csv.DictReader(csvfile))
        elif ann_path.endswith('.json'):
            dataset = json.load(open(ann_path))

        self.data_root = data_root

        # It's used to balance num of images and videos.
        if video_repeat > 0:
            self.dataset = []
            for data in dataset:
                if data.get('type', 'image') != 'video':
                    self.dataset.append(data)

            for _ in range(video_repeat):
                for data in dataset:
                    if data.get('type', 'image') == 'video':
                        self.dataset.append(data)
        else:
            self.dataset = dataset
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.enable_inpaint = enable_inpaint
        self.enable_camera_info = enable_camera_info
        self.enable_subject_info = enable_subject_info
        self.padding_subject_info = padding_subject_info

        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(min(self.video_sample_size)),
                transforms.CenterCrop(self.video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        if self.enable_camera_info:
            self.video_transforms_camera = transforms.Compose(
                [
                    transforms.Resize(min(self.video_sample_size)),
                    transforms.CenterCrop(self.video_sample_size)
                ]
            )

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
        ])

        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))

    def get_batch(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]
        video_id, text = data_info['file_path'], data_info['text']

        if data_info.get('type', 'image')=='video':
            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
                )
                if min_sample_n_frames == 0:
                    raise ValueError(f"No Frames in video.")

                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                    resized_frames = []
                    for i in range(len(pixel_values)):
                        frame = pixel_values[i]
                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                        resized_frames.append(resized_frame)
                    pixel_values = np.array(resized_frames)
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    del video_reader
                else:
                    pixel_values = pixel_values

                if not self.enable_bucket:
                    pixel_values = self.video_transforms(pixel_values)

                # Random use no text generation
                if random.random() < self.text_drop_ratio:
                    text = ''

                control_video_id = data_info['control_file_path']

                if control_video_id is not None:
                    if self.data_root is None:
                        control_video_id = control_video_id
                    else:
                        control_video_id = os.path.join(self.data_root, control_video_id)

                if self.enable_camera_info:
                    if control_video_id.lower().endswith('.txt'):
                        if not self.enable_bucket:
                            control_pixel_values = torch.zeros_like(pixel_values)

                            control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0])
                            control_camera_values = torch.from_numpy(control_camera_values).permute(0, 3, 1, 2).contiguous()
                            control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)
                            control_camera_values = self.video_transforms_camera(control_camera_values)
                        else:
                            control_pixel_values = np.zeros_like(pixel_values)

                            control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0], return_poses=True)
                            control_camera_values = torch.from_numpy(np.array(control_camera_values)).unsqueeze(0).unsqueeze(0)
                            control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)[0][0]
                            control_camera_values = np.array([control_camera_values[index] for index in batch_index])
                    else:
                        if not self.enable_bucket:
                            control_pixel_values = torch.zeros_like(pixel_values)
                            control_camera_values = None
                        else:
                            control_pixel_values = np.zeros_like(pixel_values)
                            control_camera_values = None
                else:
                    if control_video_id is not None:
                        with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
                            try:
                                sample_args = (control_video_reader, batch_index)
                                control_pixel_values = func_timeout(
                                    VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                                )
                                resized_frames = []
                                for i in range(len(control_pixel_values)):
                                    frame = control_pixel_values[i]
                                    resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
                                    resized_frames.append(resized_frame)
                                control_pixel_values = np.array(resized_frames)
                            except FunctionTimedOut:
                                raise ValueError(f"Read {idx} timeout.")
                            except Exception as e:
                                raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                            if not self.enable_bucket:
                                control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
                                control_pixel_values = control_pixel_values / 255.
                                del control_video_reader
                            else:
                                control_pixel_values = control_pixel_values

                        if not self.enable_bucket:
                            control_pixel_values = self.video_transforms(control_pixel_values)
                    else:
                        if not self.enable_bucket:
                            control_pixel_values = torch.zeros_like(pixel_values)
                        else:
                            control_pixel_values = np.zeros_like(pixel_values)
                    control_camera_values = None

                if self.enable_subject_info:
                    if not self.enable_bucket:
                        visual_height, visual_width = pixel_values.shape[-2:]
                    else:
                        visual_height, visual_width = pixel_values.shape[1:3]

                    subject_id = data_info.get('object_file_path', [])
                    shuffle(subject_id)
                    subject_images = []
                    for i in range(min(len(subject_id), 4)):
                        subject_image = Image.open(subject_id[i])
                        width, height = subject_image.size
                        total_pixels = width * height

                        if self.padding_subject_info:
                            img = padding_image(subject_image, visual_width, visual_height)
                        else:
                            img = resize_image_with_target_area(subject_image, 1024 * 1024)

                        if random.random() < 0.5:
                            img = img.transpose(Image.FLIP_LEFT_RIGHT)
                        subject_images.append(np.array(img))
                    if self.padding_subject_info:
                        subject_image = np.array(subject_images)
                    else:
                        subject_image = subject_images
                else:
                    subject_image = None

                return pixel_values, control_pixel_values, subject_image, control_camera_values, text, "video"
        else:
            image_path, text = data_info['file_path'], data_info['text']
            if self.data_root is not None:
                image_path = os.path.join(self.data_root, image_path)
            image = Image.open(image_path).convert('RGB')
            if not self.enable_bucket:
                image = self.image_transforms(image).unsqueeze(0)
            else:
                image = np.expand_dims(np.array(image), 0)

            if random.random() < self.text_drop_ratio:
                text = ''

            control_image_id = data_info['control_file_path']

            if self.data_root is None:
                control_image_id = control_image_id
            else:
                control_image_id = os.path.join(self.data_root, control_image_id)

            control_image = Image.open(control_image_id).convert('RGB')
            if not self.enable_bucket:
                control_image = self.image_transforms(control_image).unsqueeze(0)
            else:
                control_image = np.expand_dims(np.array(control_image), 0)

            if self.enable_subject_info:
                if not self.enable_bucket:
                    visual_height, visual_width = image.shape[-2:]
                else:
                    visual_height, visual_width = image.shape[1:3]

                subject_id = data_info.get('object_file_path', [])
                shuffle(subject_id)
                subject_images = []
                for i in range(min(len(subject_id), 4)):
                    subject_image = Image.open(subject_id[i]).convert('RGB')
                    width, height = subject_image.size
                    total_pixels = width * height

                    if self.padding_subject_info:
                        img = padding_image(subject_image, visual_width, visual_height)
                    else:
                        img = resize_image_with_target_area(subject_image, 1024 * 1024)

                    if random.random() < 0.5:
                        img = img.transpose(Image.FLIP_LEFT_RIGHT)
                    subject_images.append(np.array(img))
                if self.padding_subject_info:
                    subject_image = np.array(subject_images)
                else:
                    subject_image = subject_images
            else:
                subject_image = None

            return image, control_image, subject_image, None, text, 'image'

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
| 564 |
+
data_type_local = data_info_local.get('type', 'image')
|
| 565 |
+
if data_type_local != data_type:
|
| 566 |
+
raise ValueError("data_type_local != data_type")
|
| 567 |
+
|
| 568 |
+
pixel_values, control_pixel_values, subject_image, control_camera_values, name, data_type = self.get_batch(idx)
|
| 569 |
+
|
| 570 |
+
sample["pixel_values"] = pixel_values
|
| 571 |
+
sample["control_pixel_values"] = control_pixel_values
|
| 572 |
+
sample["subject_image"] = subject_image
|
| 573 |
+
sample["text"] = name
|
| 574 |
+
sample["data_type"] = data_type
|
| 575 |
+
sample["idx"] = idx
|
| 576 |
+
|
| 577 |
+
if self.enable_camera_info:
|
| 578 |
+
sample["control_camera_values"] = control_camera_values
|
| 579 |
+
|
| 580 |
+
if len(sample) > 0:
|
| 581 |
+
break
|
| 582 |
+
except Exception as e:
|
| 583 |
+
print(e, self.dataset[idx % len(self.dataset)])
|
| 584 |
+
idx = random.randint(0, self.length-1)
|
| 585 |
+
|
| 586 |
+
if self.enable_inpaint and not self.enable_bucket:
|
| 587 |
+
mask = get_random_mask(pixel_values.size())
|
| 588 |
+
mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
|
| 589 |
+
sample["mask_pixel_values"] = mask_pixel_values
|
| 590 |
+
sample["mask"] = mask
|
| 591 |
+
|
| 592 |
+
clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
|
| 593 |
+
clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
|
| 594 |
+
sample["clip_pixel_values"] = clip_pixel_values
|
| 595 |
+
|
| 596 |
+
return sample
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
class ImageVideoSafetensorsDataset(Dataset):
|
| 600 |
+
def __init__(
|
| 601 |
+
self,
|
| 602 |
+
ann_path,
|
| 603 |
+
data_root=None,
|
| 604 |
+
):
|
| 605 |
+
# Loading annotations from files
|
| 606 |
+
print(f"loading annotations from {ann_path} ...")
|
| 607 |
+
if ann_path.endswith('.json'):
|
| 608 |
+
dataset = json.load(open(ann_path))
|
| 609 |
+
|
| 610 |
+
self.data_root = data_root
|
| 611 |
+
self.dataset = dataset
|
| 612 |
+
self.length = len(self.dataset)
|
| 613 |
+
print(f"data scale: {self.length}")
|
| 614 |
+
|
| 615 |
+
def __len__(self):
|
| 616 |
+
return self.length
|
| 617 |
+
|
| 618 |
+
def __getitem__(self, idx):
|
| 619 |
+
if self.data_root is None:
|
| 620 |
+
path = self.dataset[idx]["file_path"]
|
| 621 |
+
else:
|
| 622 |
+
path = os.path.join(self.data_root, self.dataset[idx]["file_path"])
|
| 623 |
+
state_dict = load_file(path)
|
| 624 |
+
return state_dict
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
class TextDataset(Dataset):
|
| 628 |
+
def __init__(self, ann_path, text_drop_ratio=0.0):
|
| 629 |
+
print(f"loading annotations from {ann_path} ...")
|
| 630 |
+
with open(ann_path, 'r') as f:
|
| 631 |
+
self.dataset = json.load(f)
|
| 632 |
+
self.length = len(self.dataset)
|
| 633 |
+
print(f"data scale: {self.length}")
|
| 634 |
+
self.text_drop_ratio = text_drop_ratio
|
| 635 |
+
|
| 636 |
+
def __len__(self):
|
| 637 |
+
return self.length
|
| 638 |
+
|
| 639 |
+
def __getitem__(self, idx):
|
| 640 |
+
while True:
|
| 641 |
+
try:
|
| 642 |
+
item = self.dataset[idx]
|
| 643 |
+
text = item['text']
|
| 644 |
+
|
| 645 |
+
# Randomly drop text (for classifier-free guidance)
|
| 646 |
+
if random.random() < self.text_drop_ratio:
|
| 647 |
+
text = ''
|
| 648 |
+
|
| 649 |
+
sample = {
|
| 650 |
+
"text": text,
|
| 651 |
+
"idx": idx
|
| 652 |
+
}
|
| 653 |
+
return sample
|
| 654 |
+
|
| 655 |
+
except Exception as e:
|
| 656 |
+
print(f"Error at index {idx}: {e}, retrying with random index...")
|
| 657 |
+
idx = np.random.randint(0, self.length - 1)
|
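A minimal usage sketch for the TextDataset defined above (added for illustration, not part of the commit): it assumes the annotation JSON is a flat list of {"text": ...} records, and the annotation path below is hypothetical.

from torch.utils.data import DataLoader
from videox_fun.data.dataset_image_video import TextDataset

dataset = TextDataset("annotations/prompts.json", text_drop_ratio=0.1)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
for batch in loader:
    # batch["text"] is a list of strings; batch["idx"] is a tensor of indices
    print(batch["text"], batch["idx"])
    break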
videox_fun/data/dataset_video.py
ADDED
|
@@ -0,0 +1,901 @@
|
|
| 1 |
+
import csv
|
| 2 |
+
import gc
|
| 3 |
+
import io
|
| 4 |
+
import json
|
| 5 |
+
import math
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
from contextlib import contextmanager
|
| 9 |
+
from threading import Thread
|
| 10 |
+
|
| 11 |
+
import albumentations
|
| 12 |
+
import cv2
|
| 13 |
+
import librosa
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import torchvision.transforms as transforms
|
| 17 |
+
from decord import VideoReader
|
| 18 |
+
from einops import rearrange
|
| 19 |
+
from func_timeout import FunctionTimedOut, func_timeout
|
| 20 |
+
from PIL import Image
|
| 21 |
+
from torch.utils.data import BatchSampler, Sampler
|
| 22 |
+
from torch.utils.data.dataset import Dataset
|
| 23 |
+
|
| 24 |
+
from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager,
|
| 25 |
+
custom_meshgrid, get_random_mask, get_relative_pose,
|
| 26 |
+
get_video_reader_batch, padding_image, process_pose_file,
|
| 27 |
+
process_pose_params, ray_condition, resize_frame,
|
| 28 |
+
resize_image_with_target_area)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class WebVid10M(Dataset):
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
csv_path, video_folder,
|
| 35 |
+
sample_size=256, sample_stride=4, sample_n_frames=16,
|
| 36 |
+
enable_bucket=False, enable_inpaint=False, is_image=False,
|
| 37 |
+
):
|
| 38 |
+
print(f"loading annotations from {csv_path} ...")
|
| 39 |
+
with open(csv_path, 'r') as csvfile:
|
| 40 |
+
self.dataset = list(csv.DictReader(csvfile))
|
| 41 |
+
self.length = len(self.dataset)
|
| 42 |
+
print(f"data scale: {self.length}")
|
| 43 |
+
|
| 44 |
+
self.video_folder = video_folder
|
| 45 |
+
self.sample_stride = sample_stride
|
| 46 |
+
self.sample_n_frames = sample_n_frames
|
| 47 |
+
self.enable_bucket = enable_bucket
|
| 48 |
+
self.enable_inpaint = enable_inpaint
|
| 49 |
+
self.is_image = is_image
|
| 50 |
+
|
| 51 |
+
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
|
| 52 |
+
self.pixel_transforms = transforms.Compose([
|
| 53 |
+
transforms.Resize(sample_size[0]),
|
| 54 |
+
transforms.CenterCrop(sample_size),
|
| 55 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 56 |
+
])
|
| 57 |
+
|
| 58 |
+
def get_batch(self, idx):
|
| 59 |
+
video_dict = self.dataset[idx]
|
| 60 |
+
videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
|
| 61 |
+
|
| 62 |
+
video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
|
| 63 |
+
video_reader = VideoReader(video_dir)
|
| 64 |
+
video_length = len(video_reader)
|
| 65 |
+
|
| 66 |
+
if not self.is_image:
|
| 67 |
+
clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
|
| 68 |
+
start_idx = random.randint(0, video_length - clip_length)
|
| 69 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
|
| 70 |
+
else:
|
| 71 |
+
batch_index = [random.randint(0, video_length - 1)]
|
| 72 |
+
|
| 73 |
+
if not self.enable_bucket:
|
| 74 |
+
pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
|
| 75 |
+
pixel_values = pixel_values / 255.
|
| 76 |
+
del video_reader
|
| 77 |
+
else:
|
| 78 |
+
pixel_values = video_reader.get_batch(batch_index).asnumpy()
|
| 79 |
+
|
| 80 |
+
if self.is_image:
|
| 81 |
+
pixel_values = pixel_values[0]
|
| 82 |
+
return pixel_values, name
|
| 83 |
+
|
| 84 |
+
def __len__(self):
|
| 85 |
+
return self.length
|
| 86 |
+
|
| 87 |
+
def __getitem__(self, idx):
|
| 88 |
+
while True:
|
| 89 |
+
try:
|
| 90 |
+
pixel_values, name = self.get_batch(idx)
|
| 91 |
+
break
|
| 92 |
+
|
| 93 |
+
except Exception as e:
|
| 94 |
+
print("Error info:", e)
|
| 95 |
+
idx = random.randint(0, self.length-1)
|
| 96 |
+
|
| 97 |
+
if not self.enable_bucket:
|
| 98 |
+
pixel_values = self.pixel_transforms(pixel_values)
|
| 99 |
+
if self.enable_inpaint:
|
| 100 |
+
mask = get_random_mask(pixel_values.size())
|
| 101 |
+
mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
|
| 102 |
+
sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
|
| 103 |
+
else:
|
| 104 |
+
sample = dict(pixel_values=pixel_values, text=name)
|
| 105 |
+
return sample
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class VideoDataset(Dataset):
|
| 109 |
+
def __init__(
|
| 110 |
+
self,
|
| 111 |
+
ann_path, data_root=None,
|
| 112 |
+
sample_size=256, sample_stride=4, sample_n_frames=16,
|
| 113 |
+
enable_bucket=False, enable_inpaint=False
|
| 114 |
+
):
|
| 115 |
+
print(f"loading annotations from {ann_path} ...")
|
| 116 |
+
self.dataset = json.load(open(ann_path, 'r'))
|
| 117 |
+
self.length = len(self.dataset)
|
| 118 |
+
print(f"data scale: {self.length}")
|
| 119 |
+
|
| 120 |
+
self.data_root = data_root
|
| 121 |
+
self.sample_stride = sample_stride
|
| 122 |
+
self.sample_n_frames = sample_n_frames
|
| 123 |
+
self.enable_bucket = enable_bucket
|
| 124 |
+
self.enable_inpaint = enable_inpaint
# get_batch below reads the attribute names used by the other datasets in this
# file; define them here (with those datasets' default values) so the class runs.
self.video_sample_stride = sample_stride
self.video_sample_n_frames = sample_n_frames
self.video_length_drop_start = 0.1
self.video_length_drop_end = 0.9
self.text_drop_ratio = 0.1
|
| 125 |
+
|
| 126 |
+
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
|
| 127 |
+
self.pixel_transforms = transforms.Compose(
|
| 128 |
+
[
|
| 129 |
+
transforms.Resize(sample_size[0]),
|
| 130 |
+
transforms.CenterCrop(sample_size),
|
| 131 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 132 |
+
]
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
def get_batch(self, idx):
|
| 136 |
+
video_dict = self.dataset[idx]
|
| 137 |
+
video_id, text = video_dict['file_path'], video_dict['text']
|
| 138 |
+
|
| 139 |
+
if self.data_root is None:
|
| 140 |
+
video_dir = video_id
|
| 141 |
+
else:
|
| 142 |
+
video_dir = os.path.join(self.data_root, video_id)
|
| 143 |
+
|
| 144 |
+
with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
|
| 145 |
+
min_sample_n_frames = min(
|
| 146 |
+
self.video_sample_n_frames,
|
| 147 |
+
int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
|
| 148 |
+
)
|
| 149 |
+
if min_sample_n_frames == 0:
|
| 150 |
+
raise ValueError(f"No Frames in video.")
|
| 151 |
+
|
| 152 |
+
video_length = int(self.video_length_drop_end * len(video_reader))
|
| 153 |
+
clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
|
| 154 |
+
start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
|
| 155 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
|
| 156 |
+
|
| 157 |
+
try:
|
| 158 |
+
sample_args = (video_reader, batch_index)
|
| 159 |
+
pixel_values = func_timeout(
|
| 160 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 161 |
+
)
|
| 162 |
+
except FunctionTimedOut:
|
| 163 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 164 |
+
except Exception as e:
|
| 165 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 166 |
+
|
| 167 |
+
if not self.enable_bucket:
|
| 168 |
+
pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 169 |
+
pixel_values = pixel_values / 255.
|
| 170 |
+
del video_reader
|
| 171 |
+
else:
|
| 172 |
+
pixel_values = pixel_values
|
| 173 |
+
|
| 174 |
+
if not self.enable_bucket:
|
| 175 |
+
pixel_values = self.pixel_transforms(pixel_values)
|
| 176 |
+
|
| 177 |
+
# Random use no text generation
|
| 178 |
+
if random.random() < self.text_drop_ratio:
|
| 179 |
+
text = ''
|
| 180 |
+
return pixel_values, text
|
| 181 |
+
|
| 182 |
+
def __len__(self):
|
| 183 |
+
return self.length
|
| 184 |
+
|
| 185 |
+
def __getitem__(self, idx):
|
| 186 |
+
while True:
|
| 187 |
+
sample = {}
|
| 188 |
+
try:
|
| 189 |
+
pixel_values, name = self.get_batch(idx)
|
| 190 |
+
sample["pixel_values"] = pixel_values
|
| 191 |
+
sample["text"] = name
|
| 192 |
+
sample["idx"] = idx
|
| 193 |
+
if len(sample) > 0:
|
| 194 |
+
break
|
| 195 |
+
|
| 196 |
+
except Exception as e:
|
| 197 |
+
print(e, self.dataset[idx % len(self.dataset)])
|
| 198 |
+
idx = random.randint(0, self.length-1)
|
| 199 |
+
|
| 200 |
+
if self.enable_inpaint and not self.enable_bucket:
|
| 201 |
+
mask = get_random_mask(pixel_values.size())
|
| 202 |
+
mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
|
| 203 |
+
sample["mask_pixel_values"] = mask_pixel_values
|
| 204 |
+
sample["mask"] = mask
|
| 205 |
+
|
| 206 |
+
clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
|
| 207 |
+
clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
|
| 208 |
+
sample["clip_pixel_values"] = clip_pixel_values
|
| 209 |
+
|
| 210 |
+
return sample
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class VideoSpeechDataset(Dataset):
|
| 214 |
+
def __init__(
|
| 215 |
+
self,
|
| 216 |
+
ann_path, data_root=None,
|
| 217 |
+
video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
|
| 218 |
+
enable_bucket=False, enable_inpaint=False,
|
| 219 |
+
audio_sr=16000, # added: target audio sample rate
|
| 220 |
+
text_drop_ratio=0.1 # added: text drop probability
|
| 221 |
+
):
|
| 222 |
+
print(f"loading annotations from {ann_path} ...")
|
| 223 |
+
self.dataset = json.load(open(ann_path, 'r'))
|
| 224 |
+
self.length = len(self.dataset)
|
| 225 |
+
print(f"data scale: {self.length}")
|
| 226 |
+
|
| 227 |
+
self.data_root = data_root
|
| 228 |
+
self.video_sample_stride = video_sample_stride
|
| 229 |
+
self.video_sample_n_frames = video_sample_n_frames
|
| 230 |
+
self.enable_bucket = enable_bucket
|
| 231 |
+
self.enable_inpaint = enable_inpaint
|
| 232 |
+
self.audio_sr = audio_sr
|
| 233 |
+
self.text_drop_ratio = text_drop_ratio
|
| 234 |
+
|
| 235 |
+
video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
|
| 236 |
+
self.pixel_transforms = transforms.Compose(
|
| 237 |
+
[
|
| 238 |
+
transforms.Resize(video_sample_size[0]),
|
| 239 |
+
transforms.CenterCrop(video_sample_size),
|
| 240 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 241 |
+
]
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
def get_batch(self, idx):
|
| 245 |
+
video_dict = self.dataset[idx]
|
| 246 |
+
video_id, text = video_dict['file_path'], video_dict['text']
|
| 247 |
+
audio_id = video_dict['audio_path']
|
| 248 |
+
|
| 249 |
+
if self.data_root is None:
|
| 250 |
+
video_path = video_id
|
| 251 |
+
else:
|
| 252 |
+
video_path = os.path.join(self.data_root, video_id)
|
| 253 |
+
|
| 254 |
+
if self.data_root is None:
|
| 255 |
+
audio_path = audio_id
|
| 256 |
+
else:
|
| 257 |
+
audio_path = os.path.join(self.data_root, audio_id)
|
| 258 |
+
|
| 259 |
+
if not os.path.exists(audio_path):
|
| 260 |
+
raise FileNotFoundError(f"Audio file not found for {video_path}")
|
| 261 |
+
|
| 262 |
+
with VideoReader_contextmanager(video_path, num_threads=2) as video_reader:
|
| 263 |
+
total_frames = len(video_reader)
|
| 264 |
+
fps = video_reader.get_avg_fps() # original video frame rate
|
| 265 |
+
|
| 266 |
+
# compute how many video frames are actually sampled (respecting the clip boundary)
|
| 267 |
+
max_possible_frames = (total_frames - 1) // self.video_sample_stride + 1
|
| 268 |
+
actual_n_frames = min(self.video_sample_n_frames, max_possible_frames)
|
| 269 |
+
if actual_n_frames <= 0:
|
| 270 |
+
raise ValueError(f"Video too short: {video_path}")
|
| 271 |
+
|
| 272 |
+
# randomly pick the start frame
|
| 273 |
+
max_start = total_frames - (actual_n_frames - 1) * self.video_sample_stride - 1
|
| 274 |
+
start_frame = random.randint(0, max_start) if max_start > 0 else 0
|
| 275 |
+
frame_indices = [start_frame + i * self.video_sample_stride for i in range(actual_n_frames)]
|
| 276 |
+
|
| 277 |
+
# read the video frames
|
| 278 |
+
try:
|
| 279 |
+
sample_args = (video_reader, frame_indices)
|
| 280 |
+
pixel_values = func_timeout(
|
| 281 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 282 |
+
)
|
| 283 |
+
except FunctionTimedOut:
|
| 284 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 285 |
+
except Exception as e:
|
| 286 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 287 |
+
|
| 288 |
+
# video post-processing
|
| 289 |
+
if not self.enable_bucket:
|
| 290 |
+
pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 291 |
+
pixel_values = pixel_values / 255.
|
| 292 |
+
pixel_values = self.pixel_transforms(pixel_values)
|
| 293 |
+
|
| 294 |
+
# === added: load and slice the matching audio ===
|
| 295 |
+
# start/end time of the video clip (in seconds)
|
| 296 |
+
start_time = start_frame / fps
|
| 297 |
+
end_time = (start_frame + (actual_n_frames - 1) * self.video_sample_stride) / fps
|
| 298 |
+
duration = end_time - start_time
|
| 299 |
+
|
| 300 |
+
# load the entire audio with librosa (loading only the needed part would need precise seeking, which is not used here, so load first and then slice)
|
| 301 |
+
audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr) # resample to the target sr
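# Worked example with illustrative numbers (not from the source): fps = 25,
# start_frame = 50 and audio_sr = 16000 give start_time = 2.0 s and
# start_sample = 32000, so the audio slice stays aligned with the sampled frames.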
|
| 302 |
+
|
| 303 |
+
# convert times to sample indices
|
| 304 |
+
start_sample = int(start_time * self.audio_sr)
|
| 305 |
+
end_sample = int(end_time * self.audio_sr)
|
| 306 |
+
|
| 307 |
+
# slice safely
|
| 308 |
+
if start_sample >= len(audio_input):
|
| 309 |
+
# audio too short: zero-pad (or truncate)
|
| 310 |
+
audio_segment = np.zeros(int(duration * self.audio_sr), dtype=np.float32)
|
| 311 |
+
else:
|
| 312 |
+
audio_segment = audio_input[start_sample:end_sample]
|
| 313 |
+
# if the slice is too short, pad with zeros
|
| 314 |
+
target_len = int(duration * self.audio_sr)
|
| 315 |
+
if len(audio_segment) < target_len:
|
| 316 |
+
audio_segment = np.pad(audio_segment, (0, target_len - len(audio_segment)), mode='constant')
|
| 317 |
+
|
| 318 |
+
# === random text drop ===
|
| 319 |
+
if random.random() < self.text_drop_ratio:
|
| 320 |
+
text = ''
|
| 321 |
+
|
| 322 |
+
return pixel_values, text, audio_segment, sample_rate
|
| 323 |
+
|
| 324 |
+
def __len__(self):
|
| 325 |
+
return self.length
|
| 326 |
+
|
| 327 |
+
def __getitem__(self, idx):
|
| 328 |
+
while True:
|
| 329 |
+
sample = {}
|
| 330 |
+
try:
|
| 331 |
+
pixel_values, text, audio, sample_rate = self.get_batch(idx)
|
| 332 |
+
sample["pixel_values"] = pixel_values
|
| 333 |
+
sample["text"] = text
|
| 334 |
+
sample["audio"] = torch.from_numpy(audio).float() # 转为 tensor
|
| 335 |
+
sample["sample_rate"] = sample_rate
|
| 336 |
+
sample["idx"] = idx
|
| 337 |
+
break
|
| 338 |
+
except Exception as e:
|
| 339 |
+
print(f"Error processing {idx}: {e}, retrying with random idx...")
|
| 340 |
+
idx = random.randint(0, self.length - 1)
|
| 341 |
+
|
| 342 |
+
if self.enable_inpaint and not self.enable_bucket:
|
| 343 |
+
mask = get_random_mask(pixel_values.size(), image_start_only=True)
|
| 344 |
+
mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
|
| 345 |
+
sample["mask_pixel_values"] = mask_pixel_values
|
| 346 |
+
sample["mask"] = mask
|
| 347 |
+
|
| 348 |
+
clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
|
| 349 |
+
clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
|
| 350 |
+
sample["clip_pixel_values"] = clip_pixel_values
|
| 351 |
+
|
| 352 |
+
return sample
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
class VideoSpeechControlDataset(Dataset):
|
| 356 |
+
def __init__(
|
| 357 |
+
self,
|
| 358 |
+
ann_path, data_root=None,
|
| 359 |
+
video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
|
| 360 |
+
enable_bucket=False, enable_inpaint=False,
|
| 361 |
+
audio_sr=16000,
|
| 362 |
+
text_drop_ratio=0.1,
|
| 363 |
+
enable_motion_info=False,
|
| 364 |
+
motion_frames=73,
|
| 365 |
+
):
|
| 366 |
+
print(f"loading annotations from {ann_path} ...")
|
| 367 |
+
self.dataset = json.load(open(ann_path, 'r'))
|
| 368 |
+
self.length = len(self.dataset)
|
| 369 |
+
print(f"data scale: {self.length}")
|
| 370 |
+
|
| 371 |
+
self.data_root = data_root
|
| 372 |
+
self.video_sample_stride = video_sample_stride
|
| 373 |
+
self.video_sample_n_frames = video_sample_n_frames
|
| 374 |
+
self.enable_bucket = enable_bucket
|
| 375 |
+
self.enable_inpaint = enable_inpaint
|
| 376 |
+
self.audio_sr = audio_sr
|
| 377 |
+
self.text_drop_ratio = text_drop_ratio
|
| 378 |
+
self.enable_motion_info = enable_motion_info
|
| 379 |
+
self.motion_frames = motion_frames
|
| 380 |
+
|
| 381 |
+
video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
|
| 382 |
+
self.pixel_transforms = transforms.Compose(
|
| 383 |
+
[
|
| 384 |
+
transforms.Resize(video_sample_size[0]),
|
| 385 |
+
transforms.CenterCrop(video_sample_size),
|
| 386 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 387 |
+
]
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
self.video_sample_size = video_sample_size
|
| 391 |
+
|
| 392 |
+
def get_batch(self, idx):
|
| 393 |
+
video_dict = self.dataset[idx]
|
| 394 |
+
video_id, text = video_dict['file_path'], video_dict['text']
|
| 395 |
+
audio_id = video_dict['audio_path']
|
| 396 |
+
control_video_id = video_dict['control_file_path']
|
| 397 |
+
|
| 398 |
+
if self.data_root is None:
|
| 399 |
+
video_path = video_id
|
| 400 |
+
else:
|
| 401 |
+
video_path = os.path.join(self.data_root, video_id)
|
| 402 |
+
|
| 403 |
+
if self.data_root is None:
|
| 404 |
+
audio_path = audio_id
|
| 405 |
+
else:
|
| 406 |
+
audio_path = os.path.join(self.data_root, audio_id)
|
| 407 |
+
|
| 408 |
+
if self.data_root is None:
|
| 409 |
+
control_video_id = control_video_id
|
| 410 |
+
else:
|
| 411 |
+
control_video_id = os.path.join(self.data_root, control_video_id)
|
| 412 |
+
|
| 413 |
+
if not os.path.exists(audio_path):
|
| 414 |
+
raise FileNotFoundError(f"Audio file not found for {video_path}")
|
| 415 |
+
|
| 416 |
+
# Video information
|
| 417 |
+
with VideoReader_contextmanager(video_path, num_threads=2) as video_reader:
|
| 418 |
+
total_frames = len(video_reader)
|
| 419 |
+
fps = video_reader.get_avg_fps()
|
| 420 |
+
if fps <= 0:
|
| 421 |
+
raise ValueError(f"Video has negative fps: {video_path}")
|
| 422 |
+
local_video_sample_stride = self.video_sample_stride
|
| 423 |
+
new_fps = int(fps // local_video_sample_stride)
|
| 424 |
+
while new_fps > 30:
|
| 425 |
+
local_video_sample_stride = local_video_sample_stride + 1
|
| 426 |
+
new_fps = int(fps // local_video_sample_stride)
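# Worked example with assumed inputs: a 120 fps source at stride 4 already gives
# new_fps = 30, while a 144 fps source is stepped to stride 5 (new_fps = 28), so
# the effective frame rate never exceeds 30 fps.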
|
| 427 |
+
|
| 428 |
+
max_possible_frames = (total_frames - 1) // local_video_sample_stride + 1
|
| 429 |
+
actual_n_frames = min(self.video_sample_n_frames, max_possible_frames)
|
| 430 |
+
if actual_n_frames <= 0:
|
| 431 |
+
raise ValueError(f"Video too short: {video_path}")
|
| 432 |
+
|
| 433 |
+
max_start = total_frames - (actual_n_frames - 1) * local_video_sample_stride - 1
|
| 434 |
+
start_frame = random.randint(0, max_start) if max_start > 0 else 0
|
| 435 |
+
frame_indices = [start_frame + i * local_video_sample_stride for i in range(actual_n_frames)]
|
| 436 |
+
|
| 437 |
+
try:
|
| 438 |
+
sample_args = (video_reader, frame_indices)
|
| 439 |
+
pixel_values = func_timeout(
|
| 440 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 441 |
+
)
|
| 442 |
+
except FunctionTimedOut:
|
| 443 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 444 |
+
except Exception as e:
|
| 445 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 446 |
+
|
| 447 |
+
_, height, width, channel = np.shape(pixel_values)
|
| 448 |
+
if self.enable_motion_info:
|
| 449 |
+
motion_pixel_values = np.ones([self.motion_frames, height, width, channel]) * 127.5
|
| 450 |
+
if start_frame > 0:
|
| 451 |
+
motion_max_possible_frames = (start_frame - 1) // local_video_sample_stride + 1
|
| 452 |
+
motion_frame_indices = [0 + i * local_video_sample_stride for i in range(motion_max_possible_frames)]
|
| 453 |
+
motion_frame_indices = motion_frame_indices[-self.motion_frames:]
|
| 454 |
+
|
| 455 |
+
_motion_sample_args = (video_reader, motion_frame_indices)
|
| 456 |
+
_motion_pixel_values = func_timeout(
|
| 457 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=_motion_sample_args
|
| 458 |
+
)
|
| 459 |
+
motion_pixel_values[-len(motion_frame_indices):] = _motion_pixel_values
|
| 460 |
+
|
| 461 |
+
if not self.enable_bucket:
|
| 462 |
+
motion_pixel_values = torch.from_numpy(motion_pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 463 |
+
motion_pixel_values = motion_pixel_values / 255.
|
| 464 |
+
motion_pixel_values = self.pixel_transforms(motion_pixel_values)
|
| 465 |
+
else:
|
| 466 |
+
motion_pixel_values = None
|
| 467 |
+
|
| 468 |
+
if not self.enable_bucket:
|
| 469 |
+
pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 470 |
+
pixel_values = pixel_values / 255.
|
| 471 |
+
pixel_values = self.pixel_transforms(pixel_values)
|
| 472 |
+
|
| 473 |
+
# Audio information
|
| 474 |
+
start_time = start_frame / fps
|
| 475 |
+
end_time = (start_frame + (actual_n_frames - 1) * local_video_sample_stride) / fps
|
| 476 |
+
duration = end_time - start_time
|
| 477 |
+
|
| 478 |
+
audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr)
|
| 479 |
+
start_sample = int(start_time * self.audio_sr)
|
| 480 |
+
end_sample = int(end_time * self.audio_sr)
|
| 481 |
+
|
| 482 |
+
if start_sample >= len(audio_input):
|
| 483 |
+
raise ValueError(f"Audio file too short: {audio_path}")
|
| 484 |
+
else:
|
| 485 |
+
audio_segment = audio_input[start_sample:end_sample]
|
| 486 |
+
target_len = int(duration * self.audio_sr)
|
| 487 |
+
if len(audio_segment) < target_len:
|
| 488 |
+
raise ValueError(f"Audio file too short: {audio_path}")
|
| 489 |
+
|
| 490 |
+
# Control information
|
| 491 |
+
with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
|
| 492 |
+
try:
|
| 493 |
+
sample_args = (control_video_reader, frame_indices)
|
| 494 |
+
control_pixel_values = func_timeout(
|
| 495 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 496 |
+
)
|
| 497 |
+
resized_frames = []
|
| 498 |
+
for i in range(len(control_pixel_values)):
|
| 499 |
+
frame = control_pixel_values[i]
|
| 500 |
+
resized_frame = resize_frame(frame, max(self.video_sample_size))
|
| 501 |
+
resized_frames.append(resized_frame)
|
| 502 |
+
control_pixel_values = np.array(resized_frames)
|
| 503 |
+
except FunctionTimedOut:
|
| 504 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 505 |
+
except Exception as e:
|
| 506 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 507 |
+
|
| 508 |
+
if not self.enable_bucket:
|
| 509 |
+
control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 510 |
+
control_pixel_values = control_pixel_values / 255.
|
| 511 |
+
del control_video_reader
|
| 512 |
+
else:
|
| 513 |
+
control_pixel_values = control_pixel_values
|
| 514 |
+
|
| 515 |
+
if not self.enable_bucket:
|
| 516 |
+
control_pixel_values = self.pixel_transforms(control_pixel_values)
|
| 517 |
+
|
| 518 |
+
if random.random() < self.text_drop_ratio:
|
| 519 |
+
text = ''
|
| 520 |
+
|
| 521 |
+
return pixel_values, motion_pixel_values, control_pixel_values, text, audio_segment, sample_rate, new_fps
|
| 522 |
+
|
| 523 |
+
def __len__(self):
|
| 524 |
+
return self.length
|
| 525 |
+
|
| 526 |
+
def __getitem__(self, idx):
|
| 527 |
+
while True:
|
| 528 |
+
sample = {}
|
| 529 |
+
try:
|
| 530 |
+
pixel_values, motion_pixel_values, control_pixel_values, text, audio, sample_rate, new_fps = self.get_batch(idx)
|
| 531 |
+
sample["pixel_values"] = pixel_values
|
| 532 |
+
sample["motion_pixel_values"] = motion_pixel_values
|
| 533 |
+
sample["control_pixel_values"] = control_pixel_values
|
| 534 |
+
sample["text"] = text
|
| 535 |
+
sample["audio"] = torch.from_numpy(audio).float() # 转为 tensor
|
| 536 |
+
sample["sample_rate"] = sample_rate
|
| 537 |
+
sample["fps"] = new_fps
|
| 538 |
+
sample["idx"] = idx
|
| 539 |
+
break
|
| 540 |
+
except Exception as e:
|
| 541 |
+
print(f"Error processing {idx}: {e}, retrying with random idx...")
|
| 542 |
+
idx = random.randint(0, self.length - 1)
|
| 543 |
+
|
| 544 |
+
if self.enable_inpaint and not self.enable_bucket:
|
| 545 |
+
mask = get_random_mask(pixel_values.size(), image_start_only=True)
|
| 546 |
+
mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
|
| 547 |
+
sample["mask_pixel_values"] = mask_pixel_values
|
| 548 |
+
sample["mask"] = mask
|
| 549 |
+
|
| 550 |
+
clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
|
| 551 |
+
clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
|
| 552 |
+
sample["clip_pixel_values"] = clip_pixel_values
|
| 553 |
+
|
| 554 |
+
return sample
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
class VideoAnimateDataset(Dataset):
|
| 558 |
+
def __init__(
|
| 559 |
+
self,
|
| 560 |
+
ann_path, data_root=None,
|
| 561 |
+
video_sample_size=512,
|
| 562 |
+
video_sample_stride=4,
|
| 563 |
+
video_sample_n_frames=16,
|
| 564 |
+
video_repeat=0,
|
| 565 |
+
text_drop_ratio=0.1,
|
| 566 |
+
enable_bucket=False,
|
| 567 |
+
video_length_drop_start=0.1,
|
| 568 |
+
video_length_drop_end=0.9,
|
| 569 |
+
return_file_name=False,
|
| 570 |
+
):
|
| 571 |
+
# Loading annotations from files
|
| 572 |
+
print(f"loading annotations from {ann_path} ...")
|
| 573 |
+
if ann_path.endswith('.csv'):
|
| 574 |
+
with open(ann_path, 'r') as csvfile:
|
| 575 |
+
dataset = list(csv.DictReader(csvfile))
|
| 576 |
+
elif ann_path.endswith('.json'):
|
| 577 |
+
dataset = json.load(open(ann_path))
|
| 578 |
+
|
| 579 |
+
self.data_root = data_root
|
| 580 |
+
|
| 581 |
+
# It's used to balance num of images and videos.
|
| 582 |
+
if video_repeat > 0:
|
| 583 |
+
self.dataset = []
|
| 584 |
+
for data in dataset:
|
| 585 |
+
if data.get('type', 'image') != 'video':
|
| 586 |
+
self.dataset.append(data)
|
| 587 |
+
|
| 588 |
+
for _ in range(video_repeat):
|
| 589 |
+
for data in dataset:
|
| 590 |
+
if data.get('type', 'image') == 'video':
|
| 591 |
+
self.dataset.append(data)
|
| 592 |
+
else:
|
| 593 |
+
self.dataset = dataset
|
| 594 |
+
del dataset
|
| 595 |
+
|
| 596 |
+
self.length = len(self.dataset)
|
| 597 |
+
print(f"data scale: {self.length}")
|
| 598 |
+
# TODO: enable bucket training
|
| 599 |
+
self.enable_bucket = enable_bucket
|
| 600 |
+
self.text_drop_ratio = text_drop_ratio
|
| 601 |
+
|
| 602 |
+
self.video_length_drop_start = video_length_drop_start
|
| 603 |
+
self.video_length_drop_end = video_length_drop_end
|
| 604 |
+
|
| 605 |
+
# Video params
|
| 606 |
+
self.video_sample_stride = video_sample_stride
|
| 607 |
+
self.video_sample_n_frames = video_sample_n_frames
|
| 608 |
+
self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
|
| 609 |
+
self.video_transforms = transforms.Compose(
|
| 610 |
+
[
|
| 611 |
+
transforms.Resize(min(self.video_sample_size)),
|
| 612 |
+
transforms.CenterCrop(self.video_sample_size),
|
| 613 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 614 |
+
]
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
self.larger_side_of_image_and_video = min(self.video_sample_size)
|
| 618 |
+
|
| 619 |
+
def get_batch(self, idx):
|
| 620 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
| 621 |
+
video_id, text = data_info['file_path'], data_info['text']
|
| 622 |
+
|
| 623 |
+
if self.data_root is None:
|
| 624 |
+
video_dir = video_id
|
| 625 |
+
else:
|
| 626 |
+
video_dir = os.path.join(self.data_root, video_id)
|
| 627 |
+
|
| 628 |
+
with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
|
| 629 |
+
min_sample_n_frames = min(
|
| 630 |
+
self.video_sample_n_frames,
|
| 631 |
+
int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
|
| 632 |
+
)
|
| 633 |
+
if min_sample_n_frames == 0:
|
| 634 |
+
raise ValueError(f"No Frames in video.")
|
| 635 |
+
|
| 636 |
+
video_length = int(self.video_length_drop_end * len(video_reader))
|
| 637 |
+
clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
|
| 638 |
+
start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
|
| 639 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
|
| 640 |
+
|
| 641 |
+
try:
|
| 642 |
+
sample_args = (video_reader, batch_index)
|
| 643 |
+
pixel_values = func_timeout(
|
| 644 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 645 |
+
)
|
| 646 |
+
resized_frames = []
|
| 647 |
+
for i in range(len(pixel_values)):
|
| 648 |
+
frame = pixel_values[i]
|
| 649 |
+
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
|
| 650 |
+
resized_frames.append(resized_frame)
|
| 651 |
+
pixel_values = np.array(resized_frames)
|
| 652 |
+
except FunctionTimedOut:
|
| 653 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 654 |
+
except Exception as e:
|
| 655 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 656 |
+
|
| 657 |
+
if not self.enable_bucket:
|
| 658 |
+
pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 659 |
+
pixel_values = pixel_values / 255.
|
| 660 |
+
del video_reader
|
| 661 |
+
else:
|
| 662 |
+
pixel_values = pixel_values
|
| 663 |
+
|
| 664 |
+
if not self.enable_bucket:
|
| 665 |
+
pixel_values = self.video_transforms(pixel_values)
|
| 666 |
+
|
| 667 |
+
# Random use no text generation
|
| 668 |
+
if random.random() < self.text_drop_ratio:
|
| 669 |
+
text = ''
|
| 670 |
+
|
| 671 |
+
control_video_id = data_info['control_file_path']
|
| 672 |
+
|
| 673 |
+
if control_video_id is not None:
|
| 674 |
+
if self.data_root is None:
|
| 675 |
+
control_video_id = control_video_id
|
| 676 |
+
else:
|
| 677 |
+
control_video_id = os.path.join(self.data_root, control_video_id)
|
| 678 |
+
|
| 679 |
+
if control_video_id is not None:
|
| 680 |
+
with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
|
| 681 |
+
try:
|
| 682 |
+
sample_args = (control_video_reader, batch_index)
|
| 683 |
+
control_pixel_values = func_timeout(
|
| 684 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 685 |
+
)
|
| 686 |
+
resized_frames = []
|
| 687 |
+
for i in range(len(control_pixel_values)):
|
| 688 |
+
frame = control_pixel_values[i]
|
| 689 |
+
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
|
| 690 |
+
resized_frames.append(resized_frame)
|
| 691 |
+
control_pixel_values = np.array(resized_frames)
|
| 692 |
+
except FunctionTimedOut:
|
| 693 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 694 |
+
except Exception as e:
|
| 695 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 696 |
+
|
| 697 |
+
if not self.enable_bucket:
|
| 698 |
+
control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 699 |
+
control_pixel_values = control_pixel_values / 255.
|
| 700 |
+
del control_video_reader
|
| 701 |
+
else:
|
| 702 |
+
control_pixel_values = control_pixel_values
|
| 703 |
+
|
| 704 |
+
if not self.enable_bucket:
|
| 705 |
+
control_pixel_values = self.video_transforms(control_pixel_values)
|
| 706 |
+
else:
|
| 707 |
+
if not self.enable_bucket:
|
| 708 |
+
control_pixel_values = torch.zeros_like(pixel_values)
|
| 709 |
+
else:
|
| 710 |
+
control_pixel_values = np.zeros_like(pixel_values)
|
| 711 |
+
|
| 712 |
+
face_video_id = data_info['face_file_path']
|
| 713 |
+
|
| 714 |
+
if face_video_id is not None:
|
| 715 |
+
if self.data_root is None:
|
| 716 |
+
face_video_id = face_video_id
|
| 717 |
+
else:
|
| 718 |
+
face_video_id = os.path.join(self.data_root, face_video_id)
|
| 719 |
+
|
| 720 |
+
if face_video_id is not None:
|
| 721 |
+
with VideoReader_contextmanager(face_video_id, num_threads=2) as face_video_reader:
|
| 722 |
+
try:
|
| 723 |
+
sample_args = (face_video_reader, batch_index)
|
| 724 |
+
face_pixel_values = func_timeout(
|
| 725 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 726 |
+
)
|
| 727 |
+
resized_frames = []
|
| 728 |
+
for i in range(len(face_pixel_values)):
|
| 729 |
+
frame = face_pixel_values[i]
|
| 730 |
+
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
|
| 731 |
+
resized_frames.append(resized_frame)
|
| 732 |
+
face_pixel_values = np.array(resized_frames)
|
| 733 |
+
except FunctionTimedOut:
|
| 734 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 735 |
+
except Exception as e:
|
| 736 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 737 |
+
|
| 738 |
+
if not self.enable_bucket:
|
| 739 |
+
face_pixel_values = torch.from_numpy(face_pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 740 |
+
face_pixel_values = face_pixel_values / 255.
|
| 741 |
+
del face_video_reader
|
| 742 |
+
else:
|
| 743 |
+
face_pixel_values = face_pixel_values
|
| 744 |
+
|
| 745 |
+
if not self.enable_bucket:
|
| 746 |
+
face_pixel_values = self.video_transforms(face_pixel_values)
|
| 747 |
+
else:
|
| 748 |
+
if not self.enable_bucket:
|
| 749 |
+
face_pixel_values = torch.zeros_like(pixel_values)
|
| 750 |
+
else:
|
| 751 |
+
face_pixel_values = np.zeros_like(pixel_values)
|
| 752 |
+
|
| 753 |
+
background_video_id = data_info.get('background_file_path', None)
|
| 754 |
+
|
| 755 |
+
if background_video_id is not None:
|
| 756 |
+
if self.data_root is None:
|
| 757 |
+
background_video_id = background_video_id
|
| 758 |
+
else:
|
| 759 |
+
background_video_id = os.path.join(self.data_root, background_video_id)
|
| 760 |
+
|
| 761 |
+
if background_video_id is not None:
|
| 762 |
+
with VideoReader_contextmanager(background_video_id, num_threads=2) as background_video_reader:
|
| 763 |
+
try:
|
| 764 |
+
sample_args = (background_video_reader, batch_index)
|
| 765 |
+
background_pixel_values = func_timeout(
|
| 766 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 767 |
+
)
|
| 768 |
+
resized_frames = []
|
| 769 |
+
for i in range(len(background_pixel_values)):
|
| 770 |
+
frame = background_pixel_values[i]
|
| 771 |
+
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
|
| 772 |
+
resized_frames.append(resized_frame)
|
| 773 |
+
background_pixel_values = np.array(resized_frames)
|
| 774 |
+
except FunctionTimedOut:
|
| 775 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 776 |
+
except Exception as e:
|
| 777 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 778 |
+
|
| 779 |
+
if not self.enable_bucket:
|
| 780 |
+
background_pixel_values = torch.from_numpy(background_pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 781 |
+
background_pixel_values = background_pixel_values / 255.
|
| 782 |
+
del background_video_reader
|
| 783 |
+
else:
|
| 784 |
+
background_pixel_values = background_pixel_values
|
| 785 |
+
|
| 786 |
+
if not self.enable_bucket:
|
| 787 |
+
background_pixel_values = self.video_transforms(background_pixel_values)
|
| 788 |
+
else:
|
| 789 |
+
if not self.enable_bucket:
|
| 790 |
+
background_pixel_values = torch.ones_like(pixel_values) * 127.5
|
| 791 |
+
else:
|
| 792 |
+
background_pixel_values = np.ones_like(pixel_values) * 127.5
|
| 793 |
+
|
| 794 |
+
mask_video_id = data_info.get('mask_file_path', None)
|
| 795 |
+
|
| 796 |
+
if mask_video_id is not None:
|
| 797 |
+
if self.data_root is None:
|
| 798 |
+
mask_video_id = mask_video_id
|
| 799 |
+
else:
|
| 800 |
+
mask_video_id = os.path.join(self.data_root, mask_video_id)
|
| 801 |
+
|
| 802 |
+
if mask_video_id is not None:
|
| 803 |
+
with VideoReader_contextmanager(mask_video_id, num_threads=2) as mask_video_reader:
|
| 804 |
+
try:
|
| 805 |
+
sample_args = (mask_video_reader, batch_index)
|
| 806 |
+
mask = func_timeout(
|
| 807 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 808 |
+
)
|
| 809 |
+
resized_frames = []
|
| 810 |
+
for i in range(len(mask)):
|
| 811 |
+
frame = mask[i]
|
| 812 |
+
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
|
| 813 |
+
resized_frames.append(resized_frame)
|
| 814 |
+
mask = np.array(resized_frames)
|
| 815 |
+
except FunctionTimedOut:
|
| 816 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 817 |
+
except Exception as e:
|
| 818 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 819 |
+
|
| 820 |
+
if not self.enable_bucket:
|
| 821 |
+
mask = torch.from_numpy(mask).permute(0, 3, 1, 2).contiguous()
|
| 822 |
+
mask = mask / 255.
|
| 823 |
+
del mask_video_reader
|
| 824 |
+
else:
|
| 825 |
+
mask = mask
|
| 826 |
+
else:
|
| 827 |
+
if not self.enable_bucket:
|
| 828 |
+
mask = torch.ones_like(pixel_values)
|
| 829 |
+
else:
|
| 830 |
+
mask = np.ones_like(pixel_values) * 255
|
| 831 |
+
mask = mask[:, :, :, :1]
|
| 832 |
+
|
| 833 |
+
ref_pixel_values_path = data_info.get('ref_file_path', [])
|
| 834 |
+
if self.data_root is not None:
|
| 835 |
+
ref_pixel_values_path = os.path.join(self.data_root, ref_pixel_values_path)
|
| 836 |
+
ref_pixel_values = Image.open(ref_pixel_values_path).convert('RGB')
|
| 837 |
+
|
| 838 |
+
if not self.enable_bucket:
|
| 839 |
+
raise ValueError("Not enable_bucket is not supported now. ")
|
| 840 |
+
else:
|
| 841 |
+
ref_pixel_values = np.array(ref_pixel_values)
|
| 842 |
+
|
| 843 |
+
return pixel_values, control_pixel_values, face_pixel_values, background_pixel_values, mask, ref_pixel_values, text, "video"
|
| 844 |
+
|
| 845 |
+
def __len__(self):
|
| 846 |
+
return self.length
|
| 847 |
+
|
| 848 |
+
def __getitem__(self, idx):
|
| 849 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
| 850 |
+
data_type = data_info.get('type', 'image')
|
| 851 |
+
while True:
|
| 852 |
+
sample = {}
|
| 853 |
+
try:
|
| 854 |
+
data_info_local = self.dataset[idx % len(self.dataset)]
|
| 855 |
+
data_type_local = data_info_local.get('type', 'image')
|
| 856 |
+
if data_type_local != data_type:
|
| 857 |
+
raise ValueError("data_type_local != data_type")
|
| 858 |
+
|
| 859 |
+
pixel_values, control_pixel_values, face_pixel_values, background_pixel_values, mask, ref_pixel_values, name, data_type = \
|
| 860 |
+
self.get_batch(idx)
|
| 861 |
+
|
| 862 |
+
sample["pixel_values"] = pixel_values
|
| 863 |
+
sample["control_pixel_values"] = control_pixel_values
|
| 864 |
+
sample["face_pixel_values"] = face_pixel_values
|
| 865 |
+
sample["background_pixel_values"] = background_pixel_values
|
| 866 |
+
sample["mask"] = mask
|
| 867 |
+
sample["ref_pixel_values"] = ref_pixel_values
|
| 868 |
+
sample["clip_pixel_values"] = ref_pixel_values
|
| 869 |
+
sample["text"] = name
|
| 870 |
+
sample["data_type"] = data_type
|
| 871 |
+
sample["idx"] = idx
|
| 872 |
+
|
| 873 |
+
if len(sample) > 0:
|
| 874 |
+
break
|
| 875 |
+
except Exception as e:
|
| 876 |
+
print(e, self.dataset[idx % len(self.dataset)])
|
| 877 |
+
idx = random.randint(0, self.length-1)
|
| 878 |
+
|
| 879 |
+
return sample
|
| 880 |
+
|
| 881 |
+
|
| 882 |
+
if __name__ == "__main__":
|
| 883 |
+
if 1:
|
| 884 |
+
dataset = VideoDataset(
|
| 885 |
+
ann_path="./webvidval/results_2M_val.json",
|
| 886 |
+
sample_size=256,
|
| 887 |
+
sample_stride=4, sample_n_frames=16,
|
| 888 |
+
)
|
| 889 |
+
|
| 890 |
+
if 0:
|
| 891 |
+
dataset = WebVid10M(
|
| 892 |
+
csv_path="./webvid/results_2M_val.csv",
|
| 893 |
+
video_folder="./webvid/2M_val",
|
| 894 |
+
sample_size=256,
|
| 895 |
+
sample_stride=4, sample_n_frames=16,
|
| 896 |
+
is_image=False,
|
| 897 |
+
)
|
| 898 |
+
|
| 899 |
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
|
| 900 |
+
for idx, batch in enumerate(dataloader):
|
| 901 |
+
print(batch["pixel_values"].shape, len(batch["text"]))
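# Note (added): this quick check exercises the enable_bucket=False path, where every
# sample is a normalized tensor of the same size and the default collate works. With
# enable_bucket=True the datasets return raw numpy frames of varying sizes, which
# would need a bucket-aware sampler and collate instead of this plain DataLoader.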
|
videox_fun/data/utils.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
| 1 |
+
import csv
|
| 2 |
+
import gc
|
| 3 |
+
import io
|
| 4 |
+
import json
|
| 5 |
+
import math
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
from contextlib import contextmanager
|
| 9 |
+
from random import shuffle
|
| 10 |
+
from threading import Thread
|
| 11 |
+
|
| 12 |
+
import albumentations
|
| 13 |
+
import cv2
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
import torchvision.transforms as transforms
|
| 18 |
+
from decord import VideoReader
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
from func_timeout import FunctionTimedOut, func_timeout
|
| 21 |
+
from packaging import version as pver
|
| 22 |
+
from PIL import Image
|
| 23 |
+
from safetensors.torch import load_file
|
| 24 |
+
from torch.utils.data import BatchSampler, Sampler
|
| 25 |
+
from torch.utils.data.dataset import Dataset
|
| 26 |
+
|
| 27 |
+
VIDEO_READER_TIMEOUT = 20
|
| 28 |
+
|
| 29 |
+
def get_random_mask(shape, image_start_only=False):
|
| 30 |
+
f, c, h, w = shape
|
| 31 |
+
mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
|
| 32 |
+
|
| 33 |
+
if not image_start_only:
|
| 34 |
+
if f != 1:
|
| 35 |
+
mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05])
|
| 36 |
+
else:
|
| 37 |
+
mask_index = np.random.choice([0, 1, 7, 8], p = [0.2, 0.7, 0.05, 0.05])
|
| 38 |
+
if mask_index == 0:
|
| 39 |
+
center_x = torch.randint(0, w, (1,)).item()
|
| 40 |
+
center_y = torch.randint(0, h, (1,)).item()
|
| 41 |
+
block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # width range of the block
|
| 42 |
+
block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # height range of the block
|
| 43 |
+
|
| 44 |
+
start_x = max(center_x - block_size_x // 2, 0)
|
| 45 |
+
end_x = min(center_x + block_size_x // 2, w)
|
| 46 |
+
start_y = max(center_y - block_size_y // 2, 0)
|
| 47 |
+
end_y = min(center_y + block_size_y // 2, h)
|
| 48 |
+
mask[:, :, start_y:end_y, start_x:end_x] = 1
|
| 49 |
+
elif mask_index == 1:
|
| 50 |
+
mask[:, :, :, :] = 1
|
| 51 |
+
elif mask_index == 2:
|
| 52 |
+
mask_frame_index = np.random.randint(1, 5)
|
| 53 |
+
mask[mask_frame_index:, :, :, :] = 1
|
| 54 |
+
elif mask_index == 3:
|
| 55 |
+
mask_frame_index = np.random.randint(1, 5)
|
| 56 |
+
mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
|
| 57 |
+
elif mask_index == 4:
|
| 58 |
+
center_x = torch.randint(0, w, (1,)).item()
|
| 59 |
+
center_y = torch.randint(0, h, (1,)).item()
|
| 60 |
+
            block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # width range of the masked block
|
| 61 |
+
            block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # height range of the masked block
|
| 62 |
+
|
| 63 |
+
start_x = max(center_x - block_size_x // 2, 0)
|
| 64 |
+
end_x = min(center_x + block_size_x // 2, w)
|
| 65 |
+
start_y = max(center_y - block_size_y // 2, 0)
|
| 66 |
+
end_y = min(center_y + block_size_y // 2, h)
|
| 67 |
+
|
| 68 |
+
mask_frame_before = np.random.randint(0, f // 2)
|
| 69 |
+
mask_frame_after = np.random.randint(f // 2, f)
|
| 70 |
+
mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
|
| 71 |
+
elif mask_index == 5:
|
| 72 |
+
mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
|
| 73 |
+
elif mask_index == 6:
|
| 74 |
+
num_frames_to_mask = random.randint(1, max(f // 2, 1))
|
| 75 |
+
frames_to_mask = random.sample(range(f), num_frames_to_mask)
|
| 76 |
+
|
| 77 |
+
for i in frames_to_mask:
|
| 78 |
+
block_height = random.randint(1, h // 4)
|
| 79 |
+
block_width = random.randint(1, w // 4)
|
| 80 |
+
top_left_y = random.randint(0, h - block_height)
|
| 81 |
+
top_left_x = random.randint(0, w - block_width)
|
| 82 |
+
mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
|
| 83 |
+
elif mask_index == 7:
|
| 84 |
+
center_x = torch.randint(0, w, (1,)).item()
|
| 85 |
+
center_y = torch.randint(0, h, (1,)).item()
|
| 86 |
+
            a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item() # semi-major axis of the ellipse
|
| 87 |
+
            b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item() # semi-minor axis of the ellipse
|
| 88 |
+
|
| 89 |
+
for i in range(h):
|
| 90 |
+
for j in range(w):
|
| 91 |
+
if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1:
|
| 92 |
+
mask[:, :, i, j] = 1
|
| 93 |
+
elif mask_index == 8:
|
| 94 |
+
center_x = torch.randint(0, w, (1,)).item()
|
| 95 |
+
center_y = torch.randint(0, h, (1,)).item()
|
| 96 |
+
radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()
|
| 97 |
+
for i in range(h):
|
| 98 |
+
for j in range(w):
|
| 99 |
+
if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2:
|
| 100 |
+
mask[:, :, i, j] = 1
|
| 101 |
+
elif mask_index == 9:
|
| 102 |
+
for idx in range(f):
|
| 103 |
+
if np.random.rand() > 0.5:
|
| 104 |
+
mask[idx, :, :, :] = 1
|
| 105 |
+
else:
|
| 106 |
+
raise ValueError(f"The mask_index {mask_index} is not define")
|
| 107 |
+
else:
|
| 108 |
+
if f != 1:
|
| 109 |
+
mask[1:, :, :, :] = 1
|
| 110 |
+
else:
|
| 111 |
+
mask[:, :, :, :] = 1
|
| 112 |
+
return mask
|
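A minimal usage sketch for `get_random_mask` (the `(frames, channels, height, width)` input convention and the single-channel `uint8` output follow the function above; the concrete sizes are illustrative and the helper is assumed to be in scope, e.g. imported from this module):

```python
import torch

# Random spatio-temporal mask for a 16-frame, 64x64 clip.
mask = get_random_mask((16, 3, 64, 64))
assert mask.shape == (16, 1, 64, 64)
assert mask.dtype == torch.uint8

# With image_start_only=True only the first frame stays unmasked
# (a single-frame input is masked entirely).
image_mask = get_random_mask((16, 3, 64, 64), image_start_only=True)
assert int(image_mask[0].sum()) == 0 and bool(image_mask[1:].all())
```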
| 113 |
+
|
| 114 |
+
@contextmanager
|
| 115 |
+
def VideoReader_contextmanager(*args, **kwargs):
|
| 116 |
+
vr = VideoReader(*args, **kwargs)
|
| 117 |
+
try:
|
| 118 |
+
yield vr
|
| 119 |
+
finally:
|
| 120 |
+
del vr
|
| 121 |
+
gc.collect()
|
| 122 |
+
|
| 123 |
+
def get_video_reader_batch(video_reader, batch_index):
|
| 124 |
+
frames = video_reader.get_batch(batch_index).asnumpy()
|
| 125 |
+
return frames
|
| 126 |
+
|
| 127 |
+
def resize_frame(frame, target_short_side):
|
| 128 |
+
h, w, _ = frame.shape
|
| 129 |
+
if h < w:
|
| 130 |
+
if target_short_side > h:
|
| 131 |
+
return frame
|
| 132 |
+
new_h = target_short_side
|
| 133 |
+
new_w = int(target_short_side * w / h)
|
| 134 |
+
else:
|
| 135 |
+
if target_short_side > w:
|
| 136 |
+
return frame
|
| 137 |
+
new_w = target_short_side
|
| 138 |
+
new_h = int(target_short_side * h / w)
|
| 139 |
+
|
| 140 |
+
resized_frame = cv2.resize(frame, (new_w, new_h))
|
| 141 |
+
return resized_frame
|
| 142 |
+
|
| 143 |
+
def padding_image(images, new_width, new_height):
|
| 144 |
+
new_image = Image.new('RGB', (new_width, new_height), (255, 255, 255))
|
| 145 |
+
|
| 146 |
+
aspect_ratio = images.width / images.height
|
| 147 |
+
if new_width / new_height > 1:
|
| 148 |
+
if aspect_ratio > new_width / new_height:
|
| 149 |
+
new_img_width = new_width
|
| 150 |
+
new_img_height = int(new_img_width / aspect_ratio)
|
| 151 |
+
else:
|
| 152 |
+
new_img_height = new_height
|
| 153 |
+
new_img_width = int(new_img_height * aspect_ratio)
|
| 154 |
+
else:
|
| 155 |
+
if aspect_ratio > new_width / new_height:
|
| 156 |
+
new_img_width = new_width
|
| 157 |
+
new_img_height = int(new_img_width / aspect_ratio)
|
| 158 |
+
else:
|
| 159 |
+
new_img_height = new_height
|
| 160 |
+
new_img_width = int(new_img_height * aspect_ratio)
|
| 161 |
+
|
| 162 |
+
resized_img = images.resize((new_img_width, new_img_height))
|
| 163 |
+
|
| 164 |
+
paste_x = (new_width - new_img_width) // 2
|
| 165 |
+
paste_y = (new_height - new_img_height) // 2
|
| 166 |
+
|
| 167 |
+
new_image.paste(resized_img, (paste_x, paste_y))
|
| 168 |
+
|
| 169 |
+
return new_image
|
| 170 |
+
|
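A short sketch of the two resizing helpers above (`resize_frame` and `padding_image` are the functions defined in this file; the sizes are illustrative):

```python
import numpy as np
from PIL import Image

# resize_frame: scale a 720x1280 frame so that its short side becomes 384 pixels.
frame = np.zeros((720, 1280, 3), dtype=np.uint8)
small = resize_frame(frame, 384)
print(small.shape)  # (384, 682, 3) -- aspect ratio preserved

# padding_image: letterbox a 640x360 image onto a white 512x512 canvas.
img = Image.new("RGB", (640, 360), (0, 0, 0))
padded = padding_image(img, 512, 512)
print(padded.size)  # (512, 512)
```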
| 171 |
+
def resize_image_with_target_area(img: Image.Image, target_area: int = 1024 * 1024) -> Image.Image:
|
| 172 |
+
"""
|
| 173 |
+
    Resize a PIL image to approximately the given pixel area (target_area), keeping the original aspect ratio
|
| 174 |
+
    and making sure that the new width and height are both multiples of 32.
|
| 175 |
+
|
| 176 |
+
    Args:
|
| 177 |
+
        img (PIL.Image.Image): The input image.
|
| 178 |
+
        target_area (int): Target total pixel area, e.g. 1024*1024 = 1048576.
|
| 179 |
+
|
| 180 |
+
    Returns:
|
| 181 |
+
        PIL.Image.Image: The resized image.
|
| 182 |
+
"""
|
| 183 |
+
orig_w, orig_h = img.size
|
| 184 |
+
if orig_w == 0 or orig_h == 0:
|
| 185 |
+
raise ValueError("Input image has zero width or height.")
|
| 186 |
+
|
| 187 |
+
ratio = orig_w / orig_h
|
| 188 |
+
ideal_width = math.sqrt(target_area * ratio)
|
| 189 |
+
ideal_height = ideal_width / ratio
|
| 190 |
+
|
| 191 |
+
new_width = round(ideal_width / 32) * 32
|
| 192 |
+
new_height = round(ideal_height / 32) * 32
|
| 193 |
+
|
| 194 |
+
new_width = max(32, new_width)
|
| 195 |
+
new_height = max(32, new_height)
|
| 196 |
+
|
| 197 |
+
new_width = int(new_width)
|
| 198 |
+
new_height = int(new_height)
|
| 199 |
+
|
| 200 |
+
resized_img = img.resize((new_width, new_height), Image.LANCZOS)
|
| 201 |
+
return resized_img
|
| 202 |
+
|
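For example, `resize_image_with_target_area` maps a 1920x1080 input to a roughly one-megapixel resolution whose sides are multiples of 32:

```python
from PIL import Image

img = Image.new("RGB", (1920, 1080))
out = resize_image_with_target_area(img, target_area=1024 * 1024)
print(out.size)  # (1376, 768): about 1.06 MP, both sides divisible by 32
```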
| 203 |
+
class Camera(object):
|
| 204 |
+
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 205 |
+
"""
|
| 206 |
+
def __init__(self, entry):
|
| 207 |
+
fx, fy, cx, cy = entry[1:5]
|
| 208 |
+
self.fx = fx
|
| 209 |
+
self.fy = fy
|
| 210 |
+
self.cx = cx
|
| 211 |
+
self.cy = cy
|
| 212 |
+
w2c_mat = np.array(entry[7:]).reshape(3, 4)
|
| 213 |
+
w2c_mat_4x4 = np.eye(4)
|
| 214 |
+
w2c_mat_4x4[:3, :] = w2c_mat
|
| 215 |
+
self.w2c_mat = w2c_mat_4x4
|
| 216 |
+
self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
|
| 217 |
+
|
| 218 |
+
def custom_meshgrid(*args):
|
| 219 |
+
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 220 |
+
"""
|
| 221 |
+
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
|
| 222 |
+
if pver.parse(torch.__version__) < pver.parse('1.10'):
|
| 223 |
+
return torch.meshgrid(*args)
|
| 224 |
+
else:
|
| 225 |
+
return torch.meshgrid(*args, indexing='ij')
|
| 226 |
+
|
| 227 |
+
def get_relative_pose(cam_params):
|
| 228 |
+
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 229 |
+
"""
|
| 230 |
+
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
|
| 231 |
+
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
|
| 232 |
+
cam_to_origin = 0
|
| 233 |
+
target_cam_c2w = np.array([
|
| 234 |
+
[1, 0, 0, 0],
|
| 235 |
+
[0, 1, 0, -cam_to_origin],
|
| 236 |
+
[0, 0, 1, 0],
|
| 237 |
+
[0, 0, 0, 1]
|
| 238 |
+
])
|
| 239 |
+
abs2rel = target_cam_c2w @ abs_w2cs[0]
|
| 240 |
+
ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
|
| 241 |
+
ret_poses = np.array(ret_poses, dtype=np.float32)
|
| 242 |
+
return ret_poses
|
| 243 |
+
|
| 244 |
+
def ray_condition(K, c2w, H, W, device):
|
| 245 |
+
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 246 |
+
"""
|
| 247 |
+
# c2w: B, V, 4, 4
|
| 248 |
+
# K: B, V, 4
|
| 249 |
+
|
| 250 |
+
B = K.shape[0]
|
| 251 |
+
|
| 252 |
+
j, i = custom_meshgrid(
|
| 253 |
+
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
|
| 254 |
+
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
|
| 255 |
+
)
|
| 256 |
+
i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
| 257 |
+
j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
| 258 |
+
|
| 259 |
+
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
|
| 260 |
+
|
| 261 |
+
zs = torch.ones_like(i) # [B, HxW]
|
| 262 |
+
xs = (i - cx) / fx * zs
|
| 263 |
+
ys = (j - cy) / fy * zs
|
| 264 |
+
zs = zs.expand_as(ys)
|
| 265 |
+
|
| 266 |
+
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
|
| 267 |
+
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
|
| 268 |
+
|
| 269 |
+
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
|
| 270 |
+
rays_o = c2w[..., :3, 3] # B, V, 3
|
| 271 |
+
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
|
| 272 |
+
    # c2w @ directions
|
| 273 |
+
rays_dxo = torch.cross(rays_o, rays_d)
|
| 274 |
+
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
|
| 275 |
+
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
|
| 276 |
+
# plucker = plucker.permute(0, 1, 4, 2, 3)
|
| 277 |
+
return plucker
|
| 278 |
+
|
| 279 |
+
def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
|
| 280 |
+
"""Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 281 |
+
"""
|
| 282 |
+
with open(pose_file_path, 'r') as f:
|
| 283 |
+
poses = f.readlines()
|
| 284 |
+
|
| 285 |
+
poses = [pose.strip().split(' ') for pose in poses[1:]]
|
| 286 |
+
cam_params = [[float(x) for x in pose] for pose in poses]
|
| 287 |
+
if return_poses:
|
| 288 |
+
return cam_params
|
| 289 |
+
else:
|
| 290 |
+
cam_params = [Camera(cam_param) for cam_param in cam_params]
|
| 291 |
+
|
| 292 |
+
sample_wh_ratio = width / height
|
| 293 |
+
pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
|
| 294 |
+
|
| 295 |
+
if pose_wh_ratio > sample_wh_ratio:
|
| 296 |
+
resized_ori_w = height * pose_wh_ratio
|
| 297 |
+
for cam_param in cam_params:
|
| 298 |
+
cam_param.fx = resized_ori_w * cam_param.fx / width
|
| 299 |
+
else:
|
| 300 |
+
resized_ori_h = width / pose_wh_ratio
|
| 301 |
+
for cam_param in cam_params:
|
| 302 |
+
cam_param.fy = resized_ori_h * cam_param.fy / height
|
| 303 |
+
|
| 304 |
+
intrinsic = np.asarray([[cam_param.fx * width,
|
| 305 |
+
cam_param.fy * height,
|
| 306 |
+
cam_param.cx * width,
|
| 307 |
+
cam_param.cy * height]
|
| 308 |
+
for cam_param in cam_params], dtype=np.float32)
|
| 309 |
+
|
| 310 |
+
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
|
| 311 |
+
c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
|
| 312 |
+
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
|
| 313 |
+
plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
|
| 314 |
+
plucker_embedding = plucker_embedding[None]
|
| 315 |
+
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
|
| 316 |
+
return plucker_embedding
|
| 317 |
+
|
| 318 |
+
def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
|
| 319 |
+
"""Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 320 |
+
"""
|
| 321 |
+
cam_params = [Camera(cam_param) for cam_param in cam_params]
|
| 322 |
+
|
| 323 |
+
sample_wh_ratio = width / height
|
| 324 |
+
pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
|
| 325 |
+
|
| 326 |
+
if pose_wh_ratio > sample_wh_ratio:
|
| 327 |
+
resized_ori_w = height * pose_wh_ratio
|
| 328 |
+
for cam_param in cam_params:
|
| 329 |
+
cam_param.fx = resized_ori_w * cam_param.fx / width
|
| 330 |
+
else:
|
| 331 |
+
resized_ori_h = width / pose_wh_ratio
|
| 332 |
+
for cam_param in cam_params:
|
| 333 |
+
cam_param.fy = resized_ori_h * cam_param.fy / height
|
| 334 |
+
|
| 335 |
+
intrinsic = np.asarray([[cam_param.fx * width,
|
| 336 |
+
cam_param.fy * height,
|
| 337 |
+
cam_param.cx * width,
|
| 338 |
+
cam_param.cy * height]
|
| 339 |
+
for cam_param in cam_params], dtype=np.float32)
|
| 340 |
+
|
| 341 |
+
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
|
| 342 |
+
c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
|
| 343 |
+
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
|
| 344 |
+
plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
|
| 345 |
+
plucker_embedding = plucker_embedding[None]
|
| 346 |
+
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
|
| 347 |
+
return plucker_embedding
|
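A minimal sketch of driving the CameraCtrl-style helpers above with a synthetic trajectory. The 19-value layout per pose entry (index, fx, fy, cx, cy, two unused fields, then a flattened 3x4 world-to-camera matrix) mirrors what `Camera.__init__` reads; the intrinsics and resolution below are illustrative:

```python
import numpy as np

# Two identity cameras with normalized intrinsics fx = fy = 1.0, cx = cy = 0.5.
identity_w2c = np.eye(4)[:3].reshape(-1).tolist()            # 12 values
entry = [0.0, 1.0, 1.0, 0.5, 0.5, 0.0, 0.0] + identity_w2c   # 19 values per pose line
cam_params = [entry, entry]

plucker = process_pose_params(cam_params, width=672, height=384)
print(plucker.shape)  # torch.Size([2, 384, 672, 6]) -- one Plücker ray map per frame
```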
videox_fun/pipeline/__init__.py
ADDED
|
@@ -0,0 +1,62 @@
| 1 |
+
# from .pipeline_cogvideox_fun import CogVideoXFunPipeline
|
| 2 |
+
# from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline
|
| 3 |
+
# from .pipeline_cogvideox_fun_inpaint import CogVideoXFunInpaintPipeline
|
| 4 |
+
# from .pipeline_fantasy_talking import FantasyTalkingPipeline
|
| 5 |
+
# from .pipeline_flux import FluxPipeline
|
| 6 |
+
# from .pipeline_flux2 import Flux2Pipeline
|
| 7 |
+
# from .pipeline_flux2_control import Flux2ControlPipeline
|
| 8 |
+
# from .pipeline_hunyuanvideo import HunyuanVideoPipeline
|
| 9 |
+
# from .pipeline_hunyuanvideo_i2v import HunyuanVideoI2VPipeline
|
| 10 |
+
# from .pipeline_qwenimage import QwenImagePipeline
|
| 11 |
+
# from .pipeline_qwenimage_edit import QwenImageEditPipeline
|
| 12 |
+
# from .pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
|
| 13 |
+
# from .pipeline_wan import WanPipeline
|
| 14 |
+
# from .pipeline_wan2_2 import Wan2_2Pipeline
|
| 15 |
+
# from .pipeline_wan2_2_animate import Wan2_2AnimatePipeline
|
| 16 |
+
# from .pipeline_wan2_2_fun_control import Wan2_2FunControlPipeline
|
| 17 |
+
# from .pipeline_wan2_2_fun_inpaint import Wan2_2FunInpaintPipeline
|
| 18 |
+
# from .pipeline_wan2_2_s2v import Wan2_2S2VPipeline
|
| 19 |
+
# from .pipeline_wan2_2_ti2v import Wan2_2TI2VPipeline
|
| 20 |
+
# from .pipeline_wan2_2_vace_fun import Wan2_2VaceFunPipeline
|
| 21 |
+
# from .pipeline_wan_fun_control import WanFunControlPipeline
|
| 22 |
+
# from .pipeline_wan_fun_inpaint import WanFunInpaintPipeline
|
| 23 |
+
# from .pipeline_wan_phantom import WanFunPhantomPipeline
|
| 24 |
+
# from .pipeline_wan_vace import WanVacePipeline
|
| 25 |
+
from .pipeline_z_image import ZImagePipeline
|
| 26 |
+
from .pipeline_z_image_control import ZImageControlPipeline
|
| 27 |
+
|
| 28 |
+
# WanFunPipeline = WanPipeline
|
| 29 |
+
# WanI2VPipeline = WanFunInpaintPipeline
|
| 30 |
+
|
| 31 |
+
# Wan2_2FunPipeline = Wan2_2Pipeline
|
| 32 |
+
# Wan2_2I2VPipeline = Wan2_2FunInpaintPipeline
|
| 33 |
+
|
| 34 |
+
# import importlib.util
|
| 35 |
+
|
| 36 |
+
# if importlib.util.find_spec("paifuser") is not None:
|
| 37 |
+
# # --------------------------------------------------------------- #
|
| 38 |
+
# # Sparse Attention
|
| 39 |
+
# # --------------------------------------------------------------- #
|
| 40 |
+
# from paifuser.ops import sparse_reset
|
| 41 |
+
|
| 42 |
+
# # Wan2.1
|
| 43 |
+
# WanFunInpaintPipeline.__call__ = sparse_reset(WanFunInpaintPipeline.__call__)
|
| 44 |
+
# WanFunPipeline.__call__ = sparse_reset(WanFunPipeline.__call__)
|
| 45 |
+
# WanFunControlPipeline.__call__ = sparse_reset(WanFunControlPipeline.__call__)
|
| 46 |
+
# WanI2VPipeline.__call__ = sparse_reset(WanI2VPipeline.__call__)
|
| 47 |
+
# WanPipeline.__call__ = sparse_reset(WanPipeline.__call__)
|
| 48 |
+
# WanVacePipeline.__call__ = sparse_reset(WanVacePipeline.__call__)
|
| 49 |
+
|
| 50 |
+
# # Phantom
|
| 51 |
+
# WanFunPhantomPipeline.__call__ = sparse_reset(WanFunPhantomPipeline.__call__)
|
| 52 |
+
|
| 53 |
+
# # Wan2.2
|
| 54 |
+
# Wan2_2FunInpaintPipeline.__call__ = sparse_reset(Wan2_2FunInpaintPipeline.__call__)
|
| 55 |
+
# Wan2_2FunPipeline.__call__ = sparse_reset(Wan2_2FunPipeline.__call__)
|
| 56 |
+
# Wan2_2FunControlPipeline.__call__ = sparse_reset(Wan2_2FunControlPipeline.__call__)
|
| 57 |
+
# Wan2_2Pipeline.__call__ = sparse_reset(Wan2_2Pipeline.__call__)
|
| 58 |
+
# Wan2_2I2VPipeline.__call__ = sparse_reset(Wan2_2I2VPipeline.__call__)
|
| 59 |
+
# Wan2_2TI2VPipeline.__call__ = sparse_reset(Wan2_2TI2VPipeline.__call__)
|
| 60 |
+
# Wan2_2S2VPipeline.__call__ = sparse_reset(Wan2_2S2VPipeline.__call__)
|
| 61 |
+
# Wan2_2VaceFunPipeline.__call__ = sparse_reset(Wan2_2VaceFunPipeline.__call__)
|
| 62 |
+
# Wan2_2AnimatePipeline.__call__ = sparse_reset(Wan2_2AnimatePipeline.__call__)
|
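Only the two Z-Image pipelines are re-exported here (everything else stays commented out), so downstream code imports them as:

```python
from videox_fun.pipeline import ZImageControlPipeline, ZImagePipeline
```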
videox_fun/pipeline/pipeline_cogvideox_fun.py
ADDED
|
@@ -0,0 +1,862 @@
| 1 |
+
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
import math
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 24 |
+
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
| 25 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 26 |
+
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
| 27 |
+
from diffusers.utils import BaseOutput, logging, replace_example_docstring
|
| 28 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 29 |
+
from diffusers.video_processor import VideoProcessor
|
| 30 |
+
|
| 31 |
+
from ..models import (AutoencoderKLCogVideoX,
|
| 32 |
+
CogVideoXTransformer3DModel, T5EncoderModel,
|
| 33 |
+
T5Tokenizer)
|
| 34 |
+
|
| 35 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
EXAMPLE_DOC_STRING = """
|
| 39 |
+
Examples:
|
| 40 |
+
```python
|
| 41 |
+
pass
|
| 42 |
+
```
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed
|
| 47 |
+
def get_3d_rotary_pos_embed(
|
| 48 |
+
embed_dim,
|
| 49 |
+
crops_coords,
|
| 50 |
+
grid_size,
|
| 51 |
+
temporal_size,
|
| 52 |
+
theta: int = 10000,
|
| 53 |
+
use_real: bool = True,
|
| 54 |
+
grid_type: str = "linspace",
|
| 55 |
+
max_size: Optional[Tuple[int, int]] = None,
|
| 56 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 57 |
+
"""
|
| 58 |
+
RoPE for video tokens with 3D structure.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
embed_dim: (`int`):
|
| 62 |
+
The embedding dimension size, corresponding to hidden_size_head.
|
| 63 |
+
crops_coords (`Tuple[int]`):
|
| 64 |
+
The top-left and bottom-right coordinates of the crop.
|
| 65 |
+
grid_size (`Tuple[int]`):
|
| 66 |
+
The grid size of the spatial positional embedding (height, width).
|
| 67 |
+
temporal_size (`int`):
|
| 68 |
+
The size of the temporal dimension.
|
| 69 |
+
theta (`float`):
|
| 70 |
+
Scaling factor for frequency computation.
|
| 71 |
+
grid_type (`str`):
|
| 72 |
+
Whether to use "linspace" or "slice" to compute grids.
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
|
| 76 |
+
"""
|
| 77 |
+
if use_real is not True:
|
| 78 |
+
raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
|
| 79 |
+
|
| 80 |
+
if grid_type == "linspace":
|
| 81 |
+
start, stop = crops_coords
|
| 82 |
+
grid_size_h, grid_size_w = grid_size
|
| 83 |
+
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
|
| 84 |
+
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
|
| 85 |
+
grid_t = np.arange(temporal_size, dtype=np.float32)
|
| 86 |
+
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
| 87 |
+
elif grid_type == "slice":
|
| 88 |
+
max_h, max_w = max_size
|
| 89 |
+
grid_size_h, grid_size_w = grid_size
|
| 90 |
+
grid_h = np.arange(max_h, dtype=np.float32)
|
| 91 |
+
grid_w = np.arange(max_w, dtype=np.float32)
|
| 92 |
+
grid_t = np.arange(temporal_size, dtype=np.float32)
|
| 93 |
+
else:
|
| 94 |
+
raise ValueError("Invalid value passed for `grid_type`.")
|
| 95 |
+
|
| 96 |
+
# Compute dimensions for each axis
|
| 97 |
+
dim_t = embed_dim // 4
|
| 98 |
+
dim_h = embed_dim // 8 * 3
|
| 99 |
+
dim_w = embed_dim // 8 * 3
|
| 100 |
+
|
| 101 |
+
# Temporal frequencies
|
| 102 |
+
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
|
| 103 |
+
# Spatial frequencies for height and width
|
| 104 |
+
freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
|
| 105 |
+
freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
|
| 106 |
+
|
| 107 |
+
    # Broadcast and concatenate the temporal and spatial frequencies (height and width) into a 3D tensor
|
| 108 |
+
def combine_time_height_width(freqs_t, freqs_h, freqs_w):
|
| 109 |
+
freqs_t = freqs_t[:, None, None, :].expand(
|
| 110 |
+
-1, grid_size_h, grid_size_w, -1
|
| 111 |
+
) # temporal_size, grid_size_h, grid_size_w, dim_t
|
| 112 |
+
freqs_h = freqs_h[None, :, None, :].expand(
|
| 113 |
+
temporal_size, -1, grid_size_w, -1
|
| 114 |
+
        ) # temporal_size, grid_size_h, grid_size_w, dim_h
|
| 115 |
+
freqs_w = freqs_w[None, None, :, :].expand(
|
| 116 |
+
temporal_size, grid_size_h, -1, -1
|
| 117 |
+
        ) # temporal_size, grid_size_h, grid_size_w, dim_w
|
| 118 |
+
|
| 119 |
+
freqs = torch.cat(
|
| 120 |
+
[freqs_t, freqs_h, freqs_w], dim=-1
|
| 121 |
+
) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
|
| 122 |
+
freqs = freqs.view(
|
| 123 |
+
temporal_size * grid_size_h * grid_size_w, -1
|
| 124 |
+
) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
|
| 125 |
+
return freqs
|
| 126 |
+
|
| 127 |
+
    t_cos, t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
|
| 128 |
+
    h_cos, h_sin = freqs_h # both h_cos and h_sin have shape: grid_size_h, dim_h
|
| 129 |
+
    w_cos, w_sin = freqs_w # both w_cos and w_sin have shape: grid_size_w, dim_w
|
| 130 |
+
|
| 131 |
+
if grid_type == "slice":
|
| 132 |
+
t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
|
| 133 |
+
h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
|
| 134 |
+
w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
|
| 135 |
+
|
| 136 |
+
cos = combine_time_height_width(t_cos, h_cos, w_cos)
|
| 137 |
+
sin = combine_time_height_width(t_sin, h_sin, w_sin)
|
| 138 |
+
return cos, sin
|
| 139 |
+
|
| 140 |
+
|
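A quick shape check for `get_3d_rotary_pos_embed` (the grid sizes and head dimension are illustrative; `get_1d_rotary_pos_embed` is the diffusers helper imported at the top of this file):

```python
# 3D RoPE for 8 latent frames and a 30x45 spatial token grid with head dim 64.
cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,
    crops_coords=((0, 0), (30, 45)),
    grid_size=(30, 45),
    temporal_size=8,
)
print(cos.shape, sin.shape)  # torch.Size([10800, 64]) each: 8 * 30 * 45 tokens, embed_dim per token
```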
| 141 |
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
| 142 |
+
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
| 143 |
+
tw = tgt_width
|
| 144 |
+
th = tgt_height
|
| 145 |
+
h, w = src
|
| 146 |
+
r = h / w
|
| 147 |
+
if r > (th / tw):
|
| 148 |
+
resize_height = th
|
| 149 |
+
resize_width = int(round(th / h * w))
|
| 150 |
+
else:
|
| 151 |
+
resize_width = tw
|
| 152 |
+
resize_height = int(round(tw / w * h))
|
| 153 |
+
|
| 154 |
+
crop_top = int(round((th - resize_height) / 2.0))
|
| 155 |
+
crop_left = int(round((tw - resize_width) / 2.0))
|
| 156 |
+
|
| 157 |
+
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
| 158 |
+
|
| 159 |
+
|
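For instance, fitting a 30x40 (height, width) source grid into a 32x32 target fills the width and letterboxes the height:

```python
top_left, bottom_right = get_resize_crop_region_for_grid((30, 40), tgt_width=32, tgt_height=32)
print(top_left, bottom_right)  # (4, 0) (28, 32)
```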
| 160 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 161 |
+
def retrieve_timesteps(
|
| 162 |
+
scheduler,
|
| 163 |
+
num_inference_steps: Optional[int] = None,
|
| 164 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 165 |
+
timesteps: Optional[List[int]] = None,
|
| 166 |
+
sigmas: Optional[List[float]] = None,
|
| 167 |
+
**kwargs,
|
| 168 |
+
):
|
| 169 |
+
"""
|
| 170 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 171 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
scheduler (`SchedulerMixin`):
|
| 175 |
+
The scheduler to get timesteps from.
|
| 176 |
+
num_inference_steps (`int`):
|
| 177 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 178 |
+
must be `None`.
|
| 179 |
+
device (`str` or `torch.device`, *optional*):
|
| 180 |
+
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
| 181 |
+
timesteps (`List[int]`, *optional*):
|
| 182 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 183 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 184 |
+
sigmas (`List[float]`, *optional*):
|
| 185 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 186 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 190 |
+
second element is the number of inference steps.
|
| 191 |
+
"""
|
| 192 |
+
if timesteps is not None and sigmas is not None:
|
| 193 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 194 |
+
if timesteps is not None:
|
| 195 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 196 |
+
if not accepts_timesteps:
|
| 197 |
+
raise ValueError(
|
| 198 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 199 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 200 |
+
)
|
| 201 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 202 |
+
timesteps = scheduler.timesteps
|
| 203 |
+
num_inference_steps = len(timesteps)
|
| 204 |
+
elif sigmas is not None:
|
| 205 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 206 |
+
if not accept_sigmas:
|
| 207 |
+
raise ValueError(
|
| 208 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 209 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 210 |
+
)
|
| 211 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 212 |
+
timesteps = scheduler.timesteps
|
| 213 |
+
num_inference_steps = len(timesteps)
|
| 214 |
+
else:
|
| 215 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 216 |
+
timesteps = scheduler.timesteps
|
| 217 |
+
return timesteps, num_inference_steps
|
| 218 |
+
|
| 219 |
+
|
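A minimal sketch of `retrieve_timesteps` using the CogVideoX DDIM scheduler imported above (instantiated here with its default configuration; real pipelines load the scheduler from the checkpoint):

```python
from diffusers.schedulers import CogVideoXDDIMScheduler

scheduler = CogVideoXDDIMScheduler()
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=50, device="cpu")
print(num_inference_steps, timesteps[:3])  # 50 and the first few (descending) timesteps
```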
| 220 |
+
@dataclass
|
| 221 |
+
class CogVideoXFunPipelineOutput(BaseOutput):
|
| 222 |
+
r"""
|
| 223 |
+
Output class for CogVideo pipelines.
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
| 227 |
+
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
| 228 |
+
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
| 229 |
+
`(batch_size, num_frames, channels, height, width)`.
|
| 230 |
+
"""
|
| 231 |
+
|
| 232 |
+
videos: torch.Tensor
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class CogVideoXFunPipeline(DiffusionPipeline):
|
| 236 |
+
r"""
|
| 237 |
+
Pipeline for text-to-video generation using CogVideoX_Fun.
|
| 238 |
+
|
| 239 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 240 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
vae ([`AutoencoderKL`]):
|
| 244 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
| 245 |
+
text_encoder ([`T5EncoderModel`]):
|
| 246 |
+
Frozen text-encoder. CogVideoX uses
|
| 247 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
| 248 |
+
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
| 249 |
+
tokenizer (`T5Tokenizer`):
|
| 250 |
+
Tokenizer of class
|
| 251 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
| 252 |
+
transformer ([`CogVideoXTransformer3DModel`]):
|
| 253 |
+
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
|
| 254 |
+
scheduler ([`SchedulerMixin`]):
|
| 255 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
| 256 |
+
"""
|
| 257 |
+
|
| 258 |
+
_optional_components = []
|
| 259 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 260 |
+
|
| 261 |
+
_callback_tensor_inputs = [
|
| 262 |
+
"latents",
|
| 263 |
+
"prompt_embeds",
|
| 264 |
+
"negative_prompt_embeds",
|
| 265 |
+
]
|
| 266 |
+
|
| 267 |
+
def __init__(
|
| 268 |
+
self,
|
| 269 |
+
tokenizer: T5Tokenizer,
|
| 270 |
+
text_encoder: T5EncoderModel,
|
| 271 |
+
vae: AutoencoderKLCogVideoX,
|
| 272 |
+
transformer: CogVideoXTransformer3DModel,
|
| 273 |
+
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
| 274 |
+
):
|
| 275 |
+
super().__init__()
|
| 276 |
+
|
| 277 |
+
self.register_modules(
|
| 278 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
| 279 |
+
)
|
| 280 |
+
self.vae_scale_factor_spatial = (
|
| 281 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
| 282 |
+
)
|
| 283 |
+
self.vae_scale_factor_temporal = (
|
| 284 |
+
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
| 288 |
+
|
| 289 |
+
def _get_t5_prompt_embeds(
|
| 290 |
+
self,
|
| 291 |
+
prompt: Union[str, List[str]] = None,
|
| 292 |
+
num_videos_per_prompt: int = 1,
|
| 293 |
+
max_sequence_length: int = 226,
|
| 294 |
+
device: Optional[torch.device] = None,
|
| 295 |
+
dtype: Optional[torch.dtype] = None,
|
| 296 |
+
):
|
| 297 |
+
device = device or self._execution_device
|
| 298 |
+
dtype = dtype or self.text_encoder.dtype
|
| 299 |
+
|
| 300 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 301 |
+
batch_size = len(prompt)
|
| 302 |
+
|
| 303 |
+
text_inputs = self.tokenizer(
|
| 304 |
+
prompt,
|
| 305 |
+
padding="max_length",
|
| 306 |
+
max_length=max_sequence_length,
|
| 307 |
+
truncation=True,
|
| 308 |
+
add_special_tokens=True,
|
| 309 |
+
return_tensors="pt",
|
| 310 |
+
)
|
| 311 |
+
text_input_ids = text_inputs.input_ids
|
| 312 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 313 |
+
|
| 314 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 315 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 316 |
+
logger.warning(
|
| 317 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 318 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
| 322 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 323 |
+
|
| 324 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 325 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 326 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 327 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 328 |
+
|
| 329 |
+
return prompt_embeds
|
| 330 |
+
|
| 331 |
+
def encode_prompt(
|
| 332 |
+
self,
|
| 333 |
+
prompt: Union[str, List[str]],
|
| 334 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 335 |
+
do_classifier_free_guidance: bool = True,
|
| 336 |
+
num_videos_per_prompt: int = 1,
|
| 337 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 338 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 339 |
+
max_sequence_length: int = 226,
|
| 340 |
+
device: Optional[torch.device] = None,
|
| 341 |
+
dtype: Optional[torch.dtype] = None,
|
| 342 |
+
):
|
| 343 |
+
r"""
|
| 344 |
+
Encodes the prompt into text encoder hidden states.
|
| 345 |
+
|
| 346 |
+
Args:
|
| 347 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 348 |
+
prompt to be encoded
|
| 349 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 350 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 351 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 352 |
+
less than `1`).
|
| 353 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 354 |
+
Whether to use classifier free guidance or not.
|
| 355 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 356 |
+
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
| 357 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 358 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 359 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 360 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 361 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 362 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 363 |
+
argument.
|
| 364 |
+
device: (`torch.device`, *optional*):
|
| 365 |
+
torch device
|
| 366 |
+
dtype: (`torch.dtype`, *optional*):
|
| 367 |
+
torch dtype
|
| 368 |
+
"""
|
| 369 |
+
device = device or self._execution_device
|
| 370 |
+
|
| 371 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 372 |
+
if prompt is not None:
|
| 373 |
+
batch_size = len(prompt)
|
| 374 |
+
else:
|
| 375 |
+
batch_size = prompt_embeds.shape[0]
|
| 376 |
+
|
| 377 |
+
if prompt_embeds is None:
|
| 378 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 379 |
+
prompt=prompt,
|
| 380 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 381 |
+
max_sequence_length=max_sequence_length,
|
| 382 |
+
device=device,
|
| 383 |
+
dtype=dtype,
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 387 |
+
negative_prompt = negative_prompt or ""
|
| 388 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 389 |
+
|
| 390 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 391 |
+
raise TypeError(
|
| 392 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 393 |
+
f" {type(prompt)}."
|
| 394 |
+
)
|
| 395 |
+
elif batch_size != len(negative_prompt):
|
| 396 |
+
raise ValueError(
|
| 397 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 398 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 399 |
+
" the batch size of `prompt`."
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 403 |
+
prompt=negative_prompt,
|
| 404 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 405 |
+
max_sequence_length=max_sequence_length,
|
| 406 |
+
device=device,
|
| 407 |
+
dtype=dtype,
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
return prompt_embeds, negative_prompt_embeds
|
| 411 |
+
|
| 412 |
+
def prepare_latents(
|
| 413 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
| 414 |
+
):
|
| 415 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 416 |
+
raise ValueError(
|
| 417 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 418 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
shape = (
|
| 422 |
+
batch_size,
|
| 423 |
+
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
|
| 424 |
+
num_channels_latents,
|
| 425 |
+
height // self.vae_scale_factor_spatial,
|
| 426 |
+
width // self.vae_scale_factor_spatial,
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
if latents is None:
|
| 430 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 431 |
+
else:
|
| 432 |
+
latents = latents.to(device)
|
| 433 |
+
|
| 434 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 435 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 436 |
+
return latents
|
| 437 |
+
|
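For the default 49-frame, 480x720 call, the latent shape built by `prepare_latents` works out as below (a sketch assuming the usual CogVideoX factors of 4x temporal and 8x spatial compression and 16 latent channels; the real values come from the loaded VAE and transformer configs):

```python
num_frames, height, width = 49, 480, 720
vae_scale_factor_temporal = 4    # assumption: vae.config.temporal_compression_ratio
vae_scale_factor_spatial = 8     # assumption: 2 ** (len(vae.config.block_out_channels) - 1)
num_channels_latents = 16        # assumption: latent channel count passed in by the caller

latent_shape = (
    1,                                                  # batch_size
    (num_frames - 1) // vae_scale_factor_temporal + 1,  # 13 latent frames
    num_channels_latents,
    height // vae_scale_factor_spatial,                 # 60
    width // vae_scale_factor_spatial,                  # 90
)
print(latent_shape)  # (1, 13, 16, 60, 90)
```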
| 438 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 439 |
+
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
| 440 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
| 441 |
+
|
| 442 |
+
frames = self.vae.decode(latents).sample
|
| 443 |
+
frames = (frames / 2 + 0.5).clamp(0, 1)
|
| 444 |
+
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 445 |
+
frames = frames.cpu().float().numpy()
|
| 446 |
+
return frames
|
| 447 |
+
|
| 448 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 449 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 450 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 451 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 452 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 453 |
+
# and should be between [0, 1]
|
| 454 |
+
|
| 455 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 456 |
+
extra_step_kwargs = {}
|
| 457 |
+
if accepts_eta:
|
| 458 |
+
extra_step_kwargs["eta"] = eta
|
| 459 |
+
|
| 460 |
+
# check if the scheduler accepts generator
|
| 461 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 462 |
+
if accepts_generator:
|
| 463 |
+
extra_step_kwargs["generator"] = generator
|
| 464 |
+
return extra_step_kwargs
|
| 465 |
+
|
| 466 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
| 467 |
+
def check_inputs(
|
| 468 |
+
self,
|
| 469 |
+
prompt,
|
| 470 |
+
height,
|
| 471 |
+
width,
|
| 472 |
+
negative_prompt,
|
| 473 |
+
callback_on_step_end_tensor_inputs,
|
| 474 |
+
prompt_embeds=None,
|
| 475 |
+
negative_prompt_embeds=None,
|
| 476 |
+
):
|
| 477 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 478 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 479 |
+
|
| 480 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 481 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 482 |
+
):
|
| 483 |
+
raise ValueError(
|
| 484 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 485 |
+
)
|
| 486 |
+
if prompt is not None and prompt_embeds is not None:
|
| 487 |
+
raise ValueError(
|
| 488 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 489 |
+
" only forward one of the two."
|
| 490 |
+
)
|
| 491 |
+
elif prompt is None and prompt_embeds is None:
|
| 492 |
+
raise ValueError(
|
| 493 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 494 |
+
)
|
| 495 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 496 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 497 |
+
|
| 498 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 499 |
+
raise ValueError(
|
| 500 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 501 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 505 |
+
raise ValueError(
|
| 506 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 507 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 511 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 512 |
+
raise ValueError(
|
| 513 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 514 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 515 |
+
f" {negative_prompt_embeds.shape}."
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
def fuse_qkv_projections(self) -> None:
|
| 519 |
+
r"""Enables fused QKV projections."""
|
| 520 |
+
self.fusing_transformer = True
|
| 521 |
+
self.transformer.fuse_qkv_projections()
|
| 522 |
+
|
| 523 |
+
def unfuse_qkv_projections(self) -> None:
|
| 524 |
+
r"""Disable QKV projection fusion if enabled."""
|
| 525 |
+
if not self.fusing_transformer:
|
| 526 |
+
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
| 527 |
+
else:
|
| 528 |
+
self.transformer.unfuse_qkv_projections()
|
| 529 |
+
self.fusing_transformer = False
|
| 530 |
+
|
| 531 |
+
def _prepare_rotary_positional_embeddings(
|
| 532 |
+
self,
|
| 533 |
+
height: int,
|
| 534 |
+
width: int,
|
| 535 |
+
num_frames: int,
|
| 536 |
+
device: torch.device,
|
| 537 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 538 |
+
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
| 539 |
+
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
| 540 |
+
|
| 541 |
+
p = self.transformer.config.patch_size
|
| 542 |
+
p_t = self.transformer.config.patch_size_t
|
| 543 |
+
|
| 544 |
+
base_size_width = self.transformer.config.sample_width // p
|
| 545 |
+
base_size_height = self.transformer.config.sample_height // p
|
| 546 |
+
|
| 547 |
+
if p_t is None:
|
| 548 |
+
# CogVideoX 1.0
|
| 549 |
+
grid_crops_coords = get_resize_crop_region_for_grid(
|
| 550 |
+
(grid_height, grid_width), base_size_width, base_size_height
|
| 551 |
+
)
|
| 552 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
| 553 |
+
embed_dim=self.transformer.config.attention_head_dim,
|
| 554 |
+
crops_coords=grid_crops_coords,
|
| 555 |
+
grid_size=(grid_height, grid_width),
|
| 556 |
+
temporal_size=num_frames,
|
| 557 |
+
)
|
| 558 |
+
else:
|
| 559 |
+
# CogVideoX 1.5
|
| 560 |
+
base_num_frames = (num_frames + p_t - 1) // p_t
|
| 561 |
+
|
| 562 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
| 563 |
+
embed_dim=self.transformer.config.attention_head_dim,
|
| 564 |
+
crops_coords=None,
|
| 565 |
+
grid_size=(grid_height, grid_width),
|
| 566 |
+
temporal_size=base_num_frames,
|
| 567 |
+
grid_type="slice",
|
| 568 |
+
max_size=(base_size_height, base_size_width),
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
freqs_cos = freqs_cos.to(device=device)
|
| 572 |
+
freqs_sin = freqs_sin.to(device=device)
|
| 573 |
+
return freqs_cos, freqs_sin
|
| 574 |
+
|
| 575 |
+
@property
|
| 576 |
+
def guidance_scale(self):
|
| 577 |
+
return self._guidance_scale
|
| 578 |
+
|
| 579 |
+
@property
|
| 580 |
+
def num_timesteps(self):
|
| 581 |
+
return self._num_timesteps
|
| 582 |
+
|
| 583 |
+
@property
|
| 584 |
+
def attention_kwargs(self):
|
| 585 |
+
return self._attention_kwargs
|
| 586 |
+
|
| 587 |
+
@property
|
| 588 |
+
def interrupt(self):
|
| 589 |
+
return self._interrupt
|
| 590 |
+
|
| 591 |
+
@torch.no_grad()
|
| 592 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 593 |
+
def __call__(
|
| 594 |
+
self,
|
| 595 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 596 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 597 |
+
height: int = 480,
|
| 598 |
+
width: int = 720,
|
| 599 |
+
num_frames: int = 49,
|
| 600 |
+
num_inference_steps: int = 50,
|
| 601 |
+
timesteps: Optional[List[int]] = None,
|
| 602 |
+
guidance_scale: float = 6,
|
| 603 |
+
use_dynamic_cfg: bool = False,
|
| 604 |
+
num_videos_per_prompt: int = 1,
|
| 605 |
+
eta: float = 0.0,
|
| 606 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 607 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 608 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 609 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 610 |
+
output_type: str = "numpy",
|
| 611 |
+
return_dict: bool = False,
|
| 612 |
+
callback_on_step_end: Optional[
|
| 613 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 614 |
+
] = None,
|
| 615 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 616 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 617 |
+
max_sequence_length: int = 226,
|
| 618 |
+
) -> Union[CogVideoXFunPipelineOutput, Tuple]:
|
| 619 |
+
"""
|
| 620 |
+
Function invoked when calling the pipeline for generation.
|
| 621 |
+
|
| 622 |
+
Args:
|
| 623 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 624 |
+
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
| 625 |
+
instead.
|
| 626 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 627 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 628 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 629 |
+
less than `1`).
|
| 630 |
+
            height (`int`, *optional*, defaults to `480`):
|
| 631 |
+
                The height in pixels of the generated video.
|
| 632 |
+
            width (`int`, *optional*, defaults to `720`):
|
| 633 |
+
                The width in pixels of the generated video.
|
| 634 |
+
            num_frames (`int`, defaults to `49`):
|
| 635 |
+
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
|
| 636 |
+
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
|
| 637 |
+
num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
|
| 638 |
+
needs to be satisfied is that of divisibility mentioned above.
|
| 639 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 640 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 641 |
+
expense of slower inference.
|
| 642 |
+
timesteps (`List[int]`, *optional*):
|
| 643 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 644 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 645 |
+
passed will be used. Must be in descending order.
|
| 646 |
+
            guidance_scale (`float`, *optional*, defaults to 6):
|
| 647 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 648 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 649 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 650 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 651 |
+
usually at the expense of lower image quality.
|
| 652 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 653 |
+
The number of videos to generate per prompt.
|
| 654 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 655 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 656 |
+
to make generation deterministic.
|
| 657 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 658 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 659 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 660 |
+
                tensor will be generated by sampling using the supplied random `generator`.
|
| 661 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 662 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 663 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 664 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 665 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 666 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 667 |
+
argument.
|
| 668 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 669 |
+
                The output format of the generated video. Choose between
|
| 670 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 671 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 672 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
| 673 |
+
of a plain tuple.
|
| 674 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 675 |
+
                A function that is called at the end of each denoising step during inference. It is called
|
| 676 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 677 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 678 |
+
`callback_on_step_end_tensor_inputs`.
|
| 679 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 680 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 681 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 682 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 683 |
+
max_sequence_length (`int`, defaults to `226`):
|
| 684 |
+
Maximum sequence length in encoded prompt. Must be consistent with
|
| 685 |
+
`self.transformer.config.max_text_seq_length`, otherwise results may be poor.
|
| 686 |
+
|
| 687 |
+
Examples:
|
| 688 |
+
|
| 689 |
+
Returns:
|
| 690 |
+
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`:
|
| 691 |
+
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a
|
| 692 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
| 693 |
+
"""
|
| 694 |
+
|
| 695 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 696 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 697 |
+
|
| 698 |
+
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
|
| 699 |
+
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
|
| 700 |
+
num_frames = num_frames or self.transformer.config.sample_frames
|
| 701 |
+
|
| 702 |
+
num_videos_per_prompt = 1
|
| 703 |
+
|
| 704 |
+
# 1. Check inputs. Raise error if not correct
|
| 705 |
+
self.check_inputs(
|
| 706 |
+
prompt,
|
| 707 |
+
height,
|
| 708 |
+
width,
|
| 709 |
+
negative_prompt,
|
| 710 |
+
callback_on_step_end_tensor_inputs,
|
| 711 |
+
prompt_embeds,
|
| 712 |
+
negative_prompt_embeds,
|
| 713 |
+
)
|
| 714 |
+
self._guidance_scale = guidance_scale
|
| 715 |
+
self._attention_kwargs = attention_kwargs
|
| 716 |
+
self._interrupt = False
|
| 717 |
+
|
| 718 |
+
# 2. Default call parameters
|
| 719 |
+
if prompt is not None and isinstance(prompt, str):
|
| 720 |
+
batch_size = 1
|
| 721 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 722 |
+
batch_size = len(prompt)
|
| 723 |
+
else:
|
| 724 |
+
batch_size = prompt_embeds.shape[0]
|
| 725 |
+
|
| 726 |
+
device = self._execution_device
|
| 727 |
+
|
| 728 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 729 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 730 |
+
# corresponds to doing no classifier free guidance.
|
| 731 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 732 |
+
|
| 733 |
+
# 3. Encode input prompt
|
| 734 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 735 |
+
prompt,
|
| 736 |
+
negative_prompt,
|
| 737 |
+
do_classifier_free_guidance,
|
| 738 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 739 |
+
prompt_embeds=prompt_embeds,
|
| 740 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 741 |
+
max_sequence_length=max_sequence_length,
|
| 742 |
+
device=device,
|
| 743 |
+
)
|
| 744 |
+
if do_classifier_free_guidance:
|
| 745 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 746 |
+
|
| 747 |
+
# 4. Prepare timesteps
|
| 748 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 749 |
+
self._num_timesteps = len(timesteps)
|
| 750 |
+
|
| 751 |
+
# 5. Prepare latents
|
| 752 |
+
latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
| 753 |
+
|
| 754 |
+
# For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
|
| 755 |
+
patch_size_t = self.transformer.config.patch_size_t
|
| 756 |
+
additional_frames = 0
|
| 757 |
+
if num_frames != 1 and patch_size_t is not None and latent_frames % patch_size_t != 0:
|
| 758 |
+
additional_frames = patch_size_t - latent_frames % patch_size_t
|
| 759 |
+
num_frames += additional_frames * self.vae_scale_factor_temporal
|
| 760 |
+
|
| 761 |
+
latent_channels = self.transformer.config.in_channels
|
| 762 |
+
latents = self.prepare_latents(
|
| 763 |
+
batch_size * num_videos_per_prompt,
|
| 764 |
+
latent_channels,
|
| 765 |
+
num_frames,
|
| 766 |
+
height,
|
| 767 |
+
width,
|
| 768 |
+
prompt_embeds.dtype,
|
| 769 |
+
device,
|
| 770 |
+
generator,
|
| 771 |
+
latents,
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 775 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 776 |
+
|
| 777 |
+
# 7. Create rotary embeds if required
|
| 778 |
+
image_rotary_emb = (
|
| 779 |
+
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
| 780 |
+
if self.transformer.config.use_rotary_positional_embeddings
|
| 781 |
+
else None
|
| 782 |
+
)
|
| 783 |
+
|
| 784 |
+
# 8. Denoising loop
|
| 785 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 786 |
+
|
| 787 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 788 |
+
# for DPM-solver++
|
| 789 |
+
old_pred_original_sample = None
|
| 790 |
+
for i, t in enumerate(timesteps):
|
| 791 |
+
if self.interrupt:
|
| 792 |
+
continue
|
| 793 |
+
|
| 794 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 795 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 796 |
+
|
| 797 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 798 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 799 |
+
|
| 800 |
+
# predict noise model_output
|
| 801 |
+
noise_pred = self.transformer(
|
| 802 |
+
hidden_states=latent_model_input,
|
| 803 |
+
encoder_hidden_states=prompt_embeds,
|
| 804 |
+
timestep=timestep,
|
| 805 |
+
image_rotary_emb=image_rotary_emb,
|
| 806 |
+
return_dict=False,
|
| 807 |
+
)[0]
|
| 808 |
+
noise_pred = noise_pred.float()
|
| 809 |
+
|
| 810 |
+
# perform guidance
|
| 811 |
+
if use_dynamic_cfg:
|
| 812 |
+
self._guidance_scale = 1 + guidance_scale * (
|
| 813 |
+
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
| 814 |
+
)
|
| 815 |
+
if do_classifier_free_guidance:
|
| 816 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 817 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 818 |
+
|
| 819 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 820 |
+
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
| 821 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 822 |
+
else:
|
| 823 |
+
latents, old_pred_original_sample = self.scheduler.step(
|
| 824 |
+
noise_pred,
|
| 825 |
+
old_pred_original_sample,
|
| 826 |
+
t,
|
| 827 |
+
timesteps[i - 1] if i > 0 else None,
|
| 828 |
+
latents,
|
| 829 |
+
**extra_step_kwargs,
|
| 830 |
+
return_dict=False,
|
| 831 |
+
)
|
| 832 |
+
latents = latents.to(prompt_embeds.dtype)
|
| 833 |
+
|
| 834 |
+
# call the callback, if provided
|
| 835 |
+
if callback_on_step_end is not None:
|
| 836 |
+
callback_kwargs = {}
|
| 837 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 838 |
+
callback_kwargs[k] = locals()[k]
|
| 839 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 840 |
+
|
| 841 |
+
latents = callback_outputs.pop("latents", latents)
|
| 842 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 843 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 844 |
+
|
| 845 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 846 |
+
progress_bar.update()
|
| 847 |
+
|
| 848 |
+
if output_type == "numpy":
|
| 849 |
+
video = self.decode_latents(latents)
|
| 850 |
+
elif output_type != "latent":
|
| 851 |
+
video = self.decode_latents(latents)
|
| 852 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 853 |
+
else:
|
| 854 |
+
video = latents
|
| 855 |
+
|
| 856 |
+
# Offload all models
|
| 857 |
+
self.maybe_free_model_hooks()
|
| 858 |
+
|
| 859 |
+
if not return_dict:
|
| 860 |
+
video = torch.from_numpy(video)
|
| 861 |
+
|
| 862 |
+
return CogVideoXFunPipelineOutput(videos=video)
|
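# --- Hedged usage sketch (added for illustration; not part of the diff above) ---
# Shows how the arguments documented in the docstring above fit together. The
# pipeline object is assumed to be already constructed; only the argument names
# (prompt, num_frames, num_inference_steps, guidance_scale, output_type,
# generator) come from the docstring.
import torch

def run_t2v_sketch(pipeline):
    generator = torch.Generator(device="cuda").manual_seed(42)
    output = pipeline(
        prompt="A panda playing a guitar in a bamboo forest",
        num_frames=49,            # keeps (num_frames - 1) divisible by the temporal VAE factor of 4
        num_inference_steps=50,
        guidance_scale=6.0,       # > 1.0 enables classifier-free guidance
        output_type="numpy",
        generator=generator,
    )
    return output.videos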
videox_fun/pipeline/pipeline_cogvideox_fun_control.py
ADDED
|
@@ -0,0 +1,956 @@
|
| 1 |
+
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
import math
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 25 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 26 |
+
from diffusers.models.embeddings import (get_1d_rotary_pos_embed,
|
| 27 |
+
get_3d_rotary_pos_embed)
|
| 28 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 29 |
+
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
| 30 |
+
from diffusers.utils import BaseOutput, logging, replace_example_docstring
|
| 31 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 32 |
+
from diffusers.video_processor import VideoProcessor
|
| 33 |
+
from einops import rearrange
|
| 34 |
+
|
| 35 |
+
from ..models import (AutoencoderKLCogVideoX,
|
| 36 |
+
CogVideoXTransformer3DModel, T5EncoderModel,
|
| 37 |
+
T5Tokenizer)
|
| 38 |
+
|
| 39 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
EXAMPLE_DOC_STRING = """
|
| 43 |
+
Examples:
|
| 44 |
+
```python
|
| 45 |
+
pass
|
| 46 |
+
```
|
| 47 |
+
"""
|
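# --- Hedged sketch (added; not part of the original EXAMPLE_DOC_STRING) ---
# The example docstring above is intentionally a placeholder (`pass`). A
# hypothetical call to the CogVideoXFunControlPipeline defined below could look
# like this; pipeline construction is omitted and the random tensor stands in
# for a real pose/depth/canny control sequence.
def run_control_sketch(pipeline):
    # control_video is expected as (batch, channels, frames, height, width),
    # matching the "b c f h w" rearrange performed inside __call__.
    control_video = torch.rand(1, 3, 49, 480, 720)
    output = pipeline(
        prompt="A dancer moving through a neon-lit street",
        control_video=control_video,
        num_frames=49,
        height=480,
        width=720,
        num_inference_steps=50,
        guidance_scale=6.0,
        output_type="numpy",
    )
    return output.videos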
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed
|
| 51 |
+
def get_3d_rotary_pos_embed(
|
| 52 |
+
embed_dim,
|
| 53 |
+
crops_coords,
|
| 54 |
+
grid_size,
|
| 55 |
+
temporal_size,
|
| 56 |
+
theta: int = 10000,
|
| 57 |
+
use_real: bool = True,
|
| 58 |
+
grid_type: str = "linspace",
|
| 59 |
+
max_size: Optional[Tuple[int, int]] = None,
|
| 60 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 61 |
+
"""
|
| 62 |
+
RoPE for video tokens with 3D structure.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
embed_dim: (`int`):
|
| 66 |
+
The embedding dimension size, corresponding to hidden_size_head.
|
| 67 |
+
crops_coords (`Tuple[int]`):
|
| 68 |
+
The top-left and bottom-right coordinates of the crop.
|
| 69 |
+
grid_size (`Tuple[int]`):
|
| 70 |
+
The grid size of the spatial positional embedding (height, width).
|
| 71 |
+
temporal_size (`int`):
|
| 72 |
+
The size of the temporal dimension.
|
| 73 |
+
theta (`float`):
|
| 74 |
+
Scaling factor for frequency computation.
|
| 75 |
+
grid_type (`str`):
|
| 76 |
+
Whether to use "linspace" or "slice" to compute grids.
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
|
| 80 |
+
"""
|
| 81 |
+
if use_real is not True:
|
| 82 |
+
raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
|
| 83 |
+
|
| 84 |
+
if grid_type == "linspace":
|
| 85 |
+
start, stop = crops_coords
|
| 86 |
+
grid_size_h, grid_size_w = grid_size
|
| 87 |
+
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
|
| 88 |
+
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
|
| 89 |
+
grid_t = np.arange(temporal_size, dtype=np.float32)
|
| 90 |
+
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
| 91 |
+
elif grid_type == "slice":
|
| 92 |
+
max_h, max_w = max_size
|
| 93 |
+
grid_size_h, grid_size_w = grid_size
|
| 94 |
+
grid_h = np.arange(max_h, dtype=np.float32)
|
| 95 |
+
grid_w = np.arange(max_w, dtype=np.float32)
|
| 96 |
+
grid_t = np.arange(temporal_size, dtype=np.float32)
|
| 97 |
+
else:
|
| 98 |
+
raise ValueError("Invalid value passed for `grid_type`.")
|
| 99 |
+
|
| 100 |
+
# Compute dimensions for each axis
|
| 101 |
+
dim_t = embed_dim // 4
|
| 102 |
+
dim_h = embed_dim // 8 * 3
|
| 103 |
+
dim_w = embed_dim // 8 * 3
|
| 104 |
+
|
| 105 |
+
# Temporal frequencies
|
| 106 |
+
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
|
| 107 |
+
# Spatial frequencies for height and width
|
| 108 |
+
freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
|
| 109 |
+
freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
|
| 110 |
+
|
| 111 |
+
# Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3d tensor
|
| 112 |
+
def combine_time_height_width(freqs_t, freqs_h, freqs_w):
|
| 113 |
+
freqs_t = freqs_t[:, None, None, :].expand(
|
| 114 |
+
-1, grid_size_h, grid_size_w, -1
|
| 115 |
+
) # temporal_size, grid_size_h, grid_size_w, dim_t
|
| 116 |
+
freqs_h = freqs_h[None, :, None, :].expand(
|
| 117 |
+
temporal_size, -1, grid_size_w, -1
|
| 118 |
+
) # temporal_size, grid_size_h, grid_size_w, dim_h
|
| 119 |
+
freqs_w = freqs_w[None, None, :, :].expand(
|
| 120 |
+
temporal_size, grid_size_h, -1, -1
|
| 121 |
+
) # temporal_size, grid_size_h, grid_size_w, dim_w
|
| 122 |
+
|
| 123 |
+
freqs = torch.cat(
|
| 124 |
+
[freqs_t, freqs_h, freqs_w], dim=-1
|
| 125 |
+
) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
|
| 126 |
+
freqs = freqs.view(
|
| 127 |
+
temporal_size * grid_size_h * grid_size_w, -1
|
| 128 |
+
) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
|
| 129 |
+
return freqs
|
| 130 |
+
|
| 131 |
+
t_cos, t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
|
| 132 |
+
h_cos, h_sin = freqs_h # both h_cos and h_sin have shape: grid_size_h, dim_h
|
| 133 |
+
w_cos, w_sin = freqs_w # both w_cos and w_sin have shape: grid_size_w, dim_w
|
| 134 |
+
|
| 135 |
+
if grid_type == "slice":
|
| 136 |
+
t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
|
| 137 |
+
h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
|
| 138 |
+
w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
|
| 139 |
+
|
| 140 |
+
cos = combine_time_height_width(t_cos, h_cos, w_cos)
|
| 141 |
+
sin = combine_time_height_width(t_sin, h_sin, w_sin)
|
| 142 |
+
return cos, sin
|
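# --- Hedged illustration (added for clarity; not part of the original file) ---
# The axis split above gives time embed_dim // 4 channels and height/width
# 3 * embed_dim // 8 channels each, so the three parts add back up to embed_dim
# whenever embed_dim is divisible by 8.
def _rope_dim_split_example(embed_dim: int = 64):
    dim_t = embed_dim // 4       # 16 for embed_dim = 64
    dim_h = embed_dim // 8 * 3   # 24
    dim_w = embed_dim // 8 * 3   # 24
    assert dim_t + dim_h + dim_w == embed_dim
    return dim_t, dim_h, dim_w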
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
| 146 |
+
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
| 147 |
+
tw = tgt_width
|
| 148 |
+
th = tgt_height
|
| 149 |
+
h, w = src
|
| 150 |
+
r = h / w
|
| 151 |
+
if r > (th / tw):
|
| 152 |
+
resize_height = th
|
| 153 |
+
resize_width = int(round(th / h * w))
|
| 154 |
+
else:
|
| 155 |
+
resize_width = tw
|
| 156 |
+
resize_height = int(round(tw / w * h))
|
| 157 |
+
|
| 158 |
+
crop_top = int(round((th - resize_height) / 2.0))
|
| 159 |
+
crop_left = int(round((tw - resize_width) / 2.0))
|
| 160 |
+
|
| 161 |
+
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
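# --- Hedged worked example (added; the numbers are illustrative only) ---
# Fitting a source grid of (height=24, width=64) into a target grid of width 48
# and height 30 keeps the aspect ratio (resize to 48 x 18) and centers the crop
# vertically.
def _crop_region_example():
    region = get_resize_crop_region_for_grid((24, 64), 48, 30)
    assert region == ((6, 0), (24, 48))
    return region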
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 165 |
+
def retrieve_timesteps(
|
| 166 |
+
scheduler,
|
| 167 |
+
num_inference_steps: Optional[int] = None,
|
| 168 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 169 |
+
timesteps: Optional[List[int]] = None,
|
| 170 |
+
sigmas: Optional[List[float]] = None,
|
| 171 |
+
**kwargs,
|
| 172 |
+
):
|
| 173 |
+
"""
|
| 174 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 175 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
scheduler (`SchedulerMixin`):
|
| 179 |
+
The scheduler to get timesteps from.
|
| 180 |
+
num_inference_steps (`int`):
|
| 181 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 182 |
+
must be `None`.
|
| 183 |
+
device (`str` or `torch.device`, *optional*):
|
| 184 |
+
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
| 185 |
+
timesteps (`List[int]`, *optional*):
|
| 186 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 187 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 188 |
+
sigmas (`List[float]`, *optional*):
|
| 189 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 190 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 194 |
+
second element is the number of inference steps.
|
| 195 |
+
"""
|
| 196 |
+
if timesteps is not None and sigmas is not None:
|
| 197 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 198 |
+
if timesteps is not None:
|
| 199 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 200 |
+
if not accepts_timesteps:
|
| 201 |
+
raise ValueError(
|
| 202 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 203 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 204 |
+
)
|
| 205 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 206 |
+
timesteps = scheduler.timesteps
|
| 207 |
+
num_inference_steps = len(timesteps)
|
| 208 |
+
elif sigmas is not None:
|
| 209 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 210 |
+
if not accept_sigmas:
|
| 211 |
+
raise ValueError(
|
| 212 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 213 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 214 |
+
)
|
| 215 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 216 |
+
timesteps = scheduler.timesteps
|
| 217 |
+
num_inference_steps = len(timesteps)
|
| 218 |
+
else:
|
| 219 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 220 |
+
timesteps = scheduler.timesteps
|
| 221 |
+
return timesteps, num_inference_steps
|
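# --- Hedged usage sketch (added for clarity) ---
# retrieve_timesteps either forwards num_inference_steps to the scheduler or
# installs an explicit, descending timestep list when the scheduler's
# set_timesteps accepts a `timesteps` keyword. The scheduler passed in here is
# an assumption; any diffusers scheduler with that support would do.
def _retrieve_timesteps_sketch(scheduler, device="cpu"):
    # Default path: the scheduler builds its own spacing for 50 steps.
    ts_default, n_default = retrieve_timesteps(scheduler, num_inference_steps=50, device=device)
    # Custom path: override the spacing explicitly (must be descending).
    custom = list(range(999, -1, -200))  # [999, 799, 599, 399, 199]
    ts_custom, n_custom = retrieve_timesteps(scheduler, device=device, timesteps=custom)
    return (ts_default, n_default), (ts_custom, n_custom)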
| 222 |
+
|
| 223 |
+
|
| 224 |
+
@dataclass
|
| 225 |
+
class CogVideoXFunPipelineOutput(BaseOutput):
|
| 226 |
+
r"""
|
| 227 |
+
Output class for CogVideo pipelines.
|
| 228 |
+
|
| 229 |
+
Args:
|
| 230 |
+
video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
| 231 |
+
List of video outputs. It can be a nested list of length `batch_size`, with each sub-list containing
|
| 232 |
+
denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
|
| 233 |
+
`(batch_size, num_frames, channels, height, width)`.
|
| 234 |
+
"""
|
| 235 |
+
|
| 236 |
+
videos: torch.Tensor
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class CogVideoXFunControlPipeline(DiffusionPipeline):
|
| 240 |
+
r"""
|
| 241 |
+
Pipeline for text-to-video generation using CogVideoX.
|
| 242 |
+
|
| 243 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 244 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
vae ([`AutoencoderKL`]):
|
| 248 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
| 249 |
+
text_encoder ([`T5EncoderModel`]):
|
| 250 |
+
Frozen text-encoder. CogVideoX_Fun uses
|
| 251 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
| 252 |
+
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
| 253 |
+
tokenizer (`T5Tokenizer`):
|
| 254 |
+
Tokenizer of class
|
| 255 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
| 256 |
+
transformer ([`CogVideoXTransformer3DModel`]):
|
| 257 |
+
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
|
| 258 |
+
scheduler ([`SchedulerMixin`]):
|
| 259 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
| 260 |
+
"""
|
| 261 |
+
|
| 262 |
+
_optional_components = []
|
| 263 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 264 |
+
|
| 265 |
+
_callback_tensor_inputs = [
|
| 266 |
+
"latents",
|
| 267 |
+
"prompt_embeds",
|
| 268 |
+
"negative_prompt_embeds",
|
| 269 |
+
]
|
| 270 |
+
|
| 271 |
+
def __init__(
|
| 272 |
+
self,
|
| 273 |
+
tokenizer: T5Tokenizer,
|
| 274 |
+
text_encoder: T5EncoderModel,
|
| 275 |
+
vae: AutoencoderKLCogVideoX,
|
| 276 |
+
transformer: CogVideoXTransformer3DModel,
|
| 277 |
+
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
| 278 |
+
):
|
| 279 |
+
super().__init__()
|
| 280 |
+
|
| 281 |
+
self.register_modules(
|
| 282 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
| 283 |
+
)
|
| 284 |
+
self.vae_scale_factor_spatial = (
|
| 285 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
| 286 |
+
)
|
| 287 |
+
self.vae_scale_factor_temporal = (
|
| 288 |
+
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
| 292 |
+
|
| 293 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 294 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 295 |
+
self.mask_processor = VaeImageProcessor(
|
| 296 |
+
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
def _get_t5_prompt_embeds(
|
| 300 |
+
self,
|
| 301 |
+
prompt: Union[str, List[str]] = None,
|
| 302 |
+
num_videos_per_prompt: int = 1,
|
| 303 |
+
max_sequence_length: int = 226,
|
| 304 |
+
device: Optional[torch.device] = None,
|
| 305 |
+
dtype: Optional[torch.dtype] = None,
|
| 306 |
+
):
|
| 307 |
+
device = device or self._execution_device
|
| 308 |
+
dtype = dtype or self.text_encoder.dtype
|
| 309 |
+
|
| 310 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 311 |
+
batch_size = len(prompt)
|
| 312 |
+
|
| 313 |
+
text_inputs = self.tokenizer(
|
| 314 |
+
prompt,
|
| 315 |
+
padding="max_length",
|
| 316 |
+
max_length=max_sequence_length,
|
| 317 |
+
truncation=True,
|
| 318 |
+
add_special_tokens=True,
|
| 319 |
+
return_tensors="pt",
|
| 320 |
+
)
|
| 321 |
+
text_input_ids = text_inputs.input_ids
|
| 322 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 323 |
+
|
| 324 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 325 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 326 |
+
logger.warning(
|
| 327 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 328 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
| 332 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 333 |
+
|
| 334 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 335 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 336 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 337 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 338 |
+
|
| 339 |
+
return prompt_embeds
|
| 340 |
+
|
| 341 |
+
def encode_prompt(
|
| 342 |
+
self,
|
| 343 |
+
prompt: Union[str, List[str]],
|
| 344 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 345 |
+
do_classifier_free_guidance: bool = True,
|
| 346 |
+
num_videos_per_prompt: int = 1,
|
| 347 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 348 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 349 |
+
max_sequence_length: int = 226,
|
| 350 |
+
device: Optional[torch.device] = None,
|
| 351 |
+
dtype: Optional[torch.dtype] = None,
|
| 352 |
+
):
|
| 353 |
+
r"""
|
| 354 |
+
Encodes the prompt into text encoder hidden states.
|
| 355 |
+
|
| 356 |
+
Args:
|
| 357 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 358 |
+
prompt to be encoded
|
| 359 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 360 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 361 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 362 |
+
less than `1`).
|
| 363 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 364 |
+
Whether to use classifier free guidance or not.
|
| 365 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 366 |
+
Number of videos that should be generated per prompt.
|
| 367 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 368 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 369 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 370 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 371 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 372 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 373 |
+
argument.
|
| 374 |
+
device: (`torch.device`, *optional*):
|
| 375 |
+
torch device
|
| 376 |
+
dtype: (`torch.dtype`, *optional*):
|
| 377 |
+
torch dtype
|
| 378 |
+
"""
|
| 379 |
+
device = device or self._execution_device
|
| 380 |
+
|
| 381 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 382 |
+
if prompt is not None:
|
| 383 |
+
batch_size = len(prompt)
|
| 384 |
+
else:
|
| 385 |
+
batch_size = prompt_embeds.shape[0]
|
| 386 |
+
|
| 387 |
+
if prompt_embeds is None:
|
| 388 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 389 |
+
prompt=prompt,
|
| 390 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 391 |
+
max_sequence_length=max_sequence_length,
|
| 392 |
+
device=device,
|
| 393 |
+
dtype=dtype,
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 397 |
+
negative_prompt = negative_prompt or ""
|
| 398 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 399 |
+
|
| 400 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 401 |
+
raise TypeError(
|
| 402 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 403 |
+
f" {type(prompt)}."
|
| 404 |
+
)
|
| 405 |
+
elif batch_size != len(negative_prompt):
|
| 406 |
+
raise ValueError(
|
| 407 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 408 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 409 |
+
" the batch size of `prompt`."
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 413 |
+
prompt=negative_prompt,
|
| 414 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 415 |
+
max_sequence_length=max_sequence_length,
|
| 416 |
+
device=device,
|
| 417 |
+
dtype=dtype,
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
return prompt_embeds, negative_prompt_embeds
|
| 421 |
+
|
| 422 |
+
def prepare_latents(
|
| 423 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
| 424 |
+
):
|
| 425 |
+
shape = (
|
| 426 |
+
batch_size,
|
| 427 |
+
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
|
| 428 |
+
num_channels_latents,
|
| 429 |
+
height // self.vae_scale_factor_spatial,
|
| 430 |
+
width // self.vae_scale_factor_spatial,
|
| 431 |
+
)
|
| 432 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 433 |
+
raise ValueError(
|
| 434 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 435 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
if latents is None:
|
| 439 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 440 |
+
else:
|
| 441 |
+
latents = latents.to(device)
|
| 442 |
+
|
| 443 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 444 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 445 |
+
return latents
|
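# --- Hedged illustration (added; factor values are typical, not read from a config) ---
# With a temporal compression of 4 and spatial compression of 8, a 49-frame
# 480x720 request with 16 latent channels yields the latent shape built above.
def _latent_shape_example(batch_size=1, num_frames=49, channels=16,
                          height=480, width=720, temporal=4, spatial=8):
    shape = (batch_size, (num_frames - 1) // temporal + 1, channels,
             height // spatial, width // spatial)
    assert shape == (1, 13, 16, 60, 90)
    return shape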
| 446 |
+
|
| 447 |
+
def prepare_control_latents(
|
| 448 |
+
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
| 449 |
+
):
|
| 450 |
+
# resize the mask to latents shape as we concatenate the mask to the latents
|
| 451 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
| 452 |
+
# and half precision
|
| 453 |
+
|
| 454 |
+
if mask is not None:
|
| 455 |
+
mask = mask.to(device=device, dtype=self.vae.dtype)
|
| 456 |
+
bs = 1
|
| 457 |
+
new_mask = []
|
| 458 |
+
for i in range(0, mask.shape[0], bs):
|
| 459 |
+
mask_bs = mask[i : i + bs]
|
| 460 |
+
mask_bs = self.vae.encode(mask_bs)[0]
|
| 461 |
+
mask_bs = mask_bs.mode()
|
| 462 |
+
new_mask.append(mask_bs)
|
| 463 |
+
mask = torch.cat(new_mask, dim = 0)
|
| 464 |
+
mask = mask * self.vae.config.scaling_factor
|
| 465 |
+
|
| 466 |
+
if masked_image is not None:
|
| 467 |
+
masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
|
| 468 |
+
bs = 1
|
| 469 |
+
new_mask_pixel_values = []
|
| 470 |
+
for i in range(0, masked_image.shape[0], bs):
|
| 471 |
+
mask_pixel_values_bs = masked_image[i : i + bs]
|
| 472 |
+
mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
|
| 473 |
+
mask_pixel_values_bs = mask_pixel_values_bs.mode()
|
| 474 |
+
new_mask_pixel_values.append(mask_pixel_values_bs)
|
| 475 |
+
masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
|
| 476 |
+
masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
|
| 477 |
+
else:
|
| 478 |
+
masked_image_latents = None
|
| 479 |
+
|
| 480 |
+
return mask, masked_image_latents
|
| 481 |
+
|
| 482 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 483 |
+
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
| 484 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
| 485 |
+
|
| 486 |
+
frames = self.vae.decode(latents).sample
|
| 487 |
+
frames = (frames / 2 + 0.5).clamp(0, 1)
|
| 488 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 489 |
+
frames = frames.cpu().float().numpy()
|
| 490 |
+
return frames
|
| 491 |
+
|
| 492 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 493 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 494 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 495 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 496 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 497 |
+
# and should be between [0, 1]
|
| 498 |
+
|
| 499 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 500 |
+
extra_step_kwargs = {}
|
| 501 |
+
if accepts_eta:
|
| 502 |
+
extra_step_kwargs["eta"] = eta
|
| 503 |
+
|
| 504 |
+
# check if the scheduler accepts generator
|
| 505 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 506 |
+
if accepts_generator:
|
| 507 |
+
extra_step_kwargs["generator"] = generator
|
| 508 |
+
return extra_step_kwargs
|
| 509 |
+
|
| 510 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
| 511 |
+
def check_inputs(
|
| 512 |
+
self,
|
| 513 |
+
prompt,
|
| 514 |
+
height,
|
| 515 |
+
width,
|
| 516 |
+
negative_prompt,
|
| 517 |
+
callback_on_step_end_tensor_inputs,
|
| 518 |
+
prompt_embeds=None,
|
| 519 |
+
negative_prompt_embeds=None,
|
| 520 |
+
):
|
| 521 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 522 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 523 |
+
|
| 524 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 525 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 526 |
+
):
|
| 527 |
+
raise ValueError(
|
| 528 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 529 |
+
)
|
| 530 |
+
if prompt is not None and prompt_embeds is not None:
|
| 531 |
+
raise ValueError(
|
| 532 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 533 |
+
" only forward one of the two."
|
| 534 |
+
)
|
| 535 |
+
elif prompt is None and prompt_embeds is None:
|
| 536 |
+
raise ValueError(
|
| 537 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 538 |
+
)
|
| 539 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 540 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 541 |
+
|
| 542 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 543 |
+
raise ValueError(
|
| 544 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 545 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 549 |
+
raise ValueError(
|
| 550 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 551 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 555 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 556 |
+
raise ValueError(
|
| 557 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 558 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 559 |
+
f" {negative_prompt_embeds.shape}."
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
def fuse_qkv_projections(self) -> None:
|
| 563 |
+
r"""Enables fused QKV projections."""
|
| 564 |
+
self.fusing_transformer = True
|
| 565 |
+
self.transformer.fuse_qkv_projections()
|
| 566 |
+
|
| 567 |
+
def unfuse_qkv_projections(self) -> None:
|
| 568 |
+
r"""Disable QKV projection fusion if enabled."""
|
| 569 |
+
if not self.fusing_transformer:
|
| 570 |
+
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
| 571 |
+
else:
|
| 572 |
+
self.transformer.unfuse_qkv_projections()
|
| 573 |
+
self.fusing_transformer = False
|
| 574 |
+
|
| 575 |
+
def _prepare_rotary_positional_embeddings(
|
| 576 |
+
self,
|
| 577 |
+
height: int,
|
| 578 |
+
width: int,
|
| 579 |
+
num_frames: int,
|
| 580 |
+
device: torch.device,
|
| 581 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 582 |
+
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
| 583 |
+
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
| 584 |
+
|
| 585 |
+
p = self.transformer.config.patch_size
|
| 586 |
+
p_t = self.transformer.config.patch_size_t
|
| 587 |
+
|
| 588 |
+
base_size_width = self.transformer.config.sample_width // p
|
| 589 |
+
base_size_height = self.transformer.config.sample_height // p
|
| 590 |
+
|
| 591 |
+
if p_t is None:
|
| 592 |
+
# CogVideoX 1.0
|
| 593 |
+
grid_crops_coords = get_resize_crop_region_for_grid(
|
| 594 |
+
(grid_height, grid_width), base_size_width, base_size_height
|
| 595 |
+
)
|
| 596 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
| 597 |
+
embed_dim=self.transformer.config.attention_head_dim,
|
| 598 |
+
crops_coords=grid_crops_coords,
|
| 599 |
+
grid_size=(grid_height, grid_width),
|
| 600 |
+
temporal_size=num_frames,
|
| 601 |
+
)
|
| 602 |
+
else:
|
| 603 |
+
# CogVideoX 1.5
|
| 604 |
+
base_num_frames = (num_frames + p_t - 1) // p_t
|
| 605 |
+
|
| 606 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
| 607 |
+
embed_dim=self.transformer.config.attention_head_dim,
|
| 608 |
+
crops_coords=None,
|
| 609 |
+
grid_size=(grid_height, grid_width),
|
| 610 |
+
temporal_size=base_num_frames,
|
| 611 |
+
grid_type="slice",
|
| 612 |
+
max_size=(base_size_height, base_size_width),
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
freqs_cos = freqs_cos.to(device=device)
|
| 616 |
+
freqs_sin = freqs_sin.to(device=device)
|
| 617 |
+
return freqs_cos, freqs_sin
|
| 618 |
+
|
| 619 |
+
@property
|
| 620 |
+
def guidance_scale(self):
|
| 621 |
+
return self._guidance_scale
|
| 622 |
+
|
| 623 |
+
@property
|
| 624 |
+
def num_timesteps(self):
|
| 625 |
+
return self._num_timesteps
|
| 626 |
+
|
| 627 |
+
@property
|
| 628 |
+
def interrupt(self):
|
| 629 |
+
return self._interrupt
|
| 630 |
+
|
| 631 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
|
| 632 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
| 633 |
+
# get the original timestep using init_timestep
|
| 634 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
| 635 |
+
|
| 636 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
| 637 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
| 638 |
+
|
| 639 |
+
return timesteps, num_inference_steps - t_start
|
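# --- Hedged worked example (added for clarity) ---
# For num_inference_steps = 50 and strength = 0.6, the helper above keeps only
# the last 30 timesteps: init_timestep = int(50 * 0.6) = 30 and
# t_start = 50 - 30 = 20, so denoising starts 20 steps into the schedule.
def _get_timesteps_arithmetic(num_inference_steps=50, strength=0.6):
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    assert (init_timestep, t_start) == (30, 20)
    return init_timestep, t_start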
| 640 |
+
|
| 641 |
+
@torch.no_grad()
|
| 642 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 643 |
+
def __call__(
|
| 644 |
+
self,
|
| 645 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 646 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 647 |
+
height: int = 480,
|
| 648 |
+
width: int = 720,
|
| 649 |
+
video: Union[torch.FloatTensor] = None,
|
| 650 |
+
control_video: Union[torch.FloatTensor] = None,
|
| 651 |
+
num_frames: int = 49,
|
| 652 |
+
num_inference_steps: int = 50,
|
| 653 |
+
timesteps: Optional[List[int]] = None,
|
| 654 |
+
guidance_scale: float = 6,
|
| 655 |
+
use_dynamic_cfg: bool = False,
|
| 656 |
+
num_videos_per_prompt: int = 1,
|
| 657 |
+
eta: float = 0.0,
|
| 658 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 659 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 660 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 661 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 662 |
+
output_type: str = "numpy",
|
| 663 |
+
return_dict: bool = False,
|
| 664 |
+
callback_on_step_end: Optional[
|
| 665 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 666 |
+
] = None,
|
| 667 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 668 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 669 |
+
max_sequence_length: int = 226,
|
| 670 |
+
comfyui_progressbar: bool = False,
|
| 671 |
+
) -> Union[CogVideoXFunPipelineOutput, Tuple]:
|
| 672 |
+
"""
|
| 673 |
+
Function invoked when calling the pipeline for generation.
|
| 674 |
+
|
| 675 |
+
Args:
|
| 676 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 677 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 678 |
+
instead.
|
| 679 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 680 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 681 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 682 |
+
less than `1`).
|
| 683 |
+
height (`int`, *optional*, defaults to `480`):
|
| 684 |
+
The height in pixels of the generated video frames.
|
| 685 |
+
width (`int`, *optional*, defaults to `720`):
|
| 686 |
+
The width in pixels of the generated video frames.
|
| 687 |
+
num_frames (`int`, defaults to `49`):
|
| 688 |
+
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
|
| 689 |
+
contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
|
| 690 |
+
num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
|
| 691 |
+
needs to be satisfied is that of divisibility mentioned above.
|
| 692 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 693 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 694 |
+
expense of slower inference.
|
| 695 |
+
timesteps (`List[int]`, *optional*):
|
| 696 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 697 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 698 |
+
passed will be used. Must be in descending order.
|
| 699 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
| 700 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 701 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 702 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 703 |
+
1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
| 704 |
+
usually at the expense of lower image quality.
|
| 705 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 706 |
+
The number of videos to generate per prompt.
|
| 707 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 708 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 709 |
+
to make generation deterministic.
|
| 710 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 711 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 712 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 713 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 714 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 715 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 716 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 717 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 718 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 719 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 720 |
+
argument.
|
| 721 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 722 |
+
The output format of the generated image. Choose between
|
| 723 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 724 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 725 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
| 726 |
+
of a plain tuple.
|
| 727 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 728 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
| 729 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 730 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 731 |
+
`callback_on_step_end_tensor_inputs`.
|
| 732 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 733 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 734 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 735 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 736 |
+
max_sequence_length (`int`, defaults to `226`):
|
| 737 |
+
Maximum sequence length in encoded prompt. Must be consistent with
|
| 738 |
+
`self.transformer.config.max_text_seq_length`, otherwise results may be poor.
|
| 739 |
+
|
| 740 |
+
Examples:
|
| 741 |
+
|
| 742 |
+
Returns:
|
| 743 |
+
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`:
|
| 744 |
+
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a
|
| 745 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
| 746 |
+
"""
|
| 747 |
+
|
| 748 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 749 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 750 |
+
|
| 751 |
+
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
|
| 752 |
+
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
|
| 753 |
+
num_frames = num_frames or self.transformer.config.sample_frames
|
| 754 |
+
|
| 755 |
+
num_videos_per_prompt = 1
|
| 756 |
+
|
| 757 |
+
# 1. Check inputs. Raise error if not correct
|
| 758 |
+
self.check_inputs(
|
| 759 |
+
prompt,
|
| 760 |
+
height,
|
| 761 |
+
width,
|
| 762 |
+
negative_prompt,
|
| 763 |
+
callback_on_step_end_tensor_inputs,
|
| 764 |
+
prompt_embeds,
|
| 765 |
+
negative_prompt_embeds,
|
| 766 |
+
)
|
| 767 |
+
self._guidance_scale = guidance_scale
|
| 768 |
+
self._attention_kwargs = attention_kwargs
|
| 769 |
+
self._interrupt = False
|
| 770 |
+
|
| 771 |
+
# 2. Default call parameters
|
| 772 |
+
if prompt is not None and isinstance(prompt, str):
|
| 773 |
+
batch_size = 1
|
| 774 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 775 |
+
batch_size = len(prompt)
|
| 776 |
+
else:
|
| 777 |
+
batch_size = prompt_embeds.shape[0]
|
| 778 |
+
|
| 779 |
+
device = self._execution_device
|
| 780 |
+
|
| 781 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 782 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 783 |
+
# corresponds to doing no classifier free guidance.
|
| 784 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 785 |
+
|
| 786 |
+
# 3. Encode input prompt
|
| 787 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 788 |
+
prompt,
|
| 789 |
+
negative_prompt,
|
| 790 |
+
do_classifier_free_guidance,
|
| 791 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 792 |
+
prompt_embeds=prompt_embeds,
|
| 793 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 794 |
+
max_sequence_length=max_sequence_length,
|
| 795 |
+
device=device,
|
| 796 |
+
)
|
| 797 |
+
if do_classifier_free_guidance:
|
| 798 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 799 |
+
|
| 800 |
+
# 4. Prepare timesteps
|
| 801 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 802 |
+
self._num_timesteps = len(timesteps)
|
| 803 |
+
if comfyui_progressbar:
|
| 804 |
+
from comfy.utils import ProgressBar
|
| 805 |
+
pbar = ProgressBar(num_inference_steps + 2)
|
| 806 |
+
|
| 807 |
+
if control_video is not None:
|
| 808 |
+
video_length = control_video.shape[2]
|
| 809 |
+
control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 810 |
+
control_video = control_video.to(dtype=torch.float32)
|
| 811 |
+
control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
|
| 812 |
+
else:
|
| 813 |
+
control_video = None
|
| 814 |
+
|
| 815 |
+
# Magvae needs the number of frames to be 4n + 1.
|
| 816 |
+
local_latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
| 817 |
+
# For CogVideoX 1.5, the latent frames should be clipped to make it divisible by patch_size_t
|
| 818 |
+
patch_size_t = self.transformer.config.patch_size_t
|
| 819 |
+
additional_frames = 0
|
| 820 |
+
if patch_size_t is not None and local_latent_length % patch_size_t != 0:
|
| 821 |
+
additional_frames = local_latent_length % patch_size_t
|
| 822 |
+
num_frames -= additional_frames * self.vae_scale_factor_temporal
|
| 823 |
+
if num_frames <= 0:
|
| 824 |
+
num_frames = 1
|
| 825 |
+
if control_video is not None and video_length > num_frames:
|
| 826 |
+
logger.warning("The length of condition video is not right, the latent frames should be clipped to make it divisible by patch_size_t. ")
|
| 827 |
+
video_length = num_frames
|
| 828 |
+
control_video = control_video[:, :, :video_length]
|
| 829 |
+
|
| 830 |
+
# 5. Prepare latents.
|
| 831 |
+
latent_channels = self.vae.config.latent_channels
|
| 832 |
+
latents = self.prepare_latents(
|
| 833 |
+
batch_size * num_videos_per_prompt,
|
| 834 |
+
latent_channels,
|
| 835 |
+
num_frames,
|
| 836 |
+
height,
|
| 837 |
+
width,
|
| 838 |
+
prompt_embeds.dtype,
|
| 839 |
+
device,
|
| 840 |
+
generator,
|
| 841 |
+
latents,
|
| 842 |
+
)
|
| 843 |
+
if comfyui_progressbar:
|
| 844 |
+
pbar.update(1)
|
| 845 |
+
|
| 846 |
+
control_video_latents = self.prepare_control_latents(
|
| 847 |
+
None,
|
| 848 |
+
control_video,
|
| 849 |
+
batch_size,
|
| 850 |
+
height,
|
| 851 |
+
width,
|
| 852 |
+
prompt_embeds.dtype,
|
| 853 |
+
device,
|
| 854 |
+
generator,
|
| 855 |
+
do_classifier_free_guidance
|
| 856 |
+
)[1]
|
| 857 |
+
control_video_latents_input = (
|
| 858 |
+
torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
|
| 859 |
+
)
|
| 860 |
+
control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w")
|
| 861 |
+
|
| 862 |
+
if comfyui_progressbar:
|
| 863 |
+
pbar.update(1)
|
| 864 |
+
|
| 865 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 866 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 867 |
+
|
| 868 |
+
# 7. Create rotary embeds if required
|
| 869 |
+
image_rotary_emb = (
|
| 870 |
+
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
| 871 |
+
if self.transformer.config.use_rotary_positional_embeddings
|
| 872 |
+
else None
|
| 873 |
+
)
|
| 874 |
+
|
| 875 |
+
# 8. Denoising loop
|
| 876 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 877 |
+
|
| 878 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 879 |
+
# for DPM-solver++
|
| 880 |
+
old_pred_original_sample = None
|
| 881 |
+
for i, t in enumerate(timesteps):
|
| 882 |
+
if self.interrupt:
|
| 883 |
+
continue
|
| 884 |
+
|
| 885 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 886 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 887 |
+
|
| 888 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 889 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 890 |
+
|
| 891 |
+
# predict noise model_output
|
| 892 |
+
noise_pred = self.transformer(
|
| 893 |
+
hidden_states=latent_model_input,
|
| 894 |
+
encoder_hidden_states=prompt_embeds,
|
| 895 |
+
timestep=timestep,
|
| 896 |
+
image_rotary_emb=image_rotary_emb,
|
| 897 |
+
return_dict=False,
|
| 898 |
+
control_latents=control_latents,
|
| 899 |
+
)[0]
|
| 900 |
+
noise_pred = noise_pred.float()
|
| 901 |
+
|
| 902 |
+
# perform guidance
|
| 903 |
+
if use_dynamic_cfg:
|
| 904 |
+
self._guidance_scale = 1 + guidance_scale * (
|
| 905 |
+
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
| 906 |
+
)
|
| 907 |
+
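                # Note on the dynamic-CFG branch above: it applies the cosine schedule
                #   scale(t) = 1 + guidance_scale * (1 - cos(pi * ((S - t) / S) ** 5)) / 2
                # with S = num_inference_steps; since (1 - cos(x)) / 2 always lies in [0, 1],
                # the effective scale stays within [1, 1 + guidance_scale].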
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                else:
                    latents, old_pred_original_sample = self.scheduler.step(
                        noise_pred,
                        old_pred_original_sample,
                        t,
                        timesteps[i - 1] if i > 0 else None,
                        latents,
                        **extra_step_kwargs,
                        return_dict=False,
                    )
                latents = latents.to(prompt_embeds.dtype)

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                if comfyui_progressbar:
                    pbar.update(1)

        if output_type == "numpy":
            video = self.decode_latents(latents)
        elif not output_type == "latent":
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            video = torch.from_numpy(video)

        return CogVideoXFunPipelineOutput(videos=video)
videox_fun/pipeline/pipeline_cogvideox_fun_inpaint.py
ADDED
|
@@ -0,0 +1,1136 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.embeddings import get_1d_rotary_pos_embed
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from diffusers.utils import BaseOutput, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from einops import rearrange

from ..models import (AutoencoderKLCogVideoX,
                      CogVideoXTransformer3DModel, T5EncoderModel,
                      T5Tokenizer)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        pass
        ```
"""

# Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed
def get_3d_rotary_pos_embed(
    embed_dim,
    crops_coords,
    grid_size,
    temporal_size,
    theta: int = 10000,
    use_real: bool = True,
    grid_type: str = "linspace",
    max_size: Optional[Tuple[int, int]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    RoPE for video tokens with 3D structure.

    Args:
    embed_dim: (`int`):
        The embedding dimension size, corresponding to hidden_size_head.
    crops_coords (`Tuple[int]`):
        The top-left and bottom-right coordinates of the crop.
    grid_size (`Tuple[int]`):
        The grid size of the spatial positional embedding (height, width).
    temporal_size (`int`):
        The size of the temporal dimension.
    theta (`float`):
        Scaling factor for frequency computation.
    grid_type (`str`):
        Whether to use "linspace" or "slice" to compute grids.

    Returns:
        `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
    """
    if use_real is not True:
        raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")

    if grid_type == "linspace":
        start, stop = crops_coords
        grid_size_h, grid_size_w = grid_size
        grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
        grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
        grid_t = np.arange(temporal_size, dtype=np.float32)
        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
    elif grid_type == "slice":
        max_h, max_w = max_size
        grid_size_h, grid_size_w = grid_size
        grid_h = np.arange(max_h, dtype=np.float32)
        grid_w = np.arange(max_w, dtype=np.float32)
        grid_t = np.arange(temporal_size, dtype=np.float32)
    else:
        raise ValueError("Invalid value passed for `grid_type`.")

    # Compute dimensions for each axis
    dim_t = embed_dim // 4
    dim_h = embed_dim // 8 * 3
    dim_w = embed_dim // 8 * 3

    # Temporal frequencies
    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
    # Spatial frequencies for height and width
    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)

    # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3d tensor
    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
        freqs_t = freqs_t[:, None, None, :].expand(
            -1, grid_size_h, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_t
        freqs_h = freqs_h[None, :, None, :].expand(
            temporal_size, -1, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_h
        freqs_w = freqs_w[None, None, :, :].expand(
            temporal_size, grid_size_h, -1, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_w

        freqs = torch.cat(
            [freqs_t, freqs_h, freqs_w], dim=-1
        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
        freqs = freqs.view(
            temporal_size * grid_size_h * grid_size_w, -1
        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
        return freqs

    t_cos, t_sin = freqs_t  # both t_cos and t_sin has shape: temporal_size, dim_t
    h_cos, h_sin = freqs_h  # both h_cos and h_sin has shape: grid_size_h, dim_h
    w_cos, w_sin = freqs_w  # both w_cos and w_sin has shape: grid_size_w, dim_w

    if grid_type == "slice":
        t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
        h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
        w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]

    cos = combine_time_height_width(t_cos, h_cos, w_cos)
    sin = combine_time_height_width(t_sin, h_sin, w_sin)
    return cos, sin

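# Note on the axis split used in get_3d_rotary_pos_embed above: dim_t = embed_dim // 4 and
# dim_h = dim_w = 3 * embed_dim // 8, so for any embed_dim divisible by 8 the three axes together
# cover the full head dimension. For example, embed_dim = 64 gives dim_t = 16 and
# dim_h = dim_w = 24 (16 + 24 + 24 = 64).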
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    tw = tgt_width
    th = tgt_height
    h, w = src
    r = h / w
    if r > (th / tw):
        resize_height = th
        resize_width = int(round(th / h * w))
    else:
        resize_width = tw
        resize_height = int(round(tw / w * h))

    crop_top = int(round((th - resize_height) / 2.0))
    crop_left = int(round((tw - resize_width) / 2.0))

    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)

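# Worked example (for illustration only): get_resize_crop_region_for_grid((48, 72), 36, 36)
# rescales the 48x72 source grid to 24x36 (aspect ratio preserved) and centers it in the target,
# returning ((6, 0), (30, 36)), i.e. the top-left and bottom-right corners of the crop region
# that is later used to build the RoPE grid.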
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


def resize_mask(mask, latent, process_first_frame_only=True):
    latent_size = latent.size()
    batch_size, channels, num_frames, height, width = mask.shape

    if process_first_frame_only:
        target_size = list(latent_size[2:])
        target_size[0] = 1
        first_frame_resized = F.interpolate(
            mask[:, :, 0:1, :, :],
            size=target_size,
            mode='trilinear',
            align_corners=False
        )

        target_size = list(latent_size[2:])
        target_size[0] = target_size[0] - 1
        if target_size[0] != 0:
            remaining_frames_resized = F.interpolate(
                mask[:, :, 1:, :, :],
                size=target_size,
                mode='trilinear',
                align_corners=False
            )
            resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
        else:
            resized_mask = first_frame_resized
    else:
        target_size = list(latent_size[2:])
        resized_mask = F.interpolate(
            mask,
            size=target_size,
            mode='trilinear',
            align_corners=False
        )
    return resized_mask


def add_noise_to_reference_video(image, ratio=None):
    if ratio is None:
        sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
        sigma = torch.exp(sigma).to(image.dtype)
    else:
        sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio

    image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
    image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
    image = image + image_noise
    return image

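# Note on add_noise_to_reference_video: when `ratio` is None the per-sample noise level is drawn
# log-normally, sigma = exp(N(-3.0, 0.5)) with median exp(-3) ~ 0.05; otherwise sigma equals
# `ratio` for every sample. Pixels whose value is exactly -1 receive no added noise.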
@dataclass
class CogVideoXFunPipelineOutput(BaseOutput):
    r"""
    Output class for CogVideo pipelines.

    Args:
        video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    videos: torch.Tensor


class CogVideoXFunInpaintPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-video generation using CogVideoX.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. CogVideoX_Fun uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
        tokenizer (`T5Tokenizer`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        transformer ([`CogVideoXTransformer3DModel`]):
            A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
    """

    _optional_components = []
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]

    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        vae: AutoencoderKLCogVideoX,
        transformer: CogVideoXTransformer3DModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )
        self.vae_scale_factor_spatial = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
        self.vae_scale_factor_temporal = (
            self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
        )

        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
        )

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 226,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        video_length,
        dtype,
        device,
        generator,
        latents=None,
        video=None,
        timestep=None,
        is_strength_max=True,
        return_noise=False,
        return_video_latents=False,
    ):
        shape = (
            batch_size,
            (video_length - 1) // self.vae_scale_factor_temporal + 1,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if return_video_latents or (latents is None and not is_strength_max):
            video = video.to(device=device, dtype=self.vae.dtype)

            bs = 1
            new_video = []
            for i in range(0, video.shape[0], bs):
                video_bs = video[i : i + bs]
                video_bs = self.vae.encode(video_bs)[0]
                video_bs = video_bs.sample()
                new_video.append(video_bs)
            video = torch.cat(new_video, dim = 0)
            video = video * self.vae.config.scaling_factor

            video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
            video_latents = video_latents.to(device=device, dtype=dtype)
            video_latents = rearrange(video_latents, "b c f h w -> b f c h w")

        if latents is None:
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            # if strength is 1. then initialise the latents to noise, else initial to image + noise
            latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
            # if pure noise then scale the initial latents by the Scheduler's init sigma
            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
        else:
            noise = latents.to(device)
            latents = noise * self.scheduler.init_noise_sigma

        # scale the initial noise by the standard deviation required by the scheduler
        outputs = (latents,)

        if return_noise:
            outputs += (noise,)

        if return_video_latents:
            outputs += (video_latents,)

        return outputs

    def prepare_mask_latents(
        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
    ):
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision

        if mask is not None:
            mask = mask.to(device=device, dtype=self.vae.dtype)
            bs = 1
            new_mask = []
            for i in range(0, mask.shape[0], bs):
                mask_bs = mask[i : i + bs]
                mask_bs = self.vae.encode(mask_bs)[0]
                mask_bs = mask_bs.mode()
                new_mask.append(mask_bs)
            mask = torch.cat(new_mask, dim = 0)
            mask = mask * self.vae.config.scaling_factor

        if masked_image is not None:
            if self.transformer.config.add_noise_in_inpaint_model:
                masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
            masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
            bs = 1
            new_mask_pixel_values = []
            for i in range(0, masked_image.shape[0], bs):
                mask_pixel_values_bs = masked_image[i : i + bs]
                mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
                mask_pixel_values_bs = mask_pixel_values_bs.mode()
                new_mask_pixel_values.append(mask_pixel_values_bs)
            masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
            masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
        else:
            masked_image_latents = None

        return mask, masked_image_latents

    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
        latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
        latents = 1 / self.vae.config.scaling_factor * latents

        frames = self.vae.decode(latents).sample
        frames = (frames / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        frames = frames.cpu().float().numpy()
        return frames

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def fuse_qkv_projections(self) -> None:
        r"""Enables fused QKV projections."""
        self.fusing_transformer = True
        self.transformer.fuse_qkv_projections()

    def unfuse_qkv_projections(self) -> None:
        r"""Disable QKV projection fusion if enabled."""
        if not self.fusing_transformer:
            logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
        else:
            self.transformer.unfuse_qkv_projections()
            self.fusing_transformer = False

    def _prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

        p = self.transformer.config.patch_size
        p_t = self.transformer.config.patch_size_t

        base_size_width = self.transformer.config.sample_width // p
        base_size_height = self.transformer.config.sample_height // p

        if p_t is None:
            # CogVideoX 1.0
            grid_crops_coords = get_resize_crop_region_for_grid(
                (grid_height, grid_width), base_size_width, base_size_height
            )
            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=grid_crops_coords,
                grid_size=(grid_height, grid_width),
                temporal_size=num_frames,
            )
        else:
            # CogVideoX 1.5
            base_num_frames = (num_frames + p_t - 1) // p_t

            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
                embed_dim=self.transformer.config.attention_head_dim,
                crops_coords=None,
                grid_size=(grid_height, grid_width),
                temporal_size=base_num_frames,
                grid_type="slice",
                max_size=(base_size_height, base_size_width),
            )

        freqs_cos = freqs_cos.to(device=device)
        freqs_sin = freqs_sin.to(device=device)
        return freqs_cos, freqs_sin

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        return self._attention_kwargs

    @property
    def interrupt(self):
        return self._interrupt

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

        return timesteps, num_inference_steps - t_start

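    # Worked example (for illustration only): with num_inference_steps = 50, strength = 0.6 and a
    # first-order scheduler, get_timesteps computes init_timestep = 30 and t_start = 20, so only
    # the last 30 scheduler timesteps are kept and inpainting starts from a partially noised
    # latent rather than from pure noise.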
@torch.no_grad()
|
| 739 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 740 |
+
def __call__(
|
| 741 |
+
self,
|
| 742 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 743 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 744 |
+
height: int = 480,
|
| 745 |
+
width: int = 720,
|
| 746 |
+
video: Union[torch.FloatTensor] = None,
|
| 747 |
+
mask_video: Union[torch.FloatTensor] = None,
|
| 748 |
+
masked_video_latents: Union[torch.FloatTensor] = None,
|
| 749 |
+
num_frames: int = 49,
|
| 750 |
+
num_inference_steps: int = 50,
|
| 751 |
+
timesteps: Optional[List[int]] = None,
|
| 752 |
+
guidance_scale: float = 6,
|
| 753 |
+
use_dynamic_cfg: bool = False,
|
| 754 |
+
num_videos_per_prompt: int = 1,
|
| 755 |
+
eta: float = 0.0,
|
| 756 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 757 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 758 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 759 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 760 |
+
output_type: str = "numpy",
|
| 761 |
+
return_dict: bool = False,
|
| 762 |
+
callback_on_step_end: Optional[
|
| 763 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 764 |
+
] = None,
|
| 765 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 766 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 767 |
+
max_sequence_length: int = 226,
|
| 768 |
+
strength: float = 1,
|
| 769 |
+
noise_aug_strength: float = 0.0563,
|
| 770 |
+
comfyui_progressbar: bool = False,
|
| 771 |
+
) -> Union[CogVideoXFunPipelineOutput, Tuple]:
|
| 772 |
+
"""
|
| 773 |
+
Function invoked when calling the pipeline for generation.
|
| 774 |
+
|
| 775 |
+
Args:
|
| 776 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 777 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 778 |
+
instead.
|
| 779 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 780 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 781 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 782 |
+
less than `1`).
|
| 783 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 784 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 785 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 786 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 787 |
+
num_frames (`int`, defaults to `48`):
|
| 788 |
+
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
|
| 789 |
+
contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
|
| 790 |
+
num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
|
| 791 |
+
needs to be satisfied is that of divisibility mentioned above.
|
| 792 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 793 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 794 |
+
expense of slower inference.
|
| 795 |
+
timesteps (`List[int]`, *optional*):
|
| 796 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 797 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 798 |
+
passed will be used. Must be in descending order.
|
| 799 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
| 800 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 801 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 802 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 803 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 804 |
+
usually at the expense of lower image quality.
|
| 805 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 806 |
+
The number of videos to generate per prompt.
|
| 807 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 808 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 809 |
+
to make generation deterministic.
|
| 810 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 811 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 812 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 813 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
| 814 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 815 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 816 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 817 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 818 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 819 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 820 |
+
argument.
|
| 821 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 822 |
+
The output format of the generate image. Choose between
|
| 823 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 824 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 825 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
| 826 |
+
of a plain tuple.
|
| 827 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 828 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 829 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 830 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 831 |
+
`callback_on_step_end_tensor_inputs`.
|
| 832 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 833 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 834 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 835 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 836 |
+
max_sequence_length (`int`, defaults to `226`):
|
| 837 |
+
Maximum sequence length in encoded prompt. Must be consistent with
|
| 838 |
+
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
|
| 839 |
+
|
| 840 |
+
Examples:
|
| 841 |
+
|
| 842 |
+
Returns:
|
| 843 |
+
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`:
|
| 844 |
+
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a
|
| 845 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
| 846 |
+
"""
|
| 847 |
+
|
| 848 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 849 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 850 |
+
|
| 851 |
+
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
|
| 852 |
+
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
|
| 853 |
+
num_frames = num_frames or self.transformer.config.sample_frames
|
| 854 |
+
|
| 855 |
+
num_videos_per_prompt = 1
|
| 856 |
+
|
| 857 |
+
# 1. Check inputs. Raise error if not correct
|
| 858 |
+
self.check_inputs(
|
| 859 |
+
prompt,
|
| 860 |
+
height,
|
| 861 |
+
width,
|
| 862 |
+
negative_prompt,
|
| 863 |
+
callback_on_step_end_tensor_inputs,
|
| 864 |
+
prompt_embeds,
|
| 865 |
+
negative_prompt_embeds,
|
| 866 |
+
)
|
| 867 |
+
self._guidance_scale = guidance_scale
|
| 868 |
+
self._attention_kwargs = attention_kwargs
|
| 869 |
+
self._interrupt = False
|
| 870 |
+
|
| 871 |
+
# 2. Default call parameters
|
| 872 |
+
if prompt is not None and isinstance(prompt, str):
|
| 873 |
+
batch_size = 1
|
| 874 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 875 |
+
batch_size = len(prompt)
|
| 876 |
+
else:
|
| 877 |
+
batch_size = prompt_embeds.shape[0]
|
| 878 |
+
|
| 879 |
+
device = self._execution_device
|
| 880 |
+
|
| 881 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 882 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 883 |
+
# corresponds to doing no classifier free guidance.
|
| 884 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 885 |
+
|
| 886 |
+
# 3. Encode input prompt
|
| 887 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 888 |
+
prompt,
|
| 889 |
+
negative_prompt,
|
| 890 |
+
do_classifier_free_guidance,
|
| 891 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 892 |
+
prompt_embeds=prompt_embeds,
|
| 893 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 894 |
+
max_sequence_length=max_sequence_length,
|
| 895 |
+
device=device,
|
| 896 |
+
)
|
| 897 |
+
if do_classifier_free_guidance:
|
| 898 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 899 |
+
|
| 900 |
+
# 4. set timesteps
|
| 901 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 902 |
+
timesteps, num_inference_steps = self.get_timesteps(
|
| 903 |
+
num_inference_steps=num_inference_steps, strength=strength, device=device
|
| 904 |
+
)
|
| 905 |
+
self._num_timesteps = len(timesteps)
|
| 906 |
+
if comfyui_progressbar:
|
| 907 |
+
from comfy.utils import ProgressBar
|
| 908 |
+
pbar = ProgressBar(num_inference_steps + 2)
|
| 909 |
+
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
|
| 910 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
|
| 911 |
+
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
|
| 912 |
+
is_strength_max = strength == 1.0
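As a sketch of what `strength` does here, assuming the usual diffusers img2img convention for `get_timesteps` (which is not shown in this hunk): strength controls how much of the schedule is actually run, and `strength == 1.0` starts from pure noise.

```python
# Assumed img2img-style timestep selection: keep only the last `strength` fraction of steps.
num_inference_steps = 50
for strength in (0.25, 0.5, 1.0):
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    print(f"strength={strength}: skip the first {t_start} steps, run {num_inference_steps - t_start}")
```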
|
| 913 |
+
|
| 914 |
+
# 5. Prepare latents.
|
| 915 |
+
if video is not None:
|
| 916 |
+
video_length = video.shape[2]
|
| 917 |
+
init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 918 |
+
init_video = init_video.to(dtype=torch.float32)
|
| 919 |
+
init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
|
| 920 |
+
else:
|
| 921 |
+
init_video = None
|
| 922 |
+
|
| 923 |
+
# Magvae needs the number of frames to be 4n + 1.
|
| 924 |
+
local_latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
| 925 |
+
# For CogVideoX 1.5, the latent frames should be clipped to make it divisible by patch_size_t
|
| 926 |
+
patch_size_t = self.transformer.config.patch_size_t
|
| 927 |
+
additional_frames = 0
|
| 928 |
+
if patch_size_t is not None and local_latent_length % patch_size_t != 0:
|
| 929 |
+
additional_frames = local_latent_length % patch_size_t
|
| 930 |
+
num_frames -= additional_frames * self.vae_scale_factor_temporal
|
| 931 |
+
if num_frames <= 0:
|
| 932 |
+
num_frames = 1
|
| 933 |
+
if video_length > num_frames:
|
| 934 |
+
logger.warning("The length of condition video is not right, the latent frames should be clipped to make it divisible by patch_size_t. ")
|
| 935 |
+
video_length = num_frames
|
| 936 |
+
video = video[:, :, :video_length]
|
| 937 |
+
init_video = init_video[:, :, :video_length]
|
| 938 |
+
mask_video = mask_video[:, :, :video_length]
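The frame-count bookkeeping above can be sanity-checked in isolation; the sketch below assumes typical CogVideoX values (`vae_scale_factor_temporal = 4`, `patch_size_t = 2`), which are illustrative rather than read from a config.

```python
# Standalone sketch of the clipping logic above: the VAE wants 4n + 1 input frames and
# the resulting latent frame count must be divisible by patch_size_t.
def clip_num_frames(num_frames, vae_scale_factor_temporal=4, patch_size_t=2):
    latent_length = (num_frames - 1) // vae_scale_factor_temporal + 1
    if patch_size_t is not None and latent_length % patch_size_t != 0:
        num_frames -= (latent_length % patch_size_t) * vae_scale_factor_temporal
    return max(num_frames, 1)

print(clip_num_frames(49))  # 13 latent frames -> clipped to 45 (12 latent frames)
print(clip_num_frames(53))  # 14 latent frames, already divisible -> stays 53
```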
|
| 939 |
+
|
| 940 |
+
num_channels_latents = self.vae.config.latent_channels
|
| 941 |
+
num_channels_transformer = self.transformer.config.in_channels
|
| 942 |
+
return_image_latents = num_channels_transformer == num_channels_latents
|
| 943 |
+
|
| 944 |
+
latents_outputs = self.prepare_latents(
|
| 945 |
+
batch_size * num_videos_per_prompt,
|
| 946 |
+
num_channels_latents,
|
| 947 |
+
height,
|
| 948 |
+
width,
|
| 949 |
+
video_length,
|
| 950 |
+
prompt_embeds.dtype,
|
| 951 |
+
device,
|
| 952 |
+
generator,
|
| 953 |
+
latents,
|
| 954 |
+
video=init_video,
|
| 955 |
+
timestep=latent_timestep,
|
| 956 |
+
is_strength_max=is_strength_max,
|
| 957 |
+
return_noise=True,
|
| 958 |
+
return_video_latents=return_image_latents,
|
| 959 |
+
)
|
| 960 |
+
if return_image_latents:
|
| 961 |
+
latents, noise, image_latents = latents_outputs
|
| 962 |
+
else:
|
| 963 |
+
latents, noise = latents_outputs
|
| 964 |
+
if comfyui_progressbar:
|
| 965 |
+
pbar.update(1)
|
| 966 |
+
|
| 967 |
+
if mask_video is not None:
|
| 968 |
+
if (mask_video == 255).all():
|
| 969 |
+
mask_latents = torch.zeros_like(latents)[:, :, :1].to(latents.device, latents.dtype)
|
| 970 |
+
masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
|
| 971 |
+
|
| 972 |
+
mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
|
| 973 |
+
masked_video_latents_input = (
|
| 974 |
+
torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
|
| 975 |
+
)
|
| 976 |
+
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
|
| 977 |
+
else:
|
| 978 |
+
# Prepare mask latent variables
|
| 979 |
+
video_length = video.shape[2]
|
| 980 |
+
mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 981 |
+
mask_condition = mask_condition.to(dtype=torch.float32)
|
| 982 |
+
mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
|
| 983 |
+
|
| 984 |
+
if num_channels_transformer != num_channels_latents:
|
| 985 |
+
mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
|
| 986 |
+
if masked_video_latents is None:
|
| 987 |
+
masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
|
| 988 |
+
else:
|
| 989 |
+
masked_video = masked_video_latents
|
| 990 |
+
|
| 991 |
+
_, masked_video_latents = self.prepare_mask_latents(
|
| 992 |
+
None,
|
| 993 |
+
masked_video,
|
| 994 |
+
batch_size,
|
| 995 |
+
height,
|
| 996 |
+
width,
|
| 997 |
+
prompt_embeds.dtype,
|
| 998 |
+
device,
|
| 999 |
+
generator,
|
| 1000 |
+
do_classifier_free_guidance,
|
| 1001 |
+
noise_aug_strength=noise_aug_strength,
|
| 1002 |
+
)
|
| 1003 |
+
mask_latents = resize_mask(1 - mask_condition, masked_video_latents)
|
| 1004 |
+
mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
|
| 1005 |
+
|
| 1006 |
+
mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
|
| 1007 |
+
mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
|
| 1008 |
+
|
| 1009 |
+
mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
|
| 1010 |
+
masked_video_latents_input = (
|
| 1011 |
+
torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
|
| 1012 |
+
)
|
| 1013 |
+
|
| 1014 |
+
mask = rearrange(mask, "b c f h w -> b f c h w")
|
| 1015 |
+
mask_input = rearrange(mask_input, "b c f h w -> b f c h w")
|
| 1016 |
+
masked_video_latents_input = rearrange(masked_video_latents_input, "b c f h w -> b f c h w")
|
| 1017 |
+
|
| 1018 |
+
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
|
| 1019 |
+
else:
|
| 1020 |
+
mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
|
| 1021 |
+
mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
|
| 1022 |
+
mask = rearrange(mask, "b c f h w -> b f c h w")
|
| 1023 |
+
|
| 1024 |
+
inpaint_latents = None
|
| 1025 |
+
else:
|
| 1026 |
+
if num_channels_transformer != num_channels_latents:
|
| 1027 |
+
mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
|
| 1028 |
+
masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
|
| 1029 |
+
|
| 1030 |
+
mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
|
| 1031 |
+
masked_video_latents_input = (
|
| 1032 |
+
torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
|
| 1033 |
+
)
|
| 1034 |
+
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
|
| 1035 |
+
else:
|
| 1036 |
+
mask = torch.zeros_like(init_video[:, :1])
|
| 1037 |
+
mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
|
| 1038 |
+
mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
|
| 1039 |
+
mask = rearrange(mask, "b c f h w -> b f c h w")
|
| 1040 |
+
|
| 1041 |
+
inpaint_latents = None
|
| 1042 |
+
if comfyui_progressbar:
|
| 1043 |
+
pbar.update(1)
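To make the tensor bookkeeping of this step concrete, here is a shape-only sketch of the final concatenation; all dimensions are illustrative assumptions (a 1-channel mask, 16-channel masked-video latents, and a 13-frame latent video).

```python
import torch

# After the rearrange to "b f c h w", the single-channel mask and the masked-video
# latents are stacked along the channel axis (dim=2), giving the extra conditioning
# channels the inpaint transformer expects next to the noisy latents.
mask_input = torch.zeros(2, 13, 1, 60, 90)                   # CFG doubles the batch to 2
masked_video_latents_input = torch.zeros(2, 13, 16, 60, 90)
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2)
print(inpaint_latents.shape)  # torch.Size([2, 13, 17, 60, 90])
```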
|
| 1044 |
+
|
| 1045 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 1046 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 1047 |
+
|
| 1048 |
+
# 7. Create rotary embeds if required
|
| 1049 |
+
image_rotary_emb = (
|
| 1050 |
+
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
| 1051 |
+
if self.transformer.config.use_rotary_positional_embeddings
|
| 1052 |
+
else None
|
| 1053 |
+
)
|
| 1054 |
+
|
| 1055 |
+
# 8. Denoising loop
|
| 1056 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 1057 |
+
|
| 1058 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1059 |
+
# for DPM-solver++
|
| 1060 |
+
old_pred_original_sample = None
|
| 1061 |
+
for i, t in enumerate(timesteps):
|
| 1062 |
+
if self.interrupt:
|
| 1063 |
+
continue
|
| 1064 |
+
|
| 1065 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 1066 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 1067 |
+
|
| 1068 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 1069 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 1070 |
+
|
| 1071 |
+
# predict noise model_output
|
| 1072 |
+
noise_pred = self.transformer(
|
| 1073 |
+
hidden_states=latent_model_input,
|
| 1074 |
+
encoder_hidden_states=prompt_embeds,
|
| 1075 |
+
timestep=timestep,
|
| 1076 |
+
image_rotary_emb=image_rotary_emb,
|
| 1077 |
+
return_dict=False,
|
| 1078 |
+
inpaint_latents=inpaint_latents,
|
| 1079 |
+
)[0]
|
| 1080 |
+
noise_pred = noise_pred.float()
|
| 1081 |
+
|
| 1082 |
+
# perform guidance
|
| 1083 |
+
if use_dynamic_cfg:
|
| 1084 |
+
self._guidance_scale = 1 + guidance_scale * (
|
| 1085 |
+
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
| 1086 |
+
)
|
| 1087 |
+
if do_classifier_free_guidance:
|
| 1088 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 1089 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 1090 |
+
|
| 1091 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 1092 |
+
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
| 1093 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 1094 |
+
else:
|
| 1095 |
+
latents, old_pred_original_sample = self.scheduler.step(
|
| 1096 |
+
noise_pred,
|
| 1097 |
+
old_pred_original_sample,
|
| 1098 |
+
t,
|
| 1099 |
+
timesteps[i - 1] if i > 0 else None,
|
| 1100 |
+
latents,
|
| 1101 |
+
**extra_step_kwargs,
|
| 1102 |
+
return_dict=False,
|
| 1103 |
+
)
|
| 1104 |
+
latents = latents.to(prompt_embeds.dtype)
|
| 1105 |
+
|
| 1106 |
+
# call the callback, if provided
|
| 1107 |
+
if callback_on_step_end is not None:
|
| 1108 |
+
callback_kwargs = {}
|
| 1109 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 1110 |
+
callback_kwargs[k] = locals()[k]
|
| 1111 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1112 |
+
|
| 1113 |
+
latents = callback_outputs.pop("latents", latents)
|
| 1114 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1115 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 1116 |
+
|
| 1117 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1118 |
+
progress_bar.update()
|
| 1119 |
+
if comfyui_progressbar:
|
| 1120 |
+
pbar.update(1)
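The guidance arithmetic inside the loop reduces to a single blend of the two halves of the batched prediction; a shape-only sketch (dimensions are illustrative):

```python
import torch

# The batched forward pass stacks the unconditional prediction first, then the
# text-conditioned one; chunking and blending them implements classifier-free guidance.
guidance_scale = 6.0
noise_pred = torch.randn(2, 13, 16, 60, 90)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # torch.Size([1, 13, 16, 60, 90])
```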
|
| 1121 |
+
|
| 1122 |
+
if output_type == "numpy":
|
| 1123 |
+
video = self.decode_latents(latents)
|
| 1124 |
+
elif not output_type == "latent":
|
| 1125 |
+
video = self.decode_latents(latents)
|
| 1126 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 1127 |
+
else:
|
| 1128 |
+
video = latents
|
| 1129 |
+
|
| 1130 |
+
# Offload all models
|
| 1131 |
+
self.maybe_free_model_hooks()
|
| 1132 |
+
|
| 1133 |
+
if not return_dict:
|
| 1134 |
+
video = torch.from_numpy(video)
|
| 1135 |
+
|
| 1136 |
+
return CogVideoXFunPipelineOutput(videos=video)
|
videox_fun/pipeline/pipeline_fantasy_talking.py
ADDED
|
@@ -0,0 +1,754 @@
| 1 |
+
import inspect
|
| 2 |
+
import math
|
| 3 |
+
import copy
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import torchvision.transforms.functional as TF
|
| 11 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
| 12 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 13 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 14 |
+
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
| 15 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 16 |
+
from diffusers.utils import BaseOutput, logging, replace_example_docstring
|
| 17 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 18 |
+
from diffusers.video_processor import VideoProcessor
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
from PIL import Image
|
| 21 |
+
from torchvision import transforms
|
| 22 |
+
from transformers import T5Tokenizer
|
| 23 |
+
|
| 24 |
+
from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
|
| 25 |
+
Wan2_2Transformer3DModel_S2V, WanAudioEncoder,
|
| 26 |
+
WanT5EncoderModel)
|
| 27 |
+
from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
| 28 |
+
get_sampling_sigmas)
|
| 29 |
+
from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 30 |
+
|
| 31 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
EXAMPLE_DOC_STRING = """
|
| 35 |
+
Examples:
|
| 36 |
+
```python
|
| 37 |
+
pass
|
| 38 |
+
```
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 43 |
+
def retrieve_timesteps(
|
| 44 |
+
scheduler,
|
| 45 |
+
num_inference_steps: Optional[int] = None,
|
| 46 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 47 |
+
timesteps: Optional[List[int]] = None,
|
| 48 |
+
sigmas: Optional[List[float]] = None,
|
| 49 |
+
**kwargs,
|
| 50 |
+
):
|
| 51 |
+
"""
|
| 52 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 53 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
scheduler (`SchedulerMixin`):
|
| 57 |
+
The scheduler to get timesteps from.
|
| 58 |
+
num_inference_steps (`int`):
|
| 59 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 60 |
+
must be `None`.
|
| 61 |
+
device (`str` or `torch.device`, *optional*):
|
| 62 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 63 |
+
timesteps (`List[int]`, *optional*):
|
| 64 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 65 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 66 |
+
sigmas (`List[float]`, *optional*):
|
| 67 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 68 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 72 |
+
second element is the number of inference steps.
|
| 73 |
+
"""
|
| 74 |
+
if timesteps is not None and sigmas is not None:
|
| 75 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 76 |
+
if timesteps is not None:
|
| 77 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 78 |
+
if not accepts_timesteps:
|
| 79 |
+
raise ValueError(
|
| 80 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 81 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 82 |
+
)
|
| 83 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 84 |
+
timesteps = scheduler.timesteps
|
| 85 |
+
num_inference_steps = len(timesteps)
|
| 86 |
+
elif sigmas is not None:
|
| 87 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 88 |
+
if not accept_sigmas:
|
| 89 |
+
raise ValueError(
|
| 90 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 91 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 92 |
+
)
|
| 93 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 94 |
+
timesteps = scheduler.timesteps
|
| 95 |
+
num_inference_steps = len(timesteps)
|
| 96 |
+
else:
|
| 97 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 98 |
+
timesteps = scheduler.timesteps
|
| 99 |
+
return timesteps, num_inference_steps
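A small usage sketch for `retrieve_timesteps` (the scheduler choice and step count are illustrative):

```python
from diffusers import FlowMatchEulerDiscreteScheduler

# With only num_inference_steps given, the helper defers entirely to the scheduler's
# own spacing; `timesteps=[...]` or `sigmas=[...]` are forwarded only when the
# scheduler's set_timesteps signature accepts them, otherwise a ValueError is raised.
scheduler = FlowMatchEulerDiscreteScheduler()
timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=8)
print(n, timesteps)
```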
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def resize_mask(mask, latent, process_first_frame_only=True):
|
| 103 |
+
latent_size = latent.size()
|
| 104 |
+
batch_size, channels, num_frames, height, width = mask.shape
|
| 105 |
+
|
| 106 |
+
if process_first_frame_only:
|
| 107 |
+
target_size = list(latent_size[2:])
|
| 108 |
+
target_size[0] = 1
|
| 109 |
+
first_frame_resized = F.interpolate(
|
| 110 |
+
mask[:, :, 0:1, :, :],
|
| 111 |
+
size=target_size,
|
| 112 |
+
mode='trilinear',
|
| 113 |
+
align_corners=False
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
target_size = list(latent_size[2:])
|
| 117 |
+
target_size[0] = target_size[0] - 1
|
| 118 |
+
if target_size[0] != 0:
|
| 119 |
+
remaining_frames_resized = F.interpolate(
|
| 120 |
+
mask[:, :, 1:, :, :],
|
| 121 |
+
size=target_size,
|
| 122 |
+
mode='trilinear',
|
| 123 |
+
align_corners=False
|
| 124 |
+
)
|
| 125 |
+
resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
|
| 126 |
+
else:
|
| 127 |
+
resized_mask = first_frame_resized
|
| 128 |
+
else:
|
| 129 |
+
target_size = list(latent_size[2:])
|
| 130 |
+
resized_mask = F.interpolate(
|
| 131 |
+
mask,
|
| 132 |
+
size=target_size,
|
| 133 |
+
mode='trilinear',
|
| 134 |
+
align_corners=False
|
| 135 |
+
)
|
| 136 |
+
return resized_mask
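A quick shape check for `resize_mask` (tensor sizes are illustrative): a 9-frame pixel-space mask resized onto a 3-frame latent grid, with the first frame handled separately.

```python
import torch

# With process_first_frame_only=True the first frame is interpolated on its own and the
# remaining frames are squeezed into the rest of the latent time axis.
mask = torch.ones(1, 1, 9, 64, 64)      # pixel-space mask: (b, 1, frames, H, W)
latent = torch.zeros(1, 16, 3, 8, 8)    # latent target:    (b, c, frames, h, w)
print(resize_mask(mask, latent, process_first_frame_only=True).shape)  # (1, 1, 3, 8, 8)
```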
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@dataclass
|
| 140 |
+
class WanPipelineOutput(BaseOutput):
|
| 141 |
+
r"""
|
| 142 |
+
Output class for Wan-based pipelines.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
| 146 |
+
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
| 147 |
+
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
| 148 |
+
`(batch_size, num_frames, channels, height, width)`.
|
| 149 |
+
"""
|
| 150 |
+
|
| 151 |
+
videos: torch.Tensor
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class FantasyTalkingPipeline(DiffusionPipeline):
|
| 155 |
+
r"""
|
| 156 |
+
Pipeline for audio-driven video generation (FantasyTalking) built on Wan.
|
| 157 |
+
|
| 158 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 159 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
_optional_components = ["transformer_2", "audio_encoder"]
|
| 163 |
+
model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae"
|
| 164 |
+
|
| 165 |
+
_callback_tensor_inputs = [
|
| 166 |
+
"latents",
|
| 167 |
+
"prompt_embeds",
|
| 168 |
+
"negative_prompt_embeds",
|
| 169 |
+
]
|
| 170 |
+
|
| 171 |
+
def __init__(
|
| 172 |
+
self,
|
| 173 |
+
tokenizer: AutoTokenizer,
|
| 174 |
+
text_encoder: WanT5EncoderModel,
|
| 175 |
+
audio_encoder: WanAudioEncoder,
|
| 176 |
+
vae: AutoencoderKLWan,
|
| 177 |
+
transformer: Wan2_2Transformer3DModel_S2V,
|
| 178 |
+
clip_image_encoder: CLIPModel,
|
| 179 |
+
transformer_2: Wan2_2Transformer3DModel_S2V = None,
|
| 180 |
+
scheduler: FlowMatchEulerDiscreteScheduler = None,
|
| 181 |
+
):
|
| 182 |
+
super().__init__()
|
| 183 |
+
|
| 184 |
+
self.register_modules(
|
| 185 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
|
| 186 |
+
transformer_2=transformer_2, scheduler=scheduler, clip_image_encoder=clip_image_encoder, audio_encoder=audio_encoder
|
| 187 |
+
)
|
| 188 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
|
| 189 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
|
| 190 |
+
self.mask_processor = VaeImageProcessor(
|
| 191 |
+
vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
def _get_t5_prompt_embeds(
|
| 195 |
+
self,
|
| 196 |
+
prompt: Union[str, List[str]] = None,
|
| 197 |
+
num_videos_per_prompt: int = 1,
|
| 198 |
+
max_sequence_length: int = 512,
|
| 199 |
+
device: Optional[torch.device] = None,
|
| 200 |
+
dtype: Optional[torch.dtype] = None,
|
| 201 |
+
):
|
| 202 |
+
device = device or self._execution_device
|
| 203 |
+
dtype = dtype or self.text_encoder.dtype
|
| 204 |
+
|
| 205 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 206 |
+
batch_size = len(prompt)
|
| 207 |
+
|
| 208 |
+
text_inputs = self.tokenizer(
|
| 209 |
+
prompt,
|
| 210 |
+
padding="max_length",
|
| 211 |
+
max_length=max_sequence_length,
|
| 212 |
+
truncation=True,
|
| 213 |
+
add_special_tokens=True,
|
| 214 |
+
return_tensors="pt",
|
| 215 |
+
)
|
| 216 |
+
text_input_ids = text_inputs.input_ids
|
| 217 |
+
prompt_attention_mask = text_inputs.attention_mask
|
| 218 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 219 |
+
|
| 220 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 221 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 222 |
+
logger.warning(
|
| 223 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 224 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
|
| 228 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
|
| 229 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 230 |
+
|
| 231 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 232 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 233 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 234 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 235 |
+
|
| 236 |
+
return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
| 237 |
+
|
| 238 |
+
def encode_prompt(
|
| 239 |
+
self,
|
| 240 |
+
prompt: Union[str, List[str]],
|
| 241 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 242 |
+
do_classifier_free_guidance: bool = True,
|
| 243 |
+
num_videos_per_prompt: int = 1,
|
| 244 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 245 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 246 |
+
max_sequence_length: int = 512,
|
| 247 |
+
device: Optional[torch.device] = None,
|
| 248 |
+
dtype: Optional[torch.dtype] = None,
|
| 249 |
+
):
|
| 250 |
+
r"""
|
| 251 |
+
Encodes the prompt into text encoder hidden states.
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 255 |
+
prompt to be encoded
|
| 256 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 257 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 258 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 259 |
+
less than `1`).
|
| 260 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 261 |
+
Whether to use classifier free guidance or not.
|
| 262 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 263 |
+
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
| 264 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 265 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 266 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 267 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 268 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 269 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 270 |
+
argument.
|
| 271 |
+
device: (`torch.device`, *optional*):
|
| 272 |
+
torch device
|
| 273 |
+
dtype: (`torch.dtype`, *optional*):
|
| 274 |
+
torch dtype
|
| 275 |
+
"""
|
| 276 |
+
device = device or self._execution_device
|
| 277 |
+
|
| 278 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 279 |
+
if prompt is not None:
|
| 280 |
+
batch_size = len(prompt)
|
| 281 |
+
else:
|
| 282 |
+
batch_size = prompt_embeds.shape[0]
|
| 283 |
+
|
| 284 |
+
if prompt_embeds is None:
|
| 285 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 286 |
+
prompt=prompt,
|
| 287 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 288 |
+
max_sequence_length=max_sequence_length,
|
| 289 |
+
device=device,
|
| 290 |
+
dtype=dtype,
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 294 |
+
negative_prompt = negative_prompt or ""
|
| 295 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 296 |
+
|
| 297 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 298 |
+
raise TypeError(
|
| 299 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 300 |
+
f" {type(prompt)}."
|
| 301 |
+
)
|
| 302 |
+
elif batch_size != len(negative_prompt):
|
| 303 |
+
raise ValueError(
|
| 304 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 305 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 306 |
+
" the batch size of `prompt`."
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 310 |
+
prompt=negative_prompt,
|
| 311 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 312 |
+
max_sequence_length=max_sequence_length,
|
| 313 |
+
device=device,
|
| 314 |
+
dtype=dtype,
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
return prompt_embeds, negative_prompt_embeds
|
| 318 |
+
|
| 319 |
+
def prepare_latents(
|
| 320 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, num_length_latents=None
|
| 321 |
+
):
|
| 322 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 323 |
+
raise ValueError(
|
| 324 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 325 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
shape = (
|
| 329 |
+
batch_size,
|
| 330 |
+
num_channels_latents,
|
| 331 |
+
(num_frames - 1) // self.vae.temporal_compression_ratio + 1 if num_length_latents is None else num_length_latents,
|
| 332 |
+
height // self.vae.spatial_compression_ratio,
|
| 333 |
+
width // self.vae.spatial_compression_ratio,
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
if latents is None:
|
| 337 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 338 |
+
else:
|
| 339 |
+
latents = latents.to(device)
|
| 340 |
+
|
| 341 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 342 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
| 343 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 344 |
+
return latents
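For orientation, the latent shape produced here works out as below when the usual Wan VAE compression ratios are assumed (temporal 4, spatial 8; illustrative values, not read from a checkpoint):

```python
# 81 frames at 480x832 -> a (1, 16, 21, 60, 104) latent under the assumed ratios.
batch_size, latent_channels = 1, 16
num_frames, height, width = 81, 480, 832
temporal_ratio, spatial_ratio = 4, 8
shape = (
    batch_size,
    latent_channels,
    (num_frames - 1) // temporal_ratio + 1,
    height // spatial_ratio,
    width // spatial_ratio,
)
print(shape)  # (1, 16, 21, 60, 104)
```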
|
| 345 |
+
|
| 346 |
+
def prepare_mask_latents(
|
| 347 |
+
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
|
| 348 |
+
):
|
| 349 |
+
# resize the mask to latents shape as we concatenate the mask to the latents
|
| 350 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
| 351 |
+
# and half precision
|
| 352 |
+
|
| 353 |
+
if mask is not None:
|
| 354 |
+
mask = mask.to(device=device, dtype=self.vae.dtype)
|
| 355 |
+
bs = 1
|
| 356 |
+
new_mask = []
|
| 357 |
+
for i in range(0, mask.shape[0], bs):
|
| 358 |
+
mask_bs = mask[i : i + bs]
|
| 359 |
+
mask_bs = self.vae.encode(mask_bs)[0]
|
| 360 |
+
mask_bs = mask_bs.mode()
|
| 361 |
+
new_mask.append(mask_bs)
|
| 362 |
+
mask = torch.cat(new_mask, dim = 0)
|
| 363 |
+
# mask = mask * self.vae.config.scaling_factor
|
| 364 |
+
|
| 365 |
+
if masked_image is not None:
|
| 366 |
+
masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
|
| 367 |
+
bs = 1
|
| 368 |
+
new_mask_pixel_values = []
|
| 369 |
+
for i in range(0, masked_image.shape[0], bs):
|
| 370 |
+
mask_pixel_values_bs = masked_image[i : i + bs]
|
| 371 |
+
mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
|
| 372 |
+
mask_pixel_values_bs = mask_pixel_values_bs.mode()
|
| 373 |
+
new_mask_pixel_values.append(mask_pixel_values_bs)
|
| 374 |
+
masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
|
| 375 |
+
# masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
|
| 376 |
+
else:
|
| 377 |
+
masked_image_latents = None
|
| 378 |
+
|
| 379 |
+
return mask, masked_image_latents
|
| 380 |
+
|
| 381 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 382 |
+
frames = self.vae.decode(latents.to(self.vae.dtype)).sample
|
| 383 |
+
frames = (frames / 2 + 0.5).clamp(0, 1)
|
| 384 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 385 |
+
frames = frames.cpu().float().numpy()
|
| 386 |
+
return frames
|
| 387 |
+
|
| 388 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 389 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 390 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 391 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 392 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 393 |
+
# and should be between [0, 1]
|
| 394 |
+
|
| 395 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 396 |
+
extra_step_kwargs = {}
|
| 397 |
+
if accepts_eta:
|
| 398 |
+
extra_step_kwargs["eta"] = eta
|
| 399 |
+
|
| 400 |
+
# check if the scheduler accepts generator
|
| 401 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 402 |
+
if accepts_generator:
|
| 403 |
+
extra_step_kwargs["generator"] = generator
|
| 404 |
+
return extra_step_kwargs
|
| 405 |
+
|
| 406 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
| 407 |
+
def check_inputs(
|
| 408 |
+
self,
|
| 409 |
+
prompt,
|
| 410 |
+
height,
|
| 411 |
+
width,
|
| 412 |
+
negative_prompt,
|
| 413 |
+
callback_on_step_end_tensor_inputs,
|
| 414 |
+
prompt_embeds=None,
|
| 415 |
+
negative_prompt_embeds=None,
|
| 416 |
+
):
|
| 417 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 418 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 419 |
+
|
| 420 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 421 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 422 |
+
):
|
| 423 |
+
raise ValueError(
|
| 424 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 425 |
+
)
|
| 426 |
+
if prompt is not None and prompt_embeds is not None:
|
| 427 |
+
raise ValueError(
|
| 428 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 429 |
+
" only forward one of the two."
|
| 430 |
+
)
|
| 431 |
+
elif prompt is None and prompt_embeds is None:
|
| 432 |
+
raise ValueError(
|
| 433 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 434 |
+
)
|
| 435 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 436 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 437 |
+
|
| 438 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 439 |
+
raise ValueError(
|
| 440 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 441 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 445 |
+
raise ValueError(
|
| 446 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 447 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 451 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 452 |
+
raise ValueError(
|
| 453 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 454 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 455 |
+
f" {negative_prompt_embeds.shape}."
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
@property
|
| 459 |
+
def guidance_scale(self):
|
| 460 |
+
return self._guidance_scale
|
| 461 |
+
|
| 462 |
+
@property
|
| 463 |
+
def num_timesteps(self):
|
| 464 |
+
return self._num_timesteps
|
| 465 |
+
|
| 466 |
+
@property
|
| 467 |
+
def attention_kwargs(self):
|
| 468 |
+
return self._attention_kwargs
|
| 469 |
+
|
| 470 |
+
@property
|
| 471 |
+
def interrupt(self):
|
| 472 |
+
return self._interrupt
|
| 473 |
+
|
| 474 |
+
@torch.no_grad()
|
| 475 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 476 |
+
def __call__(
|
| 477 |
+
self,
|
| 478 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 479 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 480 |
+
height: int = 480,
|
| 481 |
+
width: int = 720,
|
| 482 |
+
video: Union[torch.FloatTensor] = None,
|
| 483 |
+
mask_video: Union[torch.FloatTensor] = None,
|
| 484 |
+
audio_path = None,
|
| 485 |
+
num_frames: int = 49,
|
| 486 |
+
num_inference_steps: int = 50,
|
| 487 |
+
timesteps: Optional[List[int]] = None,
|
| 488 |
+
guidance_scale: float = 6,
|
| 489 |
+
num_videos_per_prompt: int = 1,
|
| 490 |
+
eta: float = 0.0,
|
| 491 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 492 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 493 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 494 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 495 |
+
output_type: str = "numpy",
|
| 496 |
+
return_dict: bool = False,
|
| 497 |
+
callback_on_step_end: Optional[
|
| 498 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 499 |
+
] = None,
|
| 500 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 501 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 502 |
+
clip_image: Image = None,
|
| 503 |
+
max_sequence_length: int = 512,
|
| 504 |
+
comfyui_progressbar: bool = False,
|
| 505 |
+
shift: int = 5,
|
| 506 |
+
fps: int = 16,
|
| 507 |
+
) -> Union[WanPipelineOutput, Tuple]:
|
| 508 |
+
"""
|
| 509 |
+
Function invoked when calling the pipeline for generation.
|
| 510 |
+
Args:
|
| 511 |
+
|
| 512 |
+
Examples:
|
| 513 |
+
|
| 514 |
+
Returns:
|
| 515 |
+
|
| 516 |
+
"""
|
| 517 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 518 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 519 |
+
num_videos_per_prompt = 1
|
| 520 |
+
|
| 521 |
+
# 1. Check inputs. Raise error if not correct
|
| 522 |
+
self.check_inputs(
|
| 523 |
+
prompt,
|
| 524 |
+
height,
|
| 525 |
+
width,
|
| 526 |
+
negative_prompt,
|
| 527 |
+
callback_on_step_end_tensor_inputs,
|
| 528 |
+
prompt_embeds,
|
| 529 |
+
negative_prompt_embeds,
|
| 530 |
+
)
|
| 531 |
+
self._guidance_scale = guidance_scale
|
| 532 |
+
self._attention_kwargs = attention_kwargs
|
| 533 |
+
self._interrupt = False
|
| 534 |
+
|
| 535 |
+
# 2. Default call parameters
|
| 536 |
+
if prompt is not None and isinstance(prompt, str):
|
| 537 |
+
batch_size = 1
|
| 538 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 539 |
+
batch_size = len(prompt)
|
| 540 |
+
else:
|
| 541 |
+
batch_size = prompt_embeds.shape[0]
|
| 542 |
+
|
| 543 |
+
device = self._execution_device
|
| 544 |
+
weight_dtype = self.text_encoder.dtype
|
| 545 |
+
|
| 546 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 547 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 548 |
+
# corresponds to doing no classifier free guidance.
|
| 549 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 550 |
+
|
| 551 |
+
# 3. Encode input prompt
|
| 552 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 553 |
+
prompt,
|
| 554 |
+
negative_prompt,
|
| 555 |
+
do_classifier_free_guidance,
|
| 556 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 557 |
+
prompt_embeds=prompt_embeds,
|
| 558 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 559 |
+
max_sequence_length=max_sequence_length,
|
| 560 |
+
device=device,
|
| 561 |
+
)
|
| 562 |
+
if do_classifier_free_guidance:
|
| 563 |
+
in_prompt_embeds = negative_prompt_embeds + prompt_embeds
|
| 564 |
+
else:
|
| 565 |
+
in_prompt_embeds = prompt_embeds
|
| 566 |
+
|
| 567 |
+
# 4. Prepare timesteps
|
| 568 |
+
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
| 569 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
|
| 570 |
+
elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
|
| 571 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
|
| 572 |
+
timesteps = self.scheduler.timesteps
|
| 573 |
+
elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
|
| 574 |
+
sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
|
| 575 |
+
timesteps, _ = retrieve_timesteps(
|
| 576 |
+
self.scheduler,
|
| 577 |
+
device=device,
|
| 578 |
+
sigmas=sampling_sigmas)
|
| 579 |
+
else:
|
| 580 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 581 |
+
self._num_timesteps = len(timesteps)
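A sketch of the shift-aware branch above in isolation. The import path is inferred from this module's relative imports and the constructor defaults are assumed; a larger `shift` biases the schedule toward high noise levels.

```python
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler

scheduler = FlowUniPCMultistepScheduler()
scheduler.set_timesteps(30, device="cpu", shift=5)
print(scheduler.timesteps[:5])  # the first few (high-noise) timesteps of the shifted schedule
```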
|
| 582 |
+
if comfyui_progressbar:
|
| 583 |
+
from comfy.utils import ProgressBar
|
| 584 |
+
pbar = ProgressBar(num_inference_steps + 2)
|
| 585 |
+
|
| 586 |
+
# 5. Prepare latents.
|
| 587 |
+
if video is not None:
|
| 588 |
+
video_length = video.shape[2]
|
| 589 |
+
init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 590 |
+
init_video = init_video.to(dtype=torch.float32)
|
| 591 |
+
init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
|
| 592 |
+
else:
|
| 593 |
+
init_video = None
|
| 594 |
+
|
| 595 |
+
latent_channels = self.vae.config.latent_channels
|
| 596 |
+
latents = self.prepare_latents(
|
| 597 |
+
batch_size * num_videos_per_prompt,
|
| 598 |
+
latent_channels,
|
| 599 |
+
num_frames,
|
| 600 |
+
height,
|
| 601 |
+
width,
|
| 602 |
+
weight_dtype,
|
| 603 |
+
device,
|
| 604 |
+
generator,
|
| 605 |
+
latents,
|
| 606 |
+
)
|
| 607 |
+
if comfyui_progressbar:
|
| 608 |
+
pbar.update(1)
|
| 609 |
+
|
| 610 |
+
# Prepare mask latent variables
|
| 611 |
+
if init_video is not None:
|
| 612 |
+
if (mask_video == 255).all():
|
| 613 |
+
mask_latents = torch.tile(
|
| 614 |
+
torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1]
|
| 615 |
+
)
|
| 616 |
+
masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype)
|
| 617 |
+
else:
|
| 618 |
+
bs, _, video_length, height, width = video.size()
|
| 619 |
+
mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 620 |
+
mask_condition = mask_condition.to(dtype=torch.float32)
|
| 621 |
+
mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
|
| 622 |
+
|
| 623 |
+
masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5)
|
| 624 |
+
_, masked_video_latents = self.prepare_mask_latents(
|
| 625 |
+
None,
|
| 626 |
+
masked_video,
|
| 627 |
+
batch_size,
|
| 628 |
+
height,
|
| 629 |
+
width,
|
| 630 |
+
weight_dtype,
|
| 631 |
+
device,
|
| 632 |
+
generator,
|
| 633 |
+
do_classifier_free_guidance,
|
| 634 |
+
noise_aug_strength=None,
|
| 635 |
+
)
|
| 636 |
+
|
| 637 |
+
mask_condition = torch.concat(
|
| 638 |
+
[
|
| 639 |
+
torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2),
|
| 640 |
+
mask_condition[:, :, 1:]
|
| 641 |
+
], dim=2
|
| 642 |
+
)
|
| 643 |
+
mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width)
|
| 644 |
+
mask_condition = mask_condition.transpose(1, 2)
|
| 645 |
+
mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype)
|
| 646 |
+
|
| 647 |
+
# Prepare clip latent variables
|
| 648 |
+
if clip_image is not None:
|
| 649 |
+
clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
|
| 650 |
+
clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
|
| 651 |
+
else:
|
| 652 |
+
clip_image = Image.new("RGB", (512, 512), color=(0, 0, 0))
|
| 653 |
+
clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
|
| 654 |
+
clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
|
| 655 |
+
clip_context = torch.zeros_like(clip_context)
|
| 656 |
+
|
| 657 |
+
# Extract audio emb
|
| 658 |
+
audio_wav2vec_fea = self.audio_encoder.extract_audio_feat(audio_path, num_frames=num_frames, fps=fps)
|
| 659 |
+
|
| 660 |
+
if comfyui_progressbar:
|
| 661 |
+
pbar.update(1)
|
| 662 |
+
|
| 663 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 664 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 665 |
+
|
| 666 |
+
target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
|
| 667 |
+
seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
|
| 668 |
+
# 7. Denoising loop
|
| 669 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 670 |
+
self.transformer.num_inference_steps = num_inference_steps
|
| 671 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 672 |
+
for i, t in enumerate(timesteps):
|
| 673 |
+
self.transformer.current_steps = i
|
| 674 |
+
|
| 675 |
+
if self.interrupt:
|
| 676 |
+
continue
|
| 677 |
+
|
| 678 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 679 |
+
if hasattr(self.scheduler, "scale_model_input"):
|
| 680 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 681 |
+
|
| 682 |
+
if init_video is not None:
|
| 683 |
+
mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
|
| 684 |
+
masked_video_latents_input = (
|
| 685 |
+
torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
|
| 686 |
+
)
|
| 687 |
+
y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype)
|
| 688 |
+
|
| 689 |
+
clip_context_input = (
|
| 690 |
+
torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
audio_wav2vec_fea_input = (
|
| 694 |
+
torch.cat([audio_wav2vec_fea] * 2) if do_classifier_free_guidance else audio_wav2vec_fea
|
| 695 |
+
)
|
| 696 |
+
|
| 697 |
+
audio_scale = torch.tensor(
|
| 698 |
+
[0.75, 1]
|
| 699 |
+
).to(latent_model_input.device, latent_model_input.dtype)
|
| 700 |
+
|
| 701 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 702 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 703 |
+
|
| 704 |
+
# predict noise model_output
|
| 705 |
+
with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
|
| 706 |
+
noise_pred = self.transformer(
|
| 707 |
+
x=latent_model_input,
|
| 708 |
+
context=in_prompt_embeds,
|
| 709 |
+
t=timestep,
|
| 710 |
+
seq_len=seq_len,
|
| 711 |
+
y=y,
|
| 712 |
+
audio_wav2vec_fea=audio_wav2vec_fea_input,
|
| 713 |
+
audio_scale=audio_scale,
|
| 714 |
+
clip_fea=clip_context_input,
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
# perform guidance
|
| 718 |
+
if do_classifier_free_guidance:
|
| 719 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 720 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 721 |
+
|
| 722 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 723 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 724 |
+
|
| 725 |
+
if callback_on_step_end is not None:
|
| 726 |
+
callback_kwargs = {}
|
| 727 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 728 |
+
callback_kwargs[k] = locals()[k]
|
| 729 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 730 |
+
|
| 731 |
+
latents = callback_outputs.pop("latents", latents)
|
| 732 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 733 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 734 |
+
|
| 735 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 736 |
+
progress_bar.update()
|
| 737 |
+
if comfyui_progressbar:
|
| 738 |
+
pbar.update(1)
|
| 739 |
+
|
| 740 |
+
if output_type == "numpy":
|
| 741 |
+
video = self.decode_latents(latents)
|
| 742 |
+
elif not output_type == "latent":
|
| 743 |
+
video = self.decode_latents(latents)
|
| 744 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 745 |
+
else:
|
| 746 |
+
video = latents
|
| 747 |
+
|
| 748 |
+
# Offload all models
|
| 749 |
+
self.maybe_free_model_hooks()
|
| 750 |
+
|
| 751 |
+
if not return_dict:
|
| 752 |
+
video = torch.from_numpy(video)
|
| 753 |
+
|
| 754 |
+
return WanPipelineOutput(videos=video)
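Finally, a hypothetical driver for this pipeline; component loading, tensor preparation, and every path below are placeholders for code that lives elsewhere in the repository, not part of this file.

```python
import torch
from PIL import Image

# `pipe` is an already-assembled FantasyTalkingPipeline; `video` / `mask_video` are
# pixel-space tensors shaped (batch, channels, frames, H, W), with mask value 255 on
# frames to generate, matching the checks in __call__ above.
def run_fantasy_talking(pipe, video, mask_video, clip_image, audio_path, prompt):
    return pipe(
        prompt=prompt,
        video=video,
        mask_video=mask_video,
        clip_image=clip_image,
        audio_path=audio_path,
        num_frames=video.shape[2],
        height=video.shape[3],
        width=video.shape[4],
        num_inference_steps=30,
        guidance_scale=5.0,
        fps=16,
        generator=torch.Generator("cuda").manual_seed(0),
    ).videos

# videos = run_fantasy_talking(pipe, video, mask_video,
#                              Image.open("face.png").convert("RGB"),
#                              "speech.wav", "a person speaking to the camera")
```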
|
videox_fun/pipeline/pipeline_flux.py
ADDED
|
@@ -0,0 +1,978 @@
| 1 |
+
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux.py
|
| 2 |
+
# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import PIL.Image
|
| 22 |
+
import torch
|
| 23 |
+
from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
|
| 24 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 25 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 26 |
+
from diffusers.utils import (BaseOutput, is_torch_xla_available, logging,
|
| 27 |
+
replace_example_docstring)
|
| 28 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 29 |
+
|
| 30 |
+
from ..models import (CLIPImageProcessor, CLIPTextModel,
|
| 31 |
+
CLIPTokenizer, CLIPVisionModelWithProjection,
|
| 32 |
+
FluxTransformer2DModel, T5EncoderModel, AutoencoderKL,
|
| 33 |
+
T5TokenizerFast)
|
| 34 |
+
|
| 35 |
+
if is_torch_xla_available():
|
| 36 |
+
import torch_xla.core.xla_model as xm
|
| 37 |
+
|
| 38 |
+
XLA_AVAILABLE = True
|
| 39 |
+
else:
|
| 40 |
+
XLA_AVAILABLE = False
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 44 |
+
|
| 45 |
+
EXAMPLE_DOC_STRING = """
|
| 46 |
+
Examples:
|
| 47 |
+
```py
|
| 48 |
+
>>> import torch
|
| 49 |
+
>>> from diffusers import FluxPipeline
|
| 50 |
+
|
| 51 |
+
>>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
|
| 52 |
+
>>> pipe.to("cuda")
|
| 53 |
+
>>> prompt = "A cat holding a sign that says hello world"
|
| 54 |
+
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
| 55 |
+
>>> # Refer to the pipeline documentation for more details.
|
| 56 |
+
>>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
|
| 57 |
+
>>> image.save("flux.png")
|
| 58 |
+
```
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def calculate_shift(
|
| 63 |
+
image_seq_len,
|
| 64 |
+
base_seq_len: int = 256,
|
| 65 |
+
max_seq_len: int = 4096,
|
| 66 |
+
base_shift: float = 0.5,
|
| 67 |
+
max_shift: float = 1.15,
|
| 68 |
+
):
|
| 69 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 70 |
+
b = base_shift - m * base_seq_len
|
| 71 |
+
mu = image_seq_len * m + b
|
| 72 |
+
return mu
|
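`calculate_shift` is a straight linear interpolation: the time-shift `mu` moves from `base_shift` at 256 packed tokens to `max_shift` at 4096 packed tokens. A quick numeric check (a minimal sketch using the default constants above, not part of the pipeline):

```py
# mu = m * image_seq_len + b, with the line fixed by the two anchor points
m = (1.15 - 0.5) / (4096 - 256)   # slope: (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = 0.5 - m * 256                 # intercept at base_seq_len

print(m * 4096 + b)  # 1.15  -> a 1024x1024 image packs to 64*64 = 4096 tokens, hitting max_shift
print(m * 1024 + b)  # ~0.63 -> a 512x512 image packs to 32*32 = 1024 tokens
```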
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 76 |
+
def retrieve_timesteps(
|
| 77 |
+
scheduler,
|
| 78 |
+
num_inference_steps: Optional[int] = None,
|
| 79 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 80 |
+
timesteps: Optional[List[int]] = None,
|
| 81 |
+
sigmas: Optional[List[float]] = None,
|
| 82 |
+
**kwargs,
|
| 83 |
+
):
|
| 84 |
+
r"""
|
| 85 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 86 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
scheduler (`SchedulerMixin`):
|
| 90 |
+
The scheduler to get timesteps from.
|
| 91 |
+
num_inference_steps (`int`):
|
| 92 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 93 |
+
must be `None`.
|
| 94 |
+
device (`str` or `torch.device`, *optional*):
|
| 95 |
+
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
| 96 |
+
timesteps (`List[int]`, *optional*):
|
| 97 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 98 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 99 |
+
sigmas (`List[float]`, *optional*):
|
| 100 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 101 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 105 |
+
second element is the number of inference steps.
|
| 106 |
+
"""
|
| 107 |
+
if timesteps is not None and sigmas is not None:
|
| 108 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 109 |
+
if timesteps is not None:
|
| 110 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 111 |
+
if not accepts_timesteps:
|
| 112 |
+
raise ValueError(
|
| 113 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 114 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 115 |
+
)
|
| 116 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 117 |
+
timesteps = scheduler.timesteps
|
| 118 |
+
num_inference_steps = len(timesteps)
|
| 119 |
+
elif sigmas is not None:
|
| 120 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 121 |
+
if not accept_sigmas:
|
| 122 |
+
raise ValueError(
|
| 123 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 124 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 125 |
+
)
|
| 126 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 127 |
+
timesteps = scheduler.timesteps
|
| 128 |
+
num_inference_steps = len(timesteps)
|
| 129 |
+
else:
|
| 130 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 131 |
+
timesteps = scheduler.timesteps
|
| 132 |
+
return timesteps, num_inference_steps
|
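A hedged usage sketch for `retrieve_timesteps` (the scheduler instantiation and device are assumptions for illustration; `__call__` further below additionally passes custom `sigmas` and the `mu` shift computed above):

```py
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()  # assumed default config, just for the sketch

# Only a step count: the scheduler builds its own sigma schedule.
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=28, device="cpu")
print(num_inference_steps, timesteps.shape)  # 28 torch.Size([28])
```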
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@dataclass
|
| 136 |
+
class FluxPipelineOutput(BaseOutput):
|
| 137 |
+
"""
|
| 138 |
+
Output class for Flux image generation pipelines.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`)
|
| 142 |
+
List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
|
| 143 |
+
height, width, num_channels)`. PIL images or numpy arrays represent the denoised images of the diffusion
|
| 144 |
+
pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be
|
| 145 |
+
passed to the decoder.
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
@dataclass
|
| 152 |
+
class FluxPriorReduxPipelineOutput(BaseOutput):
|
| 153 |
+
"""
|
| 154 |
+
Output class for Flux Prior Redux pipelines.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
prompt_embeds (`torch.Tensor`):
|
| 158 |
+
The prompt embeddings produced by the prior Redux pipeline, to be passed on to the base Flux pipeline.
|
| 159 |
+
pooled_prompt_embeds (`torch.Tensor`): The pooled prompt embeddings to be passed on to the base Flux pipeline.
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
prompt_embeds: torch.Tensor
|
| 163 |
+
pooled_prompt_embeds: torch.Tensor
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class FluxPipeline(
|
| 167 |
+
DiffusionPipeline,
|
| 168 |
+
):
|
| 169 |
+
r"""
|
| 170 |
+
The Flux pipeline for text-to-image generation.
|
| 171 |
+
|
| 172 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
transformer ([`FluxTransformer2DModel`]):
|
| 176 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
| 177 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 178 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 179 |
+
vae ([`AutoencoderKL`]):
|
| 180 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 181 |
+
text_encoder ([`CLIPTextModel`]):
|
| 182 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
| 183 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
| 184 |
+
text_encoder_2 ([`T5EncoderModel`]):
|
| 185 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
| 186 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
| 187 |
+
tokenizer (`CLIPTokenizer`):
|
| 188 |
+
Tokenizer of class
|
| 189 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 190 |
+
tokenizer_2 (`T5TokenizerFast`):
|
| 191 |
+
Second Tokenizer of class
|
| 192 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
| 193 |
+
"""
|
| 194 |
+
|
| 195 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
|
| 196 |
+
_optional_components = ["image_encoder", "feature_extractor"]
|
| 197 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 198 |
+
|
| 199 |
+
def __init__(
|
| 200 |
+
self,
|
| 201 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 202 |
+
vae: AutoencoderKL,
|
| 203 |
+
text_encoder: CLIPTextModel,
|
| 204 |
+
tokenizer: CLIPTokenizer,
|
| 205 |
+
text_encoder_2: T5EncoderModel,
|
| 206 |
+
tokenizer_2: T5TokenizerFast,
|
| 207 |
+
transformer: FluxTransformer2DModel,
|
| 208 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
| 209 |
+
feature_extractor: CLIPImageProcessor = None,
|
| 210 |
+
):
|
| 211 |
+
super().__init__()
|
| 212 |
+
|
| 213 |
+
self.register_modules(
|
| 214 |
+
vae=vae,
|
| 215 |
+
text_encoder=text_encoder,
|
| 216 |
+
text_encoder_2=text_encoder_2,
|
| 217 |
+
tokenizer=tokenizer,
|
| 218 |
+
tokenizer_2=tokenizer_2,
|
| 219 |
+
transformer=transformer,
|
| 220 |
+
scheduler=scheduler,
|
| 221 |
+
image_encoder=image_encoder,
|
| 222 |
+
feature_extractor=feature_extractor,
|
| 223 |
+
)
|
| 224 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 225 |
+
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
| 226 |
+
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
| 227 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
| 228 |
+
self.tokenizer_max_length = (
|
| 229 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
| 230 |
+
)
|
| 231 |
+
self.default_sample_size = 128
|
| 232 |
+
|
| 233 |
+
def _get_t5_prompt_embeds(
|
| 234 |
+
self,
|
| 235 |
+
prompt: Union[str, List[str]] = None,
|
| 236 |
+
num_images_per_prompt: int = 1,
|
| 237 |
+
max_sequence_length: int = 512,
|
| 238 |
+
device: Optional[torch.device] = None,
|
| 239 |
+
dtype: Optional[torch.dtype] = None,
|
| 240 |
+
):
|
| 241 |
+
device = device or self._execution_device
|
| 242 |
+
dtype = dtype or self.text_encoder.dtype
|
| 243 |
+
|
| 244 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 245 |
+
batch_size = len(prompt)
|
| 246 |
+
|
| 247 |
+
text_inputs = self.tokenizer_2(
|
| 248 |
+
prompt,
|
| 249 |
+
padding="max_length",
|
| 250 |
+
max_length=max_sequence_length,
|
| 251 |
+
truncation=True,
|
| 252 |
+
return_length=False,
|
| 253 |
+
return_overflowing_tokens=False,
|
| 254 |
+
return_tensors="pt",
|
| 255 |
+
)
|
| 256 |
+
text_input_ids = text_inputs.input_ids
|
| 257 |
+
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
| 258 |
+
|
| 259 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 260 |
+
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
| 261 |
+
logger.warning(
|
| 262 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 263 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
| 267 |
+
|
| 268 |
+
dtype = self.text_encoder_2.dtype
|
| 269 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 270 |
+
|
| 271 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 272 |
+
|
| 273 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
| 274 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 275 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 276 |
+
|
| 277 |
+
return prompt_embeds
|
| 278 |
+
|
| 279 |
+
def _get_clip_prompt_embeds(
|
| 280 |
+
self,
|
| 281 |
+
prompt: Union[str, List[str]],
|
| 282 |
+
num_images_per_prompt: int = 1,
|
| 283 |
+
device: Optional[torch.device] = None,
|
| 284 |
+
):
|
| 285 |
+
device = device or self._execution_device
|
| 286 |
+
|
| 287 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 288 |
+
batch_size = len(prompt)
|
| 289 |
+
|
| 290 |
+
text_inputs = self.tokenizer(
|
| 291 |
+
prompt,
|
| 292 |
+
padding="max_length",
|
| 293 |
+
max_length=self.tokenizer_max_length,
|
| 294 |
+
truncation=True,
|
| 295 |
+
return_overflowing_tokens=False,
|
| 296 |
+
return_length=False,
|
| 297 |
+
return_tensors="pt",
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
text_input_ids = text_inputs.input_ids
|
| 301 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 302 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 303 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
| 304 |
+
logger.warning(
|
| 305 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 306 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
| 307 |
+
)
|
| 308 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
| 309 |
+
|
| 310 |
+
# Use pooled output of CLIPTextModel
|
| 311 |
+
prompt_embeds = prompt_embeds.pooler_output
|
| 312 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
| 313 |
+
|
| 314 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 315 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
| 316 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 317 |
+
|
| 318 |
+
return prompt_embeds
|
| 319 |
+
|
| 320 |
+
def encode_prompt(
|
| 321 |
+
self,
|
| 322 |
+
prompt: Union[str, List[str]],
|
| 323 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 324 |
+
device: Optional[torch.device] = None,
|
| 325 |
+
num_images_per_prompt: int = 1,
|
| 326 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 327 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 328 |
+
max_sequence_length: int = 512,
|
| 329 |
+
lora_scale: Optional[float] = None,
|
| 330 |
+
):
|
| 331 |
+
r"""
|
| 332 |
+
|
| 333 |
+
Args:
|
| 334 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 335 |
+
prompt to be encoded
|
| 336 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 337 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 338 |
+
used in all text-encoders
|
| 339 |
+
device: (`torch.device`):
|
| 340 |
+
torch device
|
| 341 |
+
num_images_per_prompt (`int`):
|
| 342 |
+
number of images that should be generated per prompt
|
| 343 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 344 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 345 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 346 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 347 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 348 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 349 |
+
lora_scale (`float`, *optional*):
|
| 350 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
| 351 |
+
"""
|
| 352 |
+
device = device or self._execution_device
|
| 353 |
+
|
| 354 |
+
# set lora scale so that monkey patched LoRA
|
| 355 |
+
# function of text encoder can correctly access it
|
| 356 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 357 |
+
|
| 358 |
+
if prompt_embeds is None:
|
| 359 |
+
prompt_2 = prompt_2 or prompt
|
| 360 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
| 361 |
+
|
| 362 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
| 363 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
| 364 |
+
prompt=prompt,
|
| 365 |
+
device=device,
|
| 366 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 367 |
+
)
|
| 368 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 369 |
+
prompt=prompt_2,
|
| 370 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 371 |
+
max_sequence_length=max_sequence_length,
|
| 372 |
+
device=device,
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
| 376 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
| 377 |
+
|
| 378 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
| 379 |
+
|
| 380 |
+
def encode_image(self, image, device, num_images_per_prompt):
|
| 381 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
| 382 |
+
|
| 383 |
+
if not isinstance(image, torch.Tensor):
|
| 384 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
| 385 |
+
|
| 386 |
+
image = image.to(device=device, dtype=dtype)
|
| 387 |
+
image_embeds = self.image_encoder(image).image_embeds
|
| 388 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 389 |
+
return image_embeds
|
| 390 |
+
|
| 391 |
+
def prepare_ip_adapter_image_embeds(
|
| 392 |
+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
|
| 393 |
+
):
|
| 394 |
+
image_embeds = []
|
| 395 |
+
if ip_adapter_image_embeds is None:
|
| 396 |
+
if not isinstance(ip_adapter_image, list):
|
| 397 |
+
ip_adapter_image = [ip_adapter_image]
|
| 398 |
+
|
| 399 |
+
if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
| 400 |
+
raise ValueError(
|
| 401 |
+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
for single_ip_adapter_image in ip_adapter_image:
|
| 405 |
+
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
|
| 406 |
+
image_embeds.append(single_image_embeds[None, :])
|
| 407 |
+
else:
|
| 408 |
+
if not isinstance(ip_adapter_image_embeds, list):
|
| 409 |
+
ip_adapter_image_embeds = [ip_adapter_image_embeds]
|
| 410 |
+
|
| 411 |
+
if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
| 412 |
+
raise ValueError(
|
| 413 |
+
f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
for single_image_embeds in ip_adapter_image_embeds:
|
| 417 |
+
image_embeds.append(single_image_embeds)
|
| 418 |
+
|
| 419 |
+
ip_adapter_image_embeds = []
|
| 420 |
+
for single_image_embeds in image_embeds:
|
| 421 |
+
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
| 422 |
+
single_image_embeds = single_image_embeds.to(device=device)
|
| 423 |
+
ip_adapter_image_embeds.append(single_image_embeds)
|
| 424 |
+
|
| 425 |
+
return ip_adapter_image_embeds
|
| 426 |
+
|
| 427 |
+
def check_inputs(
|
| 428 |
+
self,
|
| 429 |
+
prompt,
|
| 430 |
+
prompt_2,
|
| 431 |
+
height,
|
| 432 |
+
width,
|
| 433 |
+
negative_prompt=None,
|
| 434 |
+
negative_prompt_2=None,
|
| 435 |
+
prompt_embeds=None,
|
| 436 |
+
negative_prompt_embeds=None,
|
| 437 |
+
pooled_prompt_embeds=None,
|
| 438 |
+
negative_pooled_prompt_embeds=None,
|
| 439 |
+
callback_on_step_end_tensor_inputs=None,
|
| 440 |
+
max_sequence_length=None,
|
| 441 |
+
):
|
| 442 |
+
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
| 443 |
+
logger.warning(
|
| 444 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 448 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 449 |
+
):
|
| 450 |
+
raise ValueError(
|
| 451 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
if prompt is not None and prompt_embeds is not None:
|
| 455 |
+
raise ValueError(
|
| 456 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 457 |
+
" only forward one of the two."
|
| 458 |
+
)
|
| 459 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
| 460 |
+
raise ValueError(
|
| 461 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 462 |
+
" only forward one of the two."
|
| 463 |
+
)
|
| 464 |
+
elif prompt is None and prompt_embeds is None:
|
| 465 |
+
raise ValueError(
|
| 466 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 467 |
+
)
|
| 468 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 469 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 470 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
| 471 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
| 472 |
+
|
| 473 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 474 |
+
raise ValueError(
|
| 475 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 476 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 477 |
+
)
|
| 478 |
+
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
| 479 |
+
raise ValueError(
|
| 480 |
+
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
| 481 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
| 485 |
+
raise ValueError(
|
| 486 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
| 487 |
+
)
|
| 488 |
+
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
| 489 |
+
raise ValueError(
|
| 490 |
+
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
| 494 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
| 495 |
+
|
| 496 |
+
@staticmethod
|
| 497 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
| 498 |
+
latent_image_ids = torch.zeros(height, width, 3)
|
| 499 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
| 500 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
| 501 |
+
|
| 502 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
| 503 |
+
|
| 504 |
+
latent_image_ids = latent_image_ids.reshape(
|
| 505 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
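`_prepare_latent_image_ids` just enumerates the (row, column) position of every packed latent token in channels 1 and 2 (channel 0 stays zero); these ids drive the transformer's rotary position embedding. A tiny sketch of its output for a 2x2 packed grid:

```py
import torch

height, width = 2, 2  # the packed grid, i.e. (latent_height // 2, latent_width // 2)
ids = torch.zeros(height, width, 3)
ids[..., 1] += torch.arange(height)[:, None]  # row index
ids[..., 2] += torch.arange(width)[None, :]   # column index
print(ids.reshape(height * width, 3))
# tensor([[0., 0., 0.],
#         [0., 0., 1.],
#         [0., 1., 0.],
#         [0., 1., 1.]])
```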
| 509 |
+
|
| 510 |
+
@staticmethod
|
| 511 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
| 512 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
| 513 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
| 514 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
| 515 |
+
|
| 516 |
+
return latents
|
| 517 |
+
|
| 518 |
+
@staticmethod
|
| 519 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
| 520 |
+
batch_size, num_patches, channels = latents.shape
|
| 521 |
+
|
| 522 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 523 |
+
# latent height and width to be divisible by 2.
|
| 524 |
+
height = 2 * (int(height) // (vae_scale_factor * 2))
|
| 525 |
+
width = 2 * (int(width) // (vae_scale_factor * 2))
|
| 526 |
+
|
| 527 |
+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
| 528 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
| 529 |
+
|
| 530 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
| 531 |
+
|
| 532 |
+
return latents
|
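`_pack_latents` folds every 2x2 patch of the 16-channel VAE latent into one 64-dimensional token, and `_unpack_latents` inverts it exactly. A shape round-trip sketch (the 1024x1024 / 16-channel numbers are assumptions for illustration; `FluxPipeline` is the class defined above):

```py
import torch

latents = torch.randn(1, 16, 128, 128)  # (B, C, latent_h, latent_w) for a 1024x1024 image
packed = FluxPipeline._pack_latents(latents, 1, 16, 128, 128)
print(packed.shape)  # torch.Size([1, 4096, 64]) == (B, (h/2)*(w/2), C*4)

# vae_scale_factor is 8 here, so height/width are given in pixel space.
unpacked = FluxPipeline._unpack_latents(packed, height=1024, width=1024, vae_scale_factor=8)
assert unpacked.shape == latents.shape and torch.equal(unpacked, latents)
```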
| 533 |
+
|
| 534 |
+
def enable_vae_slicing(self):
|
| 535 |
+
r"""
|
| 536 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
| 537 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
| 538 |
+
"""
|
| 539 |
+
self.vae.enable_slicing()
|
| 540 |
+
|
| 541 |
+
def disable_vae_slicing(self):
|
| 542 |
+
r"""
|
| 543 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
| 544 |
+
computing decoding in one step.
|
| 545 |
+
"""
|
| 546 |
+
self.vae.disable_slicing()
|
| 547 |
+
|
| 548 |
+
def enable_vae_tiling(self):
|
| 549 |
+
r"""
|
| 550 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
| 551 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
| 552 |
+
processing larger images.
|
| 553 |
+
"""
|
| 554 |
+
self.vae.enable_tiling()
|
| 555 |
+
|
| 556 |
+
def disable_vae_tiling(self):
|
| 557 |
+
r"""
|
| 558 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
| 559 |
+
computing decoding in one step.
|
| 560 |
+
"""
|
| 561 |
+
self.vae.disable_tiling()
|
| 562 |
+
|
| 563 |
+
def prepare_latents(
|
| 564 |
+
self,
|
| 565 |
+
batch_size,
|
| 566 |
+
num_channels_latents,
|
| 567 |
+
height,
|
| 568 |
+
width,
|
| 569 |
+
dtype,
|
| 570 |
+
device,
|
| 571 |
+
generator,
|
| 572 |
+
latents=None,
|
| 573 |
+
):
|
| 574 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 575 |
+
# latent height and width to be divisible by 2.
|
| 576 |
+
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
| 577 |
+
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
| 578 |
+
|
| 579 |
+
shape = (batch_size, num_channels_latents, height, width)
|
| 580 |
+
|
| 581 |
+
if latents is not None:
|
| 582 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
| 583 |
+
return latents.to(device=device, dtype=dtype), latent_image_ids
|
| 584 |
+
|
| 585 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 586 |
+
raise ValueError(
|
| 587 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 588 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 592 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
| 593 |
+
|
| 594 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
| 595 |
+
|
| 596 |
+
return latents, latent_image_ids
|
| 597 |
+
|
| 598 |
+
@property
|
| 599 |
+
def guidance_scale(self):
|
| 600 |
+
return self._guidance_scale
|
| 601 |
+
|
| 602 |
+
@property
|
| 603 |
+
def joint_attention_kwargs(self):
|
| 604 |
+
return self._joint_attention_kwargs
|
| 605 |
+
|
| 606 |
+
@property
|
| 607 |
+
def num_timesteps(self):
|
| 608 |
+
return self._num_timesteps
|
| 609 |
+
|
| 610 |
+
@property
|
| 611 |
+
def current_timestep(self):
|
| 612 |
+
return self._current_timestep
|
| 613 |
+
|
| 614 |
+
@property
|
| 615 |
+
def interrupt(self):
|
| 616 |
+
return self._interrupt
|
| 617 |
+
|
| 618 |
+
@torch.no_grad()
|
| 619 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 620 |
+
def __call__(
|
| 621 |
+
self,
|
| 622 |
+
prompt: Union[str, List[str]] = None,
|
| 623 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 624 |
+
negative_prompt: Union[str, List[str]] = None,
|
| 625 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 626 |
+
true_cfg_scale: float = 1.0,
|
| 627 |
+
height: Optional[int] = None,
|
| 628 |
+
width: Optional[int] = None,
|
| 629 |
+
num_inference_steps: int = 28,
|
| 630 |
+
sigmas: Optional[List[float]] = None,
|
| 631 |
+
guidance_scale: float = 3.5,
|
| 632 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 633 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 634 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 635 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 636 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 637 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 638 |
+
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 639 |
+
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 640 |
+
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 641 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 642 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 643 |
+
output_type: Optional[str] = "pil",
|
| 644 |
+
return_dict: bool = True,
|
| 645 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 646 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 647 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 648 |
+
max_sequence_length: int = 512,
|
| 649 |
+
):
|
| 650 |
+
r"""
|
| 651 |
+
Function invoked when calling the pipeline for generation.
|
| 652 |
+
|
| 653 |
+
Args:
|
| 654 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 655 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
| 656 |
+
instead.
|
| 657 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 658 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
|
| 659 |
+
will be used instead.
|
| 660 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 661 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 662 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
| 663 |
+
not greater than `1`).
|
| 664 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 665 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 666 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
| 667 |
+
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
| 668 |
+
True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
|
| 669 |
+
`negative_prompt` is provided.
|
| 670 |
+
height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
|
| 671 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 672 |
+
width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
|
| 673 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 674 |
+
num_inference_steps (`int`, *optional*, defaults to 28):
|
| 675 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 676 |
+
expense of slower inference.
|
| 677 |
+
sigmas (`List[float]`, *optional*):
|
| 678 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 679 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 680 |
+
will be used.
|
| 681 |
+
guidance_scale (`float`, *optional*, defaults to 3.5):
|
| 682 |
+
Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
| 683 |
+
a model to generate images more aligned with `prompt` at the expense of lower image quality.
|
| 684 |
+
|
| 685 |
+
Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
|
| 686 |
+
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
| 687 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 688 |
+
The number of images to generate per prompt.
|
| 689 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 690 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 691 |
+
to make generation deterministic.
|
| 692 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 693 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 694 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 695 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 696 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 697 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 698 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 699 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 700 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 701 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 702 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
| 703 |
+
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
| 704 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of the same length as the number of
|
| 705 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
| 706 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 707 |
+
negative_ip_adapter_image:
|
| 708 |
+
(`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
| 709 |
+
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
| 710 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of the same length as the number of
|
| 711 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
| 712 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 713 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 714 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 715 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 716 |
+
argument.
|
| 717 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 718 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 719 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 720 |
+
input argument.
|
| 721 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 722 |
+
The output format of the generated image. Choose between
|
| 723 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 724 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 725 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
| 726 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 727 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 728 |
+
`self.processor` in
|
| 729 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 730 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 731 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
| 732 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 733 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 734 |
+
`callback_on_step_end_tensor_inputs`.
|
| 735 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 736 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 737 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 738 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 739 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
| 740 |
+
|
| 741 |
+
Examples:
|
| 742 |
+
|
| 743 |
+
Returns:
|
| 744 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
| 745 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
| 746 |
+
images.
|
| 747 |
+
"""
|
| 748 |
+
|
| 749 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 750 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 751 |
+
|
| 752 |
+
# 1. Check inputs. Raise error if not correct
|
| 753 |
+
self.check_inputs(
|
| 754 |
+
prompt,
|
| 755 |
+
prompt_2,
|
| 756 |
+
height,
|
| 757 |
+
width,
|
| 758 |
+
negative_prompt=negative_prompt,
|
| 759 |
+
negative_prompt_2=negative_prompt_2,
|
| 760 |
+
prompt_embeds=prompt_embeds,
|
| 761 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 762 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 763 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 764 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 765 |
+
max_sequence_length=max_sequence_length,
|
| 766 |
+
)
|
| 767 |
+
|
| 768 |
+
self._guidance_scale = guidance_scale
|
| 769 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 770 |
+
self._current_timestep = None
|
| 771 |
+
self._interrupt = False
|
| 772 |
+
|
| 773 |
+
# 2. Define call parameters
|
| 774 |
+
if prompt is not None and isinstance(prompt, str):
|
| 775 |
+
batch_size = 1
|
| 776 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 777 |
+
batch_size = len(prompt)
|
| 778 |
+
else:
|
| 779 |
+
batch_size = prompt_embeds.shape[0]
|
| 780 |
+
|
| 781 |
+
device = self._execution_device
|
| 782 |
+
|
| 783 |
+
lora_scale = (
|
| 784 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
| 785 |
+
)
|
| 786 |
+
has_neg_prompt = negative_prompt is not None or (
|
| 787 |
+
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
|
| 788 |
+
)
|
| 789 |
+
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
| 790 |
+
(
|
| 791 |
+
prompt_embeds,
|
| 792 |
+
pooled_prompt_embeds,
|
| 793 |
+
text_ids,
|
| 794 |
+
) = self.encode_prompt(
|
| 795 |
+
prompt=prompt,
|
| 796 |
+
prompt_2=prompt_2,
|
| 797 |
+
prompt_embeds=prompt_embeds,
|
| 798 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 799 |
+
device=device,
|
| 800 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 801 |
+
max_sequence_length=max_sequence_length,
|
| 802 |
+
lora_scale=lora_scale,
|
| 803 |
+
)
|
| 804 |
+
if do_true_cfg:
|
| 805 |
+
(
|
| 806 |
+
negative_prompt_embeds,
|
| 807 |
+
negative_pooled_prompt_embeds,
|
| 808 |
+
negative_text_ids,
|
| 809 |
+
) = self.encode_prompt(
|
| 810 |
+
prompt=negative_prompt,
|
| 811 |
+
prompt_2=negative_prompt_2,
|
| 812 |
+
prompt_embeds=negative_prompt_embeds,
|
| 813 |
+
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 814 |
+
device=device,
|
| 815 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 816 |
+
max_sequence_length=max_sequence_length,
|
| 817 |
+
lora_scale=lora_scale,
|
| 818 |
+
)
|
| 819 |
+
|
| 820 |
+
# 4. Prepare latent variables
|
| 821 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
| 822 |
+
latents, latent_image_ids = self.prepare_latents(
|
| 823 |
+
batch_size * num_images_per_prompt,
|
| 824 |
+
num_channels_latents,
|
| 825 |
+
height,
|
| 826 |
+
width,
|
| 827 |
+
prompt_embeds.dtype,
|
| 828 |
+
device,
|
| 829 |
+
generator,
|
| 830 |
+
latents,
|
| 831 |
+
)
|
| 832 |
+
|
| 833 |
+
# 5. Prepare timesteps
|
| 834 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
| 835 |
+
if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
|
| 836 |
+
sigmas = None
|
| 837 |
+
image_seq_len = latents.shape[1]
|
| 838 |
+
mu = calculate_shift(
|
| 839 |
+
image_seq_len,
|
| 840 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 841 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 842 |
+
self.scheduler.config.get("base_shift", 0.5),
|
| 843 |
+
self.scheduler.config.get("max_shift", 1.15),
|
| 844 |
+
)
|
| 845 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 846 |
+
self.scheduler,
|
| 847 |
+
num_inference_steps,
|
| 848 |
+
device,
|
| 849 |
+
sigmas=sigmas,
|
| 850 |
+
mu=mu,
|
| 851 |
+
)
|
| 852 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 853 |
+
self._num_timesteps = len(timesteps)
|
| 854 |
+
|
| 855 |
+
# handle guidance
|
| 856 |
+
if self.transformer.config.guidance_embeds:
|
| 857 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 858 |
+
guidance = guidance.expand(latents.shape[0])
|
| 859 |
+
else:
|
| 860 |
+
guidance = None
|
| 861 |
+
|
| 862 |
+
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
|
| 863 |
+
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
|
| 864 |
+
):
|
| 865 |
+
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
| 866 |
+
negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
|
| 867 |
+
|
| 868 |
+
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
|
| 869 |
+
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
|
| 870 |
+
):
|
| 871 |
+
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
| 872 |
+
ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
|
| 873 |
+
|
| 874 |
+
if self.joint_attention_kwargs is None:
|
| 875 |
+
self._joint_attention_kwargs = {}
|
| 876 |
+
|
| 877 |
+
image_embeds = None
|
| 878 |
+
negative_image_embeds = None
|
| 879 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
| 880 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 881 |
+
ip_adapter_image,
|
| 882 |
+
ip_adapter_image_embeds,
|
| 883 |
+
device,
|
| 884 |
+
batch_size * num_images_per_prompt,
|
| 885 |
+
)
|
| 886 |
+
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
|
| 887 |
+
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 888 |
+
negative_ip_adapter_image,
|
| 889 |
+
negative_ip_adapter_image_embeds,
|
| 890 |
+
device,
|
| 891 |
+
batch_size * num_images_per_prompt,
|
| 892 |
+
)
|
| 893 |
+
|
| 894 |
+
# 6. Denoising loop
|
| 895 |
+
# We set the index here to remove DtoH sync, helpful especially during compilation.
|
| 896 |
+
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
|
| 897 |
+
self.scheduler.set_begin_index(0)
|
| 898 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 899 |
+
for i, t in enumerate(timesteps):
|
| 900 |
+
if self.interrupt:
|
| 901 |
+
continue
|
| 902 |
+
|
| 903 |
+
self._current_timestep = t
|
| 904 |
+
if image_embeds is not None:
|
| 905 |
+
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
|
| 906 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 907 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 908 |
+
|
| 909 |
+
with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device):
|
| 910 |
+
noise_pred = self.transformer(
|
| 911 |
+
hidden_states=latents,
|
| 912 |
+
timestep=timestep / 1000,
|
| 913 |
+
guidance=guidance,
|
| 914 |
+
pooled_projections=pooled_prompt_embeds,
|
| 915 |
+
encoder_hidden_states=prompt_embeds,
|
| 916 |
+
txt_ids=text_ids,
|
| 917 |
+
img_ids=latent_image_ids,
|
| 918 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 919 |
+
return_dict=False,
|
| 920 |
+
)[0]
|
| 921 |
+
|
| 922 |
+
if do_true_cfg:
|
| 923 |
+
with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device):
|
| 924 |
+
neg_noise_pred = self.transformer(
|
| 925 |
+
hidden_states=latents,
|
| 926 |
+
timestep=timestep / 1000,
|
| 927 |
+
guidance=guidance,
|
| 928 |
+
pooled_projections=negative_pooled_prompt_embeds,
|
| 929 |
+
encoder_hidden_states=negative_prompt_embeds,
|
| 930 |
+
txt_ids=negative_text_ids,
|
| 931 |
+
img_ids=latent_image_ids,
|
| 932 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 933 |
+
return_dict=False,
|
| 934 |
+
)[0]
|
| 935 |
+
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
| 936 |
+
|
| 937 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 938 |
+
latents_dtype = latents.dtype
|
| 939 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 940 |
+
|
| 941 |
+
if latents.dtype != latents_dtype:
|
| 942 |
+
if torch.backends.mps.is_available():
|
| 943 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 944 |
+
latents = latents.to(latents_dtype)
|
| 945 |
+
|
| 946 |
+
if callback_on_step_end is not None:
|
| 947 |
+
callback_kwargs = {}
|
| 948 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 949 |
+
callback_kwargs[k] = locals()[k]
|
| 950 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 951 |
+
|
| 952 |
+
latents = callback_outputs.pop("latents", latents)
|
| 953 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 954 |
+
|
| 955 |
+
# call the callback, if provided
|
| 956 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 957 |
+
progress_bar.update()
|
| 958 |
+
|
| 959 |
+
if XLA_AVAILABLE:
|
| 960 |
+
xm.mark_step()
|
| 961 |
+
|
| 962 |
+
self._current_timestep = None
|
| 963 |
+
|
| 964 |
+
if output_type == "latent":
|
| 965 |
+
image = latents
|
| 966 |
+
else:
|
| 967 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 968 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 969 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 970 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 971 |
+
|
| 972 |
+
# Offload all models
|
| 973 |
+
self.maybe_free_model_hooks()
|
| 974 |
+
|
| 975 |
+
if not return_dict:
|
| 976 |
+
return (image,)
|
| 977 |
+
|
| 978 |
+
return FluxPipelineOutput(images=image)
|
videox_fun/pipeline/pipeline_flux2.py
ADDED
|
@@ -0,0 +1,900 @@
| 1 |
+
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux2/pipeline_flux2.py
|
| 2 |
+
# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from diffusers.utils import (BaseOutput, is_torch_xla_available, logging,
|
| 18 |
+
replace_example_docstring)
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import PIL
|
| 24 |
+
import torch
|
| 25 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 26 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 27 |
+
from diffusers.utils import (is_torch_xla_available, logging,
|
| 28 |
+
replace_example_docstring)
|
| 29 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 30 |
+
|
| 31 |
+
from ..models import (AutoencoderKLFlux2, Flux2ImageProcessor,
|
| 32 |
+
Flux2Transformer2DModel, Mistral3ForConditionalGeneration, AutoProcessor)
|
| 33 |
+
|
| 34 |
+
if is_torch_xla_available():
|
| 35 |
+
import torch_xla.core.xla_model as xm
|
| 36 |
+
|
| 37 |
+
XLA_AVAILABLE = True
|
| 38 |
+
else:
|
| 39 |
+
XLA_AVAILABLE = False
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 43 |
+
|
| 44 |
+
EXAMPLE_DOC_STRING = """
|
| 45 |
+
Examples:
|
| 46 |
+
```py
|
| 47 |
+
>>> import torch
|
| 48 |
+
>>> from diffusers import Flux2Pipeline
|
| 49 |
+
|
| 50 |
+
>>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16)
|
| 51 |
+
>>> pipe.to("cuda")
|
| 52 |
+
>>> prompt = "A cat holding a sign that says hello world"
|
| 53 |
+
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
| 54 |
+
>>> # Refer to the pipeline documentation for more details.
|
| 55 |
+
>>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0]
|
| 56 |
+
>>> image.save("flux.png")
|
| 57 |
+
```
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def format_text_input(prompts: List[str], system_message: str = None):
|
| 62 |
+
# Remove [IMG] tokens from prompts to avoid Pixtral validation issues
|
| 63 |
+
# when truncation is enabled. The processor counts [IMG] tokens and fails
|
| 64 |
+
# if the count changes after truncation.
|
| 65 |
+
cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
|
| 66 |
+
|
| 67 |
+
return [
|
| 68 |
+
[
|
| 69 |
+
{
|
| 70 |
+
"role": "system",
|
| 71 |
+
"content": [{"type": "text", "text": system_message}],
|
| 72 |
+
},
|
| 73 |
+
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
| 74 |
+
]
|
| 75 |
+
for prompt in cleaned_txt
|
| 76 |
+
]
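For clarity, this is the payload the helper returns for a single prompt before it is handed to the chat template. A hedged illustration that assumes the `format_text_input` defined above is in scope; the system message is abbreviated:

```py
# Expected structure for one prompt; the system message is abbreviated here.
messages = format_text_input(
    ["A cat holding a sign"], system_message="You are an AI that reasons about image descriptions."
)
assert messages == [
    [
        {"role": "system", "content": [{"type": "text", "text": "You are an AI that reasons about image descriptions."}]},
        {"role": "user", "content": [{"type": "text", "text": "A cat holding a sign"}]},
    ]
]
```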
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
|
| 80 |
+
a1, b1 = 8.73809524e-05, 1.89833333
|
| 81 |
+
a2, b2 = 0.00016927, 0.45666666
|
| 82 |
+
|
| 83 |
+
if image_seq_len > 4300:
|
| 84 |
+
mu = a2 * image_seq_len + b2
|
| 85 |
+
return float(mu)
|
| 86 |
+
|
| 87 |
+
m_200 = a2 * image_seq_len + b2
|
| 88 |
+
m_10 = a1 * image_seq_len + b1
|
| 89 |
+
|
| 90 |
+
a = (m_200 - m_10) / 190.0
|
| 91 |
+
b = m_200 - 200.0 * a
|
| 92 |
+
mu = a * num_steps + b
|
| 93 |
+
|
| 94 |
+
return float(mu)
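A quick numeric sketch of what this schedule-shift helper does, using the constants above: for sequences of at most 4300 tokens, `mu` is interpolated linearly in `num_steps` between the 10-step fit `(a1, b1)` and the 200-step fit `(a2, b2)`; the sequence length 4096 below corresponds to a 1024x1024 image and is only an example:

```py
# Numeric sketch of compute_empirical_mu for image_seq_len <= 4300 (constants copied from above).
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666

image_seq_len = 4096                      # e.g. a 1024x1024 image -> (1024 / 16) ** 2 packed tokens
m_10 = a1 * image_seq_len + b1            # empirical mu at 10 steps
m_200 = a2 * image_seq_len + b2           # empirical mu at 200 steps
slope = (m_200 - m_10) / 190.0
intercept = m_200 - 200.0 * slope

for num_steps in (10, 50, 200):
    mu = slope * num_steps + intercept    # reproduces m_10 at 10 steps and m_200 at 200 steps
    print(num_steps, round(mu, 3))
```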
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 98 |
+
def retrieve_timesteps(
|
| 99 |
+
scheduler,
|
| 100 |
+
num_inference_steps: Optional[int] = None,
|
| 101 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 102 |
+
timesteps: Optional[List[int]] = None,
|
| 103 |
+
sigmas: Optional[List[float]] = None,
|
| 104 |
+
**kwargs,
|
| 105 |
+
):
|
| 106 |
+
r"""
|
| 107 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 108 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
scheduler (`SchedulerMixin`):
|
| 112 |
+
The scheduler to get timesteps from.
|
| 113 |
+
num_inference_steps (`int`):
|
| 114 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 115 |
+
must be `None`.
|
| 116 |
+
device (`str` or `torch.device`, *optional*):
|
| 117 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 118 |
+
timesteps (`List[int]`, *optional*):
|
| 119 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 120 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 121 |
+
sigmas (`List[float]`, *optional*):
|
| 122 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 123 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 127 |
+
second element is the number of inference steps.
|
| 128 |
+
"""
|
| 129 |
+
if timesteps is not None and sigmas is not None:
|
| 130 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 131 |
+
if timesteps is not None:
|
| 132 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 133 |
+
if not accepts_timesteps:
|
| 134 |
+
raise ValueError(
|
| 135 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 136 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 137 |
+
)
|
| 138 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 139 |
+
timesteps = scheduler.timesteps
|
| 140 |
+
num_inference_steps = len(timesteps)
|
| 141 |
+
elif sigmas is not None:
|
| 142 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 143 |
+
if not accept_sigmas:
|
| 144 |
+
raise ValueError(
|
| 145 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 146 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 147 |
+
)
|
| 148 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 149 |
+
timesteps = scheduler.timesteps
|
| 150 |
+
num_inference_steps = len(timesteps)
|
| 151 |
+
else:
|
| 152 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 153 |
+
timesteps = scheduler.timesteps
|
| 154 |
+
return timesteps, num_inference_steps
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 158 |
+
def retrieve_latents(
|
| 159 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 160 |
+
):
|
| 161 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 162 |
+
return encoder_output.latent_dist.sample(generator)
|
| 163 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 164 |
+
return encoder_output.latent_dist.mode()
|
| 165 |
+
elif hasattr(encoder_output, "latents"):
|
| 166 |
+
return encoder_output.latents
|
| 167 |
+
else:
|
| 168 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@dataclass
|
| 172 |
+
class Flux2PipelineOutput(BaseOutput):
|
| 173 |
+
"""
|
| 174 |
+
Output class for Flux2 image generation pipelines.
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`)
|
| 178 |
+
List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
|
| 179 |
+
height, width, num_channels)`. PIL images or numpy arrays represent the denoised images of the diffusion
|
| 180 |
+
pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be
|
| 181 |
+
passed to the decoder.
|
| 182 |
+
"""
|
| 183 |
+
|
| 184 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class Flux2Pipeline(DiffusionPipeline):
|
| 188 |
+
r"""
|
| 189 |
+
The Flux2 pipeline for text-to-image generation.
|
| 190 |
+
|
| 191 |
+
Reference: [https://bfl.ai/blog/flux-2](https://bfl.ai/blog/flux-2)
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
transformer ([`Flux2Transformer2DModel`]):
|
| 195 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
| 196 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 197 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 198 |
+
vae ([`AutoencoderKLFlux2`]):
|
| 199 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 200 |
+
text_encoder ([`Mistral3ForConditionalGeneration`]):
|
| 201 |
+
[Mistral3ForConditionalGeneration](https://huggingface.co/docs/transformers/en/model_doc/mistral3#transformers.Mistral3ForConditionalGeneration)
|
| 202 |
+
tokenizer (`AutoProcessor`):
|
| 203 |
+
Tokenizer of class
|
| 204 |
+
[PixtralProcessor](https://huggingface.co/docs/transformers/en/model_doc/pixtral#transformers.PixtralProcessor).
|
| 205 |
+
"""
|
| 206 |
+
|
| 207 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 208 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 209 |
+
|
| 210 |
+
def __init__(
|
| 211 |
+
self,
|
| 212 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 213 |
+
vae: AutoencoderKLFlux2,
|
| 214 |
+
text_encoder: Mistral3ForConditionalGeneration,
|
| 215 |
+
tokenizer: AutoProcessor,
|
| 216 |
+
transformer: Flux2Transformer2DModel,
|
| 217 |
+
):
|
| 218 |
+
super().__init__()
|
| 219 |
+
|
| 220 |
+
self.register_modules(
|
| 221 |
+
vae=vae,
|
| 222 |
+
text_encoder=text_encoder,
|
| 223 |
+
tokenizer=tokenizer,
|
| 224 |
+
scheduler=scheduler,
|
| 225 |
+
transformer=transformer,
|
| 226 |
+
)
|
| 227 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 228 |
+
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
| 229 |
+
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
| 230 |
+
self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
| 231 |
+
self.tokenizer_max_length = 512
|
| 232 |
+
self.default_sample_size = 128
|
| 233 |
+
|
| 234 |
+
# fmt: off
|
| 235 |
+
self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."
|
| 236 |
+
# fmt: on
|
| 237 |
+
|
| 238 |
+
@staticmethod
|
| 239 |
+
def _get_mistral_3_small_prompt_embeds(
|
| 240 |
+
text_encoder: Mistral3ForConditionalGeneration,
|
| 241 |
+
tokenizer: AutoProcessor,
|
| 242 |
+
prompt: Union[str, List[str]],
|
| 243 |
+
dtype: Optional[torch.dtype] = None,
|
| 244 |
+
device: Optional[torch.device] = None,
|
| 245 |
+
max_sequence_length: int = 512,
|
| 246 |
+
# fmt: off
|
| 247 |
+
system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.",
|
| 248 |
+
# fmt: on
|
| 249 |
+
hidden_states_layers: List[int] = (10, 20, 30),
|
| 250 |
+
):
|
| 251 |
+
dtype = text_encoder.dtype if dtype is None else dtype
|
| 252 |
+
device = text_encoder.device if device is None else device
|
| 253 |
+
|
| 254 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 255 |
+
|
| 256 |
+
# Format input messages
|
| 257 |
+
messages_batch = format_text_input(prompts=prompt, system_message=system_message)
|
| 258 |
+
|
| 259 |
+
# Process all messages at once
|
| 260 |
+
inputs = tokenizer.apply_chat_template(
|
| 261 |
+
messages_batch,
|
| 262 |
+
add_generation_prompt=False,
|
| 263 |
+
tokenize=True,
|
| 264 |
+
return_dict=True,
|
| 265 |
+
return_tensors="pt",
|
| 266 |
+
padding="max_length",
|
| 267 |
+
truncation=True,
|
| 268 |
+
max_length=max_sequence_length,
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
# Move to device
|
| 272 |
+
input_ids = inputs["input_ids"].to(device)
|
| 273 |
+
attention_mask = inputs["attention_mask"].to(device)
|
| 274 |
+
|
| 275 |
+
# Forward pass through the model
|
| 276 |
+
output = text_encoder(
|
| 277 |
+
input_ids=input_ids,
|
| 278 |
+
attention_mask=attention_mask,
|
| 279 |
+
output_hidden_states=True,
|
| 280 |
+
use_cache=False,
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
# Only use outputs from intermediate layers and stack them
|
| 284 |
+
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
|
| 285 |
+
out = out.to(dtype=dtype, device=device)
|
| 286 |
+
|
| 287 |
+
batch_size, num_channels, seq_len, hidden_dim = out.shape
|
| 288 |
+
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
|
| 289 |
+
|
| 290 |
+
return prompt_embeds
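A shape-only sketch of the layer stacking performed above, with random tensors standing in for the Mistral hidden states at layers 10/20/30; the hidden size 5120 is illustrative, not taken from the checkpoint:

```py
import torch

# Random tensors stand in for the text encoder's hidden states at layers 10/20/30.
batch_size, seq_len, hidden_dim = 2, 512, 5120   # hidden_dim is illustrative
hidden_states = {k: torch.randn(batch_size, seq_len, hidden_dim) for k in (10, 20, 30)}

out = torch.stack([hidden_states[k] for k in (10, 20, 30)], dim=1)            # (B, 3, L, D)
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, 3 * hidden_dim)
print(prompt_embeds.shape)                                                    # torch.Size([2, 512, 15360])
```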
|
| 291 |
+
|
| 292 |
+
@staticmethod
|
| 293 |
+
def _prepare_text_ids(
|
| 294 |
+
x: torch.Tensor, # (B, L, D) or (L, D)
|
| 295 |
+
t_coord: Optional[torch.Tensor] = None,
|
| 296 |
+
):
|
| 297 |
+
B, L, _ = x.shape
|
| 298 |
+
out_ids = []
|
| 299 |
+
|
| 300 |
+
for i in range(B):
|
| 301 |
+
t = torch.arange(1) if t_coord is None else t_coord[i]
|
| 302 |
+
h = torch.arange(1)
|
| 303 |
+
w = torch.arange(1)
|
| 304 |
+
l = torch.arange(L)
|
| 305 |
+
|
| 306 |
+
coords = torch.cartesian_prod(t, h, w, l)
|
| 307 |
+
out_ids.append(coords)
|
| 308 |
+
|
| 309 |
+
return torch.stack(out_ids)
|
| 310 |
+
|
| 311 |
+
def encode_prompt(
|
| 312 |
+
self,
|
| 313 |
+
prompt: Union[str, List[str]],
|
| 314 |
+
device: Optional[torch.device] = None,
|
| 315 |
+
num_images_per_prompt: int = 1,
|
| 316 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 317 |
+
max_sequence_length: int = 512,
|
| 318 |
+
text_encoder_out_layers: Tuple[int] = (10, 20, 30),
|
| 319 |
+
):
|
| 320 |
+
device = device or self._execution_device
|
| 321 |
+
|
| 322 |
+
if prompt is None:
|
| 323 |
+
prompt = ""
|
| 324 |
+
|
| 325 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 326 |
+
|
| 327 |
+
if prompt_embeds is None:
|
| 328 |
+
prompt_embeds = self._get_mistral_3_small_prompt_embeds(
|
| 329 |
+
text_encoder=self.text_encoder,
|
| 330 |
+
tokenizer=self.tokenizer,
|
| 331 |
+
prompt=prompt,
|
| 332 |
+
device=device,
|
| 333 |
+
max_sequence_length=max_sequence_length,
|
| 334 |
+
system_message=self.system_message,
|
| 335 |
+
hidden_states_layers=text_encoder_out_layers,
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
batch_size, seq_len, _ = prompt_embeds.shape
|
| 339 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 340 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 341 |
+
|
| 342 |
+
text_ids = self._prepare_text_ids(prompt_embeds)
|
| 343 |
+
text_ids = text_ids.to(device)
|
| 344 |
+
return prompt_embeds, text_ids
|
| 345 |
+
|
| 346 |
+
@staticmethod
|
| 347 |
+
def _prepare_latent_ids(
|
| 348 |
+
latents: torch.Tensor, # (B, C, H, W)
|
| 349 |
+
):
|
| 350 |
+
r"""
|
| 351 |
+
Generates 4D position coordinates (T, H, W, L) for latent tensors.
|
| 352 |
+
|
| 353 |
+
Args:
|
| 354 |
+
latents (torch.Tensor):
|
| 355 |
+
Latent tensor of shape (B, C, H, W)
|
| 356 |
+
|
| 357 |
+
Returns:
|
| 358 |
+
torch.Tensor:
|
| 359 |
+
Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0,
|
| 360 |
+
H=[0..H-1], W=[0..W-1], L=0
|
| 361 |
+
"""
|
| 362 |
+
|
| 363 |
+
batch_size, _, height, width = latents.shape
|
| 364 |
+
|
| 365 |
+
t = torch.arange(1) # [0] - time dimension
|
| 366 |
+
h = torch.arange(height)
|
| 367 |
+
w = torch.arange(width)
|
| 368 |
+
l = torch.arange(1) # [0] - layer dimension
|
| 369 |
+
|
| 370 |
+
# Create position IDs: (H*W, 4)
|
| 371 |
+
latent_ids = torch.cartesian_prod(t, h, w, l)
|
| 372 |
+
|
| 373 |
+
# Expand to batch: (B, H*W, 4)
|
| 374 |
+
latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
|
| 375 |
+
|
| 376 |
+
return latent_ids
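A tiny worked example of the `(T, H, W, L)` position ids produced above, for a 2x3 latent grid (illustrative sizes):

```py
import torch

t, l = torch.arange(1), torch.arange(1)          # time and layer axes are singletons
h, w = torch.arange(2), torch.arange(3)          # 2x3 latent grid
latent_ids = torch.cartesian_prod(t, h, w, l)    # (2 * 3, 4), width index varies fastest
print(latent_ids.tolist())
# [[0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 2, 0], [0, 1, 0, 0], [0, 1, 1, 0], [0, 1, 2, 0]]
```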
|
| 377 |
+
|
| 378 |
+
@staticmethod
|
| 379 |
+
def _prepare_image_ids(
|
| 380 |
+
image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...]
|
| 381 |
+
scale: int = 10,
|
| 382 |
+
):
|
| 383 |
+
r"""
|
| 384 |
+
Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
|
| 385 |
+
|
| 386 |
+
This function creates a unique coordinate for every pixel/patch across all input latents with different
|
| 387 |
+
dimensions.
|
| 388 |
+
|
| 389 |
+
Args:
|
| 390 |
+
image_latents (List[torch.Tensor]):
|
| 391 |
+
A list of image latent feature tensors, typically of shape (C, H, W).
|
| 392 |
+
scale (int, optional):
|
| 393 |
+
A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
|
| 394 |
+
latent is: 'scale + scale * i'. Defaults to 10.
|
| 395 |
+
|
| 396 |
+
Returns:
|
| 397 |
+
torch.Tensor:
|
| 398 |
+
The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all
|
| 399 |
+
input latents.
|
| 400 |
+
|
| 401 |
+
Coordinate Components (Dimension 4):
|
| 402 |
+
- T (Time): The unique index indicating which latent image the coordinate belongs to.
|
| 403 |
+
- H (Height): The row index within that latent image.
|
| 404 |
+
- W (Width): The column index within that latent image.
|
| 405 |
+
- L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
|
| 406 |
+
"""
|
| 407 |
+
|
| 408 |
+
if not isinstance(image_latents, list):
|
| 409 |
+
raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
|
| 410 |
+
|
| 411 |
+
# create time offset for each reference image
|
| 412 |
+
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
|
| 413 |
+
t_coords = [t.view(-1) for t in t_coords]
|
| 414 |
+
|
| 415 |
+
image_latent_ids = []
|
| 416 |
+
for x, t in zip(image_latents, t_coords):
|
| 417 |
+
x = x.squeeze(0)
|
| 418 |
+
_, height, width = x.shape
|
| 419 |
+
|
| 420 |
+
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
|
| 421 |
+
image_latent_ids.append(x_ids)
|
| 422 |
+
|
| 423 |
+
image_latent_ids = torch.cat(image_latent_ids, dim=0)
|
| 424 |
+
image_latent_ids = image_latent_ids.unsqueeze(0)
|
| 425 |
+
|
| 426 |
+
return image_latent_ids
|
| 427 |
+
|
| 428 |
+
@staticmethod
|
| 429 |
+
def _patchify_latents(latents):
|
| 430 |
+
batch_size, num_channels_latents, height, width = latents.shape
|
| 431 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
| 432 |
+
latents = latents.permute(0, 1, 3, 5, 2, 4)
|
| 433 |
+
latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
|
| 434 |
+
return latents
|
| 435 |
+
|
| 436 |
+
@staticmethod
|
| 437 |
+
def _unpatchify_latents(latents):
|
| 438 |
+
batch_size, num_channels_latents, height, width = latents.shape
|
| 439 |
+
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
|
| 440 |
+
latents = latents.permute(0, 1, 4, 2, 5, 3)
|
| 441 |
+
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
|
| 442 |
+
return latents
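A round-trip sanity check for the 2x2 patchify/unpatchify pair: folding each 2x2 pixel block into the channel dimension and unfolding it again should be lossless. A hedged sketch with illustrative shapes, assuming the `Flux2Pipeline` class defined in this file is in scope:

```py
import torch

x = torch.randn(1, 32, 8, 8)                           # (B, C, H, W), illustrative sizes
patched = Flux2Pipeline._patchify_latents(x)           # (1, 128, 4, 4): 2x2 pixels folded into channels
restored = Flux2Pipeline._unpatchify_latents(patched)  # back to (1, 32, 8, 8)
assert patched.shape == (1, 128, 4, 4)
assert torch.equal(restored, x)                        # the round trip is lossless
```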
|
| 443 |
+
|
| 444 |
+
@staticmethod
|
| 445 |
+
def _pack_latents(latents):
|
| 446 |
+
"""
|
| 447 |
+
pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
|
| 448 |
+
"""
|
| 449 |
+
|
| 450 |
+
batch_size, num_channels, height, width = latents.shape
|
| 451 |
+
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
|
| 452 |
+
|
| 453 |
+
return latents
|
| 454 |
+
|
| 455 |
+
@staticmethod
|
| 456 |
+
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tensor:
|
| 457 |
+
"""
|
| 458 |
+
using position ids to scatter tokens into place
|
| 459 |
+
"""
|
| 460 |
+
x_list = []
|
| 461 |
+
for data, pos in zip(x, x_ids):
|
| 462 |
+
_, ch = data.shape # noqa: F841
|
| 463 |
+
h_ids = pos[:, 1].to(torch.int64)
|
| 464 |
+
w_ids = pos[:, 2].to(torch.int64)
|
| 465 |
+
|
| 466 |
+
h = torch.max(h_ids) + 1
|
| 467 |
+
w = torch.max(w_ids) + 1
|
| 468 |
+
|
| 469 |
+
flat_ids = h_ids * w + w_ids
|
| 470 |
+
|
| 471 |
+
out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
|
| 472 |
+
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
|
| 473 |
+
|
| 474 |
+
# reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
|
| 475 |
+
|
| 476 |
+
out = out.view(h, w, ch).permute(2, 0, 1)
|
| 477 |
+
x_list.append(out)
|
| 478 |
+
|
| 479 |
+
return torch.stack(x_list, dim=0)
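The pack/scatter-unpack pair can be checked the same way: packing flattens the spatial grid in row-major order, and the position ids from `_prepare_latent_ids` let `_unpack_latents_with_ids` put every token back on the grid. A hedged sketch with illustrative shapes, again assuming `Flux2Pipeline` is in scope:

```py
import torch

latents = torch.randn(1, 128, 4, 6)                                # patchified latent, illustrative sizes
ids = Flux2Pipeline._prepare_latent_ids(latents)                   # (1, 24, 4) position ids
packed = Flux2Pipeline._pack_latents(latents)                      # (1, 24, 128) token sequence
restored = Flux2Pipeline._unpack_latents_with_ids(packed, ids)     # scattered back to (1, 128, 4, 6)
assert torch.equal(restored, latents)
```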
|
| 480 |
+
|
| 481 |
+
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
| 482 |
+
if image.ndim != 4:
|
| 483 |
+
raise ValueError(f"Expected image dims 4, got {image.ndim}.")
|
| 484 |
+
|
| 485 |
+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
|
| 486 |
+
image_latents = self._patchify_latents(image_latents)
|
| 487 |
+
|
| 488 |
+
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
|
| 489 |
+
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
|
| 490 |
+
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
|
| 491 |
+
|
| 492 |
+
return image_latents
|
| 493 |
+
|
| 494 |
+
def prepare_latents(
|
| 495 |
+
self,
|
| 496 |
+
batch_size,
|
| 497 |
+
num_latents_channels,
|
| 498 |
+
height,
|
| 499 |
+
width,
|
| 500 |
+
dtype,
|
| 501 |
+
device,
|
| 502 |
+
generator: torch.Generator,
|
| 503 |
+
latents: Optional[torch.Tensor] = None,
|
| 504 |
+
):
|
| 505 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 506 |
+
# latent height and width to be divisible by 2.
|
| 507 |
+
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
| 508 |
+
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
| 509 |
+
|
| 510 |
+
shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)
|
| 511 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 512 |
+
raise ValueError(
|
| 513 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 514 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 515 |
+
)
|
| 516 |
+
if latents is None:
|
| 517 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 518 |
+
else:
|
| 519 |
+
latents = latents.to(device=device, dtype=dtype)
|
| 520 |
+
|
| 521 |
+
latent_ids = self._prepare_latent_ids(latents)
|
| 522 |
+
latent_ids = latent_ids.to(device)
|
| 523 |
+
|
| 524 |
+
latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C]
|
| 525 |
+
return latents, latent_ids
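A worked example of the shape bookkeeping above for a 1024x1024 request; the 32 base latent channels are an assumption for illustration (they correspond to `transformer.config.in_channels // 4`):

```py
vae_scale_factor = 8
num_latents_channels = 32                              # assumed: transformer.config.in_channels // 4
height = 2 * (1024 // (vae_scale_factor * 2))          # 128 latent rows
width = 2 * (1024 // (vae_scale_factor * 2))           # 128 latent cols
shape = (1, num_latents_channels * 4, height // 2, width // 2)
print(shape)                                           # (1, 128, 64, 64) before packing
print(shape[2] * shape[3])                             # 4096 tokens after _pack_latents
```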
|
| 526 |
+
|
| 527 |
+
def prepare_image_latents(
|
| 528 |
+
self,
|
| 529 |
+
images: List[torch.Tensor],
|
| 530 |
+
batch_size,
|
| 531 |
+
generator: torch.Generator,
|
| 532 |
+
device,
|
| 533 |
+
dtype,
|
| 534 |
+
):
|
| 535 |
+
image_latents = []
|
| 536 |
+
for image in images:
|
| 537 |
+
image = image.to(device=device, dtype=dtype)
|
| 538 |
+
image_latent = self._encode_vae_image(image=image, generator=generator)
|
| 539 |
+
image_latents.append(image_latent) # (1, 128, 32, 32)
|
| 540 |
+
|
| 541 |
+
image_latent_ids = self._prepare_image_ids(image_latents)
|
| 542 |
+
|
| 543 |
+
# Pack each latent and concatenate
|
| 544 |
+
packed_latents = []
|
| 545 |
+
for latent in image_latents:
|
| 546 |
+
# latent: (1, 128, 32, 32)
|
| 547 |
+
packed = self._pack_latents(latent) # (1, 1024, 128)
|
| 548 |
+
packed = packed.squeeze(0) # (1024, 128) - remove batch dim
|
| 549 |
+
packed_latents.append(packed)
|
| 550 |
+
|
| 551 |
+
# Concatenate all reference tokens along sequence dimension
|
| 552 |
+
image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128)
|
| 553 |
+
image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128)
|
| 554 |
+
|
| 555 |
+
image_latents = image_latents.repeat(batch_size, 1, 1)
|
| 556 |
+
image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
|
| 557 |
+
image_latent_ids = image_latent_ids.to(device)
|
| 558 |
+
|
| 559 |
+
return image_latents, image_latent_ids
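A sketch of how multiple reference images end up in one conditioning sequence: each reference keeps its own T offset (10, 20, ...) from `_prepare_image_ids`, so its packed tokens remain distinguishable after concatenation. Shapes are illustrative and `Flux2Pipeline` is assumed to be in scope:

```py
import torch

ref_latents = [torch.randn(1, 128, 4, 4), torch.randn(1, 128, 2, 2)]  # two patchified references
ids = Flux2Pipeline._prepare_image_ids(ref_latents)                   # (1, 16 + 4, 4)
print(ids.shape)                                                       # torch.Size([1, 20, 4])
print(ids[0, 0, 0].item())                                             # 10 -> T offset of reference 0
print(ids[0, 16, 0].item())                                            # 20 -> T offset of reference 1
```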
|
| 560 |
+
|
| 561 |
+
def check_inputs(
|
| 562 |
+
self,
|
| 563 |
+
prompt,
|
| 564 |
+
height,
|
| 565 |
+
width,
|
| 566 |
+
prompt_embeds=None,
|
| 567 |
+
callback_on_step_end_tensor_inputs=None,
|
| 568 |
+
):
|
| 569 |
+
if (
|
| 570 |
+
height is not None
|
| 571 |
+
and height % (self.vae_scale_factor * 2) != 0
|
| 572 |
+
or width is not None
|
| 573 |
+
and width % (self.vae_scale_factor * 2) != 0
|
| 574 |
+
):
|
| 575 |
+
logger.warning(
|
| 576 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 580 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 581 |
+
):
|
| 582 |
+
raise ValueError(
|
| 583 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 584 |
+
)
|
| 585 |
+
|
| 586 |
+
if prompt is not None and prompt_embeds is not None:
|
| 587 |
+
raise ValueError(
|
| 588 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 589 |
+
" only forward one of the two."
|
| 590 |
+
)
|
| 591 |
+
elif prompt is None and prompt_embeds is None:
|
| 592 |
+
raise ValueError(
|
| 593 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 594 |
+
)
|
| 595 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 596 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 597 |
+
|
| 598 |
+
@property
|
| 599 |
+
def guidance_scale(self):
|
| 600 |
+
return self._guidance_scale
|
| 601 |
+
|
| 602 |
+
@property
|
| 603 |
+
def joint_attention_kwargs(self):
|
| 604 |
+
return self._joint_attention_kwargs
|
| 605 |
+
|
| 606 |
+
@property
|
| 607 |
+
def num_timesteps(self):
|
| 608 |
+
return self._num_timesteps
|
| 609 |
+
|
| 610 |
+
@property
|
| 611 |
+
def current_timestep(self):
|
| 612 |
+
return self._current_timestep
|
| 613 |
+
|
| 614 |
+
@property
|
| 615 |
+
def interrupt(self):
|
| 616 |
+
return self._interrupt
|
| 617 |
+
|
| 618 |
+
@torch.no_grad()
|
| 619 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 620 |
+
def __call__(
|
| 621 |
+
self,
|
| 622 |
+
image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
|
| 623 |
+
prompt: Union[str, List[str]] = None,
|
| 624 |
+
height: Optional[int] = None,
|
| 625 |
+
width: Optional[int] = None,
|
| 626 |
+
num_inference_steps: int = 50,
|
| 627 |
+
sigmas: Optional[List[float]] = None,
|
| 628 |
+
guidance_scale: Optional[float] = 4.0,
|
| 629 |
+
num_images_per_prompt: int = 1,
|
| 630 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 631 |
+
latents: Optional[torch.Tensor] = None,
|
| 632 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 633 |
+
output_type: Optional[str] = "pil",
|
| 634 |
+
return_dict: bool = True,
|
| 635 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 636 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 637 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 638 |
+
max_sequence_length: int = 512,
|
| 639 |
+
text_encoder_out_layers: Tuple[int] = (10, 20, 30),
|
| 640 |
+
):
|
| 641 |
+
r"""
|
| 642 |
+
Function invoked when calling the pipeline for generation.
|
| 643 |
+
|
| 644 |
+
Args:
|
| 645 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
| 646 |
+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
| 647 |
+
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
|
| 648 |
+
of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
| 649 |
+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
|
| 650 |
+
latents as `image`, but if passing latents directly it is not encoded again.
|
| 651 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 652 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
| 653 |
+
instead.
|
| 654 |
+
guidance_scale (`float`, *optional*, defaults to 4.0):
|
| 655 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 656 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 657 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 658 |
+
`guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked to
|
| 659 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 660 |
+
height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
|
| 661 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 662 |
+
width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
|
| 663 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 664 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 665 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 666 |
+
expense of slower inference.
|
| 667 |
+
sigmas (`List[float]`, *optional*):
|
| 668 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 669 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 670 |
+
will be used.
|
| 671 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 672 |
+
The number of images to generate per prompt.
|
| 673 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 674 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 675 |
+
to make generation deterministic.
|
| 676 |
+
latents (`torch.Tensor`, *optional*):
|
| 677 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 678 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 679 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 680 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 681 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 682 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 683 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 684 |
+
The output format of the generated image. Choose between
|
| 685 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 686 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 687 |
+
Whether or not to return a [`~pipelines.flux2.Flux2PipelineOutput`] instead of a plain tuple.
|
| 688 |
+
attention_kwargs (`dict`, *optional*):
|
| 689 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 690 |
+
`self.processor` in
|
| 691 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 692 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 693 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
| 694 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 695 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 696 |
+
`callback_on_step_end_tensor_inputs`.
|
| 697 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 698 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 699 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 700 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 701 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
| 702 |
+
text_encoder_out_layers (`Tuple[int]`):
|
| 703 |
+
Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
|
| 704 |
+
|
| 705 |
+
Examples:
|
| 706 |
+
|
| 707 |
+
Returns:
|
| 708 |
+
[`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if
|
| 709 |
+
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
| 710 |
+
generated images.
|
| 711 |
+
"""
|
| 712 |
+
|
| 713 |
+
# 1. Check inputs. Raise error if not correct
|
| 714 |
+
self.check_inputs(
|
| 715 |
+
prompt=prompt,
|
| 716 |
+
height=height,
|
| 717 |
+
width=width,
|
| 718 |
+
prompt_embeds=prompt_embeds,
|
| 719 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 720 |
+
)
|
| 721 |
+
|
| 722 |
+
self._guidance_scale = guidance_scale
|
| 723 |
+
self._attention_kwargs = attention_kwargs
|
| 724 |
+
self._current_timestep = None
|
| 725 |
+
self._interrupt = False
|
| 726 |
+
|
| 727 |
+
# 2. Define call parameters
|
| 728 |
+
if prompt is not None and isinstance(prompt, str):
|
| 729 |
+
batch_size = 1
|
| 730 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 731 |
+
batch_size = len(prompt)
|
| 732 |
+
else:
|
| 733 |
+
batch_size = prompt_embeds.shape[0]
|
| 734 |
+
|
| 735 |
+
device = self._execution_device
|
| 736 |
+
|
| 737 |
+
# 3. prepare text embeddings
|
| 738 |
+
prompt_embeds, text_ids = self.encode_prompt(
|
| 739 |
+
prompt=prompt,
|
| 740 |
+
prompt_embeds=prompt_embeds,
|
| 741 |
+
device=device,
|
| 742 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 743 |
+
max_sequence_length=max_sequence_length,
|
| 744 |
+
text_encoder_out_layers=text_encoder_out_layers,
|
| 745 |
+
)
|
| 746 |
+
|
| 747 |
+
# 4. process images
|
| 748 |
+
if image is not None and not isinstance(image, list):
|
| 749 |
+
image = [image]
|
| 750 |
+
|
| 751 |
+
condition_images = None
|
| 752 |
+
if image is not None:
|
| 753 |
+
for img in image:
|
| 754 |
+
self.image_processor.check_image_input(img)
|
| 755 |
+
|
| 756 |
+
condition_images = []
|
| 757 |
+
for img in image:
|
| 758 |
+
image_width, image_height = img.size
|
| 759 |
+
if image_width * image_height > 1024 * 1024:
|
| 760 |
+
img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
|
| 761 |
+
image_width, image_height = img.size
|
| 762 |
+
|
| 763 |
+
multiple_of = self.vae_scale_factor * 2
|
| 764 |
+
image_width = (image_width // multiple_of) * multiple_of
|
| 765 |
+
image_height = (image_height // multiple_of) * multiple_of
|
| 766 |
+
img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
|
| 767 |
+
condition_images.append(img)
|
| 768 |
+
height = height or image_height
|
| 769 |
+
width = width or image_width
|
| 770 |
+
|
| 771 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 772 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 773 |
+
|
| 774 |
+
# 5. prepare latent variables
|
| 775 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
| 776 |
+
latents, latent_ids = self.prepare_latents(
|
| 777 |
+
batch_size=batch_size * num_images_per_prompt,
|
| 778 |
+
num_latents_channels=num_channels_latents,
|
| 779 |
+
height=height,
|
| 780 |
+
width=width,
|
| 781 |
+
dtype=prompt_embeds.dtype,
|
| 782 |
+
device=device,
|
| 783 |
+
generator=generator,
|
| 784 |
+
latents=latents,
|
| 785 |
+
)
|
| 786 |
+
|
| 787 |
+
image_latents = None
|
| 788 |
+
image_latent_ids = None
|
| 789 |
+
if condition_images is not None:
|
| 790 |
+
image_latents, image_latent_ids = self.prepare_image_latents(
|
| 791 |
+
images=condition_images,
|
| 792 |
+
batch_size=batch_size * num_images_per_prompt,
|
| 793 |
+
generator=generator,
|
| 794 |
+
device=device,
|
| 795 |
+
dtype=self.vae.dtype,
|
| 796 |
+
)
|
| 797 |
+
|
| 798 |
+
# 6. Prepare timesteps
|
| 799 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
| 800 |
+
if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
|
| 801 |
+
sigmas = None
|
| 802 |
+
image_seq_len = latents.shape[1]
|
| 803 |
+
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
|
| 804 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 805 |
+
self.scheduler,
|
| 806 |
+
num_inference_steps,
|
| 807 |
+
device,
|
| 808 |
+
sigmas=sigmas,
|
| 809 |
+
mu=mu,
|
| 810 |
+
)
|
| 811 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 812 |
+
self._num_timesteps = len(timesteps)
|
| 813 |
+
|
| 814 |
+
# handle guidance
|
| 815 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 816 |
+
guidance = guidance.expand(latents.shape[0])
|
| 817 |
+
|
| 818 |
+
# 7. Denoising loop
|
| 819 |
+
# We set the index here to remove DtoH sync, helpful especially during compilation.
|
| 820 |
+
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
|
| 821 |
+
self.scheduler.set_begin_index(0)
|
| 822 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 823 |
+
for i, t in enumerate(timesteps):
|
| 824 |
+
if self.interrupt:
|
| 825 |
+
continue
|
| 826 |
+
|
| 827 |
+
self._current_timestep = t
|
| 828 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 829 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 830 |
+
|
| 831 |
+
latent_model_input = latents.to(self.transformer.dtype)
|
| 832 |
+
latent_image_ids = latent_ids
|
| 833 |
+
|
| 834 |
+
if image_latents is not None:
|
| 835 |
+
latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
|
| 836 |
+
latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
|
| 837 |
+
|
| 838 |
+
noise_pred = self.transformer(
|
| 839 |
+
hidden_states=latent_model_input, # (B, image_seq_len, C)
|
| 840 |
+
timestep=timestep / 1000,
|
| 841 |
+
guidance=guidance,
|
| 842 |
+
encoder_hidden_states=prompt_embeds,
|
| 843 |
+
txt_ids=text_ids, # B, text_seq_len, 4
|
| 844 |
+
img_ids=latent_image_ids, # B, image_seq_len, 4
|
| 845 |
+
joint_attention_kwargs=self._attention_kwargs,
|
| 846 |
+
return_dict=False,
|
| 847 |
+
)[0]
|
| 848 |
+
|
| 849 |
+
noise_pred = noise_pred[:, : latents.size(1)]
|
| 850 |
+
|
| 851 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 852 |
+
latents_dtype = latents.dtype
|
| 853 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 854 |
+
|
| 855 |
+
if latents.dtype != latents_dtype:
|
| 856 |
+
if torch.backends.mps.is_available():
|
| 857 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 858 |
+
latents = latents.to(latents_dtype)
|
| 859 |
+
|
| 860 |
+
if callback_on_step_end is not None:
|
| 861 |
+
callback_kwargs = {}
|
| 862 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 863 |
+
callback_kwargs[k] = locals()[k]
|
| 864 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 865 |
+
|
| 866 |
+
latents = callback_outputs.pop("latents", latents)
|
| 867 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 868 |
+
|
| 869 |
+
# call the callback, if provided
|
| 870 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 871 |
+
progress_bar.update()
|
| 872 |
+
|
| 873 |
+
if XLA_AVAILABLE:
|
| 874 |
+
xm.mark_step()
|
| 875 |
+
|
| 876 |
+
self._current_timestep = None
|
| 877 |
+
|
| 878 |
+
if output_type == "latent":
|
| 879 |
+
image = latents
|
| 880 |
+
else:
|
| 881 |
+
torch.save({"pred": latents}, "pred_d.pt")
|
| 882 |
+
latents = self._unpack_latents_with_ids(latents, latent_ids)
|
| 883 |
+
|
| 884 |
+
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
|
| 885 |
+
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
|
| 886 |
+
latents.device, latents.dtype
|
| 887 |
+
)
|
| 888 |
+
latents = latents * latents_bn_std + latents_bn_mean
|
| 889 |
+
latents = self._unpatchify_latents(latents)
|
| 890 |
+
|
| 891 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 892 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 893 |
+
|
| 894 |
+
# Offload all models
|
| 895 |
+
self.maybe_free_model_hooks()
|
| 896 |
+
|
| 897 |
+
if not return_dict:
|
| 898 |
+
return (image,)
|
| 899 |
+
|
| 900 |
+
return Flux2PipelineOutput(images=image)
|
videox_fun/pipeline/pipeline_flux2_control.py
ADDED
|
@@ -0,0 +1,973 @@
| 1 |
+
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux2/pipeline_flux2.py
|
| 2 |
+
# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from diffusers.utils import (BaseOutput, is_torch_xla_available, logging,
|
| 18 |
+
replace_example_docstring)
|
| 19 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import PIL
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 28 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 29 |
+
from diffusers.utils import (is_torch_xla_available, logging,
|
| 30 |
+
replace_example_docstring)
|
| 31 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 32 |
+
|
| 33 |
+
from ..models import (AutoencoderKLFlux2, Flux2ImageProcessor,
|
| 34 |
+
Flux2ControlTransformer2DModel, Mistral3ForConditionalGeneration, AutoProcessor)
|
| 35 |
+
|
| 36 |
+
if is_torch_xla_available():
|
| 37 |
+
import torch_xla.core.xla_model as xm
|
| 38 |
+
|
| 39 |
+
XLA_AVAILABLE = True
|
| 40 |
+
else:
|
| 41 |
+
XLA_AVAILABLE = False
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 45 |
+
|
| 46 |
+
EXAMPLE_DOC_STRING = """
|
| 47 |
+
Examples:
|
| 48 |
+
```py
|
| 49 |
+
>>> import torch
|
| 50 |
+
>>> from diffusers import Flux2Pipeline
|
| 51 |
+
|
| 52 |
+
>>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16)
|
| 53 |
+
>>> pipe.to("cuda")
|
| 54 |
+
>>> prompt = "A cat holding a sign that says hello world"
|
| 55 |
+
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
| 56 |
+
>>> # Refer to the pipeline documentation for more details.
|
| 57 |
+
>>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0]
|
| 58 |
+
>>> image.save("flux.png")
|
| 59 |
+
```
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def format_text_input(prompts: List[str], system_message: str = None):
|
| 64 |
+
# Remove [IMG] tokens from prompts to avoid Pixtral validation issues
|
| 65 |
+
# when truncation is enabled. The processor counts [IMG] tokens and fails
|
| 66 |
+
# if the count changes after truncation.
|
| 67 |
+
cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
|
| 68 |
+
|
| 69 |
+
return [
|
| 70 |
+
[
|
| 71 |
+
{
|
| 72 |
+
"role": "system",
|
| 73 |
+
"content": [{"type": "text", "text": system_message}],
|
| 74 |
+
},
|
| 75 |
+
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
| 76 |
+
]
|
| 77 |
+
for prompt in cleaned_txt
|
| 78 |
+
]
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
|
| 82 |
+
a1, b1 = 8.73809524e-05, 1.89833333
|
| 83 |
+
a2, b2 = 0.00016927, 0.45666666
|
| 84 |
+
|
| 85 |
+
if image_seq_len > 4300:
|
| 86 |
+
mu = a2 * image_seq_len + b2
|
| 87 |
+
return float(mu)
|
| 88 |
+
|
| 89 |
+
m_200 = a2 * image_seq_len + b2
|
| 90 |
+
m_10 = a1 * image_seq_len + b1
|
| 91 |
+
|
| 92 |
+
a = (m_200 - m_10) / 190.0
|
| 93 |
+
b = m_200 - 200.0 * a
|
| 94 |
+
mu = a * num_steps + b
|
| 95 |
+
|
| 96 |
+
return float(mu)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 100 |
+
def retrieve_timesteps(
|
| 101 |
+
scheduler,
|
| 102 |
+
num_inference_steps: Optional[int] = None,
|
| 103 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 104 |
+
timesteps: Optional[List[int]] = None,
|
| 105 |
+
sigmas: Optional[List[float]] = None,
|
| 106 |
+
**kwargs,
|
| 107 |
+
):
|
| 108 |
+
r"""
|
| 109 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 110 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
scheduler (`SchedulerMixin`):
|
| 114 |
+
The scheduler to get timesteps from.
|
| 115 |
+
num_inference_steps (`int`):
|
| 116 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 117 |
+
must be `None`.
|
| 118 |
+
device (`str` or `torch.device`, *optional*):
|
| 119 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 120 |
+
timesteps (`List[int]`, *optional*):
|
| 121 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 122 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 123 |
+
sigmas (`List[float]`, *optional*):
|
| 124 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 125 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 129 |
+
second element is the number of inference steps.
|
| 130 |
+
"""
|
| 131 |
+
if timesteps is not None and sigmas is not None:
|
| 132 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 133 |
+
if timesteps is not None:
|
| 134 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 135 |
+
if not accepts_timesteps:
|
| 136 |
+
raise ValueError(
|
| 137 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 138 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 139 |
+
)
|
| 140 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 141 |
+
timesteps = scheduler.timesteps
|
| 142 |
+
num_inference_steps = len(timesteps)
|
| 143 |
+
elif sigmas is not None:
|
| 144 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 145 |
+
if not accept_sigmas:
|
| 146 |
+
raise ValueError(
|
| 147 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 148 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 149 |
+
)
|
| 150 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 151 |
+
timesteps = scheduler.timesteps
|
| 152 |
+
num_inference_steps = len(timesteps)
|
| 153 |
+
else:
|
| 154 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 155 |
+
timesteps = scheduler.timesteps
|
| 156 |
+
return timesteps, num_inference_steps
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 160 |
+
def retrieve_latents(
|
| 161 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 162 |
+
):
|
| 163 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 164 |
+
return encoder_output.latent_dist.sample(generator)
|
| 165 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 166 |
+
return encoder_output.latent_dist.mode()
|
| 167 |
+
elif hasattr(encoder_output, "latents"):
|
| 168 |
+
return encoder_output.latents
|
| 169 |
+
else:
|
| 170 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
@dataclass
|
| 174 |
+
class Flux2PipelineOutput(BaseOutput):
|
| 175 |
+
"""
|
| 176 |
+
Output class for Flux2 image generation pipelines.
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`):
|
| 180 |
+
List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
|
| 181 |
+
height, width, num_channels)`. PIL images or numpy arrays represent the denoised images of the diffusion
|
| 182 |
+
pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be
|
| 183 |
+
passed to the decoder.
|
| 184 |
+
"""
|
| 185 |
+
|
| 186 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class Flux2ControlPipeline(DiffusionPipeline):
|
| 190 |
+
r"""
|
| 191 |
+
The Flux2 control pipeline for text-to-image generation with additional control, mask, and inpaint image conditioning.
|
| 192 |
+
|
| 193 |
+
Reference: [https://bfl.ai/blog/flux-2](https://bfl.ai/blog/flux-2)
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
transformer ([`Flux2ControlTransformer2DModel`]):
|
| 197 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
| 198 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 199 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 200 |
+
vae ([`AutoencoderKLFlux2`]):
|
| 201 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 202 |
+
text_encoder ([`Mistral3ForConditionalGeneration`]):
|
| 203 |
+
[Mistral3ForConditionalGeneration](https://huggingface.co/docs/transformers/en/model_doc/mistral3#transformers.Mistral3ForConditionalGeneration)
|
| 204 |
+
tokenizer (`AutoProcessor`):
|
| 205 |
+
Tokenizer of class
|
| 206 |
+
[PixtralProcessor](https://huggingface.co/docs/transformers/en/model_doc/pixtral#transformers.PixtralProcessor).
|
| 207 |
+
"""
|
| 208 |
+
|
| 209 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 210 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 211 |
+
|
| 212 |
+
def __init__(
|
| 213 |
+
self,
|
| 214 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 215 |
+
vae: AutoencoderKLFlux2,
|
| 216 |
+
text_encoder: Mistral3ForConditionalGeneration,
|
| 217 |
+
tokenizer: AutoProcessor,
|
| 218 |
+
transformer: Flux2ControlTransformer2DModel,
|
| 219 |
+
):
|
| 220 |
+
super().__init__()
|
| 221 |
+
|
| 222 |
+
self.register_modules(
|
| 223 |
+
vae=vae,
|
| 224 |
+
text_encoder=text_encoder,
|
| 225 |
+
tokenizer=tokenizer,
|
| 226 |
+
scheduler=scheduler,
|
| 227 |
+
transformer=transformer,
|
| 228 |
+
)
|
| 229 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 230 |
+
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
| 231 |
+
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
| 232 |
+
self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
| 233 |
+
self.diffusers_image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
| 234 |
+
self.mask_processor = VaeImageProcessor(
|
| 235 |
+
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
| 236 |
+
)
|
| 237 |
+
self.tokenizer_max_length = 512
|
| 238 |
+
self.default_sample_size = 128
|
| 239 |
+
|
| 240 |
+
# fmt: off
|
| 241 |
+
self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."
|
| 242 |
+
# fmt: on
|
| 243 |
+
|
| 244 |
+
@staticmethod
|
| 245 |
+
def _get_mistral_3_small_prompt_embeds(
|
| 246 |
+
text_encoder: Mistral3ForConditionalGeneration,
|
| 247 |
+
tokenizer: AutoProcessor,
|
| 248 |
+
prompt: Union[str, List[str]],
|
| 249 |
+
dtype: Optional[torch.dtype] = None,
|
| 250 |
+
device: Optional[torch.device] = None,
|
| 251 |
+
max_sequence_length: int = 512,
|
| 252 |
+
# fmt: off
|
| 253 |
+
system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.",
|
| 254 |
+
# fmt: on
|
| 255 |
+
hidden_states_layers: List[int] = (10, 20, 30),
|
| 256 |
+
):
|
| 257 |
+
dtype = text_encoder.dtype if dtype is None else dtype
|
| 258 |
+
device = text_encoder.device if device is None else device
|
| 259 |
+
|
| 260 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 261 |
+
|
| 262 |
+
# Format input messages
|
| 263 |
+
messages_batch = format_text_input(prompts=prompt, system_message=system_message)
|
| 264 |
+
|
| 265 |
+
# Process all messages at once
|
| 266 |
+
inputs = tokenizer.apply_chat_template(
|
| 267 |
+
messages_batch,
|
| 268 |
+
add_generation_prompt=False,
|
| 269 |
+
tokenize=True,
|
| 270 |
+
return_dict=True,
|
| 271 |
+
return_tensors="pt",
|
| 272 |
+
padding="max_length",
|
| 273 |
+
truncation=True,
|
| 274 |
+
max_length=max_sequence_length,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
# Move to device
|
| 278 |
+
input_ids = inputs["input_ids"].to(device)
|
| 279 |
+
attention_mask = inputs["attention_mask"].to(device)
|
| 280 |
+
|
| 281 |
+
# Forward pass through the model
|
| 282 |
+
output = text_encoder(
|
| 283 |
+
input_ids=input_ids,
|
| 284 |
+
attention_mask=attention_mask,
|
| 285 |
+
output_hidden_states=True,
|
| 286 |
+
use_cache=False,
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
# Only use outputs from intermediate layers and stack them
|
| 290 |
+
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
|
| 291 |
+
out = out.to(dtype=dtype, device=device)
|
| 292 |
+
|
| 293 |
+
batch_size, num_channels, seq_len, hidden_dim = out.shape
|
| 294 |
+
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
|
| 295 |
+
|
| 296 |
+
return prompt_embeds
|
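A shape-only sketch (added for illustration, not in the original diff) of the layer-stacking step in `_get_mistral_3_small_prompt_embeds`, using random tensors in place of real Mistral hidden states; all dimensions are made up.

```python
import torch

batch, seq_len, hidden_dim = 2, 512, 64
# Stand-ins for the hidden states of layers 10, 20 and 30.
hidden_states = [torch.randn(batch, seq_len, hidden_dim) for _ in range(3)]

out = torch.stack(hidden_states, dim=1)                # (B, 3, L, D)
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch, seq_len, 3 * hidden_dim)
print(prompt_embeds.shape)                             # torch.Size([2, 512, 192])
```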
| 297 |
+
|
| 298 |
+
@staticmethod
|
| 299 |
+
def _prepare_text_ids(
|
| 300 |
+
x: torch.Tensor, # (B, L, D) or (L, D)
|
| 301 |
+
t_coord: Optional[torch.Tensor] = None,
|
| 302 |
+
):
|
| 303 |
+
B, L, _ = x.shape
|
| 304 |
+
out_ids = []
|
| 305 |
+
|
| 306 |
+
for i in range(B):
|
| 307 |
+
t = torch.arange(1) if t_coord is None else t_coord[i]
|
| 308 |
+
h = torch.arange(1)
|
| 309 |
+
w = torch.arange(1)
|
| 310 |
+
l = torch.arange(L)
|
| 311 |
+
|
| 312 |
+
coords = torch.cartesian_prod(t, h, w, l)
|
| 313 |
+
out_ids.append(coords)
|
| 314 |
+
|
| 315 |
+
return torch.stack(out_ids)
|
| 316 |
+
|
| 317 |
+
def encode_prompt(
|
| 318 |
+
self,
|
| 319 |
+
prompt: Union[str, List[str]],
|
| 320 |
+
device: Optional[torch.device] = None,
|
| 321 |
+
num_images_per_prompt: int = 1,
|
| 322 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 323 |
+
max_sequence_length: int = 512,
|
| 324 |
+
text_encoder_out_layers: Tuple[int] = (10, 20, 30),
|
| 325 |
+
):
|
| 326 |
+
device = device or self._execution_device
|
| 327 |
+
|
| 328 |
+
if prompt is None:
|
| 329 |
+
prompt = ""
|
| 330 |
+
|
| 331 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 332 |
+
|
| 333 |
+
if prompt_embeds is None:
|
| 334 |
+
prompt_embeds = self._get_mistral_3_small_prompt_embeds(
|
| 335 |
+
text_encoder=self.text_encoder,
|
| 336 |
+
tokenizer=self.tokenizer,
|
| 337 |
+
prompt=prompt,
|
| 338 |
+
device=device,
|
| 339 |
+
max_sequence_length=max_sequence_length,
|
| 340 |
+
system_message=self.system_message,
|
| 341 |
+
hidden_states_layers=text_encoder_out_layers,
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
batch_size, seq_len, _ = prompt_embeds.shape
|
| 345 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 346 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 347 |
+
|
| 348 |
+
text_ids = self._prepare_text_ids(prompt_embeds)
|
| 349 |
+
text_ids = text_ids.to(device)
|
| 350 |
+
return prompt_embeds, text_ids
|
| 351 |
+
|
| 352 |
+
@staticmethod
|
| 353 |
+
def _prepare_latent_ids(
|
| 354 |
+
latents: torch.Tensor, # (B, C, H, W)
|
| 355 |
+
):
|
| 356 |
+
r"""
|
| 357 |
+
Generates 4D position coordinates (T, H, W, L) for latent tensors.
|
| 358 |
+
|
| 359 |
+
Args:
|
| 360 |
+
latents (torch.Tensor):
|
| 361 |
+
Latent tensor of shape (B, C, H, W)
|
| 362 |
+
|
| 363 |
+
Returns:
|
| 364 |
+
torch.Tensor:
|
| 365 |
+
Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0,
|
| 366 |
+
H=[0..H-1], W=[0..W-1], L=0
|
| 367 |
+
"""
|
| 368 |
+
|
| 369 |
+
batch_size, _, height, width = latents.shape
|
| 370 |
+
|
| 371 |
+
t = torch.arange(1) # [0] - time dimension
|
| 372 |
+
h = torch.arange(height)
|
| 373 |
+
w = torch.arange(width)
|
| 374 |
+
l = torch.arange(1) # [0] - layer dimension
|
| 375 |
+
|
| 376 |
+
# Create position IDs: (H*W, 4)
|
| 377 |
+
latent_ids = torch.cartesian_prod(t, h, w, l)
|
| 378 |
+
|
| 379 |
+
# Expand to batch: (B, H*W, 4)
|
| 380 |
+
latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
|
| 381 |
+
|
| 382 |
+
return latent_ids
|
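An added standalone sketch of the coordinate construction described in the `_prepare_latent_ids` docstring; the latent shape is a toy value.

```python
import torch

latents = torch.randn(2, 128, 4, 6)                    # (B, C, H, W)
batch_size, _, height, width = latents.shape

coords = torch.cartesian_prod(
    torch.arange(1), torch.arange(height), torch.arange(width), torch.arange(1)
)                                                      # (H*W, 4) rows of (T, H, W, L)
latent_ids = coords.unsqueeze(0).expand(batch_size, -1, -1)
print(latent_ids.shape)                                # torch.Size([2, 24, 4])
print(latent_ids[0, :3])                               # T=0, H=0, W=0..2, L=0
```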
| 383 |
+
|
| 384 |
+
@staticmethod
|
| 385 |
+
def _prepare_image_ids(
|
| 386 |
+
image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...]
|
| 387 |
+
scale: int = 10,
|
| 388 |
+
):
|
| 389 |
+
r"""
|
| 390 |
+
Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
|
| 391 |
+
|
| 392 |
+
This function creates a unique coordinate for every pixel/patch across all input latents, which may have different
|
| 393 |
+
dimensions.
|
| 394 |
+
|
| 395 |
+
Args:
|
| 396 |
+
image_latents (List[torch.Tensor]):
|
| 397 |
+
A list of image latent feature tensors, typically of shape (C, H, W).
|
| 398 |
+
scale (int, optional):
|
| 399 |
+
A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
|
| 400 |
+
latent is: 'scale + scale * i'. Defaults to 10.
|
| 401 |
+
|
| 402 |
+
Returns:
|
| 403 |
+
torch.Tensor:
|
| 404 |
+
The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all
|
| 405 |
+
input latents.
|
| 406 |
+
|
| 407 |
+
Coordinate Components (Dimension 4):
|
| 408 |
+
- T (Time): The unique index indicating which latent image the coordinate belongs to.
|
| 409 |
+
- H (Height): The row index within that latent image.
|
| 410 |
+
- W (Width): The column index within that latent image.
|
| 411 |
+
- L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
|
| 412 |
+
"""
|
| 413 |
+
|
| 414 |
+
if not isinstance(image_latents, list):
|
| 415 |
+
raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
|
| 416 |
+
|
| 417 |
+
# create time offset for each reference image
|
| 418 |
+
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
|
| 419 |
+
t_coords = [t.view(-1) for t in t_coords]
|
| 420 |
+
|
| 421 |
+
image_latent_ids = []
|
| 422 |
+
for x, t in zip(image_latents, t_coords):
|
| 423 |
+
x = x.squeeze(0)
|
| 424 |
+
_, height, width = x.shape
|
| 425 |
+
|
| 426 |
+
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
|
| 427 |
+
image_latent_ids.append(x_ids)
|
| 428 |
+
|
| 429 |
+
image_latent_ids = torch.cat(image_latent_ids, dim=0)
|
| 430 |
+
image_latent_ids = image_latent_ids.unsqueeze(0)
|
| 431 |
+
|
| 432 |
+
return image_latent_ids
|
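An added sketch of the time-offset scheme documented above for `_prepare_image_ids`, using two toy reference latents of different sizes.

```python
import torch

image_latents = [torch.randn(1, 128, 4, 4), torch.randn(1, 128, 2, 6)]
scale = 10

# Each reference image gets its own T offset: 10 for the first, 20 for the second, ...
t_coords = [torch.tensor([scale + scale * i]) for i in range(len(image_latents))]

ids = []
for x, t in zip(image_latents, t_coords):
    _, _, h, w = x.shape
    ids.append(torch.cartesian_prod(t, torch.arange(h), torch.arange(w), torch.arange(1)))
image_latent_ids = torch.cat(ids, dim=0).unsqueeze(0)

print(image_latent_ids.shape)                            # torch.Size([1, 28, 4]) -> 16 + 12 positions
print(image_latent_ids[0, 0], image_latent_ids[0, 16])   # first row of each T block
```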
| 433 |
+
|
| 434 |
+
@staticmethod
|
| 435 |
+
def _patchify_latents(latents):
|
| 436 |
+
batch_size, num_channels_latents, height, width = latents.shape
|
| 437 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
| 438 |
+
latents = latents.permute(0, 1, 3, 5, 2, 4)
|
| 439 |
+
latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
|
| 440 |
+
return latents
|
| 441 |
+
|
| 442 |
+
@staticmethod
|
| 443 |
+
def _unpatchify_latents(latents):
|
| 444 |
+
batch_size, num_channels_latents, height, width = latents.shape
|
| 445 |
+
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
|
| 446 |
+
latents = latents.permute(0, 1, 4, 2, 5, 3)
|
| 447 |
+
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
|
| 448 |
+
return latents
|
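An added round-trip check showing that `_patchify_latents` and `_unpatchify_latents` are exact inverses; it assumes the `Flux2ControlPipeline` class from this file is in scope (both methods are static, so no instance is needed).

```python
import torch

latents = torch.randn(1, 32, 8, 8)                               # H and W divisible by 2

patched = Flux2ControlPipeline._patchify_latents(latents)        # (1, 128, 4, 4)
restored = Flux2ControlPipeline._unpatchify_latents(patched)     # (1, 32, 8, 8)

print(patched.shape, restored.shape)
print(torch.allclose(latents, restored))                         # True
```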
| 449 |
+
|
| 450 |
+
@staticmethod
|
| 451 |
+
def _pack_latents(latents):
|
| 452 |
+
"""
|
| 453 |
+
pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
|
| 454 |
+
"""
|
| 455 |
+
|
| 456 |
+
batch_size, num_channels, height, width = latents.shape
|
| 457 |
+
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
|
| 458 |
+
|
| 459 |
+
return latents
|
| 460 |
+
|
| 461 |
+
@staticmethod
|
| 462 |
+
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tensor:
|
| 463 |
+
"""
|
| 464 |
+
using position ids to scatter tokens into place
|
| 465 |
+
"""
|
| 466 |
+
x_list = []
|
| 467 |
+
for data, pos in zip(x, x_ids):
|
| 468 |
+
_, ch = data.shape # noqa: F841
|
| 469 |
+
h_ids = pos[:, 1].to(torch.int64)
|
| 470 |
+
w_ids = pos[:, 2].to(torch.int64)
|
| 471 |
+
|
| 472 |
+
h = torch.max(h_ids) + 1
|
| 473 |
+
w = torch.max(w_ids) + 1
|
| 474 |
+
|
| 475 |
+
flat_ids = h_ids * w + w_ids
|
| 476 |
+
|
| 477 |
+
out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
|
| 478 |
+
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
|
| 479 |
+
|
| 480 |
+
# reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
|
| 481 |
+
|
| 482 |
+
out = out.view(h, w, ch).permute(2, 0, 1)
|
| 483 |
+
x_list.append(out)
|
| 484 |
+
|
| 485 |
+
return torch.stack(x_list, dim=0)
|
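An added sketch combining `_prepare_latent_ids`, `_pack_latents` and `_unpack_latents_with_ids` into a round trip, again assuming the class from this file is in scope.

```python
import torch

latents = torch.randn(2, 16, 4, 6)                                     # (B, C, H, W)

latent_ids = Flux2ControlPipeline._prepare_latent_ids(latents)         # (2, 24, 4)
packed = Flux2ControlPipeline._pack_latents(latents)                   # (2, 24, 16)
restored = Flux2ControlPipeline._unpack_latents_with_ids(packed, latent_ids)

print(restored.shape)                                                  # torch.Size([2, 16, 4, 6])
print(torch.allclose(latents, restored))                               # True
```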
| 486 |
+
|
| 487 |
+
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
| 488 |
+
if image.ndim != 4:
|
| 489 |
+
raise ValueError(f"Expected image dims 4, got {image.ndim}.")
|
| 490 |
+
|
| 491 |
+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
|
| 492 |
+
image_latents = self._patchify_latents(image_latents)
|
| 493 |
+
|
| 494 |
+
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
|
| 495 |
+
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
|
| 496 |
+
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
|
| 497 |
+
|
| 498 |
+
return image_latents
|
| 499 |
+
|
| 500 |
+
def prepare_latents(
|
| 501 |
+
self,
|
| 502 |
+
batch_size,
|
| 503 |
+
num_latents_channels,
|
| 504 |
+
height,
|
| 505 |
+
width,
|
| 506 |
+
dtype,
|
| 507 |
+
device,
|
| 508 |
+
generator: torch.Generator,
|
| 509 |
+
latents: Optional[torch.Tensor] = None,
|
| 510 |
+
):
|
| 511 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 512 |
+
# latent height and width to be divisible by 2.
|
| 513 |
+
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
| 514 |
+
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
| 515 |
+
|
| 516 |
+
shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)
|
| 517 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 518 |
+
raise ValueError(
|
| 519 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 520 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 521 |
+
)
|
| 522 |
+
if latents is None:
|
| 523 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 524 |
+
else:
|
| 525 |
+
latents = latents.to(device=device, dtype=dtype)
|
| 526 |
+
|
| 527 |
+
latent_ids = self._prepare_latent_ids(latents)
|
| 528 |
+
latent_ids = latent_ids.to(device)
|
| 529 |
+
|
| 530 |
+
latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C]
|
| 531 |
+
return latents, latent_ids
|
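An added arithmetic sketch of the size handling in `prepare_latents`: the pixel size is reduced by the VAE scale factor and then 2x2-packed, which determines the packed token count (`image_seq_len`). Values assume the default `vae_scale_factor` of 8.

```python
vae_scale_factor = 8
height = width = 1024

latent_h = 2 * (height // (vae_scale_factor * 2))     # 128
latent_w = 2 * (width // (vae_scale_factor * 2))      # 128
image_seq_len = (latent_h // 2) * (latent_w // 2)     # 64 * 64 = 4096 packed tokens
print(latent_h, latent_w, image_seq_len)
```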
| 532 |
+
|
| 533 |
+
def prepare_image_latents(
|
| 534 |
+
self,
|
| 535 |
+
images: List[torch.Tensor],
|
| 536 |
+
batch_size,
|
| 537 |
+
generator: torch.Generator,
|
| 538 |
+
device,
|
| 539 |
+
dtype,
|
| 540 |
+
):
|
| 541 |
+
image_latents = []
|
| 542 |
+
for image in images:
|
| 543 |
+
image = image.to(device=device, dtype=dtype)
|
| 544 |
+
image_latent = self._encode_vae_image(image=image, generator=generator)
|
| 545 |
+
image_latents.append(image_latent)  # (1, 128, 32, 32)
|
| 546 |
+
|
| 547 |
+
image_latent_ids = self._prepare_image_ids(image_latents)
|
| 548 |
+
|
| 549 |
+
# Pack each latent and concatenate
|
| 550 |
+
packed_latents = []
|
| 551 |
+
for latent in image_latents:
|
| 552 |
+
# latent: (1, 128, 32, 32)
|
| 553 |
+
packed = self._pack_latents(latent) # (1, 1024, 128)
|
| 554 |
+
packed = packed.squeeze(0) # (1024, 128) - remove batch dim
|
| 555 |
+
packed_latents.append(packed)
|
| 556 |
+
|
| 557 |
+
# Concatenate all reference tokens along sequence dimension
|
| 558 |
+
image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128)
|
| 559 |
+
image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128)
|
| 560 |
+
|
| 561 |
+
image_latents = image_latents.repeat(batch_size, 1, 1)
|
| 562 |
+
image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
|
| 563 |
+
image_latent_ids = image_latent_ids.to(device)
|
| 564 |
+
|
| 565 |
+
return image_latents, image_latent_ids
|
| 566 |
+
|
| 567 |
+
def check_inputs(
|
| 568 |
+
self,
|
| 569 |
+
prompt,
|
| 570 |
+
height,
|
| 571 |
+
width,
|
| 572 |
+
prompt_embeds=None,
|
| 573 |
+
callback_on_step_end_tensor_inputs=None,
|
| 574 |
+
):
|
| 575 |
+
if (
|
| 576 |
+
height is not None
|
| 577 |
+
and height % (self.vae_scale_factor * 2) != 0
|
| 578 |
+
or width is not None
|
| 579 |
+
and width % (self.vae_scale_factor * 2) != 0
|
| 580 |
+
):
|
| 581 |
+
logger.warning(
|
| 582 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 586 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 587 |
+
):
|
| 588 |
+
raise ValueError(
|
| 589 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
if prompt is not None and prompt_embeds is not None:
|
| 593 |
+
raise ValueError(
|
| 594 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 595 |
+
" only forward one of the two."
|
| 596 |
+
)
|
| 597 |
+
elif prompt is None and prompt_embeds is None:
|
| 598 |
+
raise ValueError(
|
| 599 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 600 |
+
)
|
| 601 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 602 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 603 |
+
|
| 604 |
+
@property
|
| 605 |
+
def guidance_scale(self):
|
| 606 |
+
return self._guidance_scale
|
| 607 |
+
|
| 608 |
+
@property
|
| 609 |
+
def joint_attention_kwargs(self):
|
| 610 |
+
return self._attention_kwargs
|
| 611 |
+
|
| 612 |
+
@property
|
| 613 |
+
def num_timesteps(self):
|
| 614 |
+
return self._num_timesteps
|
| 615 |
+
|
| 616 |
+
@property
|
| 617 |
+
def current_timestep(self):
|
| 618 |
+
return self._current_timestep
|
| 619 |
+
|
| 620 |
+
@property
|
| 621 |
+
def interrupt(self):
|
| 622 |
+
return self._interrupt
|
| 623 |
+
|
| 624 |
+
@torch.no_grad()
|
| 625 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 626 |
+
def __call__(
|
| 627 |
+
self,
|
| 628 |
+
prompt: Union[str, List[str]] = None,
|
| 629 |
+
height: Optional[int] = None,
|
| 630 |
+
width: Optional[int] = None,
|
| 631 |
+
|
| 632 |
+
image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
|
| 633 |
+
inpaint_image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
|
| 634 |
+
mask_image: Union[torch.FloatTensor] = None,
|
| 635 |
+
control_image: Union[torch.FloatTensor] = None,
|
| 636 |
+
control_context_scale: float = 1.0,
|
| 637 |
+
|
| 638 |
+
num_inference_steps: int = 50,
|
| 639 |
+
sigmas: Optional[List[float]] = None,
|
| 640 |
+
guidance_scale: Optional[float] = 4.0,
|
| 641 |
+
num_images_per_prompt: int = 1,
|
| 642 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 643 |
+
latents: Optional[torch.Tensor] = None,
|
| 644 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 645 |
+
output_type: Optional[str] = "pil",
|
| 646 |
+
return_dict: bool = True,
|
| 647 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 648 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 649 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 650 |
+
max_sequence_length: int = 512,
|
| 651 |
+
text_encoder_out_layers: Tuple[int] = (10, 20, 30),
|
| 652 |
+
):
|
| 653 |
+
r"""
|
| 654 |
+
Function invoked when calling the pipeline for generation.
|
| 655 |
+
|
| 656 |
+
Args:
|
| 657 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
| 658 |
+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
| 659 |
+
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
| 660 |
+
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
| 661 |
+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
| 662 |
+
latents as `image`, but if passing latents directly it is not encoded again.
|
| 663 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 664 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
| 665 |
+
instead.
|
| 666 |
+
guidance_scale (`float`, *optional*, defaults to 4.0):
|
| 667 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 668 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 669 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 670 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 671 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 672 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 673 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 674 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 675 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 676 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 677 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 678 |
+
expense of slower inference.
|
| 679 |
+
sigmas (`List[float]`, *optional*):
|
| 680 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 681 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 682 |
+
will be used.
|
| 683 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 684 |
+
The number of images to generate per prompt.
|
| 685 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 686 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 687 |
+
to make generation deterministic.
|
| 688 |
+
latents (`torch.Tensor`, *optional*):
|
| 689 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 690 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 691 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 692 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 693 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 694 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 695 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 696 |
+
The output format of the generate image. Choose between
|
| 697 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 698 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 699 |
+
Whether or not to return a [`~pipelines.flux2.Flux2PipelineOutput`] instead of a plain tuple.
|
| 700 |
+
attention_kwargs (`dict`, *optional*):
|
| 701 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 702 |
+
`self.processor` in
|
| 703 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 704 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 705 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
| 706 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 707 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 708 |
+
`callback_on_step_end_tensor_inputs`.
|
| 709 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 710 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 711 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 712 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 713 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
| 714 |
+
text_encoder_out_layers (`Tuple[int]`):
|
| 715 |
+
Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
|
| 716 |
+
|
| 717 |
+
Examples:
|
| 718 |
+
|
| 719 |
+
Returns:
|
| 720 |
+
[`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if
|
| 721 |
+
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
| 722 |
+
generated images.
|
| 723 |
+
"""
|
| 724 |
+
|
| 725 |
+
# 1. Check inputs. Raise error if not correct
|
| 726 |
+
self.check_inputs(
|
| 727 |
+
prompt=prompt,
|
| 728 |
+
height=height,
|
| 729 |
+
width=width,
|
| 730 |
+
prompt_embeds=prompt_embeds,
|
| 731 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
self._guidance_scale = guidance_scale
|
| 735 |
+
self._attention_kwargs = attention_kwargs
|
| 736 |
+
self._current_timestep = None
|
| 737 |
+
self._interrupt = False
|
| 738 |
+
|
| 739 |
+
# 2. Define call parameters
|
| 740 |
+
if prompt is not None and isinstance(prompt, str):
|
| 741 |
+
batch_size = 1
|
| 742 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 743 |
+
batch_size = len(prompt)
|
| 744 |
+
else:
|
| 745 |
+
batch_size = prompt_embeds.shape[0]
|
| 746 |
+
|
| 747 |
+
device = self._execution_device
|
| 748 |
+
weight_dtype = self.text_encoder.dtype
|
| 749 |
+
|
| 750 |
+
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(device, weight_dtype)
|
| 751 |
+
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
|
| 752 |
+
device, weight_dtype
|
| 753 |
+
)
|
| 754 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 755 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 756 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
| 757 |
+
|
| 758 |
+
# Prepare mask latent variables
|
| 759 |
+
if mask_image is not None:
|
| 760 |
+
mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
| 761 |
+
mask_condition = torch.tile(mask_condition, [1, 3, 1, 1]).to(dtype=weight_dtype, device=device)
|
| 762 |
+
|
| 763 |
+
if inpaint_image is not None:
|
| 764 |
+
init_image = self.diffusers_image_processor.preprocess(inpaint_image, height=height, width=width)
|
| 765 |
+
init_image = init_image.to(dtype=weight_dtype, device=device) * (mask_condition < 0.5)
|
| 766 |
+
inpaint_latent = self.vae.encode(init_image)[0].mode()
|
| 767 |
+
else:
|
| 768 |
+
inpaint_latent = torch.zeros((batch_size, num_channels_latents * 4, height // 2 // self.vae_scale_factor, width // 2 // self.vae_scale_factor)).to(device, weight_dtype)
|
| 769 |
+
|
| 770 |
+
if control_image is not None:
|
| 771 |
+
control_image = self.diffusers_image_processor.preprocess(control_image, height=height, width=width)
|
| 772 |
+
control_image = control_image.to(dtype=weight_dtype, device=device)
|
| 773 |
+
control_latents = self.vae.encode(control_image)[0].mode()
|
| 774 |
+
else:
|
| 775 |
+
control_latents = torch.zeros_like(inpaint_latent)
|
| 776 |
+
|
| 777 |
+
mask_condition = F.interpolate(1 - mask_condition[:, :1], size=control_latents.size()[-2:], mode='nearest').to(device, weight_dtype)
|
| 778 |
+
mask_condition = self._patchify_latents(mask_condition)
|
| 779 |
+
mask_condition = self._pack_latents(mask_condition)
|
| 780 |
+
|
| 781 |
+
if inpaint_image is not None:
|
| 782 |
+
inpaint_latent = self._patchify_latents(inpaint_latent)
|
| 783 |
+
inpaint_latent = (inpaint_latent - latents_bn_mean) / latents_bn_std
|
| 784 |
+
inpaint_latent = self._pack_latents(inpaint_latent)
|
| 785 |
+
else:
|
| 786 |
+
inpaint_latent = self._patchify_latents(inpaint_latent)
|
| 787 |
+
inpaint_latent = self._pack_latents(inpaint_latent)
|
| 788 |
+
|
| 789 |
+
if control_image is not None:
|
| 790 |
+
control_latents = self._patchify_latents(control_latents)
|
| 791 |
+
control_latents = (control_latents - latents_bn_mean) / latents_bn_std
|
| 792 |
+
control_latents = self._pack_latents(control_latents)
|
| 793 |
+
else:
|
| 794 |
+
control_latents = self._patchify_latents(control_latents)
|
| 795 |
+
control_latents = self._pack_latents(control_latents)
|
| 796 |
+
control_context = torch.concat([control_latents, mask_condition, inpaint_latent], dim=2)
|
| 797 |
+
|
| 798 |
+
# 3. prepare text embeddings
|
| 799 |
+
prompt_embeds, text_ids = self.encode_prompt(
|
| 800 |
+
prompt=prompt,
|
| 801 |
+
prompt_embeds=prompt_embeds,
|
| 802 |
+
device=device,
|
| 803 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 804 |
+
max_sequence_length=max_sequence_length,
|
| 805 |
+
text_encoder_out_layers=text_encoder_out_layers,
|
| 806 |
+
)
|
| 807 |
+
|
| 808 |
+
# 4. process images
|
| 809 |
+
if image is not None and not isinstance(image, list):
|
| 810 |
+
image = [image]
|
| 811 |
+
|
| 812 |
+
condition_images = None
|
| 813 |
+
if image is not None:
|
| 814 |
+
for img in image:
|
| 815 |
+
self.image_processor.check_image_input(img)
|
| 816 |
+
|
| 817 |
+
condition_images = []
|
| 818 |
+
for img in image:
|
| 819 |
+
image_width, image_height = img.size
|
| 820 |
+
if image_width * image_height > 1024 * 1024:
|
| 821 |
+
img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
|
| 822 |
+
image_width, image_height = img.size
|
| 823 |
+
|
| 824 |
+
multiple_of = self.vae_scale_factor * 2
|
| 825 |
+
image_width = (image_width // multiple_of) * multiple_of
|
| 826 |
+
image_height = (image_height // multiple_of) * multiple_of
|
| 827 |
+
img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
|
| 828 |
+
condition_images.append(img)
|
| 829 |
+
height = height or image_height
|
| 830 |
+
width = width or image_width
|
| 831 |
+
|
| 832 |
+
# 5. prepare latent variables
|
| 833 |
+
latents, latent_ids = self.prepare_latents(
|
| 834 |
+
batch_size=batch_size * num_images_per_prompt,
|
| 835 |
+
num_latents_channels=num_channels_latents,
|
| 836 |
+
height=height,
|
| 837 |
+
width=width,
|
| 838 |
+
dtype=prompt_embeds.dtype,
|
| 839 |
+
device=device,
|
| 840 |
+
generator=generator,
|
| 841 |
+
latents=latents,
|
| 842 |
+
)
|
| 843 |
+
|
| 844 |
+
image_latents = None
|
| 845 |
+
image_latent_ids = None
|
| 846 |
+
if condition_images is not None:
|
| 847 |
+
image_latents, image_latent_ids = self.prepare_image_latents(
|
| 848 |
+
images=condition_images,
|
| 849 |
+
batch_size=batch_size * num_images_per_prompt,
|
| 850 |
+
generator=generator,
|
| 851 |
+
device=device,
|
| 852 |
+
dtype=self.vae.dtype,
|
| 853 |
+
)
|
| 854 |
+
|
| 855 |
+
# 6. Prepare timesteps
|
| 856 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
| 857 |
+
if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
|
| 858 |
+
sigmas = None
|
| 859 |
+
image_seq_len = latents.shape[1]
|
| 860 |
+
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
|
| 861 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 862 |
+
self.scheduler,
|
| 863 |
+
num_inference_steps,
|
| 864 |
+
device,
|
| 865 |
+
sigmas=sigmas,
|
| 866 |
+
mu=mu,
|
| 867 |
+
)
|
| 868 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 869 |
+
self._num_timesteps = len(timesteps)
|
| 870 |
+
|
| 871 |
+
# handle guidance
|
| 872 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 873 |
+
guidance = guidance.expand(latents.shape[0])
|
| 874 |
+
|
| 875 |
+
# 7. Denoising loop
|
| 876 |
+
# We set the index here to remove DtoH sync, helpful especially during compilation.
|
| 877 |
+
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
|
| 878 |
+
self.scheduler.set_begin_index(0)
|
| 879 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 880 |
+
for i, t in enumerate(timesteps):
|
| 881 |
+
if self.interrupt:
|
| 882 |
+
continue
|
| 883 |
+
|
| 884 |
+
self._current_timestep = t
|
| 885 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 886 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 887 |
+
|
| 888 |
+
latent_model_input = latents.to(self.transformer.dtype)
|
| 889 |
+
control_context_input = control_context.to(self.transformer.dtype)
|
| 890 |
+
latent_image_ids = latent_ids
|
| 891 |
+
|
| 892 |
+
if image_latents is not None:
|
| 893 |
+
latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
|
| 894 |
+
latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
|
| 895 |
+
|
| 896 |
+
local_bs, local_length, local_c = control_context.size()
|
| 897 |
+
control_context_input = torch.cat(
|
| 898 |
+
[
|
| 899 |
+
control_context,
|
| 900 |
+
torch.zeros(
|
| 901 |
+
[
|
| 902 |
+
local_bs,
|
| 903 |
+
image_latents.size()[1],
|
| 904 |
+
local_c
|
| 905 |
+
]
|
| 906 |
+
).to(control_context.device, control_context.dtype)],
|
| 907 |
+
dim=1
|
| 908 |
+
).to(self.transformer.dtype)
|
| 909 |
+
|
| 910 |
+
noise_pred = self.transformer(
|
| 911 |
+
hidden_states=latent_model_input, # (B, image_seq_len, C)
|
| 912 |
+
timestep=timestep / 1000,
|
| 913 |
+
guidance=guidance,
|
| 914 |
+
encoder_hidden_states=prompt_embeds,
|
| 915 |
+
txt_ids=text_ids, # B, text_seq_len, 4
|
| 916 |
+
img_ids=latent_image_ids, # B, image_seq_len, 4
|
| 917 |
+
joint_attention_kwargs=self._attention_kwargs,
|
| 918 |
+
control_context=control_context_input,
|
| 919 |
+
control_context_scale=control_context_scale,
|
| 920 |
+
return_dict=False,
|
| 921 |
+
)[0]
|
| 922 |
+
|
| 923 |
+
noise_pred = noise_pred[:, : latents.size(1)]
|
| 924 |
+
|
| 925 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 926 |
+
latents_dtype = latents.dtype
|
| 927 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 928 |
+
|
| 929 |
+
if latents.dtype != latents_dtype:
|
| 930 |
+
if torch.backends.mps.is_available():
|
| 931 |
+
# some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 932 |
+
latents = latents.to(latents_dtype)
|
| 933 |
+
|
| 934 |
+
if callback_on_step_end is not None:
|
| 935 |
+
callback_kwargs = {}
|
| 936 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 937 |
+
callback_kwargs[k] = locals()[k]
|
| 938 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 939 |
+
|
| 940 |
+
latents = callback_outputs.pop("latents", latents)
|
| 941 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 942 |
+
|
| 943 |
+
# call the callback, if provided
|
| 944 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 945 |
+
progress_bar.update()
|
| 946 |
+
|
| 947 |
+
if XLA_AVAILABLE:
|
| 948 |
+
xm.mark_step()
|
| 949 |
+
|
| 950 |
+
self._current_timestep = None
|
| 951 |
+
|
| 952 |
+
if output_type == "latent":
|
| 953 |
+
image = latents
|
| 954 |
+
else:
|
| 955 |
+
latents = self._unpack_latents_with_ids(latents, latent_ids)
|
| 956 |
+
|
| 957 |
+
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
|
| 958 |
+
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
|
| 959 |
+
latents.device, latents.dtype
|
| 960 |
+
)
|
| 961 |
+
latents = latents * latents_bn_std + latents_bn_mean
|
| 962 |
+
latents = self._unpatchify_latents(latents)
|
| 963 |
+
|
| 964 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 965 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 966 |
+
|
| 967 |
+
# Offload all models
|
| 968 |
+
self.maybe_free_model_hooks()
|
| 969 |
+
|
| 970 |
+
if not return_dict:
|
| 971 |
+
return (image,)
|
| 972 |
+
|
| 973 |
+
return Flux2PipelineOutput(images=image)
|
videox_fun/pipeline/pipeline_hunyuanvideo.py
ADDED
|
@@ -0,0 +1,805 @@
| 1 |
+
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
|
| 2 |
+
# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 23 |
+
from diffusers.loaders import HunyuanVideoLoraLoaderMixin
|
| 24 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 25 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 26 |
+
from diffusers.utils import (BaseOutput, deprecate, is_torch_xla_available,
|
| 27 |
+
logging, replace_example_docstring)
|
| 28 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 29 |
+
from diffusers.video_processor import VideoProcessor
|
| 30 |
+
|
| 31 |
+
from ..models import (AutoencoderKLHunyuanVideo, CLIPImageProcessor,
|
| 32 |
+
CLIPTextModel, CLIPTokenizer,
|
| 33 |
+
HunyuanVideoTransformer3DModel, LlamaModel,
|
| 34 |
+
LlamaTokenizerFast, LlavaForConditionalGeneration)
|
| 35 |
+
|
| 36 |
+
if is_torch_xla_available():
|
| 37 |
+
import torch_xla.core.xla_model as xm
|
| 38 |
+
|
| 39 |
+
XLA_AVAILABLE = True
|
| 40 |
+
else:
|
| 41 |
+
XLA_AVAILABLE = False
|
| 42 |
+
|
| 43 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
EXAMPLE_DOC_STRING = """
|
| 47 |
+
Examples:
|
| 48 |
+
```python
|
| 49 |
+
>>> import torch
|
| 50 |
+
>>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
|
| 51 |
+
>>> from diffusers.utils import export_to_video
|
| 52 |
+
|
| 53 |
+
>>> model_id = "hunyuanvideo-community/HunyuanVideo"
|
| 54 |
+
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
| 55 |
+
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
| 56 |
+
... )
|
| 57 |
+
>>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
|
| 58 |
+
>>> pipe.vae.enable_tiling()
|
| 59 |
+
>>> pipe.to("cuda")
|
| 60 |
+
|
| 61 |
+
>>> output = pipe(
|
| 62 |
+
... prompt="A cat walks on the grass, realistic",
|
| 63 |
+
... height=320,
|
| 64 |
+
... width=512,
|
| 65 |
+
... num_frames=61,
|
| 66 |
+
... num_inference_steps=30,
|
| 67 |
+
... ).frames[0]
|
| 68 |
+
>>> export_to_video(output, "output.mp4", fps=15)
|
| 69 |
+
```
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
DEFAULT_PROMPT_TEMPLATE = {
|
| 74 |
+
"template": (
|
| 75 |
+
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
|
| 76 |
+
"1. The main content and theme of the video."
|
| 77 |
+
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
| 78 |
+
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
| 79 |
+
"4. background environment, light, style and atmosphere."
|
| 80 |
+
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
|
| 81 |
+
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
| 82 |
+
),
|
| 83 |
+
"crop_start": 95,
|
| 84 |
+
}
|
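An added sketch (not in the original diff) of how `DEFAULT_PROMPT_TEMPLATE` is used: the user prompt is substituted into the chat-style template before tokenization, and `crop_start` later drops the fixed system-prompt tokens from the encoder output.

```python
user_prompt = "A cat walks on the grass, realistic"
formatted = DEFAULT_PROMPT_TEMPLATE["template"].format(user_prompt)

print(formatted[:60])                         # starts with the system header
print(DEFAULT_PROMPT_TEMPLATE["crop_start"])  # 95 tokens of template prefix to crop
```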
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 88 |
+
def retrieve_timesteps(
|
| 89 |
+
scheduler,
|
| 90 |
+
num_inference_steps: Optional[int] = None,
|
| 91 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 92 |
+
timesteps: Optional[List[int]] = None,
|
| 93 |
+
sigmas: Optional[List[float]] = None,
|
| 94 |
+
**kwargs,
|
| 95 |
+
):
|
| 96 |
+
r"""
|
| 97 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 98 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
scheduler (`SchedulerMixin`):
|
| 102 |
+
The scheduler to get timesteps from.
|
| 103 |
+
num_inference_steps (`int`):
|
| 104 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 105 |
+
must be `None`.
|
| 106 |
+
device (`str` or `torch.device`, *optional*):
|
| 107 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 108 |
+
timesteps (`List[int]`, *optional*):
|
| 109 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 110 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 111 |
+
sigmas (`List[float]`, *optional*):
|
| 112 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 113 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 117 |
+
second element is the number of inference steps.
|
| 118 |
+
"""
|
| 119 |
+
if timesteps is not None and sigmas is not None:
|
| 120 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 121 |
+
if timesteps is not None:
|
| 122 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 123 |
+
if not accepts_timesteps:
|
| 124 |
+
raise ValueError(
|
| 125 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 126 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 127 |
+
)
|
| 128 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 129 |
+
timesteps = scheduler.timesteps
|
| 130 |
+
num_inference_steps = len(timesteps)
|
| 131 |
+
elif sigmas is not None:
|
| 132 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 133 |
+
if not accept_sigmas:
|
| 134 |
+
raise ValueError(
|
| 135 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 136 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 137 |
+
)
|
| 138 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 139 |
+
timesteps = scheduler.timesteps
|
| 140 |
+
num_inference_steps = len(timesteps)
|
| 141 |
+
else:
|
| 142 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 143 |
+
timesteps = scheduler.timesteps
|
| 144 |
+
return timesteps, num_inference_steps
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@dataclass
|
| 148 |
+
class HunyuanVideoPipelineOutput(BaseOutput):
|
| 149 |
+
r"""
|
| 150 |
+
Output class for video pipelines.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
videos (`torch.Tensor`, `np.ndarray`, or `List[List[PIL.Image.Image]]`):
|
| 154 |
+
List of video outputs. It can be a nested list of length `batch_size`, with each sub-list containing
|
| 155 |
+
denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
|
| 156 |
+
`(batch_size, num_frames, channels, height, width)`.
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
videos: torch.Tensor
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class HunyuanVideoPipeline(DiffusionPipeline):
|
| 163 |
+
r"""
|
| 164 |
+
Pipeline for text-to-video generation using HunyuanVideo.
|
| 165 |
+
|
| 166 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 167 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 168 |
+
|
| 169 |
+
Args:
|
| 170 |
+
text_encoder ([`LlamaModel`]):
|
| 171 |
+
[Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
|
| 172 |
+
tokenizer (`LlamaTokenizer`):
|
| 173 |
+
Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
|
| 174 |
+
transformer ([`HunyuanVideoTransformer3DModel`]):
|
| 175 |
+
Conditional Transformer to denoise the encoded image latents.
|
| 176 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 177 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 178 |
+
vae ([`AutoencoderKLHunyuanVideo`]):
|
| 179 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
| 180 |
+
text_encoder_2 ([`CLIPTextModel`]):
|
| 181 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
| 182 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
| 183 |
+
tokenizer_2 (`CLIPTokenizer`):
|
| 184 |
+
Tokenizer of class
|
| 185 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
| 189 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 190 |
+
|
| 191 |
+
def __init__(
|
| 192 |
+
self,
|
| 193 |
+
text_encoder: LlamaModel,
|
| 194 |
+
tokenizer: LlamaTokenizerFast,
|
| 195 |
+
transformer: HunyuanVideoTransformer3DModel,
|
| 196 |
+
vae: AutoencoderKLHunyuanVideo,
|
| 197 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 198 |
+
text_encoder_2: CLIPTextModel,
|
| 199 |
+
tokenizer_2: CLIPTokenizer,
|
| 200 |
+
):
|
| 201 |
+
super().__init__()
|
| 202 |
+
|
| 203 |
+
self.register_modules(
|
| 204 |
+
vae=vae,
|
| 205 |
+
text_encoder=text_encoder,
|
| 206 |
+
tokenizer=tokenizer,
|
| 207 |
+
transformer=transformer,
|
| 208 |
+
scheduler=scheduler,
|
| 209 |
+
text_encoder_2=text_encoder_2,
|
| 210 |
+
tokenizer_2=tokenizer_2,
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
|
| 214 |
+
self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
|
| 215 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
| 216 |
+
|
| 217 |
+
def _get_llama_prompt_embeds(
|
| 218 |
+
self,
|
| 219 |
+
prompt: Union[str, List[str]],
|
| 220 |
+
prompt_template: Dict[str, Any],
|
| 221 |
+
num_videos_per_prompt: int = 1,
|
| 222 |
+
device: Optional[torch.device] = None,
|
| 223 |
+
dtype: Optional[torch.dtype] = None,
|
| 224 |
+
max_sequence_length: int = 256,
|
| 225 |
+
num_hidden_layers_to_skip: int = 2,
|
| 226 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 227 |
+
device = device or self._execution_device
|
| 228 |
+
dtype = dtype or self.text_encoder.dtype
|
| 229 |
+
|
| 230 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 231 |
+
batch_size = len(prompt)
|
| 232 |
+
|
| 233 |
+
prompt = [prompt_template["template"].format(p) for p in prompt]
|
| 234 |
+
|
| 235 |
+
crop_start = prompt_template.get("crop_start", None)
|
| 236 |
+
if crop_start is None:
|
| 237 |
+
prompt_template_input = self.tokenizer(
|
| 238 |
+
prompt_template["template"],
|
| 239 |
+
padding="max_length",
|
| 240 |
+
return_tensors="pt",
|
| 241 |
+
return_length=False,
|
| 242 |
+
return_overflowing_tokens=False,
|
| 243 |
+
return_attention_mask=False,
|
| 244 |
+
)
|
| 245 |
+
crop_start = prompt_template_input["input_ids"].shape[-1]
|
| 246 |
+
# Remove <|eot_id|> token and placeholder {}
|
| 247 |
+
crop_start -= 2
|
| 248 |
+
|
| 249 |
+
max_sequence_length += crop_start
|
| 250 |
+
text_inputs = self.tokenizer(
|
| 251 |
+
prompt,
|
| 252 |
+
max_length=max_sequence_length,
|
| 253 |
+
padding="max_length",
|
| 254 |
+
truncation=True,
|
| 255 |
+
return_tensors="pt",
|
| 256 |
+
return_length=False,
|
| 257 |
+
return_overflowing_tokens=False,
|
| 258 |
+
return_attention_mask=True,
|
| 259 |
+
)
|
| 260 |
+
text_input_ids = text_inputs.input_ids.to(device=device)
|
| 261 |
+
prompt_attention_mask = text_inputs.attention_mask.to(device=device)
|
| 262 |
+
|
| 263 |
+
prompt_embeds = self.text_encoder(
|
| 264 |
+
input_ids=text_input_ids,
|
| 265 |
+
attention_mask=prompt_attention_mask,
|
| 266 |
+
output_hidden_states=True,
|
| 267 |
+
).hidden_states[-(num_hidden_layers_to_skip + 1)]
|
| 268 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
| 269 |
+
|
| 270 |
+
if crop_start is not None and crop_start > 0:
|
| 271 |
+
prompt_embeds = prompt_embeds[:, crop_start:]
|
| 272 |
+
prompt_attention_mask = prompt_attention_mask[:, crop_start:]
|
| 273 |
+
|
| 274 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 275 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 276 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 277 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 278 |
+
prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
|
| 279 |
+
prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
|
| 280 |
+
|
| 281 |
+
return prompt_embeds, prompt_attention_mask
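# Note (editorial, not part of the original file): with `output_hidden_states=True` the text
# encoder returns one hidden state per layer plus the embedding output, so
# `hidden_states[-(num_hidden_layers_to_skip + 1)]` with the default skip of 2 selects the
# hidden state two decoder layers before the final one, i.e. the last two layers are skipped
# when building the prompt embeddings.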
|
| 282 |
+
|
| 283 |
+
def _get_clip_prompt_embeds(
|
| 284 |
+
self,
|
| 285 |
+
prompt: Union[str, List[str]],
|
| 286 |
+
num_videos_per_prompt: int = 1,
|
| 287 |
+
device: Optional[torch.device] = None,
|
| 288 |
+
dtype: Optional[torch.dtype] = None,
|
| 289 |
+
max_sequence_length: int = 77,
|
| 290 |
+
) -> torch.Tensor:
|
| 291 |
+
device = device or self._execution_device
|
| 292 |
+
dtype = dtype or self.text_encoder_2.dtype
|
| 293 |
+
|
| 294 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 295 |
+
batch_size = len(prompt)
|
| 296 |
+
|
| 297 |
+
text_inputs = self.tokenizer_2(
|
| 298 |
+
prompt,
|
| 299 |
+
padding="max_length",
|
| 300 |
+
max_length=max_sequence_length,
|
| 301 |
+
truncation=True,
|
| 302 |
+
return_tensors="pt",
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
text_input_ids = text_inputs.input_ids
|
| 306 |
+
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
| 307 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 308 |
+
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 309 |
+
logger.warning(
|
| 310 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 311 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
|
| 315 |
+
|
| 316 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 317 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
|
| 318 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
|
| 319 |
+
|
| 320 |
+
return prompt_embeds
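# Note (editorial): `pooler_output` here is the CLIP text encoder's pooled hidden state at the
# end-of-text token, of shape (batch_size, hidden_size); for the clip-vit-large-patch14 variant
# named in the class docstring that is (batch_size, 768).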
|
| 321 |
+
|
| 322 |
+
def encode_prompt(
|
| 323 |
+
self,
|
| 324 |
+
prompt: Union[str, List[str]],
|
| 325 |
+
prompt_2: Union[str, List[str]] = None,
|
| 326 |
+
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
|
| 327 |
+
num_videos_per_prompt: int = 1,
|
| 328 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 329 |
+
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
| 330 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 331 |
+
device: Optional[torch.device] = None,
|
| 332 |
+
dtype: Optional[torch.dtype] = None,
|
| 333 |
+
max_sequence_length: int = 256,
|
| 334 |
+
):
|
| 335 |
+
if prompt_embeds is None:
|
| 336 |
+
prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
|
| 337 |
+
prompt,
|
| 338 |
+
prompt_template,
|
| 339 |
+
num_videos_per_prompt,
|
| 340 |
+
device=device,
|
| 341 |
+
dtype=dtype,
|
| 342 |
+
max_sequence_length=max_sequence_length,
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
if pooled_prompt_embeds is None:
|
| 346 |
+
if prompt_2 is None:
|
| 347 |
+
prompt_2 = prompt
|
| 348 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
| 349 |
+
prompt,
|
| 350 |
+
num_videos_per_prompt,
|
| 351 |
+
device=device,
|
| 352 |
+
dtype=dtype,
|
| 353 |
+
max_sequence_length=77,
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask
|
| 357 |
+
|
| 358 |
+
def check_inputs(
|
| 359 |
+
self,
|
| 360 |
+
prompt,
|
| 361 |
+
prompt_2,
|
| 362 |
+
height,
|
| 363 |
+
width,
|
| 364 |
+
prompt_embeds=None,
|
| 365 |
+
callback_on_step_end_tensor_inputs=None,
|
| 366 |
+
prompt_template=None,
|
| 367 |
+
):
|
| 368 |
+
if height % 16 != 0 or width % 16 != 0:
|
| 369 |
+
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
| 370 |
+
|
| 371 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 372 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 373 |
+
):
|
| 374 |
+
raise ValueError(
|
| 375 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
if prompt is not None and prompt_embeds is not None:
|
| 379 |
+
raise ValueError(
|
| 380 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 381 |
+
" only forward one of the two."
|
| 382 |
+
)
|
| 383 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
| 384 |
+
raise ValueError(
|
| 385 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 386 |
+
" only forward one of the two."
|
| 387 |
+
)
|
| 388 |
+
elif prompt is None and prompt_embeds is None:
|
| 389 |
+
raise ValueError(
|
| 390 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 391 |
+
)
|
| 392 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 393 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 394 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
| 395 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
| 396 |
+
|
| 397 |
+
if prompt_template is not None:
|
| 398 |
+
if not isinstance(prompt_template, dict):
|
| 399 |
+
raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}")
|
| 400 |
+
if "template" not in prompt_template:
|
| 401 |
+
raise ValueError(
|
| 402 |
+
f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
def prepare_latents(
|
| 406 |
+
self,
|
| 407 |
+
batch_size: int,
|
| 408 |
+
num_channels_latents: int = 32,
|
| 409 |
+
height: int = 720,
|
| 410 |
+
width: int = 1280,
|
| 411 |
+
num_frames: int = 129,
|
| 412 |
+
dtype: Optional[torch.dtype] = None,
|
| 413 |
+
device: Optional[torch.device] = None,
|
| 414 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 415 |
+
latents: Optional[torch.Tensor] = None,
|
| 416 |
+
) -> torch.Tensor:
|
| 417 |
+
if latents is not None:
|
| 418 |
+
return latents.to(device=device, dtype=dtype)
|
| 419 |
+
|
| 420 |
+
shape = (
|
| 421 |
+
batch_size,
|
| 422 |
+
num_channels_latents,
|
| 423 |
+
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
|
| 424 |
+
int(height) // self.vae_scale_factor_spatial,
|
| 425 |
+
int(width) // self.vae_scale_factor_spatial,
|
| 426 |
+
)
|
| 427 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 428 |
+
raise ValueError(
|
| 429 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 430 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 434 |
+
return latents
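# Worked example (editorial, assuming the stock HunyuanVideo configuration with 16 latent
# channels, temporal compression 4 and spatial compression 8): for the defaults
# height=720, width=1280, num_frames=129 and batch_size=1, the sampled latents have shape
# (1, 16, (129 - 1) // 4 + 1, 720 // 8, 1280 // 8) == (1, 16, 33, 90, 160).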
|
| 435 |
+
|
| 436 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 437 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
| 438 |
+
|
| 439 |
+
frames = self.vae.decode(latents).sample
|
| 440 |
+
frames = (frames / 2 + 0.5).clamp(0, 1)
|
| 441 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 442 |
+
frames = frames.cpu().float().numpy()
|
| 443 |
+
return frames
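# Note (editorial): latents are rescaled by 1 / vae.config.scaling_factor before decoding, and
# the decoded sample is mapped from [-1, 1] to [0, 1] before being returned as a float32 NumPy
# array of shape (batch, channels, frames, height, width).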
|
| 444 |
+
|
| 445 |
+
def enable_vae_slicing(self):
|
| 446 |
+
r"""
|
| 447 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
| 448 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
| 449 |
+
"""
|
| 450 |
+
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
|
| 451 |
+
deprecate(
|
| 452 |
+
"enable_vae_slicing",
|
| 453 |
+
"0.40.0",
|
| 454 |
+
depr_message,
|
| 455 |
+
)
|
| 456 |
+
self.vae.enable_slicing()
|
| 457 |
+
|
| 458 |
+
def disable_vae_slicing(self):
|
| 459 |
+
r"""
|
| 460 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
| 461 |
+
computing decoding in one step.
|
| 462 |
+
"""
|
| 463 |
+
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
|
| 464 |
+
deprecate(
|
| 465 |
+
"disable_vae_slicing",
|
| 466 |
+
"0.40.0",
|
| 467 |
+
depr_message,
|
| 468 |
+
)
|
| 469 |
+
self.vae.disable_slicing()
|
| 470 |
+
|
| 471 |
+
def enable_vae_tiling(self):
|
| 472 |
+
r"""
|
| 473 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
| 474 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
| 475 |
+
processing larger images.
|
| 476 |
+
"""
|
| 477 |
+
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
|
| 478 |
+
deprecate(
|
| 479 |
+
"enable_vae_tiling",
|
| 480 |
+
"0.40.0",
|
| 481 |
+
depr_message,
|
| 482 |
+
)
|
| 483 |
+
self.vae.enable_tiling()
|
| 484 |
+
|
| 485 |
+
def disable_vae_tiling(self):
|
| 486 |
+
r"""
|
| 487 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
| 488 |
+
computing decoding in one step.
|
| 489 |
+
"""
|
| 490 |
+
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
|
| 491 |
+
deprecate(
|
| 492 |
+
"disable_vae_tiling",
|
| 493 |
+
"0.40.0",
|
| 494 |
+
depr_message,
|
| 495 |
+
)
|
| 496 |
+
self.vae.disable_tiling()
|
| 497 |
+
|
| 498 |
+
@property
|
| 499 |
+
def guidance_scale(self):
|
| 500 |
+
return self._guidance_scale
|
| 501 |
+
|
| 502 |
+
@property
|
| 503 |
+
def num_timesteps(self):
|
| 504 |
+
return self._num_timesteps
|
| 505 |
+
|
| 506 |
+
@property
|
| 507 |
+
def attention_kwargs(self):
|
| 508 |
+
return self._attention_kwargs
|
| 509 |
+
|
| 510 |
+
@property
|
| 511 |
+
def current_timestep(self):
|
| 512 |
+
return self._current_timestep
|
| 513 |
+
|
| 514 |
+
@property
|
| 515 |
+
def interrupt(self):
|
| 516 |
+
return self._interrupt
|
| 517 |
+
|
| 518 |
+
@torch.no_grad()
|
| 519 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 520 |
+
def __call__(
|
| 521 |
+
self,
|
| 522 |
+
prompt: Union[str, List[str]] = None,
|
| 523 |
+
prompt_2: Union[str, List[str]] = None,
|
| 524 |
+
negative_prompt: Union[str, List[str]] = None,
|
| 525 |
+
negative_prompt_2: Union[str, List[str]] = None,
|
| 526 |
+
height: int = 720,
|
| 527 |
+
width: int = 1280,
|
| 528 |
+
num_frames: int = 129,
|
| 529 |
+
num_inference_steps: int = 50,
|
| 530 |
+
sigmas: List[float] = None,
|
| 531 |
+
true_cfg_scale: float = 1.0,
|
| 532 |
+
guidance_scale: float = 6.0,
|
| 533 |
+
num_videos_per_prompt: Optional[int] = 1,
|
| 534 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 535 |
+
latents: Optional[torch.Tensor] = None,
|
| 536 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 537 |
+
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
| 538 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 539 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 540 |
+
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
| 541 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 542 |
+
output_type: str = "numpy",
|
| 543 |
+
return_dict: bool = False,
|
| 544 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 545 |
+
callback_on_step_end: Optional[
|
| 546 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 547 |
+
] = None,
|
| 548 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 549 |
+
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
|
| 550 |
+
max_sequence_length: int = 256,
|
| 551 |
+
):
|
| 552 |
+
r"""
|
| 553 |
+
The call function to the pipeline for generation.
|
| 554 |
+
|
| 555 |
+
Args:
|
| 556 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 557 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
| 558 |
+
instead.
|
| 559 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 560 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
|
| 561 |
+
will be used instead.
|
| 562 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 563 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 564 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
| 565 |
+
not greater than `1`).
|
| 566 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 567 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 568 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
| 569 |
+
height (`int`, defaults to `720`):
|
| 570 |
+
The height in pixels of the generated image.
|
| 571 |
+
width (`int`, defaults to `1280`):
|
| 572 |
+
The width in pixels of the generated image.
|
| 573 |
+
num_frames (`int`, defaults to `129`):
|
| 574 |
+
The number of frames in the generated video.
|
| 575 |
+
num_inference_steps (`int`, defaults to `50`):
|
| 576 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 577 |
+
expense of slower inference.
|
| 578 |
+
sigmas (`List[float]`, *optional*):
|
| 579 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 580 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 581 |
+
will be used.
|
| 582 |
+
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
| 583 |
+
True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
|
| 584 |
+
`negative_prompt` is provided.
|
| 585 |
+
guidance_scale (`float`, defaults to `6.0`):
|
| 586 |
+
Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
| 587 |
+
a model to generate images more aligned with `prompt` at the expense of lower image quality.
|
| 588 |
+
|
| 589 |
+
Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
|
| 590 |
+
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
| 591 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 592 |
+
The number of videos to generate per prompt.
|
| 593 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 594 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 595 |
+
generation deterministic.
|
| 596 |
+
latents (`torch.Tensor`, *optional*):
|
| 597 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 598 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 599 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 600 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 601 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 602 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
| 603 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 604 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 605 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 606 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 607 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 608 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 609 |
+
argument.
|
| 610 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 611 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 612 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 613 |
+
input argument.
|
| 614 |
+
output_type (`str`, *optional*, defaults to `"numpy"`):
|
| 615 |
+
The output format of the generated video. Choose between `"numpy"`, `"latent"`, or any format supported by the video processor (e.g. `"pil"`).
|
| 616 |
+
return_dict (`bool`, *optional*, defaults to `False`):
|
| 617 |
+
Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
|
| 618 |
+
attention_kwargs (`dict`, *optional*):
|
| 619 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 620 |
+
`self.processor` in
|
| 621 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 622 |
+
clip_skip (`int`, *optional*):
|
| 623 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 624 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 625 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
| 626 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
| 627 |
+
each denoising step during inference with the following arguments: `callback_on_step_end(self:
|
| 628 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
| 629 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
| 630 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 631 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 632 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 633 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 634 |
+
|
| 635 |
+
Examples:
|
| 636 |
+
|
| 637 |
+
Returns:
|
| 638 |
+
[`~HunyuanVideoPipelineOutput`] or `tuple`:
|
| 639 |
+
If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
|
| 640 |
+
where the first element is the generated video (a NumPy array, torch tensor, latents, or PIL image
|
| 641 |
+
sequences, depending on `output_type`).
|
| 642 |
+
"""
|
| 643 |
+
|
| 644 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 645 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 646 |
+
|
| 647 |
+
# 1. Check inputs. Raise error if not correct
|
| 648 |
+
self.check_inputs(
|
| 649 |
+
prompt,
|
| 650 |
+
prompt_2,
|
| 651 |
+
height,
|
| 652 |
+
width,
|
| 653 |
+
prompt_embeds,
|
| 654 |
+
callback_on_step_end_tensor_inputs,
|
| 655 |
+
prompt_template,
|
| 656 |
+
)
|
| 657 |
+
|
| 658 |
+
has_neg_prompt = negative_prompt is not None or (
|
| 659 |
+
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
|
| 660 |
+
)
|
| 661 |
+
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
| 662 |
+
|
| 663 |
+
self._guidance_scale = guidance_scale
|
| 664 |
+
self._attention_kwargs = attention_kwargs
|
| 665 |
+
self._current_timestep = None
|
| 666 |
+
self._interrupt = False
|
| 667 |
+
|
| 668 |
+
device = self._execution_device
|
| 669 |
+
|
| 670 |
+
# 2. Define call parameters
|
| 671 |
+
if prompt is not None and isinstance(prompt, str):
|
| 672 |
+
batch_size = 1
|
| 673 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 674 |
+
batch_size = len(prompt)
|
| 675 |
+
else:
|
| 676 |
+
batch_size = prompt_embeds.shape[0]
|
| 677 |
+
|
| 678 |
+
# 3. Encode input prompt
|
| 679 |
+
transformer_dtype = self.transformer.dtype
|
| 680 |
+
prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
|
| 681 |
+
prompt=prompt,
|
| 682 |
+
prompt_2=prompt_2,
|
| 683 |
+
prompt_template=prompt_template,
|
| 684 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 685 |
+
prompt_embeds=prompt_embeds,
|
| 686 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 687 |
+
prompt_attention_mask=prompt_attention_mask,
|
| 688 |
+
device=device,
|
| 689 |
+
max_sequence_length=max_sequence_length,
|
| 690 |
+
)
|
| 691 |
+
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
| 692 |
+
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
|
| 693 |
+
pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
|
| 694 |
+
|
| 695 |
+
if do_true_cfg:
|
| 696 |
+
negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
|
| 697 |
+
prompt=negative_prompt,
|
| 698 |
+
prompt_2=negative_prompt_2,
|
| 699 |
+
prompt_template=prompt_template,
|
| 700 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 701 |
+
prompt_embeds=negative_prompt_embeds,
|
| 702 |
+
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 703 |
+
prompt_attention_mask=negative_prompt_attention_mask,
|
| 704 |
+
device=device,
|
| 705 |
+
max_sequence_length=max_sequence_length,
|
| 706 |
+
)
|
| 707 |
+
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
| 708 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
|
| 709 |
+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
|
| 710 |
+
|
| 711 |
+
# 4. Prepare timesteps
|
| 712 |
+
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
|
| 713 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
|
| 714 |
+
|
| 715 |
+
# 5. Prepare latent variables
|
| 716 |
+
num_channels_latents = self.transformer.config.in_channels
|
| 717 |
+
latents = self.prepare_latents(
|
| 718 |
+
batch_size * num_videos_per_prompt,
|
| 719 |
+
num_channels_latents,
|
| 720 |
+
height,
|
| 721 |
+
width,
|
| 722 |
+
num_frames,
|
| 723 |
+
torch.float32,
|
| 724 |
+
device,
|
| 725 |
+
generator,
|
| 726 |
+
latents,
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
# 6. Prepare guidance condition
|
| 730 |
+
guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
|
| 731 |
+
|
| 732 |
+
# 7. Denoising loop
|
| 733 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 734 |
+
self._num_timesteps = len(timesteps)
|
| 735 |
+
|
| 736 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 737 |
+
for i, t in enumerate(timesteps):
|
| 738 |
+
if self.interrupt:
|
| 739 |
+
continue
|
| 740 |
+
|
| 741 |
+
self._current_timestep = t
|
| 742 |
+
latent_model_input = latents.to(transformer_dtype)
|
| 743 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 744 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 745 |
+
|
| 746 |
+
noise_pred = self.transformer(
|
| 747 |
+
hidden_states=latent_model_input,
|
| 748 |
+
timestep=timestep,
|
| 749 |
+
encoder_hidden_states=prompt_embeds,
|
| 750 |
+
encoder_attention_mask=prompt_attention_mask,
|
| 751 |
+
pooled_projections=pooled_prompt_embeds,
|
| 752 |
+
guidance=guidance,
|
| 753 |
+
attention_kwargs=attention_kwargs,
|
| 754 |
+
return_dict=False,
|
| 755 |
+
)[0]
|
| 756 |
+
|
| 757 |
+
if do_true_cfg:
|
| 758 |
+
neg_noise_pred = self.transformer(
|
| 759 |
+
hidden_states=latent_model_input,
|
| 760 |
+
timestep=timestep,
|
| 761 |
+
encoder_hidden_states=negative_prompt_embeds,
|
| 762 |
+
encoder_attention_mask=negative_prompt_attention_mask,
|
| 763 |
+
pooled_projections=negative_pooled_prompt_embeds,
|
| 764 |
+
guidance=guidance,
|
| 765 |
+
attention_kwargs=attention_kwargs,
|
| 766 |
+
return_dict=False,
|
| 767 |
+
)[0]
|
| 768 |
+
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
| 769 |
+
|
| 770 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 771 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 772 |
+
|
| 773 |
+
if callback_on_step_end is not None:
|
| 774 |
+
callback_kwargs = {}
|
| 775 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 776 |
+
callback_kwargs[k] = locals()[k]
|
| 777 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 778 |
+
|
| 779 |
+
latents = callback_outputs.pop("latents", latents)
|
| 780 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 781 |
+
|
| 782 |
+
# call the callback, if provided
|
| 783 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 784 |
+
progress_bar.update()
|
| 785 |
+
|
| 786 |
+
if XLA_AVAILABLE:
|
| 787 |
+
xm.mark_step()
|
| 788 |
+
|
| 789 |
+
self._current_timestep = None
|
| 790 |
+
|
| 791 |
+
if output_type == "numpy":
|
| 792 |
+
video = self.decode_latents(latents)
|
| 793 |
+
elif not output_type == "latent":
|
| 794 |
+
video = self.decode_latents(latents)
|
| 795 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 796 |
+
else:
|
| 797 |
+
video = latents
|
| 798 |
+
|
| 799 |
+
# Offload all models
|
| 800 |
+
self.maybe_free_model_hooks()
|
| 801 |
+
|
| 802 |
+
if not return_dict:
|
| 803 |
+
video = torch.from_numpy(video)
|
| 804 |
+
|
| 805 |
+
return HunyuanVideoPipelineOutput(videos=video)
|
videox_fun/pipeline/pipeline_hunyuanvideo_i2v.py
ADDED
|
@@ -0,0 +1,972 @@
|
| 1 |
+
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
|
| 2 |
+
# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import PIL
|
| 22 |
+
import torch
|
| 23 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 24 |
+
from diffusers.loaders import HunyuanVideoLoraLoaderMixin
|
| 25 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 26 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 27 |
+
from diffusers.utils import (BaseOutput, deprecate, is_torch_xla_available,
|
| 28 |
+
logging, replace_example_docstring)
|
| 29 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 30 |
+
from diffusers.video_processor import VideoProcessor
|
| 31 |
+
|
| 32 |
+
from ..models import (AutoencoderKLHunyuanVideo, CLIPImageProcessor,
|
| 33 |
+
CLIPTextModel, CLIPTokenizer,
|
| 34 |
+
HunyuanVideoTransformer3DModel, LlamaModel,
|
| 35 |
+
LlamaTokenizerFast, LlavaForConditionalGeneration)
|
| 36 |
+
|
| 37 |
+
if is_torch_xla_available():
|
| 38 |
+
import torch_xla.core.xla_model as xm
|
| 39 |
+
|
| 40 |
+
XLA_AVAILABLE = True
|
| 41 |
+
else:
|
| 42 |
+
XLA_AVAILABLE = False
|
| 43 |
+
|
| 44 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
EXAMPLE_DOC_STRING = """
|
| 48 |
+
Examples:
|
| 49 |
+
```python
|
| 50 |
+
```
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
DEFAULT_PROMPT_TEMPLATE = {
|
| 55 |
+
"template": (
|
| 56 |
+
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
|
| 57 |
+
"1. The main content and theme of the video."
|
| 58 |
+
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
| 59 |
+
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
| 60 |
+
"4. background environment, light, style and atmosphere."
|
| 61 |
+
"5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
|
| 62 |
+
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
| 63 |
+
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
| 64 |
+
),
|
| 65 |
+
"crop_start": 103,
|
| 66 |
+
"image_emb_start": 5,
|
| 67 |
+
"image_emb_end": 581,
|
| 68 |
+
"image_emb_len": 576,
|
| 69 |
+
"double_return_token_id": 271,
|
| 70 |
+
}
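# Editorial note on the template constants above: `crop_start` counts the leading tokens of the
# system portion that are cropped from the returned embeddings so that only the user prompt and
# the image tokens condition the transformer; `image_emb_start`/`image_emb_end` bracket the
# `image_emb_len` (576) positions that the single <image> placeholder expands into once the
# Llava encoder injects its patch embeddings; and `double_return_token_id` is matched below to
# locate the boundary before the assistant header so that it can be cropped as well.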
|
| 71 |
+
|
| 72 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 73 |
+
def retrieve_timesteps(
|
| 74 |
+
scheduler,
|
| 75 |
+
num_inference_steps: Optional[int] = None,
|
| 76 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 77 |
+
timesteps: Optional[List[int]] = None,
|
| 78 |
+
sigmas: Optional[List[float]] = None,
|
| 79 |
+
**kwargs,
|
| 80 |
+
):
|
| 81 |
+
r"""
|
| 82 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 83 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
scheduler (`SchedulerMixin`):
|
| 87 |
+
The scheduler to get timesteps from.
|
| 88 |
+
num_inference_steps (`int`):
|
| 89 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 90 |
+
must be `None`.
|
| 91 |
+
device (`str` or `torch.device`, *optional*):
|
| 92 |
+
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
| 93 |
+
timesteps (`List[int]`, *optional*):
|
| 94 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 95 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 96 |
+
sigmas (`List[float]`, *optional*):
|
| 97 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 98 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 102 |
+
second element is the number of inference steps.
|
| 103 |
+
"""
|
| 104 |
+
if timesteps is not None and sigmas is not None:
|
| 105 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 106 |
+
if timesteps is not None:
|
| 107 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 108 |
+
if not accepts_timesteps:
|
| 109 |
+
raise ValueError(
|
| 110 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 111 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 112 |
+
)
|
| 113 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 114 |
+
timesteps = scheduler.timesteps
|
| 115 |
+
num_inference_steps = len(timesteps)
|
| 116 |
+
elif sigmas is not None:
|
| 117 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 118 |
+
if not accept_sigmas:
|
| 119 |
+
raise ValueError(
|
| 120 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 121 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 122 |
+
)
|
| 123 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 124 |
+
timesteps = scheduler.timesteps
|
| 125 |
+
num_inference_steps = len(timesteps)
|
| 126 |
+
else:
|
| 127 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 128 |
+
timesteps = scheduler.timesteps
|
| 129 |
+
return timesteps, num_inference_steps
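# Usage sketch (editorial): the pipeline below builds a linear sigma schedule and resolves it
# through this helper, e.g.
#
#     sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1]
#     timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas)
#
# Passing both `timesteps` and `sigmas` raises a ValueError, and passing neither falls back to
# the scheduler's own spacing via `scheduler.set_timesteps(num_inference_steps, ...)`.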
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _expand_input_ids_with_image_tokens(
|
| 133 |
+
text_input_ids,
|
| 134 |
+
prompt_attention_mask,
|
| 135 |
+
max_sequence_length,
|
| 136 |
+
image_token_index,
|
| 137 |
+
image_emb_len,
|
| 138 |
+
image_emb_start,
|
| 139 |
+
image_emb_end,
|
| 140 |
+
pad_token_id,
|
| 141 |
+
):
|
| 142 |
+
special_image_token_mask = text_input_ids == image_token_index
|
| 143 |
+
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
| 144 |
+
batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index)
|
| 145 |
+
|
| 146 |
+
max_expanded_length = max_sequence_length + (num_special_image_tokens.max() * (image_emb_len - 1))
|
| 147 |
+
new_token_positions = torch.cumsum((special_image_token_mask * (image_emb_len - 1) + 1), -1) - 1
|
| 148 |
+
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
| 149 |
+
|
| 150 |
+
expanded_input_ids = torch.full(
|
| 151 |
+
(text_input_ids.shape[0], max_expanded_length),
|
| 152 |
+
pad_token_id,
|
| 153 |
+
dtype=text_input_ids.dtype,
|
| 154 |
+
device=text_input_ids.device,
|
| 155 |
+
)
|
| 156 |
+
expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices]
|
| 157 |
+
expanded_input_ids[batch_indices, image_emb_start:image_emb_end] = image_token_index
|
| 158 |
+
|
| 159 |
+
expanded_attention_mask = torch.zeros(
|
| 160 |
+
(text_input_ids.shape[0], max_expanded_length),
|
| 161 |
+
dtype=prompt_attention_mask.dtype,
|
| 162 |
+
device=prompt_attention_mask.device,
|
| 163 |
+
)
|
| 164 |
+
attn_batch_indices, attention_indices = torch.where(expanded_input_ids != pad_token_id)
|
| 165 |
+
expanded_attention_mask[attn_batch_indices, attention_indices] = 1.0
|
| 166 |
+
expanded_attention_mask = expanded_attention_mask.to(prompt_attention_mask.dtype)
|
| 167 |
+
position_ids = (expanded_attention_mask.cumsum(-1) - 1).masked_fill_((expanded_attention_mask == 0), 1)
|
| 168 |
+
|
| 169 |
+
return {
|
| 170 |
+
"input_ids": expanded_input_ids,
|
| 171 |
+
"attention_mask": expanded_attention_mask,
|
| 172 |
+
"position_ids": position_ids,
|
| 173 |
+
}
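# Worked example (editorial): for a prompt containing a single <image> placeholder and
# image_emb_len=576, the expanded sequence grows by 575 positions per image token: the
# placeholder is replaced by a run of image-token ids spanning [image_emb_start, image_emb_end),
# the remaining text ids are shifted right, and the attention mask and position ids are rebuilt
# so that only non-pad positions are attended to.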
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
@dataclass
|
| 177 |
+
class HunyuanVideoPipelineOutput(BaseOutput):
|
| 178 |
+
r"""
|
| 179 |
+
Output class for video pipelines.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
| 183 |
+
List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
|
| 184 |
+
denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
|
| 185 |
+
`(batch_size, num_frames, channels, height, width)`.
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
videos: torch.Tensor
|
| 189 |
+
|
| 190 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 191 |
+
def retrieve_latents(
|
| 192 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 193 |
+
):
|
| 194 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 195 |
+
return encoder_output.latent_dist.sample(generator)
|
| 196 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 197 |
+
return encoder_output.latent_dist.mode()
|
| 198 |
+
elif hasattr(encoder_output, "latents"):
|
| 199 |
+
return encoder_output.latents
|
| 200 |
+
else:
|
| 201 |
+
raise AttributeError("Could not access latents of provided encoder_output")
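# Note (editorial): sample_mode="argmax" returns the mode of the VAE's latent distribution
# rather than a random sample; `prepare_latents` below uses it to obtain a deterministic latent
# for the conditioning image.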
|
| 202 |
+
|
| 203 |
+
class HunyuanVideoI2VPipeline(DiffusionPipeline):
|
| 204 |
+
r"""
|
| 205 |
+
Pipeline for image-to-video generation using HunyuanVideo.
|
| 206 |
+
|
| 207 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 208 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
text_encoder ([`LlavaForConditionalGeneration`]):
|
| 212 |
+
[Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
|
| 213 |
+
tokenizer (`LlamaTokenizer`):
|
| 214 |
+
Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
|
| 215 |
+
transformer ([`HunyuanVideoTransformer3DModel`]):
|
| 216 |
+
Conditional Transformer to denoise the encoded image latents.
|
| 217 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 218 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 219 |
+
vae ([`AutoencoderKLHunyuanVideo`]):
|
| 220 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
| 221 |
+
text_encoder_2 ([`CLIPTextModel`]):
|
| 222 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
| 223 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
| 224 |
+
tokenizer_2 (`CLIPTokenizer`):
|
| 225 |
+
Tokenizer of class
|
| 226 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 227 |
+
"""
|
| 228 |
+
|
| 229 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
| 230 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 231 |
+
|
| 232 |
+
def __init__(
|
| 233 |
+
self,
|
| 234 |
+
text_encoder: LlavaForConditionalGeneration,
|
| 235 |
+
tokenizer: LlamaTokenizerFast,
|
| 236 |
+
transformer: HunyuanVideoTransformer3DModel,
|
| 237 |
+
vae: AutoencoderKLHunyuanVideo,
|
| 238 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 239 |
+
text_encoder_2: CLIPTextModel,
|
| 240 |
+
tokenizer_2: CLIPTokenizer,
|
| 241 |
+
image_processor: CLIPImageProcessor,
|
| 242 |
+
):
|
| 243 |
+
super().__init__()
|
| 244 |
+
|
| 245 |
+
self.register_modules(
|
| 246 |
+
vae=vae,
|
| 247 |
+
text_encoder=text_encoder,
|
| 248 |
+
tokenizer=tokenizer,
|
| 249 |
+
transformer=transformer,
|
| 250 |
+
scheduler=scheduler,
|
| 251 |
+
text_encoder_2=text_encoder_2,
|
| 252 |
+
tokenizer_2=tokenizer_2,
|
| 253 |
+
image_processor=image_processor,
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
self.vae_scaling_factor = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.476986
|
| 257 |
+
self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
|
| 258 |
+
self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
|
| 259 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
| 260 |
+
|
| 261 |
+
def _get_llama_prompt_embeds(
|
| 262 |
+
self,
|
| 263 |
+
image: torch.Tensor,
|
| 264 |
+
prompt: Union[str, List[str]],
|
| 265 |
+
prompt_template: Dict[str, Any],
|
| 266 |
+
num_videos_per_prompt: int = 1,
|
| 267 |
+
device: Optional[torch.device] = None,
|
| 268 |
+
dtype: Optional[torch.dtype] = None,
|
| 269 |
+
max_sequence_length: int = 256,
|
| 270 |
+
num_hidden_layers_to_skip: int = 2,
|
| 271 |
+
image_embed_interleave: int = 2,
|
| 272 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 273 |
+
device = device or self._execution_device
|
| 274 |
+
dtype = dtype or self.text_encoder.dtype
|
| 275 |
+
|
| 276 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 277 |
+
prompt = [prompt_template["template"].format(p) for p in prompt]
|
| 278 |
+
|
| 279 |
+
crop_start = prompt_template.get("crop_start", None)
|
| 280 |
+
|
| 281 |
+
image_emb_len = prompt_template.get("image_emb_len", 576)
|
| 282 |
+
image_emb_start = prompt_template.get("image_emb_start", 5)
|
| 283 |
+
image_emb_end = prompt_template.get("image_emb_end", 581)
|
| 284 |
+
double_return_token_id = prompt_template.get("double_return_token_id", 271)
|
| 285 |
+
|
| 286 |
+
if crop_start is None:
|
| 287 |
+
prompt_template_input = self.tokenizer(
|
| 288 |
+
prompt_template["template"],
|
| 289 |
+
padding="max_length",
|
| 290 |
+
return_tensors="pt",
|
| 291 |
+
return_length=False,
|
| 292 |
+
return_overflowing_tokens=False,
|
| 293 |
+
return_attention_mask=False,
|
| 294 |
+
)
|
| 295 |
+
crop_start = prompt_template_input["input_ids"].shape[-1]
|
| 296 |
+
# Remove <|start_header_id|>, <|end_header_id|>, assistant, <|eot_id|>, and placeholder {}
|
| 297 |
+
crop_start -= 5
|
| 298 |
+
|
| 299 |
+
max_sequence_length += crop_start
|
| 300 |
+
text_inputs = self.tokenizer(
|
| 301 |
+
prompt,
|
| 302 |
+
max_length=max_sequence_length,
|
| 303 |
+
padding="max_length",
|
| 304 |
+
truncation=True,
|
| 305 |
+
return_tensors="pt",
|
| 306 |
+
return_length=False,
|
| 307 |
+
return_overflowing_tokens=False,
|
| 308 |
+
return_attention_mask=True,
|
| 309 |
+
)
|
| 310 |
+
text_input_ids = text_inputs.input_ids.to(device=device)
|
| 311 |
+
prompt_attention_mask = text_inputs.attention_mask.to(device=device)
|
| 312 |
+
|
| 313 |
+
image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
|
| 314 |
+
|
| 315 |
+
image_token_index = self.text_encoder.config.image_token_index
|
| 316 |
+
pad_token_id = self.text_encoder.config.pad_token_id
|
| 317 |
+
expanded_inputs = _expand_input_ids_with_image_tokens(
|
| 318 |
+
text_input_ids,
|
| 319 |
+
prompt_attention_mask,
|
| 320 |
+
max_sequence_length,
|
| 321 |
+
image_token_index,
|
| 322 |
+
image_emb_len,
|
| 323 |
+
image_emb_start,
|
| 324 |
+
image_emb_end,
|
| 325 |
+
pad_token_id,
|
| 326 |
+
)
|
| 327 |
+
prompt_embeds = self.text_encoder(
|
| 328 |
+
**expanded_inputs,
|
| 329 |
+
pixel_values=image_embeds,
|
| 330 |
+
output_hidden_states=True,
|
| 331 |
+
).hidden_states[-(num_hidden_layers_to_skip + 1)]
|
| 332 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
| 333 |
+
|
| 334 |
+
if crop_start is not None and crop_start > 0:
|
| 335 |
+
text_crop_start = crop_start - 1 + image_emb_len
|
| 336 |
+
batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)
|
| 337 |
+
|
| 338 |
+
if last_double_return_token_indices.shape[0] == 3:
|
| 339 |
+
# in case the prompt is too long
|
| 340 |
+
last_double_return_token_indices = torch.cat(
|
| 341 |
+
(last_double_return_token_indices, torch.tensor([text_input_ids.shape[-1]]))
|
| 342 |
+
)
|
| 343 |
+
batch_indices = torch.cat((batch_indices, torch.tensor([0])))
|
| 344 |
+
|
| 345 |
+
last_double_return_token_indices = last_double_return_token_indices.reshape(text_input_ids.shape[0], -1)[
|
| 346 |
+
:, -1
|
| 347 |
+
]
|
| 348 |
+
batch_indices = batch_indices.reshape(text_input_ids.shape[0], -1)[:, -1]
|
| 349 |
+
assistant_crop_start = last_double_return_token_indices - 1 + image_emb_len - 4
|
| 350 |
+
assistant_crop_end = last_double_return_token_indices - 1 + image_emb_len
|
| 351 |
+
attention_mask_assistant_crop_start = last_double_return_token_indices - 4
|
| 352 |
+
attention_mask_assistant_crop_end = last_double_return_token_indices
|
| 353 |
+
|
| 354 |
+
prompt_embed_list = []
|
| 355 |
+
prompt_attention_mask_list = []
|
| 356 |
+
image_embed_list = []
|
| 357 |
+
image_attention_mask_list = []
|
| 358 |
+
|
| 359 |
+
for i in range(text_input_ids.shape[0]):
|
| 360 |
+
prompt_embed_list.append(
|
| 361 |
+
torch.cat(
|
| 362 |
+
[
|
| 363 |
+
prompt_embeds[i, text_crop_start : assistant_crop_start[i].item()],
|
| 364 |
+
prompt_embeds[i, assistant_crop_end[i].item() :],
|
| 365 |
+
]
|
| 366 |
+
)
|
| 367 |
+
)
|
| 368 |
+
prompt_attention_mask_list.append(
|
| 369 |
+
torch.cat(
|
| 370 |
+
[
|
| 371 |
+
prompt_attention_mask[i, crop_start : attention_mask_assistant_crop_start[i].item()],
|
| 372 |
+
prompt_attention_mask[i, attention_mask_assistant_crop_end[i].item() :],
|
| 373 |
+
]
|
| 374 |
+
)
|
| 375 |
+
)
|
| 376 |
+
image_embed_list.append(prompt_embeds[i, image_emb_start:image_emb_end])
|
| 377 |
+
image_attention_mask_list.append(
|
| 378 |
+
torch.ones(image_embed_list[-1].shape[0]).to(prompt_embeds.device).to(prompt_attention_mask.dtype)
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
prompt_embed_list = torch.stack(prompt_embed_list)
|
| 382 |
+
prompt_attention_mask_list = torch.stack(prompt_attention_mask_list)
|
| 383 |
+
image_embed_list = torch.stack(image_embed_list)
|
| 384 |
+
image_attention_mask_list = torch.stack(image_attention_mask_list)
|
| 385 |
+
|
| 386 |
+
if 0 < image_embed_interleave < 6:
|
| 387 |
+
image_embed_list = image_embed_list[:, ::image_embed_interleave, :]
|
| 388 |
+
image_attention_mask_list = image_attention_mask_list[:, ::image_embed_interleave]
|
| 389 |
+
|
| 390 |
+
assert (
|
| 391 |
+
prompt_embed_list.shape[0] == prompt_attention_mask_list.shape[0]
|
| 392 |
+
and image_embed_list.shape[0] == image_attention_mask_list.shape[0]
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
prompt_embeds = torch.cat([image_embed_list, prompt_embed_list], dim=1)
|
| 396 |
+
prompt_attention_mask = torch.cat([image_attention_mask_list, prompt_attention_mask_list], dim=1)
|
| 397 |
+
|
| 398 |
+
return prompt_embeds, prompt_attention_mask
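# Note (editorial): the returned `prompt_embeds` concatenate the cropped image patch embeddings
# (optionally thinned by `image_embed_interleave`, e.g. every second token for the default of 2)
# with the cropped text embeddings along the sequence dimension, and `prompt_attention_mask` is
# assembled to match, so downstream attention treats image and text tokens uniformly.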
|
| 399 |
+
|
| 400 |
+
def _get_clip_prompt_embeds(
|
| 401 |
+
self,
|
| 402 |
+
prompt: Union[str, List[str]],
|
| 403 |
+
num_videos_per_prompt: int = 1,
|
| 404 |
+
device: Optional[torch.device] = None,
|
| 405 |
+
dtype: Optional[torch.dtype] = None,
|
| 406 |
+
max_sequence_length: int = 77,
|
| 407 |
+
) -> torch.Tensor:
|
| 408 |
+
device = device or self._execution_device
|
| 409 |
+
dtype = dtype or self.text_encoder_2.dtype
|
| 410 |
+
|
| 411 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 412 |
+
|
| 413 |
+
text_inputs = self.tokenizer_2(
|
| 414 |
+
prompt,
|
| 415 |
+
padding="max_length",
|
| 416 |
+
max_length=max_sequence_length,
|
| 417 |
+
truncation=True,
|
| 418 |
+
return_tensors="pt",
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
text_input_ids = text_inputs.input_ids
|
| 422 |
+
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
| 423 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 424 |
+
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 425 |
+
logger.warning(
|
| 426 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 427 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
|
| 431 |
+
return prompt_embeds
|
| 432 |
+
|
| 433 |
+
def encode_prompt(
|
| 434 |
+
self,
|
| 435 |
+
image: torch.Tensor,
|
| 436 |
+
prompt: Union[str, List[str]],
|
| 437 |
+
prompt_2: Union[str, List[str]] = None,
|
| 438 |
+
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
|
| 439 |
+
num_videos_per_prompt: int = 1,
|
| 440 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 441 |
+
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
| 442 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 443 |
+
device: Optional[torch.device] = None,
|
| 444 |
+
dtype: Optional[torch.dtype] = None,
|
| 445 |
+
max_sequence_length: int = 256,
|
| 446 |
+
image_embed_interleave: int = 2,
|
| 447 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 448 |
+
if prompt_embeds is None:
|
| 449 |
+
prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
|
| 450 |
+
image,
|
| 451 |
+
prompt,
|
| 452 |
+
prompt_template,
|
| 453 |
+
num_videos_per_prompt,
|
| 454 |
+
device=device,
|
| 455 |
+
dtype=dtype,
|
| 456 |
+
max_sequence_length=max_sequence_length,
|
| 457 |
+
image_embed_interleave=image_embed_interleave,
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
if pooled_prompt_embeds is None:
|
| 461 |
+
if prompt_2 is None:
|
| 462 |
+
prompt_2 = prompt
|
| 463 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
| 464 |
+
prompt,
|
| 465 |
+
num_videos_per_prompt,
|
| 466 |
+
device=device,
|
| 467 |
+
dtype=dtype,
|
| 468 |
+
max_sequence_length=77,
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask
|
| 472 |
+
|
| 473 |
+
def check_inputs(
|
| 474 |
+
self,
|
| 475 |
+
prompt,
|
| 476 |
+
prompt_2,
|
| 477 |
+
height,
|
| 478 |
+
width,
|
| 479 |
+
prompt_embeds=None,
|
| 480 |
+
callback_on_step_end_tensor_inputs=None,
|
| 481 |
+
prompt_template=None,
|
| 482 |
+
true_cfg_scale=1.0,
|
| 483 |
+
guidance_scale=1.0,
|
| 484 |
+
):
|
| 485 |
+
if height % 16 != 0 or width % 16 != 0:
|
| 486 |
+
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
| 487 |
+
|
| 488 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 489 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 490 |
+
):
|
| 491 |
+
raise ValueError(
|
| 492 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
if prompt is not None and prompt_embeds is not None:
|
| 496 |
+
raise ValueError(
|
| 497 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 498 |
+
" only forward one of the two."
|
| 499 |
+
)
|
| 500 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
| 501 |
+
raise ValueError(
|
| 502 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 503 |
+
" only forward one of the two."
|
| 504 |
+
)
|
| 505 |
+
elif prompt is None and prompt_embeds is None:
|
| 506 |
+
raise ValueError(
|
| 507 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 508 |
+
)
|
| 509 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 510 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 511 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
| 512 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
| 513 |
+
|
| 514 |
+
if prompt_template is not None:
|
| 515 |
+
if not isinstance(prompt_template, dict):
|
| 516 |
+
raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}")
|
| 517 |
+
if "template" not in prompt_template:
|
| 518 |
+
raise ValueError(
|
| 519 |
+
f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
if true_cfg_scale > 1.0 and guidance_scale > 1.0:
|
| 523 |
+
logger.warning(
|
| 524 |
+
"Both `true_cfg_scale` and `guidance_scale` are greater than 1.0. This will result in both "
|
| 525 |
+
"classifier-free guidance and embedded-guidance to be applied. This is not recommended "
|
| 526 |
+
"as it may lead to higher memory usage, slower inference and potentially worse results."
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
def prepare_latents(
|
| 530 |
+
self,
|
| 531 |
+
image: torch.Tensor,
|
| 532 |
+
batch_size: int,
|
| 533 |
+
num_channels_latents: int = 32,
|
| 534 |
+
height: int = 720,
|
| 535 |
+
width: int = 1280,
|
| 536 |
+
num_frames: int = 129,
|
| 537 |
+
dtype: Optional[torch.dtype] = None,
|
| 538 |
+
device: Optional[torch.device] = None,
|
| 539 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 540 |
+
latents: Optional[torch.Tensor] = None,
|
| 541 |
+
) -> torch.Tensor:
|
| 542 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 543 |
+
raise ValueError(
|
| 544 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 545 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
| 549 |
+
latent_height, latent_width = height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial
|
| 550 |
+
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
|
| 551 |
+
|
| 552 |
+
image = image.unsqueeze(2) # [B, C, 1, H, W]
|
| 553 |
+
if isinstance(generator, list):
|
| 554 |
+
image_latents = [
|
| 555 |
+
retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i], "argmax")
|
| 556 |
+
for i in range(batch_size)
|
| 557 |
+
]
|
| 558 |
+
else:
|
| 559 |
+
image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator, "argmax") for img in image]
|
| 560 |
+
|
| 561 |
+
image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor
|
| 562 |
+
image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1)
|
| 563 |
+
|
| 564 |
+
if latents is None:
|
| 565 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 566 |
+
else:
|
| 567 |
+
latents = latents.to(device=device, dtype=dtype)
|
| 568 |
+
|
| 569 |
+
t = torch.tensor([0.999]).to(device=device)
|
| 570 |
+
latents = latents * t + image_latents * (1 - t)
|
| 571 |
+
|
| 572 |
+
image_latents = image_latents[:, :, :1]
|
| 573 |
+
return latents, image_latents
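The blend a few lines up initializes the video latents almost entirely from noise (t = 0.999) while leaking a faint copy of the encoded conditioning image into every latent frame; only the first latent frame is kept aside so the denoising loop can re-inject it unchanged at each step. A minimal sketch of that blend on toy tensors (the 0.999 factor comes from the code above, all shapes are illustrative):

```py
import torch

# toy shapes: batch=1, 16 latent channels, 4 latent frames, 8x8 latent grid
noise = torch.randn(1, 16, 4, 8, 8)
image_latents = torch.randn(1, 16, 1, 8, 8).repeat(1, 1, 4, 1, 1)

t = torch.tensor([0.999])
blended = noise * t + image_latents * (1 - t)   # mostly noise, 0.1% image signal

first_frame = image_latents[:, :, :1]           # re-injected at every denoising step
print(blended.shape, first_frame.shape)
# torch.Size([1, 16, 4, 8, 8]) torch.Size([1, 16, 1, 8, 8])
```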
|
| 574 |
+
|
| 575 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 576 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
| 577 |
+
|
| 578 |
+
frames = self.vae.decode(latents).sample
|
| 579 |
+
frames = (frames / 2 + 0.5).clamp(0, 1)
|
| 580 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 581 |
+
frames = frames.cpu().float().numpy()
|
| 582 |
+
return frames
|
| 583 |
+
|
| 584 |
+
def enable_vae_slicing(self):
|
| 585 |
+
r"""
|
| 586 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
| 587 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
| 588 |
+
"""
|
| 589 |
+
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
|
| 590 |
+
deprecate(
|
| 591 |
+
"enable_vae_slicing",
|
| 592 |
+
"0.40.0",
|
| 593 |
+
depr_message,
|
| 594 |
+
)
|
| 595 |
+
self.vae.enable_slicing()
|
| 596 |
+
|
| 597 |
+
def disable_vae_slicing(self):
|
| 598 |
+
r"""
|
| 599 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
| 600 |
+
computing decoding in one step.
|
| 601 |
+
"""
|
| 602 |
+
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
|
| 603 |
+
deprecate(
|
| 604 |
+
"disable_vae_slicing",
|
| 605 |
+
"0.40.0",
|
| 606 |
+
depr_message,
|
| 607 |
+
)
|
| 608 |
+
self.vae.disable_slicing()
|
| 609 |
+
|
| 610 |
+
def enable_vae_tiling(self):
|
| 611 |
+
r"""
|
| 612 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
| 613 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
| 614 |
+
processing larger images.
|
| 615 |
+
"""
|
| 616 |
+
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
|
| 617 |
+
deprecate(
|
| 618 |
+
"enable_vae_tiling",
|
| 619 |
+
"0.40.0",
|
| 620 |
+
depr_message,
|
| 621 |
+
)
|
| 622 |
+
self.vae.enable_tiling()
|
| 623 |
+
|
| 624 |
+
def disable_vae_tiling(self):
|
| 625 |
+
r"""
|
| 626 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
| 627 |
+
computing decoding in one step.
|
| 628 |
+
"""
|
| 629 |
+
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
|
| 630 |
+
deprecate(
|
| 631 |
+
"disable_vae_tiling",
|
| 632 |
+
"0.40.0",
|
| 633 |
+
depr_message,
|
| 634 |
+
)
|
| 635 |
+
self.vae.disable_tiling()
|
| 636 |
+
|
| 637 |
+
@property
|
| 638 |
+
def guidance_scale(self):
|
| 639 |
+
return self._guidance_scale
|
| 640 |
+
|
| 641 |
+
@property
|
| 642 |
+
def num_timesteps(self):
|
| 643 |
+
return self._num_timesteps
|
| 644 |
+
|
| 645 |
+
@property
|
| 646 |
+
def attention_kwargs(self):
|
| 647 |
+
return self._attention_kwargs
|
| 648 |
+
|
| 649 |
+
@property
|
| 650 |
+
def current_timestep(self):
|
| 651 |
+
return self._current_timestep
|
| 652 |
+
|
| 653 |
+
@property
|
| 654 |
+
def interrupt(self):
|
| 655 |
+
return self._interrupt
|
| 656 |
+
|
| 657 |
+
@torch.no_grad()
|
| 658 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 659 |
+
def __call__(
|
| 660 |
+
self,
|
| 661 |
+
prompt: Union[str, List[str]] = None,
|
| 662 |
+
prompt_2: Union[str, List[str]] = None,
|
| 663 |
+
negative_prompt: Union[str, List[str]] = None,
|
| 664 |
+
negative_prompt_2: Union[str, List[str]] = None,
|
| 665 |
+
height: int = 720,
|
| 666 |
+
width: int = 1280,
|
| 667 |
+
num_frames: int = 129,
|
| 668 |
+
num_inference_steps: int = 50,
|
| 669 |
+
sigmas: List[float] = None,
|
| 670 |
+
true_cfg_scale: float = 1.0,
|
| 671 |
+
guidance_scale: float = 6.0,
|
| 672 |
+
num_videos_per_prompt: Optional[int] = 1,
|
| 673 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 674 |
+
latents: Optional[torch.Tensor] = None,
|
| 675 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 676 |
+
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
| 677 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 678 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 679 |
+
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
| 680 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 681 |
+
output_type: str = "numpy",
|
| 682 |
+
return_dict: bool = False,
|
| 683 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 684 |
+
callback_on_step_end: Optional[
|
| 685 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 686 |
+
] = None,
|
| 687 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 688 |
+
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
|
| 689 |
+
image: PIL.Image.Image = None,
|
| 690 |
+
max_sequence_length: int = 256,
|
| 691 |
+
image_embed_interleave: Optional[int] = None,
|
| 692 |
+
):
|
| 693 |
+
r"""
|
| 694 |
+
The call function to the pipeline for generation.
|
| 695 |
+
|
| 696 |
+
Args:
|
| 697 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 698 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
| 699 |
+
instead.
|
| 700 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 701 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
|
| 702 |
+
will be used instead.
|
| 703 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 704 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 705 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
| 706 |
+
not greater than `1`).
|
| 707 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 708 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 709 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
| 710 |
+
height (`int`, defaults to `720`):
|
| 711 |
+
The height in pixels of the generated image.
|
| 712 |
+
width (`int`, defaults to `1280`):
|
| 713 |
+
The width in pixels of the generated image.
|
| 714 |
+
num_frames (`int`, defaults to `129`):
|
| 715 |
+
The number of frames in the generated video.
|
| 716 |
+
num_inference_steps (`int`, defaults to `50`):
|
| 717 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 718 |
+
expense of slower inference.
|
| 719 |
+
sigmas (`List[float]`, *optional*):
|
| 720 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 721 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 722 |
+
will be used.
|
| 723 |
+
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
| 724 |
+
When > 1.0 and a `negative_prompt` is provided, enables true classifier-free guidance.
|
| 725 |
+
guidance_scale (`float`, defaults to `6.0`):
|
| 726 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 727 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 728 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 729 |
+
`guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked to
|
| 730 |
+
the text `prompt`, usually at the expense of lower image quality. Note that the only available
|
| 731 |
+
HunyuanVideo model is CFG-distilled, which means that traditional guidance between unconditional and
|
| 732 |
+
conditional latent is not applied.
|
| 733 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 734 |
+
The number of videos to generate per prompt.
|
| 735 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 736 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 737 |
+
generation deterministic.
|
| 738 |
+
latents (`torch.Tensor`, *optional*):
|
| 739 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 740 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 741 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 742 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 743 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 744 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
| 745 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 746 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 747 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 748 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 749 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 750 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 751 |
+
argument.
|
| 752 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 753 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 754 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 755 |
+
input argument.
|
| 756 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 757 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 758 |
+
return_dict (`bool`, *optional*, defaults to `False`):
|
| 759 |
+
Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
|
| 760 |
+
attention_kwargs (`dict`, *optional*):
|
| 761 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 762 |
+
`self.processor` in
|
| 763 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 764 |
+
clip_skip (`int`, *optional*):
|
| 765 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 766 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 767 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
| 768 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
| 769 |
+
each denoising step during inference, with the following arguments: `callback_on_step_end(self:
|
| 770 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
| 771 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
| 772 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 773 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 774 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 775 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 776 |
+
|
| 777 |
+
Examples:
|
| 778 |
+
|
| 779 |
+
Returns:
|
| 780 |
+
[`~HunyuanVideoPipelineOutput`] or `tuple`:
|
| 781 |
+
If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
|
| 782 |
+
where the first element is a list with the generated videos.
|
| 784 |
+
"""
|
| 785 |
+
|
| 786 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 787 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 788 |
+
|
| 789 |
+
# 1. Check inputs. Raise error if not correct
|
| 790 |
+
self.check_inputs(
|
| 791 |
+
prompt,
|
| 792 |
+
prompt_2,
|
| 793 |
+
height,
|
| 794 |
+
width,
|
| 795 |
+
prompt_embeds,
|
| 796 |
+
callback_on_step_end_tensor_inputs,
|
| 797 |
+
prompt_template,
|
| 798 |
+
true_cfg_scale,
|
| 799 |
+
guidance_scale,
|
| 800 |
+
)
|
| 801 |
+
|
| 802 |
+
has_neg_prompt = negative_prompt is not None or (
|
| 803 |
+
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
|
| 804 |
+
)
|
| 805 |
+
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
| 806 |
+
image_embed_interleave = (
|
| 807 |
+
image_embed_interleave
|
| 808 |
+
if image_embed_interleave is not None
|
| 809 |
+
else 4
|
| 810 |
+
)
|
| 811 |
+
|
| 812 |
+
self._guidance_scale = guidance_scale
|
| 813 |
+
self._attention_kwargs = attention_kwargs
|
| 814 |
+
self._current_timestep = None
|
| 815 |
+
self._interrupt = False
|
| 816 |
+
|
| 817 |
+
device = self._execution_device
|
| 818 |
+
|
| 819 |
+
# 2. Define call parameters
|
| 820 |
+
if prompt is not None and isinstance(prompt, str):
|
| 821 |
+
batch_size = 1
|
| 822 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 823 |
+
batch_size = len(prompt)
|
| 824 |
+
else:
|
| 825 |
+
batch_size = prompt_embeds.shape[0]
|
| 826 |
+
|
| 827 |
+
# 3. Prepare latent variables
|
| 828 |
+
vae_dtype = self.vae.dtype
|
| 829 |
+
image_tensor = self.video_processor.preprocess(image, height, width).to(device, vae_dtype)
|
| 830 |
+
|
| 831 |
+
num_channels_latents = self.transformer.config.in_channels
|
| 832 |
+
|
| 833 |
+
latents, image_latents = self.prepare_latents(
|
| 834 |
+
image_tensor,
|
| 835 |
+
batch_size * num_videos_per_prompt,
|
| 836 |
+
num_channels_latents,
|
| 837 |
+
height,
|
| 838 |
+
width,
|
| 839 |
+
num_frames,
|
| 840 |
+
torch.float32,
|
| 841 |
+
device,
|
| 842 |
+
generator,
|
| 843 |
+
latents,
|
| 844 |
+
)
|
| 845 |
+
|
| 846 |
+
# 4. Encode input prompt
|
| 847 |
+
transformer_dtype = self.transformer.dtype
|
| 848 |
+
prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
|
| 849 |
+
image=image,
|
| 850 |
+
prompt=prompt,
|
| 851 |
+
prompt_2=prompt_2,
|
| 852 |
+
prompt_template=prompt_template,
|
| 853 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 854 |
+
prompt_embeds=prompt_embeds,
|
| 855 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 856 |
+
prompt_attention_mask=prompt_attention_mask,
|
| 857 |
+
device=device,
|
| 858 |
+
max_sequence_length=max_sequence_length,
|
| 859 |
+
image_embed_interleave=image_embed_interleave,
|
| 860 |
+
)
|
| 861 |
+
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
| 862 |
+
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
|
| 863 |
+
pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
|
| 864 |
+
|
| 865 |
+
if do_true_cfg:
|
| 866 |
+
black_image = PIL.Image.new("RGB", (width, height), 0)
|
| 867 |
+
negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
|
| 868 |
+
image=black_image,
|
| 869 |
+
prompt=negative_prompt,
|
| 870 |
+
prompt_2=negative_prompt_2,
|
| 871 |
+
prompt_template=prompt_template,
|
| 872 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 873 |
+
prompt_embeds=negative_prompt_embeds,
|
| 874 |
+
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 875 |
+
prompt_attention_mask=negative_prompt_attention_mask,
|
| 876 |
+
device=device,
|
| 877 |
+
max_sequence_length=max_sequence_length,
|
| 878 |
+
)
|
| 879 |
+
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
| 880 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
|
| 881 |
+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
|
| 882 |
+
|
| 883 |
+
# 5. Prepare timesteps
|
| 884 |
+
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
|
| 885 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
|
| 886 |
+
|
| 887 |
+
# 6. Prepare guidance condition
|
| 888 |
+
guidance = None
|
| 889 |
+
if self.transformer.config.guidance_embeds:
|
| 890 |
+
guidance = (
|
| 891 |
+
torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
|
| 892 |
+
)
|
| 893 |
+
|
| 894 |
+
# 7. Denoising loop
|
| 895 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 896 |
+
self._num_timesteps = len(timesteps)
|
| 897 |
+
|
| 898 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 899 |
+
for i, t in enumerate(timesteps):
|
| 900 |
+
if self.interrupt:
|
| 901 |
+
continue
|
| 902 |
+
|
| 903 |
+
self._current_timestep = t
|
| 904 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 905 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 906 |
+
|
| 907 |
+
latent_model_input = torch.cat([image_latents, latents[:, :, 1:]], dim=2).to(transformer_dtype)
|
| 908 |
+
|
| 909 |
+
noise_pred = self.transformer(
|
| 910 |
+
hidden_states=latent_model_input,
|
| 911 |
+
timestep=timestep,
|
| 912 |
+
encoder_hidden_states=prompt_embeds,
|
| 913 |
+
encoder_attention_mask=prompt_attention_mask,
|
| 914 |
+
pooled_projections=pooled_prompt_embeds,
|
| 915 |
+
guidance=guidance,
|
| 916 |
+
attention_kwargs=attention_kwargs,
|
| 917 |
+
return_dict=False,
|
| 918 |
+
)[0]
|
| 919 |
+
|
| 920 |
+
if do_true_cfg:
|
| 921 |
+
neg_noise_pred = self.transformer(
|
| 922 |
+
hidden_states=latent_model_input,
|
| 923 |
+
timestep=timestep,
|
| 924 |
+
encoder_hidden_states=negative_prompt_embeds,
|
| 925 |
+
encoder_attention_mask=negative_prompt_attention_mask,
|
| 926 |
+
pooled_projections=negative_pooled_prompt_embeds,
|
| 927 |
+
guidance=guidance,
|
| 928 |
+
attention_kwargs=attention_kwargs,
|
| 929 |
+
return_dict=False,
|
| 930 |
+
)[0]
|
| 931 |
+
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
| 932 |
+
|
| 933 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 934 |
+
latents = self.scheduler.step(
|
| 935 |
+
noise_pred[:, :, 1:], t, latents[:, :, 1:], return_dict=False
|
| 936 |
+
)[0]
|
| 937 |
+
latents = torch.cat([image_latents, latents], dim=2)
|
| 938 |
+
latents = latents.to(self.vae.dtype)
|
| 939 |
+
|
| 940 |
+
if callback_on_step_end is not None:
|
| 941 |
+
callback_kwargs = {}
|
| 942 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 943 |
+
callback_kwargs[k] = locals()[k]
|
| 944 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 945 |
+
|
| 946 |
+
latents = callback_outputs.pop("latents", latents)
|
| 947 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 948 |
+
|
| 949 |
+
# call the callback, if provided
|
| 950 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 951 |
+
progress_bar.update()
|
| 952 |
+
|
| 953 |
+
if XLA_AVAILABLE:
|
| 954 |
+
xm.mark_step()
|
| 955 |
+
|
| 956 |
+
self._current_timestep = None
|
| 957 |
+
|
| 958 |
+
if output_type == "numpy":
|
| 959 |
+
video = self.decode_latents(latents)
|
| 960 |
+
elif not output_type == "latent":
|
| 961 |
+
video = self.decode_latents(latents)
|
| 962 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 963 |
+
else:
|
| 964 |
+
video = latents
|
| 965 |
+
|
| 966 |
+
# Offload all models
|
| 967 |
+
self.maybe_free_model_hooks()
|
| 968 |
+
|
| 969 |
+
if not return_dict:
|
| 970 |
+
video = torch.from_numpy(video)
|
| 971 |
+
|
| 972 |
+
return HunyuanVideoPipelineOutput(videos=video)
|
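For orientation, here is a hedged sketch of how this image-to-video pipeline could be driven once the repository's weights are loaded. The class name, import path, and checkpoint id below are assumptions for illustration, not something this diff pins down:

```py
# Hypothetical usage sketch -- class name and checkpoint path are assumed.
import torch
from PIL import Image
from videox_fun.pipeline import HunyuanVideoI2VPipeline  # assumed export name

pipe = HunyuanVideoI2VPipeline.from_pretrained(
    "path/to/hunyuanvideo-i2v", torch_dtype=torch.bfloat16
).to("cuda")

output = pipe(
    prompt="A cat slowly turning its head",
    image=Image.open("cat.png").convert("RGB"),
    height=720, width=1280, num_frames=129,
    num_inference_steps=50,
    output_type="numpy",
)
video = output.videos  # numpy array of decoded frames
```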
videox_fun/pipeline/pipeline_qwenimage.py
ADDED
|
@@ -0,0 +1,767 @@
|
| 1 |
+
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
|
| 2 |
+
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import PIL.Image
|
| 22 |
+
import torch
|
| 23 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 24 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 25 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 26 |
+
from diffusers.utils import (BaseOutput, is_torch_xla_available, logging,
|
| 27 |
+
replace_example_docstring)
|
| 28 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 29 |
+
|
| 30 |
+
from ..models import (AutoencoderKLQwenImage,
|
| 31 |
+
Qwen2_5_VLForConditionalGeneration,
|
| 32 |
+
Qwen2Tokenizer, QwenImageTransformer2DModel)
|
| 33 |
+
|
| 34 |
+
if is_torch_xla_available():
|
| 35 |
+
import torch_xla.core.xla_model as xm
|
| 36 |
+
|
| 37 |
+
XLA_AVAILABLE = True
|
| 38 |
+
else:
|
| 39 |
+
XLA_AVAILABLE = False
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 43 |
+
|
| 44 |
+
EXAMPLE_DOC_STRING = """
|
| 45 |
+
Examples:
|
| 46 |
+
```py
|
| 47 |
+
>>> import torch
|
| 48 |
+
>>> from diffusers import QwenImagePipeline
|
| 49 |
+
|
| 50 |
+
>>> pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
|
| 51 |
+
>>> pipe.to("cuda")
|
| 52 |
+
>>> prompt = "A cat holding a sign that says hello world"
|
| 53 |
+
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
| 54 |
+
>>> # Refer to the pipeline documentation for more details.
|
| 55 |
+
>>> image = pipe(prompt, num_inference_steps=50).images[0]
|
| 56 |
+
>>> image.save("qwenimage.png")
|
| 57 |
+
```
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def calculate_shift(
|
| 62 |
+
image_seq_len,
|
| 63 |
+
base_seq_len: int = 256,
|
| 64 |
+
max_seq_len: int = 4096,
|
| 65 |
+
base_shift: float = 0.5,
|
| 66 |
+
max_shift: float = 1.15,
|
| 67 |
+
):
|
| 68 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 69 |
+
b = base_shift - m * base_seq_len
|
| 70 |
+
mu = image_seq_len * m + b
|
| 71 |
+
return mu
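`calculate_shift` linearly interpolates the flow-matching shift `mu` between `base_shift` at `base_seq_len` and `max_shift` at `max_seq_len`. A quick numeric check with the default constants:

```py
# With the defaults: 256 tokens -> 0.5, 4096 tokens -> 1.15
m = (1.15 - 0.5) / (4096 - 256)   # ~1.693e-4 slope
b = 0.5 - m * 256                 # ~0.4567 intercept
print(round(m * 1024 + b, 3))     # 0.63  for a 1024-token image sequence
print(round(m * 4096 + b, 3))     # 1.15  at the upper end
```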
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 75 |
+
def retrieve_timesteps(
|
| 76 |
+
scheduler,
|
| 77 |
+
num_inference_steps: Optional[int] = None,
|
| 78 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 79 |
+
timesteps: Optional[List[int]] = None,
|
| 80 |
+
sigmas: Optional[List[float]] = None,
|
| 81 |
+
**kwargs,
|
| 82 |
+
):
|
| 83 |
+
r"""
|
| 84 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 85 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
scheduler (`SchedulerMixin`):
|
| 89 |
+
The scheduler to get timesteps from.
|
| 90 |
+
num_inference_steps (`int`):
|
| 91 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 92 |
+
must be `None`.
|
| 93 |
+
device (`str` or `torch.device`, *optional*):
|
| 94 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 95 |
+
timesteps (`List[int]`, *optional*):
|
| 96 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 97 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 98 |
+
sigmas (`List[float]`, *optional*):
|
| 99 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 100 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 104 |
+
second element is the number of inference steps.
|
| 105 |
+
"""
|
| 106 |
+
if timesteps is not None and sigmas is not None:
|
| 107 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 108 |
+
if timesteps is not None:
|
| 109 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 110 |
+
if not accepts_timesteps:
|
| 111 |
+
raise ValueError(
|
| 112 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 113 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 114 |
+
)
|
| 115 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 116 |
+
timesteps = scheduler.timesteps
|
| 117 |
+
num_inference_steps = len(timesteps)
|
| 118 |
+
elif sigmas is not None:
|
| 119 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 120 |
+
if not accept_sigmas:
|
| 121 |
+
raise ValueError(
|
| 122 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 123 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 124 |
+
)
|
| 125 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 126 |
+
timesteps = scheduler.timesteps
|
| 127 |
+
num_inference_steps = len(timesteps)
|
| 128 |
+
else:
|
| 129 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 130 |
+
timesteps = scheduler.timesteps
|
| 131 |
+
return timesteps, num_inference_steps
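`retrieve_timesteps` is a thin dispatcher around `scheduler.set_timesteps`: exactly one of `num_inference_steps`, `timesteps`, or `sigmas` drives the schedule, and the resolved `(timesteps, num_inference_steps)` pair is returned. A small sketch, assuming the installed `FlowMatchEulerDiscreteScheduler` accepts a `sigmas` argument (recent diffusers releases do):

```py
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler
from videox_fun.pipeline.pipeline_qwenimage import retrieve_timesteps

scheduler = FlowMatchEulerDiscreteScheduler()

# default path: let the scheduler space the steps itself
timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=30)

# custom path: hand over an explicit sigma schedule instead
sigmas = np.linspace(1.0, 1 / 30, 30)
timesteps, n = retrieve_timesteps(scheduler, sigmas=list(sigmas))
assert n == 30
```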
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@dataclass
|
| 135 |
+
class QwenImagePipelineOutput(BaseOutput):
|
| 136 |
+
"""
|
| 137 |
+
Output class for QwenImage pipelines.
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
| 141 |
+
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
| 142 |
+
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class QwenImagePipeline(DiffusionPipeline):
|
| 149 |
+
r"""
|
| 150 |
+
The QwenImage pipeline for text-to-image generation.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
transformer ([`QwenImageTransformer2DModel`]):
|
| 154 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
| 155 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 156 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 157 |
+
vae ([`AutoencoderKLQwenImage`]):
|
| 158 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 159 |
+
text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
|
| 160 |
+
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), the multimodal language model used here as the text encoder.
|
| 162 |
+
tokenizer (`Qwen2Tokenizer`):
|
| 163 |
+
Tokenizer of class
|
| 164 |
+
`Qwen2Tokenizer`.
|
| 165 |
+
"""
|
| 166 |
+
|
| 167 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 168 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 169 |
+
|
| 170 |
+
def __init__(
|
| 171 |
+
self,
|
| 172 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 173 |
+
vae: AutoencoderKLQwenImage,
|
| 174 |
+
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
| 175 |
+
tokenizer: Qwen2Tokenizer,
|
| 176 |
+
transformer: QwenImageTransformer2DModel,
|
| 177 |
+
):
|
| 178 |
+
super().__init__()
|
| 179 |
+
|
| 180 |
+
self.register_modules(
|
| 181 |
+
vae=vae,
|
| 182 |
+
text_encoder=text_encoder,
|
| 183 |
+
tokenizer=tokenizer,
|
| 184 |
+
transformer=transformer,
|
| 185 |
+
scheduler=scheduler,
|
| 186 |
+
)
|
| 187 |
+
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
| 188 |
+
# QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
| 189 |
+
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
| 190 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
| 191 |
+
self.tokenizer_max_length = 1024
|
| 192 |
+
self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
| 193 |
+
self.prompt_template_encode_start_idx = 34
|
| 194 |
+
self.default_sample_size = 128
|
| 195 |
+
|
| 196 |
+
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
| 197 |
+
bool_mask = mask.bool()
|
| 198 |
+
valid_lengths = bool_mask.sum(dim=1)
|
| 199 |
+
selected = hidden_states[bool_mask]
|
| 200 |
+
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
| 201 |
+
|
| 202 |
+
return split_result
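`_extract_masked_hidden` drops the padded positions of a batched hidden-state tensor and returns one variable-length tensor per prompt. On a toy batch:

```py
import torch

hidden = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)  # [batch=2, seq=4, dim=3]
mask = torch.tensor([[1, 1, 1, 0],    # first prompt: 3 real tokens
                     [1, 1, 0, 0]])   # second prompt: 2 real tokens

bool_mask = mask.bool()
valid_lengths = bool_mask.sum(dim=1)              # tensor([3, 2])
selected = hidden[bool_mask]                      # shape [5, 3], padding rows removed
chunks = torch.split(selected, valid_lengths.tolist(), dim=0)
print([c.shape for c in chunks])                  # [torch.Size([3, 3]), torch.Size([2, 3])]
```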
|
| 203 |
+
|
| 204 |
+
def _get_qwen_prompt_embeds(
|
| 205 |
+
self,
|
| 206 |
+
prompt: Union[str, List[str]] = None,
|
| 207 |
+
device: Optional[torch.device] = None,
|
| 208 |
+
dtype: Optional[torch.dtype] = None,
|
| 209 |
+
):
|
| 210 |
+
device = device or self._execution_device
|
| 211 |
+
dtype = dtype or self.text_encoder.dtype
|
| 212 |
+
|
| 213 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 214 |
+
|
| 215 |
+
template = self.prompt_template_encode
|
| 216 |
+
drop_idx = self.prompt_template_encode_start_idx
|
| 217 |
+
txt = [template.format(e) for e in prompt]
|
| 218 |
+
txt_tokens = self.tokenizer(
|
| 219 |
+
txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
|
| 220 |
+
).to(device)
|
| 221 |
+
encoder_hidden_states = self.text_encoder(
|
| 222 |
+
input_ids=txt_tokens.input_ids,
|
| 223 |
+
attention_mask=txt_tokens.attention_mask,
|
| 224 |
+
output_hidden_states=True,
|
| 225 |
+
)
|
| 226 |
+
hidden_states = encoder_hidden_states.hidden_states[-1]
|
| 227 |
+
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
|
| 228 |
+
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
| 229 |
+
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
| 230 |
+
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
| 231 |
+
prompt_embeds = torch.stack(
|
| 232 |
+
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
|
| 233 |
+
)
|
| 234 |
+
encoder_attention_mask = torch.stack(
|
| 235 |
+
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 239 |
+
|
| 240 |
+
return prompt_embeds, encoder_attention_mask
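The `drop_idx` constant set in `__init__` is meant to cover the tokenized chat-template prefix that wraps the user prompt, so only the prompt's own hidden states survive. A speculative sanity check, assuming `pipe` is an already constructed `QwenImagePipeline`:

```py
# Rough check that the start index lines up with the tokenized template prefix.
prefix = pipe.prompt_template_encode.split("{}")[0]
n_prefix_tokens = len(pipe.tokenizer(prefix).input_ids)
print(n_prefix_tokens, pipe.prompt_template_encode_start_idx)  # compare with 34
```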
|
| 241 |
+
|
| 242 |
+
def encode_prompt(
|
| 243 |
+
self,
|
| 244 |
+
prompt: Union[str, List[str]],
|
| 245 |
+
device: Optional[torch.device] = None,
|
| 246 |
+
num_images_per_prompt: int = 1,
|
| 247 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 248 |
+
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
| 249 |
+
max_sequence_length: int = 1024,
|
| 250 |
+
):
|
| 251 |
+
r"""
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 255 |
+
prompt to be encoded
|
| 256 |
+
device (`torch.device`):
|
| 257 |
+
torch device
|
| 258 |
+
num_images_per_prompt (`int`):
|
| 259 |
+
number of images that should be generated per prompt
|
| 260 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 261 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 262 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 263 |
+
"""
|
| 264 |
+
device = device or self._execution_device
|
| 265 |
+
|
| 266 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 267 |
+
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
|
| 268 |
+
|
| 269 |
+
if prompt_embeds is None:
|
| 270 |
+
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
|
| 271 |
+
|
| 272 |
+
prompt_embeds = prompt_embeds[:, :max_sequence_length]
|
| 273 |
+
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
|
| 274 |
+
|
| 275 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 276 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 277 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 278 |
+
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
| 279 |
+
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
| 280 |
+
|
| 281 |
+
return prompt_embeds, prompt_embeds_mask
|
| 282 |
+
|
| 283 |
+
def check_inputs(
|
| 284 |
+
self,
|
| 285 |
+
prompt,
|
| 286 |
+
height,
|
| 287 |
+
width,
|
| 288 |
+
negative_prompt=None,
|
| 289 |
+
prompt_embeds=None,
|
| 290 |
+
negative_prompt_embeds=None,
|
| 291 |
+
prompt_embeds_mask=None,
|
| 292 |
+
negative_prompt_embeds_mask=None,
|
| 293 |
+
callback_on_step_end_tensor_inputs=None,
|
| 294 |
+
max_sequence_length=None,
|
| 295 |
+
):
|
| 296 |
+
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
| 297 |
+
logger.warning(
|
| 298 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 302 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 303 |
+
):
|
| 304 |
+
raise ValueError(
|
| 305 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
if prompt is not None and prompt_embeds is not None:
|
| 309 |
+
raise ValueError(
|
| 310 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 311 |
+
" only forward one of the two."
|
| 312 |
+
)
|
| 313 |
+
elif prompt is None and prompt_embeds is None:
|
| 314 |
+
raise ValueError(
|
| 315 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 316 |
+
)
|
| 317 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 318 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 319 |
+
|
| 320 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 321 |
+
raise ValueError(
|
| 322 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 323 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
if prompt_embeds is not None and prompt_embeds_mask is None:
|
| 327 |
+
raise ValueError(
|
| 328 |
+
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
|
| 329 |
+
)
|
| 330 |
+
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
| 331 |
+
raise ValueError(
|
| 332 |
+
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
if max_sequence_length is not None and max_sequence_length > 1024:
|
| 336 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
|
| 337 |
+
|
| 338 |
+
@staticmethod
|
| 339 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
| 340 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
| 341 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
| 342 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
| 343 |
+
|
| 344 |
+
return latents
|
| 345 |
+
|
| 346 |
+
@staticmethod
|
| 347 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
| 348 |
+
batch_size, num_patches, channels = latents.shape
|
| 349 |
+
|
| 350 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 351 |
+
# latent height and width to be divisible by 2.
|
| 352 |
+
height = 2 * (int(height) // (vae_scale_factor * 2))
|
| 353 |
+
width = 2 * (int(width) // (vae_scale_factor * 2))
|
| 354 |
+
|
| 355 |
+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
| 356 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
| 357 |
+
|
| 358 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
|
| 359 |
+
|
| 360 |
+
return latents
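`_pack_latents` folds each 2x2 patch of the latent grid into the channel axis, and `_unpack_latents` restores it while adding back the frame dimension of 1 that the video VAE expects; up to that extra axis the two are exact inverses. A shape round trip (vae_scale_factor of 8 assumed, matching the fallback in `__init__`; the module path is taken from this file's location):

```py
import torch
from videox_fun.pipeline.pipeline_qwenimage import QwenImagePipeline

batch, channels, latent_h, latent_w = 1, 16, 64, 64   # e.g. a 512x512 image with 8x VAE compression
latents = torch.randn(batch, channels, latent_h, latent_w)

packed = QwenImagePipeline._pack_latents(latents, batch, channels, latent_h, latent_w)
print(packed.shape)    # torch.Size([1, 1024, 64]) -- (64/2)*(64/2) patches, 16*4 channels

unpacked = QwenImagePipeline._unpack_latents(packed, height=512, width=512, vae_scale_factor=8)
print(unpacked.shape)  # torch.Size([1, 16, 1, 64, 64]) -- frame axis of 1 added back
assert torch.equal(unpacked[:, :, 0], latents)
```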
|
| 361 |
+
|
| 362 |
+
def enable_vae_slicing(self):
|
| 363 |
+
r"""
|
| 364 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
| 365 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
| 366 |
+
"""
|
| 367 |
+
self.vae.enable_slicing()
|
| 368 |
+
|
| 369 |
+
def disable_vae_slicing(self):
|
| 370 |
+
r"""
|
| 371 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
| 372 |
+
computing decoding in one step.
|
| 373 |
+
"""
|
| 374 |
+
self.vae.disable_slicing()
|
| 375 |
+
|
| 376 |
+
def enable_vae_tiling(self):
|
| 377 |
+
r"""
|
| 378 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
| 379 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
| 380 |
+
processing larger images.
|
| 381 |
+
"""
|
| 382 |
+
self.vae.enable_tiling()
|
| 383 |
+
|
| 384 |
+
def disable_vae_tiling(self):
|
| 385 |
+
r"""
|
| 386 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
| 387 |
+
computing decoding in one step.
|
| 388 |
+
"""
|
| 389 |
+
self.vae.disable_tiling()
|
| 390 |
+
|
| 391 |
+
def prepare_latents(
|
| 392 |
+
self,
|
| 393 |
+
batch_size,
|
| 394 |
+
num_channels_latents,
|
| 395 |
+
height,
|
| 396 |
+
width,
|
| 397 |
+
dtype,
|
| 398 |
+
device,
|
| 399 |
+
generator,
|
| 400 |
+
latents=None,
|
| 401 |
+
):
|
| 402 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 403 |
+
# latent height and width to be divisible by 2.
|
| 404 |
+
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
| 405 |
+
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
| 406 |
+
|
| 407 |
+
shape = (batch_size, 1, num_channels_latents, height, width)
|
| 408 |
+
|
| 409 |
+
if latents is not None:
|
| 410 |
+
return latents.to(device=device, dtype=dtype)
|
| 411 |
+
|
| 412 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 413 |
+
raise ValueError(
|
| 414 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 415 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 419 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
| 420 |
+
|
| 421 |
+
return latents
|
| 422 |
+
|
| 423 |
+
@property
|
| 424 |
+
def guidance_scale(self):
|
| 425 |
+
return self._guidance_scale
|
| 426 |
+
|
| 427 |
+
@property
|
| 428 |
+
def attention_kwargs(self):
|
| 429 |
+
return self._attention_kwargs
|
| 430 |
+
|
| 431 |
+
@property
|
| 432 |
+
def num_timesteps(self):
|
| 433 |
+
return self._num_timesteps
|
| 434 |
+
|
| 435 |
+
@property
|
| 436 |
+
def current_timestep(self):
|
| 437 |
+
return self._current_timestep
|
| 438 |
+
|
| 439 |
+
@property
|
| 440 |
+
def interrupt(self):
|
| 441 |
+
return self._interrupt
|
| 442 |
+
|
| 443 |
+
@torch.no_grad()
|
| 444 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 445 |
+
def __call__(
|
| 446 |
+
self,
|
| 447 |
+
prompt: Union[str, List[str]] = None,
|
| 448 |
+
negative_prompt: Union[str, List[str]] = None,
|
| 449 |
+
true_cfg_scale: float = 4.0,
|
| 450 |
+
height: Optional[int] = None,
|
| 451 |
+
width: Optional[int] = None,
|
| 452 |
+
num_inference_steps: int = 50,
|
| 453 |
+
sigmas: Optional[List[float]] = None,
|
| 454 |
+
guidance_scale: float = 1.0,
|
| 455 |
+
num_images_per_prompt: int = 1,
|
| 456 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 457 |
+
latents: Optional[torch.Tensor] = None,
|
| 458 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 459 |
+
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
| 460 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 461 |
+
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
|
| 462 |
+
output_type: Optional[str] = "pil",
|
| 463 |
+
return_dict: bool = True,
|
| 464 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 465 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 466 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 467 |
+
max_sequence_length: int = 512,
|
| 468 |
+
comfyui_progressbar: bool = False,
|
| 469 |
+
):
|
| 470 |
+
r"""
|
| 471 |
+
Function invoked when calling the pipeline for generation.
|
| 472 |
+
|
| 473 |
+
Args:
|
| 474 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 475 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
| 476 |
+
instead.
|
| 477 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 478 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 479 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
| 480 |
+
not greater than `1`).
|
| 481 |
+
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
| 482 |
+
When > 1.0 and a `negative_prompt` is provided, enables true classifier-free guidance.
|
| 483 |
+
height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
|
| 484 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 485 |
+
width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
|
| 486 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 487 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 488 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 489 |
+
expense of slower inference.
|
| 490 |
+
sigmas (`List[float]`, *optional*):
|
| 491 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 492 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 493 |
+
will be used.
|
| 494 |
+
guidance_scale (`float`, *optional*, defaults to 1.0):
|
| 495 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 496 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 497 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 498 |
+
`guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked to
|
| 499 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 500 |
+
|
| 501 |
+
This parameter in the pipeline is there to support future guidance-distilled models when they come up.
|
| 502 |
+
Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance,
|
| 503 |
+
please pass `true_cfg_scale` together with a `negative_prompt` (even an empty negative prompt like " ") to enable classifier-free guidance computations.
|
| 505 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 506 |
+
The number of images to generate per prompt.
|
| 507 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 508 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 509 |
+
to make generation deterministic.
|
| 510 |
+
latents (`torch.Tensor`, *optional*):
|
| 511 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 512 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 513 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 514 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 515 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 516 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 517 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 518 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 519 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 520 |
+
argument.
|
| 521 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 522 |
+
The output format of the generated image. Choose between
|
| 523 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 524 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 525 |
+
Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
|
| 526 |
+
attention_kwargs (`dict`, *optional*):
|
| 527 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 528 |
+
`self.processor` in
|
| 529 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 530 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 531 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
| 532 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 533 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 534 |
+
`callback_on_step_end_tensor_inputs`.
|
| 535 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 536 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 537 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 538 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 539 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
| 540 |
+
|
| 541 |
+
Examples:
|
| 542 |
+
|
| 543 |
+
Returns:
|
| 544 |
+
[`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
|
| 545 |
+
[`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
| 546 |
+
returning a tuple, the first element is a list with the generated images.
|
| 547 |
+
"""
|
| 548 |
+
|
| 549 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 550 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 551 |
+
|
| 552 |
+
# 1. Check inputs. Raise error if not correct
|
| 553 |
+
self.check_inputs(
|
| 554 |
+
prompt,
|
| 555 |
+
height,
|
| 556 |
+
width,
|
| 557 |
+
negative_prompt=negative_prompt,
|
| 558 |
+
prompt_embeds=prompt_embeds,
|
| 559 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 560 |
+
prompt_embeds_mask=prompt_embeds_mask,
|
| 561 |
+
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
|
| 562 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 563 |
+
max_sequence_length=max_sequence_length,
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
self._guidance_scale = guidance_scale
|
| 567 |
+
self._attention_kwargs = attention_kwargs
|
| 568 |
+
self._current_timestep = None
|
| 569 |
+
self._interrupt = False
|
| 570 |
+
|
| 571 |
+
# 2. Define call parameters
|
| 572 |
+
if prompt is not None and isinstance(prompt, str):
|
| 573 |
+
batch_size = 1
|
| 574 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 575 |
+
batch_size = len(prompt)
|
| 576 |
+
else:
|
| 577 |
+
batch_size = prompt_embeds.shape[0]
|
| 578 |
+
|
| 579 |
+
device = self._execution_device
|
| 580 |
+
if comfyui_progressbar:
|
| 581 |
+
from comfy.utils import ProgressBar
|
| 582 |
+
pbar = ProgressBar(num_inference_steps + 2)
|
| 583 |
+
|
| 584 |
+
has_neg_prompt = negative_prompt is not None or (
|
| 585 |
+
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
|
| 586 |
+
)
|
| 587 |
+
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
| 588 |
+
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
|
| 589 |
+
prompt=prompt,
|
| 590 |
+
prompt_embeds=prompt_embeds,
|
| 591 |
+
prompt_embeds_mask=prompt_embeds_mask,
|
| 592 |
+
device=device,
|
| 593 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 594 |
+
max_sequence_length=max_sequence_length,
|
| 595 |
+
)
|
| 596 |
+
if do_true_cfg:
|
| 597 |
+
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
|
| 598 |
+
prompt=negative_prompt,
|
| 599 |
+
prompt_embeds=negative_prompt_embeds,
|
| 600 |
+
prompt_embeds_mask=negative_prompt_embeds_mask,
|
| 601 |
+
device=device,
|
| 602 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 603 |
+
max_sequence_length=max_sequence_length,
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
# 4. Prepare latent variables
|
| 607 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
| 608 |
+
latents = self.prepare_latents(
|
| 609 |
+
batch_size * num_images_per_prompt,
|
| 610 |
+
num_channels_latents,
|
| 611 |
+
height,
|
| 612 |
+
width,
|
| 613 |
+
prompt_embeds.dtype,
|
| 614 |
+
device,
|
| 615 |
+
generator,
|
| 616 |
+
latents,
|
| 617 |
+
)
|
| 618 |
+
img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size
|
| 619 |
+
if comfyui_progressbar:
|
| 620 |
+
pbar.update(1)
|
| 621 |
+
|
| 622 |
+
# 5. Prepare timesteps
|
| 623 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
| 624 |
+
image_seq_len = latents.shape[1]
|
| 625 |
+
mu = calculate_shift(
|
| 626 |
+
image_seq_len,
|
| 627 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 628 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 629 |
+
self.scheduler.config.get("base_shift", 0.5),
|
| 630 |
+
self.scheduler.config.get("max_shift", 1.15),
|
| 631 |
+
)
|
| 632 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 633 |
+
self.scheduler,
|
| 634 |
+
num_inference_steps,
|
| 635 |
+
device,
|
| 636 |
+
sigmas=sigmas,
|
| 637 |
+
mu=mu,
|
| 638 |
+
)
|
| 639 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 640 |
+
self._num_timesteps = len(timesteps)
|
| 641 |
+
|
| 642 |
+
# handle guidance
|
| 643 |
+
if self.transformer.config.guidance_embeds:
|
| 644 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 645 |
+
guidance = guidance.expand(latents.shape[0])
|
| 646 |
+
else:
|
| 647 |
+
guidance = None
|
| 648 |
+
|
| 649 |
+
if self.attention_kwargs is None:
|
| 650 |
+
self._attention_kwargs = {}
|
| 651 |
+
|
| 652 |
+
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
|
| 653 |
+
negative_txt_seq_lens = (
|
| 654 |
+
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
|
| 655 |
+
)
|
| 656 |
+
if comfyui_progressbar:
|
| 657 |
+
pbar.update(1)
|
| 658 |
+
|
| 659 |
+
# 6. Denoising loop
|
| 660 |
+
self.scheduler.set_begin_index(0)
|
| 661 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 662 |
+
for i, t in enumerate(timesteps):
|
| 663 |
+
self.transformer.current_steps = i
|
| 664 |
+
if self.interrupt:
|
| 665 |
+
continue
|
| 666 |
+
|
| 667 |
+
if do_true_cfg:
|
| 668 |
+
latent_model_input = torch.cat([latents] * 2)
|
| 669 |
+
prompt_embeds_mask_input = list(negative_prompt_embeds_mask) + list(prompt_embeds_mask)
|
| 670 |
+
prompt_embeds_input = list(negative_prompt_embeds) + list(prompt_embeds)
|
| 671 |
+
img_shapes_input = img_shapes * 2
|
| 672 |
+
txt_seq_lens_input = negative_txt_seq_lens + txt_seq_lens
|
| 673 |
+
else:
|
| 674 |
+
latent_model_input = latents
|
| 675 |
+
prompt_embeds_mask_input = prompt_embeds_mask
|
| 676 |
+
prompt_embeds_input = prompt_embeds
|
| 677 |
+
img_shapes_input = img_shapes
|
| 678 |
+
txt_seq_lens_input = txt_seq_lens
|
| 679 |
+
|
| 680 |
+
if hasattr(self.scheduler, "scale_model_input"):
|
| 681 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 682 |
+
|
| 683 |
+
# handle guidance
|
| 684 |
+
if self.transformer.config.guidance_embeds:
|
| 685 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 686 |
+
guidance = guidance.expand(latent_model_input.shape[0])
|
| 687 |
+
else:
|
| 688 |
+
guidance = None
|
| 689 |
+
|
| 690 |
+
self._current_timestep = t
|
| 691 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 692 |
+
timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
|
| 693 |
+
|
| 694 |
+
with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device):
|
| 695 |
+
noise_pred = self.transformer.forward_bs(
|
| 696 |
+
x=latent_model_input,
|
| 697 |
+
timestep=timestep / 1000,
|
| 698 |
+
guidance=guidance,
|
| 699 |
+
encoder_hidden_states_mask=prompt_embeds_mask_input,
|
| 700 |
+
encoder_hidden_states=prompt_embeds_input,
|
| 701 |
+
img_shapes=img_shapes_input,
|
| 702 |
+
txt_seq_lens=txt_seq_lens_input,
|
| 703 |
+
attention_kwargs=self.attention_kwargs,
|
| 704 |
+
return_dict=False,
|
| 705 |
+
)
|
| 706 |
+
|
| 707 |
+
if do_true_cfg:
|
| 708 |
+
neg_noise_pred, noise_pred = noise_pred.chunk(2)
|
| 709 |
+
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
| 710 |
+
|
| 711 |
+
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
|
| 712 |
+
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
|
| 713 |
+
noise_pred = comb_pred * (cond_norm / noise_norm)
|
| 714 |
+
|
| 715 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 716 |
+
latents_dtype = latents.dtype
|
| 717 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 718 |
+
|
| 719 |
+
if latents.dtype != latents_dtype:
|
| 720 |
+
if torch.backends.mps.is_available():
|
| 721 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 722 |
+
latents = latents.to(latents_dtype)
|
| 723 |
+
|
| 724 |
+
if callback_on_step_end is not None:
|
| 725 |
+
callback_kwargs = {}
|
| 726 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 727 |
+
callback_kwargs[k] = locals()[k]
|
| 728 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 729 |
+
|
| 730 |
+
latents = callback_outputs.pop("latents", latents)
|
| 731 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 732 |
+
|
| 733 |
+
# call the callback, if provided
|
| 734 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 735 |
+
progress_bar.update()
|
| 736 |
+
|
| 737 |
+
if XLA_AVAILABLE:
|
| 738 |
+
xm.mark_step()
|
| 739 |
+
|
| 740 |
+
if comfyui_progressbar:
|
| 741 |
+
pbar.update(1)
|
| 742 |
+
|
| 743 |
+
self._current_timestep = None
|
| 744 |
+
if output_type == "latent":
|
| 745 |
+
image = latents
|
| 746 |
+
else:
|
| 747 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 748 |
+
latents = latents.to(self.vae.dtype)
|
| 749 |
+
latents_mean = (
|
| 750 |
+
torch.tensor(self.vae.config.latents_mean)
|
| 751 |
+
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
| 752 |
+
.to(latents.device, latents.dtype)
|
| 753 |
+
)
|
| 754 |
+
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
| 755 |
+
latents.device, latents.dtype
|
| 756 |
+
)
|
| 757 |
+
latents = latents / latents_std + latents_mean
|
| 758 |
+
image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
|
| 759 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 760 |
+
|
| 761 |
+
# Offload all models
|
| 762 |
+
self.maybe_free_model_hooks()
|
| 763 |
+
|
| 764 |
+
if not return_dict:
|
| 765 |
+
return (image,)
|
| 766 |
+
|
| 767 |
+
return QwenImagePipelineOutput(images=image)
|
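The `__call__` above covers the whole text-to-image path: prompt encoding, optional true classifier-free guidance, flow-matching denoising, and VAE decoding. A minimal usage sketch follows, assuming `pipe` is an already-constructed instance of this pipeline (its loading code lives elsewhere in the repo) and that a CUDA device is available; all values are illustrative.

```py
import torch

# A minimal sketch, assuming `pipe` is an already-loaded instance of the pipeline above.
generator = torch.Generator(device="cuda").manual_seed(42)

result = pipe(
    prompt="A watercolor painting of a lighthouse at dawn",
    negative_prompt=" ",       # even a blank negative prompt enables true CFG
    true_cfg_scale=4.0,        # classifier-free guidance weight applied in the denoising loop
    height=1024,
    width=1024,
    num_inference_steps=50,
    generator=generator,
)
result.images[0].save("qwenimage_t2i.png")  # QwenImagePipelineOutput.images is a list of PIL images
```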
videox_fun/pipeline/pipeline_qwenimage_edit.py
ADDED
|
@@ -0,0 +1,952 @@
|
| 1 |
+
# Modified from https://github.com/naykun/diffusers/blob/main/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
|
| 2 |
+
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import math
|
| 22 |
+
import PIL.Image
|
| 23 |
+
import torch
|
| 24 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 25 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 26 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 27 |
+
from diffusers.utils import (BaseOutput, deprecate, is_torch_xla_available, logging,
|
| 28 |
+
replace_example_docstring)
|
| 29 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 30 |
+
|
| 31 |
+
from ..models import (AutoencoderKLQwenImage,
|
| 32 |
+
Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor,
|
| 33 |
+
Qwen2Tokenizer, QwenImageTransformer2DModel)
|
| 34 |
+
|
| 35 |
+
if is_torch_xla_available():
|
| 36 |
+
import torch_xla.core.xla_model as xm
|
| 37 |
+
|
| 38 |
+
XLA_AVAILABLE = True
|
| 39 |
+
else:
|
| 40 |
+
XLA_AVAILABLE = False
|
| 41 |
+
|
| 42 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 43 |
+
|
| 44 |
+
EXAMPLE_DOC_STRING = """
|
| 45 |
+
Examples:
|
| 46 |
+
```py
|
| 47 |
+
>>> import torch
|
| 48 |
+
>>> from PIL import Image
|
| 49 |
+
>>> from diffusers import QwenImageEditPipeline
|
| 50 |
+
|
| 51 |
+
>>> pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
|
| 52 |
+
>>> pipe.to("cuda")
|
| 53 |
+
>>> prompt = "Change the cat to a dog"
|
| 54 |
+
>>> image = Image.open("cat.png")
|
| 55 |
+
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
| 56 |
+
>>> # Refer to the pipeline documentation for more details.
|
| 57 |
+
>>> image = pipe(image, prompt, num_inference_steps=50).images[0]
|
| 58 |
+
>>> image.save("qwenimageedit.png")
|
| 59 |
+
```
|
| 60 |
+
"""
|
| 61 |
+
PREFERRED_QWENIMAGE_RESOLUTIONS = [
|
| 62 |
+
(672, 1568),
|
| 63 |
+
(688, 1504),
|
| 64 |
+
(720, 1456),
|
| 65 |
+
(752, 1392),
|
| 66 |
+
(800, 1328),
|
| 67 |
+
(832, 1248),
|
| 68 |
+
(880, 1184),
|
| 69 |
+
(944, 1104),
|
| 70 |
+
(1024, 1024),
|
| 71 |
+
(1104, 944),
|
| 72 |
+
(1184, 880),
|
| 73 |
+
(1248, 832),
|
| 74 |
+
(1328, 800),
|
| 75 |
+
(1392, 752),
|
| 76 |
+
(1456, 720),
|
| 77 |
+
(1504, 688),
|
| 78 |
+
(1568, 672),
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
|
| 82 |
+
def calculate_shift(
|
| 83 |
+
image_seq_len,
|
| 84 |
+
base_seq_len: int = 256,
|
| 85 |
+
max_seq_len: int = 4096,
|
| 86 |
+
base_shift: float = 0.5,
|
| 87 |
+
max_shift: float = 1.15,
|
| 88 |
+
):
|
| 89 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 90 |
+
b = base_shift - m * base_seq_len
|
| 91 |
+
mu = image_seq_len * m + b
|
| 92 |
+
return mu
|
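As a sanity check on the linear interpolation above (not part of the original file): with the default constants, `mu` equals `base_shift` at `base_seq_len` tokens and `max_shift` at `max_seq_len` tokens, the latter corresponding to a 1024x1024 image with 8x VAE downsampling and 2x2 latent packing.

```py
# Sanity check for calculate_shift with its default constants (illustrative only):
# (1024 / 16) ** 2 == 4096 latent tokens for a 1024x1024 image.
assert abs(calculate_shift(256) - 0.5) < 1e-6    # at base_seq_len, mu == base_shift
assert abs(calculate_shift(4096) - 1.15) < 1e-6  # at max_seq_len, mu == max_shift
print(calculate_shift(1024))                     # intermediate lengths interpolate linearly (~0.63)
```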
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 96 |
+
def retrieve_timesteps(
|
| 97 |
+
scheduler,
|
| 98 |
+
num_inference_steps: Optional[int] = None,
|
| 99 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 100 |
+
timesteps: Optional[List[int]] = None,
|
| 101 |
+
sigmas: Optional[List[float]] = None,
|
| 102 |
+
**kwargs,
|
| 103 |
+
):
|
| 104 |
+
r"""
|
| 105 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 106 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
scheduler (`SchedulerMixin`):
|
| 110 |
+
The scheduler to get timesteps from.
|
| 111 |
+
num_inference_steps (`int`):
|
| 112 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 113 |
+
must be `None`.
|
| 114 |
+
device (`str` or `torch.device`, *optional*):
|
| 115 |
+
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
| 116 |
+
timesteps (`List[int]`, *optional*):
|
| 117 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 118 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 119 |
+
sigmas (`List[float]`, *optional*):
|
| 120 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 121 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 122 |
+
|
| 123 |
+
Returns:
|
| 124 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 125 |
+
second element is the number of inference steps.
|
| 126 |
+
"""
|
| 127 |
+
if timesteps is not None and sigmas is not None:
|
| 128 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 129 |
+
if timesteps is not None:
|
| 130 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 131 |
+
if not accepts_timesteps:
|
| 132 |
+
raise ValueError(
|
| 133 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 134 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 135 |
+
)
|
| 136 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 137 |
+
timesteps = scheduler.timesteps
|
| 138 |
+
num_inference_steps = len(timesteps)
|
| 139 |
+
elif sigmas is not None:
|
| 140 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 141 |
+
if not accept_sigmas:
|
| 142 |
+
raise ValueError(
|
| 143 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 144 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 145 |
+
)
|
| 146 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 147 |
+
timesteps = scheduler.timesteps
|
| 148 |
+
num_inference_steps = len(timesteps)
|
| 149 |
+
else:
|
| 150 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 151 |
+
timesteps = scheduler.timesteps
|
| 152 |
+
return timesteps, num_inference_steps
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 156 |
+
def retrieve_latents(
|
| 157 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 158 |
+
):
|
| 159 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 160 |
+
return encoder_output.latent_dist.sample(generator)
|
| 161 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 162 |
+
return encoder_output.latent_dist.mode()
|
| 163 |
+
elif hasattr(encoder_output, "latents"):
|
| 164 |
+
return encoder_output.latents
|
| 165 |
+
else:
|
| 166 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def calculate_dimensions(target_area, ratio):
|
| 170 |
+
width = math.sqrt(target_area * ratio)
|
| 171 |
+
height = width / ratio
|
| 172 |
+
|
| 173 |
+
width = round(width / 32) * 32
|
| 174 |
+
height = round(height / 32) * 32
|
| 175 |
+
|
| 176 |
+
return width, height
|
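To make the snapping behaviour concrete, a couple of worked values (illustrative, not part of the original file): the target pixel area is approximately preserved while width and height are rounded to multiples of 32.

```py
# Worked examples for calculate_dimensions (illustrative): a ~1 MP target area keeps a
# square input at 1024x1024, while a 16:9 ratio snaps to 1376x768 (multiples of 32).
print(calculate_dimensions(1024 * 1024, 1.0))     # (1024, 1024)
print(calculate_dimensions(1024 * 1024, 16 / 9))  # (1376, 768)
```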
| 177 |
+
|
| 178 |
+
|
| 179 |
+
@dataclass
|
| 180 |
+
class QwenImagePipelineOutput(BaseOutput):
|
| 181 |
+
"""
|
| 182 |
+
Output class for QwenImage pipelines.
|
| 183 |
+
|
| 184 |
+
Args:
|
| 185 |
+
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
| 186 |
+
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
| 187 |
+
num_channels)`. The PIL images or numpy array represent the denoised images of the diffusion pipeline.
|
| 188 |
+
"""
|
| 189 |
+
|
| 190 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class QwenImageEditPipeline(DiffusionPipeline):
|
| 194 |
+
r"""
|
| 195 |
+
The QwenImage Edit pipeline for text-guided image editing.
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
transformer ([`QwenImageTransformer2DModel`]):
|
| 199 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
| 200 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 201 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 202 |
+
vae ([`AutoencoderKL`]):
|
| 203 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 204 |
+
text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
|
| 205 |
+
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), the multimodal text
|
| 206 |
+
encoder used to condition generation on the prompt and the input image.
|
| 207 |
+
tokenizer (`QwenTokenizer`):
|
| 208 |
+
Tokenizer of class
|
| 209 |
+
`Qwen2Tokenizer`.
|
| 210 |
+
"""
|
| 211 |
+
|
| 212 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 213 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 214 |
+
|
| 215 |
+
def __init__(
|
| 216 |
+
self,
|
| 217 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 218 |
+
vae: AutoencoderKLQwenImage,
|
| 219 |
+
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
| 220 |
+
tokenizer: Qwen2Tokenizer,
|
| 221 |
+
processor: Qwen2VLProcessor,
|
| 222 |
+
transformer: QwenImageTransformer2DModel,
|
| 223 |
+
):
|
| 224 |
+
super().__init__()
|
| 225 |
+
|
| 226 |
+
self.register_modules(
|
| 227 |
+
vae=vae,
|
| 228 |
+
text_encoder=text_encoder,
|
| 229 |
+
tokenizer=tokenizer,
|
| 230 |
+
processor=processor,
|
| 231 |
+
transformer=transformer,
|
| 232 |
+
scheduler=scheduler,
|
| 233 |
+
)
|
| 234 |
+
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
| 235 |
+
self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
|
| 236 |
+
# QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
|
| 237 |
+
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
| 238 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
| 239 |
+
self.tokenizer_max_length = 1024
|
| 240 |
+
|
| 241 |
+
self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
| 242 |
+
self.prompt_template_encode_start_idx = 64
|
| 243 |
+
self.default_sample_size = 128
|
| 244 |
+
|
| 245 |
+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
|
| 246 |
+
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
| 247 |
+
bool_mask = mask.bool()
|
| 248 |
+
valid_lengths = bool_mask.sum(dim=1)
|
| 249 |
+
selected = hidden_states[bool_mask]
|
| 250 |
+
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
| 251 |
+
|
| 252 |
+
return split_result
|
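The helper above keeps only the unpadded positions of each sequence and returns one variable-length chunk per batch element. A standalone sketch of the same logic on dummy data (shapes are illustrative):

```py
import torch

# Same logic as _extract_masked_hidden, run standalone on dummy data: two sequences
# padded to length 4 with true lengths 3 and 2, hidden size 8 (all shapes illustrative).
hidden = torch.arange(2 * 4 * 8, dtype=torch.float32).reshape(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])

bool_mask = mask.bool()
valid_lengths = bool_mask.sum(dim=1)                                   # tensor([3, 2])
chunks = torch.split(hidden[bool_mask], valid_lengths.tolist(), dim=0)
print([c.shape for c in chunks])                                       # [(3, 8), (2, 8)]
```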
| 253 |
+
|
| 254 |
+
def _get_qwen_prompt_embeds(
|
| 255 |
+
self,
|
| 256 |
+
prompt: Union[str, List[str]] = None,
|
| 257 |
+
image: Optional[torch.Tensor] = None,
|
| 258 |
+
device: Optional[torch.device] = None,
|
| 259 |
+
dtype: Optional[torch.dtype] = None,
|
| 260 |
+
):
|
| 261 |
+
device = device or self._execution_device
|
| 262 |
+
dtype = dtype or self.text_encoder.dtype
|
| 263 |
+
|
| 264 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 265 |
+
|
| 266 |
+
template = self.prompt_template_encode
|
| 267 |
+
drop_idx = self.prompt_template_encode_start_idx
|
| 268 |
+
txt = [template.format(e) for e in prompt]
|
| 269 |
+
|
| 270 |
+
model_inputs = self.processor(
|
| 271 |
+
text=txt,
|
| 272 |
+
images=image,
|
| 273 |
+
padding=True,
|
| 274 |
+
return_tensors="pt",
|
| 275 |
+
).to(device)
|
| 276 |
+
|
| 277 |
+
outputs = self.text_encoder(
|
| 278 |
+
input_ids=model_inputs.input_ids,
|
| 279 |
+
attention_mask=model_inputs.attention_mask,
|
| 280 |
+
pixel_values=model_inputs.pixel_values,
|
| 281 |
+
image_grid_thw=model_inputs.image_grid_thw,
|
| 282 |
+
output_hidden_states=True,
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
hidden_states = outputs.hidden_states[-1]
|
| 286 |
+
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
|
| 287 |
+
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
| 288 |
+
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
| 289 |
+
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
| 290 |
+
prompt_embeds = torch.stack(
|
| 291 |
+
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
|
| 292 |
+
)
|
| 293 |
+
encoder_attention_mask = torch.stack(
|
| 294 |
+
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 298 |
+
|
| 299 |
+
return prompt_embeds, encoder_attention_mask
|
| 300 |
+
|
| 301 |
+
def encode_prompt(
|
| 302 |
+
self,
|
| 303 |
+
prompt: Union[str, List[str]],
|
| 304 |
+
image: Optional[torch.Tensor] = None,
|
| 305 |
+
device: Optional[torch.device] = None,
|
| 306 |
+
num_images_per_prompt: int = 1,
|
| 307 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 308 |
+
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
| 309 |
+
max_sequence_length: int = 1024,
|
| 310 |
+
):
|
| 311 |
+
r"""
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 315 |
+
prompt to be encoded
|
| 316 |
+
image (`torch.Tensor`, *optional*):
|
| 317 |
+
image to be encoded
|
| 318 |
+
device: (`torch.device`):
|
| 319 |
+
torch device
|
| 320 |
+
num_images_per_prompt (`int`):
|
| 321 |
+
number of images that should be generated per prompt
|
| 322 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 323 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 324 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 325 |
+
"""
|
| 326 |
+
device = device or self._execution_device
|
| 327 |
+
|
| 328 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 329 |
+
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
|
| 330 |
+
|
| 331 |
+
if prompt_embeds is None:
|
| 332 |
+
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
|
| 333 |
+
|
| 334 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 335 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 336 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 337 |
+
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
| 338 |
+
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
| 339 |
+
|
| 340 |
+
return prompt_embeds, prompt_embeds_mask
|
| 341 |
+
|
| 342 |
+
def check_inputs(
|
| 343 |
+
self,
|
| 344 |
+
prompt,
|
| 345 |
+
height,
|
| 346 |
+
width,
|
| 347 |
+
negative_prompt=None,
|
| 348 |
+
prompt_embeds=None,
|
| 349 |
+
negative_prompt_embeds=None,
|
| 350 |
+
prompt_embeds_mask=None,
|
| 351 |
+
negative_prompt_embeds_mask=None,
|
| 352 |
+
callback_on_step_end_tensor_inputs=None,
|
| 353 |
+
max_sequence_length=None,
|
| 354 |
+
):
|
| 355 |
+
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
| 356 |
+
logger.warning(
|
| 357 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 361 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 362 |
+
):
|
| 363 |
+
raise ValueError(
|
| 364 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
if prompt is not None and prompt_embeds is not None:
|
| 368 |
+
raise ValueError(
|
| 369 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 370 |
+
" only forward one of the two."
|
| 371 |
+
)
|
| 372 |
+
elif prompt is None and prompt_embeds is None:
|
| 373 |
+
raise ValueError(
|
| 374 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 375 |
+
)
|
| 376 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 377 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 378 |
+
|
| 379 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 380 |
+
raise ValueError(
|
| 381 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 382 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
if prompt_embeds is not None and prompt_embeds_mask is None:
|
| 386 |
+
raise ValueError(
|
| 387 |
+
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
|
| 388 |
+
)
|
| 389 |
+
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
| 390 |
+
raise ValueError(
|
| 391 |
+
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
if max_sequence_length is not None and max_sequence_length > 1024:
|
| 395 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
|
| 396 |
+
|
| 397 |
+
@staticmethod
|
| 398 |
+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
|
| 399 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
| 400 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
| 401 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
| 402 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
| 403 |
+
|
| 404 |
+
return latents
|
| 405 |
+
|
| 406 |
+
@staticmethod
|
| 407 |
+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
|
| 408 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
| 409 |
+
batch_size, num_patches, channels = latents.shape
|
| 410 |
+
|
| 411 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 412 |
+
# latent height and width to be divisible by 2.
|
| 413 |
+
height = 2 * (int(height) // (vae_scale_factor * 2))
|
| 414 |
+
width = 2 * (int(width) // (vae_scale_factor * 2))
|
| 415 |
+
|
| 416 |
+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
| 417 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
| 418 |
+
|
| 419 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
|
| 420 |
+
|
| 421 |
+
return latents
|
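`_pack_latents` and `_unpack_latents` convert between the spatial latent grid and the flattened 2x2-patch token layout the transformer consumes. A round-trip sketch on a dummy tensor follows (shapes are illustrative; 16 channels and a 64x64 latent grid would correspond to a 512x512 image with an 8x VAE scale factor):

```py
import torch

# Round-trip sketch for the packing helpers above (shapes illustrative only).
B, C, H, W = 1, 16, 64, 64
latents = torch.randn(B, C, H, W)

packed = QwenImageEditPipeline._pack_latents(latents, B, C, H, W)
print(packed.shape)   # torch.Size([1, 1024, 64]): (H//2 * W//2) tokens with C*4 features each

unpacked = QwenImageEditPipeline._unpack_latents(packed, height=512, width=512, vae_scale_factor=8)
print(unpacked.shape)                            # torch.Size([1, 16, 1, 64, 64])
assert torch.equal(unpacked[:, :, 0], latents)   # packing then unpacking is lossless
```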
| 422 |
+
|
| 423 |
+
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
| 424 |
+
if isinstance(generator, list):
|
| 425 |
+
image_latents = [
|
| 426 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
|
| 427 |
+
for i in range(image.shape[0])
|
| 428 |
+
]
|
| 429 |
+
image_latents = torch.cat(image_latents, dim=0)
|
| 430 |
+
else:
|
| 431 |
+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
|
| 432 |
+
latents_mean = (
|
| 433 |
+
torch.tensor(self.vae.config.latents_mean)
|
| 434 |
+
.view(1, self.latent_channels, 1, 1, 1)
|
| 435 |
+
.to(image_latents.device, image_latents.dtype)
|
| 436 |
+
)
|
| 437 |
+
latents_std = (
|
| 438 |
+
torch.tensor(self.vae.config.latents_std)
|
| 439 |
+
.view(1, self.latent_channels, 1, 1, 1)
|
| 440 |
+
.to(image_latents.device, image_latents.dtype)
|
| 441 |
+
)
|
| 442 |
+
image_latents = (image_latents - latents_mean) / latents_std
|
| 443 |
+
|
| 444 |
+
return image_latents
|
| 445 |
+
|
| 446 |
+
def enable_vae_slicing(self):
|
| 447 |
+
r"""
|
| 448 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
| 449 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
| 450 |
+
"""
|
| 451 |
+
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
|
| 452 |
+
deprecate(
|
| 453 |
+
"enable_vae_slicing",
|
| 454 |
+
"0.40.0",
|
| 455 |
+
depr_message,
|
| 456 |
+
)
|
| 457 |
+
self.vae.enable_slicing()
|
| 458 |
+
|
| 459 |
+
def disable_vae_slicing(self):
|
| 460 |
+
r"""
|
| 461 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
| 462 |
+
computing decoding in one step.
|
| 463 |
+
"""
|
| 464 |
+
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
|
| 465 |
+
deprecate(
|
| 466 |
+
"disable_vae_slicing",
|
| 467 |
+
"0.40.0",
|
| 468 |
+
depr_message,
|
| 469 |
+
)
|
| 470 |
+
self.vae.disable_slicing()
|
| 471 |
+
|
| 472 |
+
def enable_vae_tiling(self):
|
| 473 |
+
r"""
|
| 474 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
| 475 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
| 476 |
+
processing larger images.
|
| 477 |
+
"""
|
| 478 |
+
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
|
| 479 |
+
deprecate(
|
| 480 |
+
"enable_vae_tiling",
|
| 481 |
+
"0.40.0",
|
| 482 |
+
depr_message,
|
| 483 |
+
)
|
| 484 |
+
self.vae.enable_tiling()
|
| 485 |
+
|
| 486 |
+
def disable_vae_tiling(self):
|
| 487 |
+
r"""
|
| 488 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
| 489 |
+
computing decoding in one step.
|
| 490 |
+
"""
|
| 491 |
+
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
|
| 492 |
+
deprecate(
|
| 493 |
+
"disable_vae_tiling",
|
| 494 |
+
"0.40.0",
|
| 495 |
+
depr_message,
|
| 496 |
+
)
|
| 497 |
+
self.vae.disable_tiling()
|
| 498 |
+
|
| 499 |
+
def prepare_latents(
|
| 500 |
+
self,
|
| 501 |
+
image,
|
| 502 |
+
batch_size,
|
| 503 |
+
num_channels_latents,
|
| 504 |
+
height,
|
| 505 |
+
width,
|
| 506 |
+
dtype,
|
| 507 |
+
device,
|
| 508 |
+
generator,
|
| 509 |
+
latents=None,
|
| 510 |
+
):
|
| 511 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 512 |
+
# latent height and width to be divisible by 2.
|
| 513 |
+
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
| 514 |
+
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
| 515 |
+
|
| 516 |
+
shape = (batch_size, 1, num_channels_latents, height, width)
|
| 517 |
+
|
| 518 |
+
image_latents = None
|
| 519 |
+
if image is not None:
|
| 520 |
+
image = image.to(device=device, dtype=dtype)
|
| 521 |
+
if image.shape[1] != self.latent_channels:
|
| 522 |
+
image_latents = self._encode_vae_image(image=image, generator=generator)
|
| 523 |
+
else:
|
| 524 |
+
image_latents = image
|
| 525 |
+
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
| 526 |
+
# expand init_latents for batch_size
|
| 527 |
+
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
| 528 |
+
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
| 529 |
+
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
| 530 |
+
raise ValueError(
|
| 531 |
+
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
| 532 |
+
)
|
| 533 |
+
else:
|
| 534 |
+
image_latents = torch.cat([image_latents], dim=0)
|
| 535 |
+
|
| 536 |
+
image_latent_height, image_latent_width = image_latents.shape[3:]
|
| 537 |
+
image_latents = self._pack_latents(
|
| 538 |
+
image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 542 |
+
raise ValueError(
|
| 543 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 544 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 545 |
+
)
|
| 546 |
+
if latents is None:
|
| 547 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 548 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
| 549 |
+
else:
|
| 550 |
+
latents = latents.to(device=device, dtype=dtype)
|
| 551 |
+
|
| 552 |
+
return latents, image_latents
|
| 553 |
+
|
| 554 |
+
@property
|
| 555 |
+
def guidance_scale(self):
|
| 556 |
+
return self._guidance_scale
|
| 557 |
+
|
| 558 |
+
@property
|
| 559 |
+
def attention_kwargs(self):
|
| 560 |
+
return self._attention_kwargs
|
| 561 |
+
|
| 562 |
+
@property
|
| 563 |
+
def num_timesteps(self):
|
| 564 |
+
return self._num_timesteps
|
| 565 |
+
|
| 566 |
+
@property
|
| 567 |
+
def current_timestep(self):
|
| 568 |
+
return self._current_timestep
|
| 569 |
+
|
| 570 |
+
@property
|
| 571 |
+
def interrupt(self):
|
| 572 |
+
return self._interrupt
|
| 573 |
+
|
| 574 |
+
@torch.no_grad()
|
| 575 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 576 |
+
def __call__(
|
| 577 |
+
self,
|
| 578 |
+
image = None,
|
| 579 |
+
prompt: Union[str, List[str]] = None,
|
| 580 |
+
negative_prompt: Union[str, List[str]] = None,
|
| 581 |
+
true_cfg_scale: float = 4.0,
|
| 582 |
+
height: Optional[int] = None,
|
| 583 |
+
width: Optional[int] = None,
|
| 584 |
+
num_inference_steps: int = 50,
|
| 585 |
+
sigmas: Optional[List[float]] = None,
|
| 586 |
+
guidance_scale: Optional[float] = None,
|
| 587 |
+
num_images_per_prompt: int = 1,
|
| 588 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 589 |
+
latents: Optional[torch.Tensor] = None,
|
| 590 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 591 |
+
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
| 592 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 593 |
+
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
|
| 594 |
+
output_type: Optional[str] = "pil",
|
| 595 |
+
return_dict: bool = True,
|
| 596 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 597 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 598 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 599 |
+
max_sequence_length: int = 512,
|
| 600 |
+
comfyui_progressbar: bool = False,
|
| 601 |
+
):
|
| 602 |
+
r"""
|
| 603 |
+
Function invoked when calling the pipeline for generation.
|
| 604 |
+
|
| 605 |
+
Args:
|
| 606 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
| 607 |
+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
| 608 |
+
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
|
| 609 |
+
of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
| 610 |
+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
|
| 611 |
+
latents as `image`, but if passing latents directly it is not encoded again.
|
| 612 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 613 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
| 614 |
+
instead.
|
| 615 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 616 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 617 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
| 618 |
+
not greater than `1`).
|
| 619 |
+
true_cfg_scale (`float`, *optional*, defaults to 4.0):
|
| 620 |
+
Guidance scale as defined in [Classifier-Free
|
| 621 |
+
Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of
|
| 622 |
+
equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is
|
| 623 |
+
enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale
|
| 624 |
+
encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of
|
| 625 |
+
lower image quality.
|
| 626 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 627 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 628 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 629 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 630 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 631 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 632 |
+
expense of slower inference.
|
| 633 |
+
sigmas (`List[float]`, *optional*):
|
| 634 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 635 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 636 |
+
will be used.
|
| 637 |
+
guidance_scale (`float`, *optional*, defaults to None):
|
| 638 |
+
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
|
| 639 |
+
where the guidance scale is applied during inference through noise prediction rescaling, guidance
|
| 640 |
+
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
|
| 641 |
+
scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to generate images
|
| 642 |
+
that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
|
| 643 |
+
parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
|
| 644 |
+
ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
|
| 645 |
+
please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
|
| 646 |
+
enable classifier-free guidance computations).
|
| 647 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 648 |
+
The number of images to generate per prompt.
|
| 649 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 650 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 651 |
+
to make generation deterministic.
|
| 652 |
+
latents (`torch.Tensor`, *optional*):
|
| 653 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 654 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 655 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 656 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 657 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 658 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 659 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 660 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 661 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 662 |
+
argument.
|
| 663 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 664 |
+
The output format of the generated image. Choose between
|
| 665 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 666 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 667 |
+
Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
|
| 668 |
+
attention_kwargs (`dict`, *optional*):
|
| 669 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 670 |
+
`self.processor` in
|
| 671 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 672 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 673 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
| 674 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 675 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 676 |
+
`callback_on_step_end_tensor_inputs`.
|
| 677 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 678 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 679 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 680 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 681 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
| 682 |
+
|
| 683 |
+
Examples:
|
| 684 |
+
|
| 685 |
+
Returns:
|
| 686 |
+
[`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
|
| 687 |
+
[`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
| 688 |
+
returning a tuple, the first element is a list with the generated images.
|
| 689 |
+
"""
|
| 690 |
+
image_size = image[0].size if isinstance(image, list) else image.size
|
| 691 |
+
calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
|
| 692 |
+
height = height or calculated_height
|
| 693 |
+
width = width or calculated_width
|
| 694 |
+
|
| 695 |
+
multiple_of = self.vae_scale_factor * 2
|
| 696 |
+
width = width // multiple_of * multiple_of
|
| 697 |
+
height = height // multiple_of * multiple_of
|
| 698 |
+
|
| 699 |
+
# 1. Check inputs. Raise error if not correct
|
| 700 |
+
self.check_inputs(
|
| 701 |
+
prompt,
|
| 702 |
+
height,
|
| 703 |
+
width,
|
| 704 |
+
negative_prompt=negative_prompt,
|
| 705 |
+
prompt_embeds=prompt_embeds,
|
| 706 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 707 |
+
prompt_embeds_mask=prompt_embeds_mask,
|
| 708 |
+
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
|
| 709 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 710 |
+
max_sequence_length=max_sequence_length,
|
| 711 |
+
)
|
| 712 |
+
|
| 713 |
+
self._guidance_scale = guidance_scale
|
| 714 |
+
self._attention_kwargs = attention_kwargs
|
| 715 |
+
self._current_timestep = None
|
| 716 |
+
self._interrupt = False
|
| 717 |
+
|
| 718 |
+
# 2. Define call parameters
|
| 719 |
+
if prompt is not None and isinstance(prompt, str):
|
| 720 |
+
batch_size = 1
|
| 721 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 722 |
+
batch_size = len(prompt)
|
| 723 |
+
else:
|
| 724 |
+
batch_size = prompt_embeds.shape[0]
|
| 725 |
+
|
| 726 |
+
device = self._execution_device
|
| 727 |
+
if comfyui_progressbar:
|
| 728 |
+
from comfy.utils import ProgressBar
|
| 729 |
+
pbar = ProgressBar(num_inference_steps + 2)
|
| 730 |
+
|
| 731 |
+
# 3. Preprocess image
|
| 732 |
+
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
|
| 733 |
+
image = self.image_processor.resize(image, calculated_height, calculated_width)
|
| 734 |
+
prompt_image = image
|
| 735 |
+
image = self.image_processor.preprocess(image, calculated_height, calculated_width)
|
| 736 |
+
image = image.unsqueeze(2)
|
| 737 |
+
|
| 738 |
+
has_neg_prompt = negative_prompt is not None or (
|
| 739 |
+
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
if true_cfg_scale > 1 and not has_neg_prompt:
|
| 743 |
+
logger.warning(
|
| 744 |
+
f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
|
| 745 |
+
)
|
| 746 |
+
elif true_cfg_scale <= 1 and has_neg_prompt:
|
| 747 |
+
logger.warning(
|
| 748 |
+
" negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
|
| 749 |
+
)
|
| 750 |
+
|
| 751 |
+
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
| 752 |
+
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
|
| 753 |
+
image=prompt_image,
|
| 754 |
+
prompt=prompt,
|
| 755 |
+
prompt_embeds=prompt_embeds,
|
| 756 |
+
prompt_embeds_mask=prompt_embeds_mask,
|
| 757 |
+
device=device,
|
| 758 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 759 |
+
max_sequence_length=max_sequence_length,
|
| 760 |
+
)
|
| 761 |
+
if do_true_cfg:
|
| 762 |
+
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
|
| 763 |
+
image=prompt_image,
|
| 764 |
+
prompt=negative_prompt,
|
| 765 |
+
prompt_embeds=negative_prompt_embeds,
|
| 766 |
+
prompt_embeds_mask=negative_prompt_embeds_mask,
|
| 767 |
+
device=device,
|
| 768 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 769 |
+
max_sequence_length=max_sequence_length,
|
| 770 |
+
)
|
| 771 |
+
if comfyui_progressbar:
|
| 772 |
+
pbar.update(1)
|
| 773 |
+
|
| 774 |
+
# 4. Prepare latent variables
|
| 775 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
| 776 |
+
latents, image_latents = self.prepare_latents(
|
| 777 |
+
image,
|
| 778 |
+
batch_size * num_images_per_prompt,
|
| 779 |
+
num_channels_latents,
|
| 780 |
+
height,
|
| 781 |
+
width,
|
| 782 |
+
prompt_embeds.dtype,
|
| 783 |
+
device,
|
| 784 |
+
generator,
|
| 785 |
+
latents,
|
| 786 |
+
)
|
| 787 |
+
img_shapes = [
|
| 788 |
+
[
|
| 789 |
+
(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
|
| 790 |
+
(1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2),
|
| 791 |
+
]
|
| 792 |
+
] * batch_size
|
| 793 |
+
|
| 794 |
+
# 5. Prepare timesteps
|
| 795 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
| 796 |
+
image_seq_len = latents.shape[1]
|
| 797 |
+
mu = calculate_shift(
|
| 798 |
+
image_seq_len,
|
| 799 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 800 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 801 |
+
self.scheduler.config.get("base_shift", 0.5),
|
| 802 |
+
self.scheduler.config.get("max_shift", 1.15),
|
| 803 |
+
)
|
| 804 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 805 |
+
self.scheduler,
|
| 806 |
+
num_inference_steps,
|
| 807 |
+
device,
|
| 808 |
+
sigmas=sigmas,
|
| 809 |
+
mu=mu,
|
| 810 |
+
)
|
| 811 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 812 |
+
self._num_timesteps = len(timesteps)
|
| 813 |
+
|
| 814 |
+
# handle guidance
|
| 815 |
+
if self.transformer.config.guidance_embeds and guidance_scale is None:
|
| 816 |
+
raise ValueError("guidance_scale is required for guidance-distilled model.")
|
| 817 |
+
elif self.transformer.config.guidance_embeds:
|
| 818 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 819 |
+
guidance = guidance.expand(latents.shape[0])
|
| 820 |
+
elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
|
| 821 |
+
logger.warning(
|
| 822 |
+
f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
|
| 823 |
+
)
|
| 824 |
+
guidance = None
|
| 825 |
+
elif not self.transformer.config.guidance_embeds and guidance_scale is None:
|
| 826 |
+
guidance = None
|
| 827 |
+
|
| 828 |
+
if self.attention_kwargs is None:
|
| 829 |
+
self._attention_kwargs = {}
|
| 830 |
+
|
| 831 |
+
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
|
| 832 |
+
negative_txt_seq_lens = (
|
| 833 |
+
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
|
| 834 |
+
)
|
| 835 |
+
if comfyui_progressbar:
|
| 836 |
+
pbar.update(1)
|
| 837 |
+
|
| 838 |
+
# 6. Denoising loop
|
| 839 |
+
self.scheduler.set_begin_index(0)
|
| 840 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 841 |
+
for i, t in enumerate(timesteps):
|
| 842 |
+
self.transformer.current_steps = i
|
| 843 |
+
if self.interrupt:
|
| 844 |
+
continue
|
| 845 |
+
|
| 846 |
+
if image_latents is not None:
|
| 847 |
+
latents_and_image_latents = torch.cat([latents, image_latents], dim=1)
|
| 848 |
+
else:
|
| 849 |
+
latents_and_image_latents = latents
|
| 850 |
+
|
| 851 |
+
if do_true_cfg:
|
| 852 |
+
latent_model_input = torch.cat([latents_and_image_latents] * 2)
|
| 853 |
+
prompt_embeds_mask_input = [_negative_prompt_embeds_mask for _negative_prompt_embeds_mask in negative_prompt_embeds_mask] + [_prompt_embeds_mask for _prompt_embeds_mask in prompt_embeds_mask]
|
| 854 |
+
prompt_embeds_input = [_negative_prompt_embeds for _negative_prompt_embeds in negative_prompt_embeds] + [_prompt_embeds for _prompt_embeds in prompt_embeds]
|
| 855 |
+
img_shapes_input = img_shapes * 2
|
| 856 |
+
txt_seq_lens_input = negative_txt_seq_lens + txt_seq_lens
|
| 857 |
+
else:
|
| 858 |
+
latent_model_input = latents_and_image_latents
|
| 859 |
+
prompt_embeds_mask_input = prompt_embeds_mask
|
| 860 |
+
prompt_embeds_input = prompt_embeds
|
| 861 |
+
img_shapes_input = img_shapes
|
| 862 |
+
txt_seq_lens_input = txt_seq_lens
|
| 863 |
+
|
| 864 |
+
if hasattr(self.scheduler, "scale_model_input"):
|
| 865 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 866 |
+
|
| 867 |
+
# handle guidance
|
| 868 |
+
if self.transformer.config.guidance_embeds:
|
| 869 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 870 |
+
guidance = guidance.expand(latent_model_input.shape[0])
|
| 871 |
+
else:
|
| 872 |
+
guidance = None
|
| 873 |
+
|
| 874 |
+
self._current_timestep = t
|
| 875 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 876 |
+
timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
|
| 877 |
+
|
| 878 |
+
with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device):
|
| 879 |
+
noise_pred = self.transformer.forward_bs(
|
| 880 |
+
x=latent_model_input,
|
| 881 |
+
timestep=timestep / 1000,
|
| 882 |
+
guidance=guidance,
|
| 883 |
+
encoder_hidden_states_mask=prompt_embeds_mask_input,
|
| 884 |
+
encoder_hidden_states=prompt_embeds_input,
|
| 885 |
+
img_shapes=img_shapes_input,
|
| 886 |
+
txt_seq_lens=txt_seq_lens_input,
|
| 887 |
+
attention_kwargs=self.attention_kwargs,
|
| 888 |
+
return_dict=False,
|
| 889 |
+
)
|
| 890 |
+
noise_pred = noise_pred[:, : latents.size(1)]
|
| 891 |
+
|
| 892 |
+
if do_true_cfg:
|
| 893 |
+
neg_noise_pred, noise_pred = noise_pred.chunk(2)
|
| 894 |
+
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
| 895 |
+
|
| 896 |
+
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
|
| 897 |
+
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
|
| 898 |
+
noise_pred = comb_pred * (cond_norm / noise_norm)
|
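The `do_true_cfg` branch above combines the negative and positive predictions and then rescales the result so its per-token norm matches the conditional prediction. A minimal self-contained sketch of that rescaling on dummy tensors (the shapes and scale are assumptions, not values taken from the pipeline):

```py
import torch

true_cfg_scale = 4.0
noise_pred = torch.randn(1, 4096, 64)        # conditional prediction (packed-latent shape)
neg_noise_pred = torch.randn(1, 4096, 64)    # negative / unconditional prediction

comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
rescaled = comb_pred * (cond_norm / noise_norm)

# Direction comes from the CFG combination, magnitude from the conditional branch.
assert torch.allclose(torch.norm(rescaled, dim=-1, keepdim=True), cond_norm, atol=1e-4)
```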
| 899 |
+
|
| 900 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 901 |
+
latents_dtype = latents.dtype
|
| 902 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 903 |
+
|
| 904 |
+
if latents.dtype != latents_dtype:
|
| 905 |
+
if torch.backends.mps.is_available():
|
| 906 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 907 |
+
latents = latents.to(latents_dtype)
|
| 908 |
+
|
| 909 |
+
if callback_on_step_end is not None:
|
| 910 |
+
callback_kwargs = {}
|
| 911 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 912 |
+
callback_kwargs[k] = locals()[k]
|
| 913 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 914 |
+
|
| 915 |
+
latents = callback_outputs.pop("latents", latents)
|
| 916 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 917 |
+
|
| 918 |
+
# call the callback, if provided
|
| 919 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 920 |
+
progress_bar.update()
|
| 921 |
+
|
| 922 |
+
if XLA_AVAILABLE:
|
| 923 |
+
xm.mark_step()
|
| 924 |
+
|
| 925 |
+
if comfyui_progressbar:
|
| 926 |
+
pbar.update(1)
|
| 927 |
+
|
| 928 |
+
self._current_timestep = None
|
| 929 |
+
if output_type == "latent":
|
| 930 |
+
image = latents
|
| 931 |
+
else:
|
| 932 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 933 |
+
latents = latents.to(self.vae.dtype)
|
| 934 |
+
latents_mean = (
|
| 935 |
+
torch.tensor(self.vae.config.latents_mean)
|
| 936 |
+
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
| 937 |
+
.to(latents.device, latents.dtype)
|
| 938 |
+
)
|
| 939 |
+
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
| 940 |
+
latents.device, latents.dtype
|
| 941 |
+
)
|
| 942 |
+
latents = latents / latents_std + latents_mean
|
| 943 |
+
image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
|
| 944 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
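One detail worth noting in the decode branch above: `latents_std` is built as a reciprocal (`1.0 / torch.tensor(...)`), so `latents / latents_std + latents_mean` is effectively `z * std + mean`, the inverse of the `(z - mean) / std` normalization applied when reference images are VAE-encoded. A tiny sketch with made-up statistics (not the real Qwen-Image VAE values):

```py
import torch

z = torch.randn(1, 16, 1, 64, 64)                       # unpacked latent
latents_mean = torch.full((1, 16, 1, 1, 1), 0.1)        # hypothetical per-channel mean
latents_std = 1.0 / torch.full((1, 16, 1, 1, 1), 2.0)   # stored as a reciprocal, as above

vae_input = z / latents_std + latents_mean              # same as z * 2.0 + 0.1
assert torch.allclose(vae_input, z * 2.0 + 0.1)
```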
| 945 |
+
|
| 946 |
+
# Offload all models
|
| 947 |
+
self.maybe_free_model_hooks()
|
| 948 |
+
|
| 949 |
+
if not return_dict:
|
| 950 |
+
return (image,)
|
| 951 |
+
|
| 952 |
+
return QwenImagePipelineOutput(images=image)
|
videox_fun/pipeline/pipeline_qwenimage_edit_plus.py
ADDED
|
@@ -0,0 +1,937 @@
| 1 |
+
# Modified from https://github.com/naykun/diffusers/blob/main/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
|
| 2 |
+
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import math
|
| 22 |
+
import PIL.Image
|
| 23 |
+
import torch
|
| 24 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 25 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 26 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 27 |
+
from diffusers.utils import (BaseOutput, is_torch_xla_available, logging,
|
| 28 |
+
replace_example_docstring)
|
| 29 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 30 |
+
|
| 31 |
+
from ..models import (AutoencoderKLQwenImage,
|
| 32 |
+
Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor,
|
| 33 |
+
Qwen2Tokenizer, QwenImageTransformer2DModel)
|
| 34 |
+
|
| 35 |
+
if is_torch_xla_available():
|
| 36 |
+
import torch_xla.core.xla_model as xm
|
| 37 |
+
|
| 38 |
+
XLA_AVAILABLE = True
|
| 39 |
+
else:
|
| 40 |
+
XLA_AVAILABLE = False
|
| 41 |
+
|
| 42 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 43 |
+
|
| 44 |
+
EXAMPLE_DOC_STRING = """
|
| 45 |
+
Examples:
|
| 46 |
+
```py
|
| 47 |
+
>>> import torch
|
| 48 |
+
>>> from PIL import Image
|
| 49 |
+
>>> from diffusers import QwenImageEditPipeline
|
| 50 |
+
|
| 51 |
+
>>> pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
|
| 52 |
+
>>> pipe.to("cuda")
|
| 53 |
+
>>> prompt = "Change the cat to a dog"
|
| 54 |
+
>>> image = Image.open("cat.png")
|
| 55 |
+
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
| 56 |
+
>>> # Refer to the pipeline documentation for more details.
|
| 57 |
+
>>> image = pipe(image, prompt, num_inference_steps=50).images[0]
|
| 58 |
+
>>> image.save("qwenimageedit.png")
|
| 59 |
+
```
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
CONDITION_IMAGE_SIZE = 384 * 384
|
| 63 |
+
VAE_IMAGE_SIZE = 1024 * 1024
|
| 64 |
+
|
| 65 |
+
PREFERRED_QWENIMAGE_RESOLUTIONS = [
|
| 66 |
+
(672, 1568),
|
| 67 |
+
(688, 1504),
|
| 68 |
+
(720, 1456),
|
| 69 |
+
(752, 1392),
|
| 70 |
+
(800, 1328),
|
| 71 |
+
(832, 1248),
|
| 72 |
+
(880, 1184),
|
| 73 |
+
(944, 1104),
|
| 74 |
+
(1024, 1024),
|
| 75 |
+
(1104, 944),
|
| 76 |
+
(1184, 880),
|
| 77 |
+
(1248, 832),
|
| 78 |
+
(1328, 800),
|
| 79 |
+
(1392, 752),
|
| 80 |
+
(1456, 720),
|
| 81 |
+
(1504, 688),
|
| 82 |
+
(1568, 672),
|
| 83 |
+
]
|
| 84 |
+
|
| 85 |
+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
|
| 86 |
+
def calculate_shift(
|
| 87 |
+
image_seq_len,
|
| 88 |
+
base_seq_len: int = 256,
|
| 89 |
+
max_seq_len: int = 4096,
|
| 90 |
+
base_shift: float = 0.5,
|
| 91 |
+
max_shift: float = 1.15,
|
| 92 |
+
):
|
| 93 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 94 |
+
b = base_shift - m * base_seq_len
|
| 95 |
+
mu = image_seq_len * m + b
|
| 96 |
+
return mu
|
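With the defaults above, `mu` interpolates linearly from `base_shift` at 256 image tokens to `max_shift` at 4096 tokens. A few spot checks (illustration only):

```py
calculate_shift(256)    # -> 0.5   (base_shift at base_seq_len)
calculate_shift(2176)   # -> 0.825 (halfway between 256 and 4096 tokens)
calculate_shift(4096)   # -> 1.15  (max_shift; e.g. a 1024x1024 image, whose
                        #    128x128 latent packs into 64*64 = 4096 tokens)
```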
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 100 |
+
def retrieve_timesteps(
|
| 101 |
+
scheduler,
|
| 102 |
+
num_inference_steps: Optional[int] = None,
|
| 103 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 104 |
+
timesteps: Optional[List[int]] = None,
|
| 105 |
+
sigmas: Optional[List[float]] = None,
|
| 106 |
+
**kwargs,
|
| 107 |
+
):
|
| 108 |
+
r"""
|
| 109 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 110 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
scheduler (`SchedulerMixin`):
|
| 114 |
+
The scheduler to get timesteps from.
|
| 115 |
+
num_inference_steps (`int`):
|
| 116 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 117 |
+
must be `None`.
|
| 118 |
+
device (`str` or `torch.device`, *optional*):
|
| 119 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 120 |
+
timesteps (`List[int]`, *optional*):
|
| 121 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 122 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 123 |
+
sigmas (`List[float]`, *optional*):
|
| 124 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 125 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 129 |
+
second element is the number of inference steps.
|
| 130 |
+
"""
|
| 131 |
+
if timesteps is not None and sigmas is not None:
|
| 132 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 133 |
+
if timesteps is not None:
|
| 134 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 135 |
+
if not accepts_timesteps:
|
| 136 |
+
raise ValueError(
|
| 137 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 138 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 139 |
+
)
|
| 140 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 141 |
+
timesteps = scheduler.timesteps
|
| 142 |
+
num_inference_steps = len(timesteps)
|
| 143 |
+
elif sigmas is not None:
|
| 144 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 145 |
+
if not accept_sigmas:
|
| 146 |
+
raise ValueError(
|
| 147 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 148 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 149 |
+
)
|
| 150 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 151 |
+
timesteps = scheduler.timesteps
|
| 152 |
+
num_inference_steps = len(timesteps)
|
| 153 |
+
else:
|
| 154 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 155 |
+
timesteps = scheduler.timesteps
|
| 156 |
+
return timesteps, num_inference_steps
|
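A minimal sketch of how the pipeline below feeds this helper, assuming a `FlowMatchEulerDiscreteScheduler` configured with dynamic shifting so the extra `mu` kwarg is forwarded to `set_timesteps` (exact scheduler behavior depends on the installed diffusers version):

```py
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=True)
num_inference_steps = 30
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)

timesteps, num_inference_steps = retrieve_timesteps(
    scheduler, num_inference_steps, device="cpu", sigmas=sigmas, mu=1.15
)
print(len(timesteps))  # 30
```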
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 160 |
+
def retrieve_latents(
|
| 161 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 162 |
+
):
|
| 163 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 164 |
+
return encoder_output.latent_dist.sample(generator)
|
| 165 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 166 |
+
return encoder_output.latent_dist.mode()
|
| 167 |
+
elif hasattr(encoder_output, "latents"):
|
| 168 |
+
return encoder_output.latents
|
| 169 |
+
else:
|
| 170 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def calculate_dimensions(target_area, ratio):
|
| 174 |
+
width = math.sqrt(target_area * ratio)
|
| 175 |
+
height = width / ratio
|
| 176 |
+
|
| 177 |
+
width = round(width / 32) * 32
|
| 178 |
+
height = round(height / 32) * 32
|
| 179 |
+
|
| 180 |
+
return width, height
|
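This helper is what turns `CONDITION_IMAGE_SIZE` (384*384, for the Qwen2.5-VL condition copy) and `VAE_IMAGE_SIZE` (1024*1024, for the VAE copy) into concrete sizes later in `__call__`. A couple of worked outputs, illustration only:

```py
calculate_dimensions(384 * 384, 16 / 9)    # -> (512, 288)   condition copy for a 16:9 input
calculate_dimensions(1024 * 1024, 16 / 9)  # -> (1376, 768)  VAE / target copy for the same input
calculate_dimensions(1024 * 1024, 1.0)     # -> (1024, 1024) square inputs stay square
```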
| 181 |
+
|
| 182 |
+
|
| 183 |
+
@dataclass
|
| 184 |
+
class QwenImagePipelineOutput(BaseOutput):
|
| 185 |
+
"""
|
| 186 |
+
Output class for Stable Diffusion pipelines.
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
| 190 |
+
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
| 191 |
+
num_channels)`. PIL images or numpy arrays represent the denoised images of the diffusion pipeline.
|
| 192 |
+
"""
|
| 193 |
+
|
| 194 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class QwenImageEditPlusPipeline(DiffusionPipeline):
|
| 198 |
+
r"""
|
| 199 |
+
The QwenImage Edit Plus pipeline for instruction-based image editing with one or more reference images.
|
| 200 |
+
|
| 201 |
+
Args:
|
| 202 |
+
transformer ([`QwenImageTransformer2DModel`]):
|
| 203 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
| 204 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 205 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 206 |
+
vae ([`AutoencoderKL`]):
|
| 207 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 208 |
+
text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
|
| 209 |
+
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
|
| 210 |
+
instruction-tuned variant, used here to jointly encode the text instruction and the reference images.
|
| 211 |
+
tokenizer (`QwenTokenizer`):
|
| 212 |
+
Tokenizer of class
|
| 213 |
+
[`Qwen2Tokenizer`](https://huggingface.co/docs/transformers/en/model_doc/qwen2).
|
| 214 |
+
"""
|
| 215 |
+
|
| 216 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 217 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 218 |
+
|
| 219 |
+
def __init__(
|
| 220 |
+
self,
|
| 221 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 222 |
+
vae: AutoencoderKLQwenImage,
|
| 223 |
+
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
| 224 |
+
tokenizer: Qwen2Tokenizer,
|
| 225 |
+
processor: Qwen2VLProcessor,
|
| 226 |
+
transformer: QwenImageTransformer2DModel,
|
| 227 |
+
):
|
| 228 |
+
super().__init__()
|
| 229 |
+
|
| 230 |
+
self.register_modules(
|
| 231 |
+
vae=vae,
|
| 232 |
+
text_encoder=text_encoder,
|
| 233 |
+
tokenizer=tokenizer,
|
| 234 |
+
processor=processor,
|
| 235 |
+
transformer=transformer,
|
| 236 |
+
scheduler=scheduler,
|
| 237 |
+
)
|
| 238 |
+
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
| 239 |
+
self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
|
| 240 |
+
# QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
|
| 241 |
+
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
| 242 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
| 243 |
+
self.tokenizer_max_length = 1024
|
| 244 |
+
|
| 245 |
+
self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
| 246 |
+
self.prompt_template_encode_start_idx = 64
|
| 247 |
+
self.default_sample_size = 128
|
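For reference, the template above wraps the user instruction in a fixed system prompt plus one `Picture N:` vision placeholder per reference image, and `prompt_template_encode_start_idx = 64` later drops the tokens of that system header from the encoder output (see `_get_qwen_prompt_embeds` below). A standalone sketch of the string that ends up being tokenized for two reference images (template copied from the attributes above; the instruction text is made up):

```py
prompt_template_encode = (
    "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, "
    "texture, objects, background), then explain how the user's text instruction should alter "
    "or modify the image. Generate a new image that meets the user's requirements while "
    "maintaining consistency with the original input where appropriate.<|im_end|>\n"
    "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
)
img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"

base_img_prompt = img_prompt_template.format(1) + img_prompt_template.format(2)
txt = prompt_template_encode.format(base_img_prompt + "replace the sky with a sunset")
```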
| 248 |
+
|
| 249 |
+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
|
| 250 |
+
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
| 251 |
+
bool_mask = mask.bool()
|
| 252 |
+
valid_lengths = bool_mask.sum(dim=1)
|
| 253 |
+
selected = hidden_states[bool_mask]
|
| 254 |
+
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
| 255 |
+
|
| 256 |
+
return split_result
|
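A minimal example of what this helper does with dummy data: it keeps only the valid (unmasked) tokens of each sample and returns one variable-length tensor per sample.

```py
import torch

hidden_states = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)
mask = torch.tensor([[1, 1, 1, 0],    # sample 0 has 3 valid tokens
                     [1, 1, 0, 0]])   # sample 1 has 2 valid tokens

bool_mask = mask.bool()
valid_lengths = bool_mask.sum(dim=1)                    # tensor([3, 2])
selected = hidden_states[bool_mask]                     # (5, 3): valid rows only
chunks = torch.split(selected, valid_lengths.tolist())  # shapes: (3, 3) and (2, 3)
```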
| 257 |
+
|
| 258 |
+
def _get_qwen_prompt_embeds(
|
| 259 |
+
self,
|
| 260 |
+
prompt: Union[str, List[str]] = None,
|
| 261 |
+
image: Optional[torch.Tensor] = None,
|
| 262 |
+
device: Optional[torch.device] = None,
|
| 263 |
+
dtype: Optional[torch.dtype] = None,
|
| 264 |
+
):
|
| 265 |
+
device = device or self._execution_device
|
| 266 |
+
dtype = dtype or self.text_encoder.dtype
|
| 267 |
+
|
| 268 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 269 |
+
img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
|
| 270 |
+
if isinstance(image, list):
|
| 271 |
+
base_img_prompt = ""
|
| 272 |
+
for i, img in enumerate(image):
|
| 273 |
+
base_img_prompt += img_prompt_template.format(i + 1)
|
| 274 |
+
elif image is not None:
|
| 275 |
+
base_img_prompt = img_prompt_template.format(1)
|
| 276 |
+
else:
|
| 277 |
+
base_img_prompt = ""
|
| 278 |
+
|
| 279 |
+
template = self.prompt_template_encode
|
| 280 |
+
|
| 281 |
+
drop_idx = self.prompt_template_encode_start_idx
|
| 282 |
+
txt = [template.format(base_img_prompt + e) for e in prompt]
|
| 283 |
+
|
| 284 |
+
model_inputs = self.processor(
|
| 285 |
+
text=txt,
|
| 286 |
+
images=image,
|
| 287 |
+
padding=True,
|
| 288 |
+
return_tensors="pt",
|
| 289 |
+
).to(device)
|
| 290 |
+
|
| 291 |
+
outputs = self.text_encoder(
|
| 292 |
+
input_ids=model_inputs.input_ids,
|
| 293 |
+
attention_mask=model_inputs.attention_mask,
|
| 294 |
+
pixel_values=model_inputs.pixel_values,
|
| 295 |
+
image_grid_thw=model_inputs.image_grid_thw,
|
| 296 |
+
output_hidden_states=True,
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
hidden_states = outputs.hidden_states[-1]
|
| 300 |
+
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
|
| 301 |
+
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
| 302 |
+
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
| 303 |
+
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
| 304 |
+
prompt_embeds = torch.stack(
|
| 305 |
+
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
|
| 306 |
+
)
|
| 307 |
+
encoder_attention_mask = torch.stack(
|
| 308 |
+
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 312 |
+
|
| 313 |
+
return prompt_embeds, encoder_attention_mask
|
| 314 |
+
|
| 315 |
+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt
|
| 316 |
+
def encode_prompt(
|
| 317 |
+
self,
|
| 318 |
+
prompt: Union[str, List[str]],
|
| 319 |
+
image: Optional[torch.Tensor] = None,
|
| 320 |
+
device: Optional[torch.device] = None,
|
| 321 |
+
num_images_per_prompt: int = 1,
|
| 322 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 323 |
+
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
| 324 |
+
max_sequence_length: int = 1024,
|
| 325 |
+
):
|
| 326 |
+
r"""
|
| 327 |
+
|
| 328 |
+
Args:
|
| 329 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 330 |
+
prompt to be encoded
|
| 331 |
+
image (`torch.Tensor`, *optional*):
|
| 332 |
+
image to be encoded
|
| 333 |
+
device: (`torch.device`):
|
| 334 |
+
torch device
|
| 335 |
+
num_images_per_prompt (`int`):
|
| 336 |
+
number of images that should be generated per prompt
|
| 337 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 338 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 339 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 340 |
+
"""
|
| 341 |
+
device = device or self._execution_device
|
| 342 |
+
|
| 343 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 344 |
+
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
|
| 345 |
+
|
| 346 |
+
if prompt_embeds is None:
|
| 347 |
+
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
|
| 348 |
+
|
| 349 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 350 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 351 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 352 |
+
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
| 353 |
+
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
| 354 |
+
|
| 355 |
+
return prompt_embeds, prompt_embeds_mask
|
| 356 |
+
|
| 357 |
+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
|
| 358 |
+
def check_inputs(
|
| 359 |
+
self,
|
| 360 |
+
prompt,
|
| 361 |
+
height,
|
| 362 |
+
width,
|
| 363 |
+
negative_prompt=None,
|
| 364 |
+
prompt_embeds=None,
|
| 365 |
+
negative_prompt_embeds=None,
|
| 366 |
+
prompt_embeds_mask=None,
|
| 367 |
+
negative_prompt_embeds_mask=None,
|
| 368 |
+
callback_on_step_end_tensor_inputs=None,
|
| 369 |
+
max_sequence_length=None,
|
| 370 |
+
):
|
| 371 |
+
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
| 372 |
+
logger.warning(
|
| 373 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 377 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 378 |
+
):
|
| 379 |
+
raise ValueError(
|
| 380 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
if prompt is not None and prompt_embeds is not None:
|
| 384 |
+
raise ValueError(
|
| 385 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 386 |
+
" only forward one of the two."
|
| 387 |
+
)
|
| 388 |
+
elif prompt is None and prompt_embeds is None:
|
| 389 |
+
raise ValueError(
|
| 390 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 391 |
+
)
|
| 392 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 393 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 394 |
+
|
| 395 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 396 |
+
raise ValueError(
|
| 397 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 398 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
if prompt_embeds is not None and prompt_embeds_mask is None:
|
| 402 |
+
raise ValueError(
|
| 403 |
+
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
|
| 404 |
+
)
|
| 405 |
+
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
| 406 |
+
raise ValueError(
|
| 407 |
+
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
if max_sequence_length is not None and max_sequence_length > 1024:
|
| 411 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
|
| 412 |
+
|
| 413 |
+
@staticmethod
|
| 414 |
+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
|
| 415 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
| 416 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
| 417 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
| 418 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
| 419 |
+
|
| 420 |
+
return latents
|
| 421 |
+
|
| 422 |
+
@staticmethod
|
| 423 |
+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
|
| 424 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
| 425 |
+
batch_size, num_patches, channels = latents.shape
|
| 426 |
+
|
| 427 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 428 |
+
# latent height and width to be divisible by 2.
|
| 429 |
+
height = 2 * (int(height) // (vae_scale_factor * 2))
|
| 430 |
+
width = 2 * (int(width) // (vae_scale_factor * 2))
|
| 431 |
+
|
| 432 |
+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
| 433 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
| 434 |
+
|
| 435 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
|
| 436 |
+
|
| 437 |
+
return latents
|
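A shape walk-through for the two static helpers above, assuming a 16-channel latent space and an 8x VAE, so a 1024x1024 image maps to a 128x128 latent that packs into 64*64 tokens of width 64:

```py
import torch

lat = torch.randn(1, 16, 128, 128)                                 # 1024x1024 image after VAE encode
packed = QwenImageEditPlusPipeline._pack_latents(lat, 1, 16, 128, 128)
print(packed.shape)                                                # (1, 4096, 64)

unpacked = QwenImageEditPlusPipeline._unpack_latents(packed, 1024, 1024, vae_scale_factor=8)
print(unpacked.shape)                                              # (1, 16, 1, 128, 128); extra frame axis for the video-style VAE
assert torch.equal(unpacked[:, :, 0], lat)                         # pack/unpack is an exact round trip
```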
| 438 |
+
|
| 439 |
+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
|
| 440 |
+
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
| 441 |
+
if isinstance(generator, list):
|
| 442 |
+
image_latents = [
|
| 443 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
|
| 444 |
+
for i in range(image.shape[0])
|
| 445 |
+
]
|
| 446 |
+
image_latents = torch.cat(image_latents, dim=0)
|
| 447 |
+
else:
|
| 448 |
+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
|
| 449 |
+
latents_mean = (
|
| 450 |
+
torch.tensor(self.vae.config.latents_mean)
|
| 451 |
+
.view(1, self.latent_channels, 1, 1, 1)
|
| 452 |
+
.to(image_latents.device, image_latents.dtype)
|
| 453 |
+
)
|
| 454 |
+
latents_std = (
|
| 455 |
+
torch.tensor(self.vae.config.latents_std)
|
| 456 |
+
.view(1, self.latent_channels, 1, 1, 1)
|
| 457 |
+
.to(image_latents.device, image_latents.dtype)
|
| 458 |
+
)
|
| 459 |
+
image_latents = (image_latents - latents_mean) / latents_std
|
| 460 |
+
|
| 461 |
+
return image_latents
|
| 462 |
+
|
| 463 |
+
def prepare_latents(
|
| 464 |
+
self,
|
| 465 |
+
images,
|
| 466 |
+
batch_size,
|
| 467 |
+
num_channels_latents,
|
| 468 |
+
height,
|
| 469 |
+
width,
|
| 470 |
+
dtype,
|
| 471 |
+
device,
|
| 472 |
+
generator,
|
| 473 |
+
latents=None,
|
| 474 |
+
):
|
| 475 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 476 |
+
# latent height and width to be divisible by 2.
|
| 477 |
+
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
| 478 |
+
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
| 479 |
+
|
| 480 |
+
shape = (batch_size, 1, num_channels_latents, height, width)
|
| 481 |
+
|
| 482 |
+
image_latents = None
|
| 483 |
+
if images is not None:
|
| 484 |
+
if not isinstance(images, list):
|
| 485 |
+
images = [images]
|
| 486 |
+
all_image_latents = []
|
| 487 |
+
for image in images:
|
| 488 |
+
image = image.to(device=device, dtype=dtype)
|
| 489 |
+
if image.shape[1] != self.latent_channels:
|
| 490 |
+
image_latents = self._encode_vae_image(image=image, generator=generator)
|
| 491 |
+
else:
|
| 492 |
+
image_latents = image
|
| 493 |
+
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
| 494 |
+
# expand init_latents for batch_size
|
| 495 |
+
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
| 496 |
+
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
| 497 |
+
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
| 498 |
+
raise ValueError(
|
| 499 |
+
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
| 500 |
+
)
|
| 501 |
+
else:
|
| 502 |
+
image_latents = torch.cat([image_latents], dim=0)
|
| 503 |
+
|
| 504 |
+
image_latent_height, image_latent_width = image_latents.shape[3:]
|
| 505 |
+
image_latents = self._pack_latents(
|
| 506 |
+
image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
|
| 507 |
+
)
|
| 508 |
+
all_image_latents.append(image_latents)
|
| 509 |
+
image_latents = torch.cat(all_image_latents, dim=1)
|
| 510 |
+
|
| 511 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 512 |
+
raise ValueError(
|
| 513 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 514 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 515 |
+
)
|
| 516 |
+
if latents is None:
|
| 517 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 518 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
| 519 |
+
else:
|
| 520 |
+
latents = latents.to(device=device, dtype=dtype)
|
| 521 |
+
|
| 522 |
+
return latents, image_latents
|
| 523 |
+
|
| 524 |
+
@property
|
| 525 |
+
def guidance_scale(self):
|
| 526 |
+
return self._guidance_scale
|
| 527 |
+
|
| 528 |
+
@property
|
| 529 |
+
def attention_kwargs(self):
|
| 530 |
+
return self._attention_kwargs
|
| 531 |
+
|
| 532 |
+
@property
|
| 533 |
+
def num_timesteps(self):
|
| 534 |
+
return self._num_timesteps
|
| 535 |
+
|
| 536 |
+
@property
|
| 537 |
+
def current_timestep(self):
|
| 538 |
+
return self._current_timestep
|
| 539 |
+
|
| 540 |
+
@property
|
| 541 |
+
def interrupt(self):
|
| 542 |
+
return self._interrupt
|
| 543 |
+
|
| 544 |
+
@torch.no_grad()
|
| 545 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 546 |
+
def __call__(
|
| 547 |
+
self,
|
| 548 |
+
image = None,
|
| 549 |
+
prompt: Union[str, List[str]] = None,
|
| 550 |
+
negative_prompt: Union[str, List[str]] = None,
|
| 551 |
+
true_cfg_scale: float = 4.0,
|
| 552 |
+
height: Optional[int] = None,
|
| 553 |
+
width: Optional[int] = None,
|
| 554 |
+
num_inference_steps: int = 50,
|
| 555 |
+
sigmas: Optional[List[float]] = None,
|
| 556 |
+
guidance_scale: Optional[float] = None,
|
| 557 |
+
num_images_per_prompt: int = 1,
|
| 558 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 559 |
+
latents: Optional[torch.Tensor] = None,
|
| 560 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 561 |
+
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
| 562 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 563 |
+
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
|
| 564 |
+
output_type: Optional[str] = "pil",
|
| 565 |
+
return_dict: bool = True,
|
| 566 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 567 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 568 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 569 |
+
max_sequence_length: int = 512,
|
| 570 |
+
comfyui_progressbar: bool = False,
|
| 571 |
+
):
|
| 572 |
+
r"""
|
| 573 |
+
Function invoked when calling the pipeline for generation.
|
| 574 |
+
|
| 575 |
+
Args:
|
| 576 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
| 577 |
+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
| 578 |
+
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
|
| 579 |
+
of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
| 580 |
+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
|
| 581 |
+
latents as `image`, but if passing latents directly it is not encoded again.
|
| 582 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 583 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
| 584 |
+
instead.
|
| 585 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 586 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 587 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
| 588 |
+
not greater than `1`).
|
| 589 |
+
true_cfg_scale (`float`, *optional*, defaults to 4.0):
|
| 590 |
+
Guidance scale as defined in [Classifier-Free
|
| 591 |
+
Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of
|
| 592 |
+
equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is
|
| 593 |
+
enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale
|
| 594 |
+
encourages to generate images that are closely linked to the text `prompt`, usually at the expense of
|
| 595 |
+
lower image quality.
|
| 596 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 597 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 598 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 599 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 600 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 601 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 602 |
+
expense of slower inference.
|
| 603 |
+
sigmas (`List[float]`, *optional*):
|
| 604 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 605 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 606 |
+
will be used.
|
| 607 |
+
guidance_scale (`float`, *optional*, defaults to None):
|
| 608 |
+
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
|
| 609 |
+
where the guidance scale is applied during inference through noise prediction rescaling, guidance
|
| 610 |
+
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
|
| 611 |
+
scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
|
| 612 |
+
that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
|
| 613 |
+
parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
|
| 614 |
+
ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
|
| 615 |
+
please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
|
| 616 |
+
enable classifier-free guidance computations).
|
| 617 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 618 |
+
The number of images to generate per prompt.
|
| 619 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 620 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 621 |
+
to make generation deterministic.
|
| 622 |
+
latents (`torch.Tensor`, *optional*):
|
| 623 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 624 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 625 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 626 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 627 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 628 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 629 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 630 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 631 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 632 |
+
argument.
|
| 633 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 634 |
+
The output format of the generate image. Choose between
|
| 635 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 636 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 637 |
+
Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
|
| 638 |
+
attention_kwargs (`dict`, *optional*):
|
| 639 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 640 |
+
`self.processor` in
|
| 641 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 642 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 643 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 644 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 645 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 646 |
+
`callback_on_step_end_tensor_inputs`.
|
| 647 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 648 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 649 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 650 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 651 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
| 652 |
+
|
| 653 |
+
Examples:
|
| 654 |
+
|
| 655 |
+
Returns:
|
| 656 |
+
[`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
|
| 657 |
+
[`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
| 658 |
+
returning a tuple, the first element is a list with the generated images.
|
| 659 |
+
"""
|
| 660 |
+
image_size = image[-1].size if isinstance(image, list) else image.size
|
| 661 |
+
calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
|
| 662 |
+
height = height or calculated_height
|
| 663 |
+
width = width or calculated_width
|
| 664 |
+
|
| 665 |
+
multiple_of = self.vae_scale_factor * 2
|
| 666 |
+
width = width // multiple_of * multiple_of
|
| 667 |
+
height = height // multiple_of * multiple_of
|
| 668 |
+
|
| 669 |
+
# 1. Check inputs. Raise error if not correct
|
| 670 |
+
self.check_inputs(
|
| 671 |
+
prompt,
|
| 672 |
+
height,
|
| 673 |
+
width,
|
| 674 |
+
negative_prompt=negative_prompt,
|
| 675 |
+
prompt_embeds=prompt_embeds,
|
| 676 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 677 |
+
prompt_embeds_mask=prompt_embeds_mask,
|
| 678 |
+
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
|
| 679 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 680 |
+
max_sequence_length=max_sequence_length,
|
| 681 |
+
)
|
| 682 |
+
|
| 683 |
+
self._guidance_scale = guidance_scale
|
| 684 |
+
self._attention_kwargs = attention_kwargs
|
| 685 |
+
self._current_timestep = None
|
| 686 |
+
self._interrupt = False
|
| 687 |
+
|
| 688 |
+
# 2. Define call parameters
|
| 689 |
+
if prompt is not None and isinstance(prompt, str):
|
| 690 |
+
batch_size = 1
|
| 691 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 692 |
+
batch_size = len(prompt)
|
| 693 |
+
else:
|
| 694 |
+
batch_size = prompt_embeds.shape[0]
|
| 695 |
+
|
| 696 |
+
device = self._execution_device
|
| 697 |
+
if comfyui_progressbar:
|
| 698 |
+
from comfy.utils import ProgressBar
|
| 699 |
+
pbar = ProgressBar(num_inference_steps + 2)
|
| 700 |
+
|
| 701 |
+
# 3. Preprocess image
|
| 702 |
+
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
|
| 703 |
+
if not isinstance(image, list):
|
| 704 |
+
image = [image]
|
| 705 |
+
condition_image_sizes = []
|
| 706 |
+
condition_images = []
|
| 707 |
+
vae_image_sizes = []
|
| 708 |
+
vae_images = []
|
| 709 |
+
for img in image:
|
| 710 |
+
image_width, image_height = img.size
|
| 711 |
+
condition_width, condition_height = calculate_dimensions(
|
| 712 |
+
CONDITION_IMAGE_SIZE, image_width / image_height
|
| 713 |
+
)
|
| 714 |
+
vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
|
| 715 |
+
condition_image_sizes.append((condition_width, condition_height))
|
| 716 |
+
vae_image_sizes.append((vae_width, vae_height))
|
| 717 |
+
condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
|
| 718 |
+
vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
|
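This is the main difference from the single-image edit pipeline: every reference image is resized twice, once to roughly `CONDITION_IMAGE_SIZE` (384*384 pixels) for the Qwen2.5-VL encoder and once to roughly `VAE_IMAGE_SIZE` (1024*1024 pixels) for VAE encoding. A sketch of the sizes the loop produces for two hypothetical inputs:

```py
refs = [(1920, 1080), (768, 1024)]                       # hypothetical (width, height) inputs
for w, h in refs:
    cond_size = calculate_dimensions(384 * 384, w / h)   # text-encoder copy
    vae_size = calculate_dimensions(1024 * 1024, w / h)  # VAE copy
    print(cond_size, vae_size)
# (512, 288) (1376, 768)
# (320, 448) (896, 1184)
```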
| 719 |
+
|
| 720 |
+
has_neg_prompt = negative_prompt is not None or (
|
| 721 |
+
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
|
| 722 |
+
)
|
| 723 |
+
|
| 724 |
+
if true_cfg_scale > 1 and not has_neg_prompt:
|
| 725 |
+
logger.warning(
|
| 726 |
+
f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
|
| 727 |
+
)
|
| 728 |
+
elif true_cfg_scale <= 1 and has_neg_prompt:
|
| 729 |
+
logger.warning(
|
| 730 |
+
" negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
|
| 731 |
+
)
|
| 732 |
+
|
| 733 |
+
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
| 734 |
+
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
|
| 735 |
+
image=condition_images,
|
| 736 |
+
prompt=prompt,
|
| 737 |
+
prompt_embeds=prompt_embeds,
|
| 738 |
+
prompt_embeds_mask=prompt_embeds_mask,
|
| 739 |
+
device=device,
|
| 740 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 741 |
+
max_sequence_length=max_sequence_length,
|
| 742 |
+
)
|
| 743 |
+
if do_true_cfg:
|
| 744 |
+
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
|
| 745 |
+
image=condition_images,
|
| 746 |
+
prompt=negative_prompt,
|
| 747 |
+
prompt_embeds=negative_prompt_embeds,
|
| 748 |
+
prompt_embeds_mask=negative_prompt_embeds_mask,
|
| 749 |
+
device=device,
|
| 750 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 751 |
+
max_sequence_length=max_sequence_length,
|
| 752 |
+
)
|
| 753 |
+
if comfyui_progressbar:
|
| 754 |
+
pbar.update(1)
|
| 755 |
+
|
| 756 |
+
# 4. Prepare latent variables
|
| 757 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
| 758 |
+
latents, image_latents = self.prepare_latents(
|
| 759 |
+
vae_images,
|
| 760 |
+
batch_size * num_images_per_prompt,
|
| 761 |
+
num_channels_latents,
|
| 762 |
+
height,
|
| 763 |
+
width,
|
| 764 |
+
prompt_embeds.dtype,
|
| 765 |
+
device,
|
| 766 |
+
generator,
|
| 767 |
+
latents,
|
| 768 |
+
)
|
| 769 |
+
img_shapes = [
|
| 770 |
+
[
|
| 771 |
+
(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
|
| 772 |
+
*[
|
| 773 |
+
(1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
|
| 774 |
+
for vae_width, vae_height in vae_image_sizes
|
| 775 |
+
],
|
| 776 |
+
]
|
| 777 |
+
] * batch_size
|
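`img_shapes` describes, per sample, how the packed token sequence splits into the target grid plus one grid per reference image, each entry being `(frames, latent_height // 2, latent_width // 2)`. A worked example for a 1024x1024 target and the two reference sizes from the sketch above, again assuming an 8x VAE:

```py
height, width = 1024, 1024
vae_image_sizes = [(1376, 768), (896, 1184)]    # (vae_width, vae_height) per reference
img_shapes = [
    [
        (1, height // 8 // 2, width // 8 // 2),
        *[(1, vae_height // 8 // 2, vae_width // 8 // 2) for vae_width, vae_height in vae_image_sizes],
    ]
]
print(img_shapes)   # [[(1, 64, 64), (1, 48, 86), (1, 74, 56)]]
```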
| 778 |
+
|
| 779 |
+
# 5. Prepare timesteps
|
| 780 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
| 781 |
+
image_seq_len = latents.shape[1]
|
| 782 |
+
mu = calculate_shift(
|
| 783 |
+
image_seq_len,
|
| 784 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 785 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 786 |
+
self.scheduler.config.get("base_shift", 0.5),
|
| 787 |
+
self.scheduler.config.get("max_shift", 1.15),
|
| 788 |
+
)
|
| 789 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 790 |
+
self.scheduler,
|
| 791 |
+
num_inference_steps,
|
| 792 |
+
device,
|
| 793 |
+
sigmas=sigmas,
|
| 794 |
+
mu=mu,
|
| 795 |
+
)
|
| 796 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 797 |
+
self._num_timesteps = len(timesteps)
|
| 798 |
+
|
| 799 |
+
# handle guidance
|
| 800 |
+
if self.transformer.config.guidance_embeds and guidance_scale is None:
|
| 801 |
+
raise ValueError("guidance_scale is required for guidance-distilled model.")
|
| 802 |
+
elif self.transformer.config.guidance_embeds:
|
| 803 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 804 |
+
guidance = guidance.expand(latents.shape[0])
|
| 805 |
+
elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
|
| 806 |
+
logger.warning(
|
| 807 |
+
f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
|
| 808 |
+
)
|
| 809 |
+
guidance = None
|
| 810 |
+
elif not self.transformer.config.guidance_embeds and guidance_scale is None:
|
| 811 |
+
guidance = None
|
| 812 |
+
|
| 813 |
+
if self.attention_kwargs is None:
|
| 814 |
+
self._attention_kwargs = {}
|
| 815 |
+
|
| 816 |
+
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
|
| 817 |
+
negative_txt_seq_lens = (
|
| 818 |
+
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
|
| 819 |
+
)
|
| 820 |
+
if comfyui_progressbar:
|
| 821 |
+
pbar.update(1)
|
| 822 |
+
|
| 823 |
+
# 6. Denoising loop
|
| 824 |
+
self.scheduler.set_begin_index(0)
|
| 825 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 826 |
+
for i, t in enumerate(timesteps):
|
| 827 |
+
self.transformer.current_steps = i
|
| 828 |
+
if self.interrupt:
|
| 829 |
+
continue
|
| 830 |
+
|
| 831 |
+
if image_latents is not None:
|
| 832 |
+
latents_and_image_latents = torch.cat([latents, image_latents], dim=1)
|
| 833 |
+
else:
|
| 834 |
+
latents_and_image_latents = latents
|
| 835 |
+
|
| 836 |
+
if do_true_cfg:
|
| 837 |
+
latent_model_input = torch.cat([latents_and_image_latents] * 2)
|
| 838 |
+
prompt_embeds_mask_input = [_negative_prompt_embeds_mask for _negative_prompt_embeds_mask in negative_prompt_embeds_mask] + [_prompt_embeds_mask for _prompt_embeds_mask in prompt_embeds_mask]
|
| 839 |
+
prompt_embeds_input = [_negative_prompt_embeds for _negative_prompt_embeds in negative_prompt_embeds] + [_prompt_embeds for _prompt_embeds in prompt_embeds]
|
| 840 |
+
img_shapes_input = img_shapes * 2
|
| 841 |
+
txt_seq_lens_input = negative_txt_seq_lens + txt_seq_lens
|
| 842 |
+
else:
|
| 843 |
+
latent_model_input = latents_and_image_latents
|
| 844 |
+
prompt_embeds_mask_input = prompt_embeds_mask
|
| 845 |
+
prompt_embeds_input = prompt_embeds
|
| 846 |
+
img_shapes_input = img_shapes
|
| 847 |
+
txt_seq_lens_input = txt_seq_lens
|
| 848 |
+
|
| 849 |
+
if hasattr(self.scheduler, "scale_model_input"):
|
| 850 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 851 |
+
|
| 852 |
+
# handle guidance
|
| 853 |
+
if self.transformer.config.guidance_embeds:
|
| 854 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 855 |
+
guidance = guidance.expand(latent_model_input.shape[0])
|
| 856 |
+
else:
|
| 857 |
+
guidance = None
|
| 858 |
+
|
| 859 |
+
self._current_timestep = t
|
| 860 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 861 |
+
timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
|
| 862 |
+
|
| 863 |
+
with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device):
|
| 864 |
+
noise_pred = self.transformer.forward_bs(
|
| 865 |
+
x=latent_model_input,
|
| 866 |
+
timestep=timestep / 1000,
|
| 867 |
+
guidance=guidance,
|
| 868 |
+
encoder_hidden_states_mask=prompt_embeds_mask_input,
|
| 869 |
+
encoder_hidden_states=prompt_embeds_input,
|
| 870 |
+
img_shapes=img_shapes_input,
|
| 871 |
+
txt_seq_lens=txt_seq_lens_input,
|
| 872 |
+
attention_kwargs=self.attention_kwargs,
|
| 873 |
+
return_dict=False,
|
| 874 |
+
)
|
| 875 |
+
noise_pred = noise_pred[:, : latents.size(1)]
|
| 876 |
+
|
| 877 |
+
if do_true_cfg:
|
| 878 |
+
neg_noise_pred, noise_pred = noise_pred.chunk(2)
|
| 879 |
+
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
| 880 |
+
|
| 881 |
+
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
|
| 882 |
+
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
|
| 883 |
+
noise_pred = comb_pred * (cond_norm / noise_norm)
|
| 884 |
+
|
| 885 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 886 |
+
latents_dtype = latents.dtype
|
| 887 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 888 |
+
|
| 889 |
+
if latents.dtype != latents_dtype:
|
| 890 |
+
if torch.backends.mps.is_available():
|
| 891 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 892 |
+
latents = latents.to(latents_dtype)
|
| 893 |
+
|
| 894 |
+
if callback_on_step_end is not None:
|
| 895 |
+
callback_kwargs = {}
|
| 896 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 897 |
+
callback_kwargs[k] = locals()[k]
|
| 898 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 899 |
+
|
| 900 |
+
latents = callback_outputs.pop("latents", latents)
|
| 901 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 902 |
+
|
| 903 |
+
# call the callback, if provided
|
| 904 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 905 |
+
progress_bar.update()
|
| 906 |
+
|
| 907 |
+
if XLA_AVAILABLE:
|
| 908 |
+
xm.mark_step()
|
| 909 |
+
|
| 910 |
+
if comfyui_progressbar:
|
| 911 |
+
pbar.update(1)
|
| 912 |
+
|
| 913 |
+
self._current_timestep = None
|
| 914 |
+
if output_type == "latent":
|
| 915 |
+
image = latents
|
| 916 |
+
else:
|
| 917 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 918 |
+
latents = latents.to(self.vae.dtype)
|
| 919 |
+
latents_mean = (
|
| 920 |
+
torch.tensor(self.vae.config.latents_mean)
|
| 921 |
+
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
| 922 |
+
.to(latents.device, latents.dtype)
|
| 923 |
+
)
|
| 924 |
+
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
| 925 |
+
latents.device, latents.dtype
|
| 926 |
+
)
|
| 927 |
+
latents = latents / latents_std + latents_mean
|
| 928 |
+
image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
|
| 929 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 930 |
+
|
| 931 |
+
# Offload all models
|
| 932 |
+
self.maybe_free_model_hooks()
|
| 933 |
+
|
| 934 |
+
if not return_dict:
|
| 935 |
+
return (image,)
|
| 936 |
+
|
| 937 |
+
return QwenImagePipelineOutput(images=image)
|
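Editor's aside (not part of the file above): in the `do_true_cfg` branch the pipeline extrapolates between the unconditional and conditional predictions and then rescales the result back to the conditional prediction's per-token norm. A minimal, self-contained sketch of that combine step, with purely illustrative tensor shapes:

```python
import torch

def true_cfg_combine(neg_pred: torch.Tensor, pos_pred: torch.Tensor, scale: float) -> torch.Tensor:
    # standard classifier-free-guidance extrapolation
    comb = neg_pred + scale * (pos_pred - neg_pred)
    # rescale so the combined prediction keeps the conditional prediction's norm
    cond_norm = torch.norm(pos_pred, dim=-1, keepdim=True)
    comb_norm = torch.norm(comb, dim=-1, keepdim=True)
    return comb * (cond_norm / comb_norm)

if __name__ == "__main__":
    neg = torch.randn(1, 16, 64)
    pos = torch.randn(1, 16, 64)
    out = true_cfg_combine(neg, pos, scale=4.0)
    print(out.shape)  # torch.Size([1, 16, 64])
```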
videox_fun/pipeline/pipeline_wan.py
ADDED
@@ -0,0 +1,576 @@
import inspect
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor

from ..models import (AutoencoderKLWan, AutoTokenizer,
                      WanT5EncoderModel, WanTransformer3DModel)
from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                get_sampling_sigmas)
from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        pass
        ```
"""


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


@dataclass
class WanPipelineOutput(BaseOutput):
    r"""
    Output class for Wan pipelines.

    Args:
        video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    videos: torch.Tensor

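Editor's aside (not part of pipeline_wan.py): a minimal usage sketch of the `retrieve_timesteps` helper defined above. The scheduler instance uses illustrative default settings, the module path assumes the layout added by this commit, and the custom-sigmas path relies on the installed diffusers version accepting `sigmas` in `set_timesteps`:

```python
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler
from videox_fun.pipeline.pipeline_wan import retrieve_timesteps

scheduler = FlowMatchEulerDiscreteScheduler()

# 1) let the scheduler build its own spacing for 30 steps
timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")

# 2) or override the spacing with custom sigmas (mutually exclusive with `timesteps`)
sigmas = np.linspace(1.0, 1 / 30, 30).tolist()
timesteps, n = retrieve_timesteps(scheduler, device="cpu", sigmas=sigmas)
print(n, timesteps[:3])
```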
class WanPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-video generation using Wan.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
    """

    _optional_components = []
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: WanT5EncoderModel,
        vae: AutoencoderKLWan,
        transformer: WanTransformer3DModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_attention_mask = text_inputs.attention_mask
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds

    def prepare_latents(
        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
    ):
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        shape = (
            batch_size,
            num_channels_latents,
            (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
            height // self.vae.spatial_compression_ratio,
            width // self.vae.spatial_compression_ratio,
        )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        if hasattr(self.scheduler, "init_noise_sigma"):
            latents = latents * self.scheduler.init_noise_sigma
        return latents

    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
        frames = self.vae.decode(latents.to(self.vae.dtype)).sample
        frames = (frames / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        frames = frames.cpu().float().numpy()
        return frames

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        return self._attention_kwargs

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 480,
        width: int = 720,
        num_frames: int = 49,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 6,
        num_videos_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: str = "numpy",
        return_dict: bool = False,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        comfyui_progressbar: bool = False,
        shift: int = 5,
    ) -> Union[WanPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.
        Args:

        Examples:

        Returns:

        """

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_on_step_end_tensor_inputs,
            prompt_embeds,
            negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._interrupt = False

        # 2. Default call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        weight_dtype = self.text_encoder.dtype

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            negative_prompt,
            do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )
        if do_classifier_free_guidance:
            in_prompt_embeds = negative_prompt_embeds + prompt_embeds
        else:
            in_prompt_embeds = prompt_embeds

        # 4. Prepare timesteps
        if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
        elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
            self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
            timesteps = self.scheduler.timesteps
        elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
            sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
            timesteps, _ = retrieve_timesteps(
                self.scheduler,
                device=device,
                sigmas=sampling_sigmas)
        else:
            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        self._num_timesteps = len(timesteps)
        if comfyui_progressbar:
            from comfy.utils import ProgressBar
            pbar = ProgressBar(num_inference_steps + 1)

        # 5. Prepare latents
        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            latent_channels,
            num_frames,
            height,
            width,
            weight_dtype,
            device,
            generator,
            latents,
        )
        if comfyui_progressbar:
            pbar.update(1)

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
        seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
        # 7. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self.transformer.num_inference_steps = num_inference_steps
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                self.transformer.current_steps = i

                if self.interrupt:
                    continue

                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                if hasattr(self.scheduler, "scale_model_input"):
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
                with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
                    noise_pred = self.transformer(
                        x=latent_model_input,
                        context=in_prompt_embeds,
                        t=timestep,
                        seq_len=seq_len,
                    )

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                if comfyui_progressbar:
                    pbar.update(1)

        if output_type == "numpy":
            video = self.decode_latents(latents)
        elif not output_type == "latent":
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            video = torch.from_numpy(video)

        return WanPipelineOutput(videos=video)
videox_fun/pipeline/pipeline_wan2_2.py
ADDED
@@ -0,0 +1,591 @@
import inspect
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor

from ..models import (AutoencoderKLWan, AutoTokenizer,
                      WanT5EncoderModel, Wan2_2Transformer3DModel)
from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                get_sampling_sigmas)
from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        pass
        ```
"""


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


@dataclass
class WanPipelineOutput(BaseOutput):
    r"""
    Output class for Wan pipelines.

    Args:
        video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    videos: torch.Tensor


class Wan2_2Pipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-video generation using Wan.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
    """

    _optional_components = ["transformer_2"]
    model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae"

    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: WanT5EncoderModel,
        vae: AutoencoderKLWan,
        transformer: Wan2_2Transformer3DModel,
        transformer_2: Wan2_2Transformer3DModel = None,
        scheduler: FlowMatchEulerDiscreteScheduler = None,
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
            transformer_2=transformer_2, scheduler=scheduler
        )
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_attention_mask = text_inputs.attention_mask
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds

    def prepare_latents(
        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
    ):
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        shape = (
            batch_size,
            num_channels_latents,
            (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
            height // self.vae.spatial_compression_ratio,
            width // self.vae.spatial_compression_ratio,
        )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        if hasattr(self.scheduler, "init_noise_sigma"):
            latents = latents * self.scheduler.init_noise_sigma
        return latents

    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
        frames = self.vae.decode(latents.to(self.vae.dtype)).sample
        frames = (frames / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        frames = frames.cpu().float().numpy()
        return frames

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        return self._attention_kwargs

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 480,
        width: int = 720,
        num_frames: int = 49,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 6,
        num_videos_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: str = "numpy",
        return_dict: bool = False,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        boundary: float = 0.875,
        comfyui_progressbar: bool = False,
        shift: int = 5,
    ) -> Union[WanPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.
        Args:

        Examples:

        Returns:

        """

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_on_step_end_tensor_inputs,
            prompt_embeds,
|
| 438 |
+
negative_prompt_embeds,
|
| 439 |
+
)
|
| 440 |
+
self._guidance_scale = guidance_scale
|
| 441 |
+
self._attention_kwargs = attention_kwargs
|
| 442 |
+
self._interrupt = False
|
| 443 |
+
|
| 444 |
+
# 2. Default call parameters
|
| 445 |
+
if prompt is not None and isinstance(prompt, str):
|
| 446 |
+
batch_size = 1
|
| 447 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 448 |
+
batch_size = len(prompt)
|
| 449 |
+
else:
|
| 450 |
+
batch_size = prompt_embeds.shape[0]
|
| 451 |
+
|
| 452 |
+
device = self._execution_device
|
| 453 |
+
weight_dtype = self.text_encoder.dtype
|
| 454 |
+
|
| 455 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 456 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 457 |
+
# corresponds to doing no classifier free guidance.
|
| 458 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 459 |
+
|
| 460 |
+
# 3. Encode input prompt
|
| 461 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 462 |
+
prompt,
|
| 463 |
+
negative_prompt,
|
| 464 |
+
do_classifier_free_guidance,
|
| 465 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 466 |
+
prompt_embeds=prompt_embeds,
|
| 467 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 468 |
+
max_sequence_length=max_sequence_length,
|
| 469 |
+
device=device,
|
| 470 |
+
)
|
| 471 |
+
if do_classifier_free_guidance:
|
| 472 |
+
in_prompt_embeds = negative_prompt_embeds + prompt_embeds
|
| 473 |
+
else:
|
| 474 |
+
in_prompt_embeds = prompt_embeds
|
| 475 |
+
|
| 476 |
+
# 4. Prepare timesteps
|
| 477 |
+
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
| 478 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
|
| 479 |
+
elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
|
| 480 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
|
| 481 |
+
timesteps = self.scheduler.timesteps
|
| 482 |
+
elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
|
| 483 |
+
sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
|
| 484 |
+
timesteps, _ = retrieve_timesteps(
|
| 485 |
+
self.scheduler,
|
| 486 |
+
device=device,
|
| 487 |
+
sigmas=sampling_sigmas)
|
| 488 |
+
else:
|
| 489 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 490 |
+
self._num_timesteps = len(timesteps)
|
| 491 |
+
if comfyui_progressbar:
|
| 492 |
+
from comfy.utils import ProgressBar
|
| 493 |
+
pbar = ProgressBar(num_inference_steps + 1)
|
| 494 |
+
|
| 495 |
+
# 5. Prepare latents
|
| 496 |
+
latent_channels = self.transformer.config.in_channels
|
| 497 |
+
latents = self.prepare_latents(
|
| 498 |
+
batch_size * num_videos_per_prompt,
|
| 499 |
+
latent_channels,
|
| 500 |
+
num_frames,
|
| 501 |
+
height,
|
| 502 |
+
width,
|
| 503 |
+
weight_dtype,
|
| 504 |
+
device,
|
| 505 |
+
generator,
|
| 506 |
+
latents,
|
| 507 |
+
)
|
| 508 |
+
if comfyui_progressbar:
|
| 509 |
+
pbar.update(1)
|
| 510 |
+
|
| 511 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 512 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 513 |
+
|
| 514 |
+
target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
|
| 515 |
+
seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
|
| 516 |
+
# 7. Denoising loop
|
| 517 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 518 |
+
self.transformer.num_inference_steps = num_inference_steps
|
| 519 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 520 |
+
for i, t in enumerate(timesteps):
|
| 521 |
+
self.transformer.current_steps = i
|
| 522 |
+
|
| 523 |
+
if self.interrupt:
|
| 524 |
+
continue
|
| 525 |
+
|
| 526 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 527 |
+
if hasattr(self.scheduler, "scale_model_input"):
|
| 528 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 529 |
+
|
| 530 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 531 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 532 |
+
|
| 533 |
+
if self.transformer_2 is not None:
|
| 534 |
+
if t >= boundary * self.scheduler.config.num_train_timesteps:
|
| 535 |
+
local_transformer = self.transformer_2
|
| 536 |
+
else:
|
| 537 |
+
local_transformer = self.transformer
|
| 538 |
+
else:
|
| 539 |
+
local_transformer = self.transformer
|
| 540 |
+
|
| 541 |
+
# predict noise model_output
|
| 542 |
+
with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
|
| 543 |
+
noise_pred = local_transformer(
|
| 544 |
+
x=latent_model_input,
|
| 545 |
+
context=in_prompt_embeds,
|
| 546 |
+
t=timestep,
|
| 547 |
+
seq_len=seq_len,
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
# perform guidance
|
| 551 |
+
if do_classifier_free_guidance:
|
| 552 |
+
if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))):
|
| 553 |
+
sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0]
|
| 554 |
+
else:
|
| 555 |
+
sample_guide_scale = self.guidance_scale
|
| 556 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 557 |
+
noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond)
|
| 558 |
+
|
| 559 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 560 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 561 |
+
|
| 562 |
+
if callback_on_step_end is not None:
|
| 563 |
+
callback_kwargs = {}
|
| 564 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 565 |
+
callback_kwargs[k] = locals()[k]
|
| 566 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 567 |
+
|
| 568 |
+
latents = callback_outputs.pop("latents", latents)
|
| 569 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 570 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 571 |
+
|
| 572 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 573 |
+
progress_bar.update()
|
| 574 |
+
if comfyui_progressbar:
|
| 575 |
+
pbar.update(1)
|
| 576 |
+
|
| 577 |
+
if output_type == "numpy":
|
| 578 |
+
video = self.decode_latents(latents)
|
| 579 |
+
elif not output_type == "latent":
|
| 580 |
+
video = self.decode_latents(latents)
|
| 581 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 582 |
+
else:
|
| 583 |
+
video = latents
|
| 584 |
+
|
| 585 |
+
# Offload all models
|
| 586 |
+
self.maybe_free_model_hooks()
|
| 587 |
+
|
| 588 |
+
if not return_dict:
|
| 589 |
+
video = torch.from_numpy(video)
|
| 590 |
+
|
| 591 |
+
return WanPipelineOutput(videos=video)
|
videox_fun/pipeline/pipeline_wan2_2_animate.py
ADDED
|
@@ -0,0 +1,929 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import math
|
| 3 |
+
from copy import deepcopy
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import copy
|
| 9 |
+
import torch
|
| 10 |
+
import cv2
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
| 14 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 15 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 16 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 17 |
+
from diffusers.utils import BaseOutput, logging, replace_example_docstring
|
| 18 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 19 |
+
from diffusers.video_processor import VideoProcessor
|
| 20 |
+
from decord import VideoReader
|
| 21 |
+
|
| 22 |
+
from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
|
| 23 |
+
WanT5EncoderModel, Wan2_2Transformer3DModel_Animate)
|
| 24 |
+
from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
| 25 |
+
get_sampling_sigmas)
|
| 26 |
+
from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
EXAMPLE_DOC_STRING = """
|
| 32 |
+
Examples:
|
| 33 |
+
```python
|
| 34 |
+
pass
|
| 35 |
+
```
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 40 |
+
def retrieve_timesteps(
|
| 41 |
+
scheduler,
|
| 42 |
+
num_inference_steps: Optional[int] = None,
|
| 43 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 44 |
+
timesteps: Optional[List[int]] = None,
|
| 45 |
+
sigmas: Optional[List[float]] = None,
|
| 46 |
+
**kwargs,
|
| 47 |
+
):
|
| 48 |
+
"""
|
| 49 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 50 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
scheduler (`SchedulerMixin`):
|
| 54 |
+
The scheduler to get timesteps from.
|
| 55 |
+
num_inference_steps (`int`):
|
| 56 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 57 |
+
must be `None`.
|
| 58 |
+
device (`str` or `torch.device`, *optional*):
|
| 59 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 60 |
+
timesteps (`List[int]`, *optional*):
|
| 61 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 62 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 63 |
+
sigmas (`List[float]`, *optional*):
|
| 64 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 65 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 69 |
+
second element is the number of inference steps.
|
| 70 |
+
"""
|
| 71 |
+
if timesteps is not None and sigmas is not None:
|
| 72 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 73 |
+
if timesteps is not None:
|
| 74 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 75 |
+
if not accepts_timesteps:
|
| 76 |
+
raise ValueError(
|
| 77 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 78 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 79 |
+
)
|
| 80 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 81 |
+
timesteps = scheduler.timesteps
|
| 82 |
+
num_inference_steps = len(timesteps)
|
| 83 |
+
elif sigmas is not None:
|
| 84 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 85 |
+
if not accept_sigmas:
|
| 86 |
+
raise ValueError(
|
| 87 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 88 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 89 |
+
)
|
| 90 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 91 |
+
timesteps = scheduler.timesteps
|
| 92 |
+
num_inference_steps = len(timesteps)
|
| 93 |
+
else:
|
| 94 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 95 |
+
timesteps = scheduler.timesteps
|
| 96 |
+
return timesteps, num_inference_steps
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@dataclass
|
| 100 |
+
class WanPipelineOutput(BaseOutput):
|
| 101 |
+
r"""
|
| 102 |
+
Output class for CogVideo pipelines.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
| 106 |
+
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
| 107 |
+
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
| 108 |
+
`(batch_size, num_frames, channels, height, width)`.
|
| 109 |
+
"""
|
| 110 |
+
|
| 111 |
+
videos: torch.Tensor
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class Wan2_2AnimatePipeline(DiffusionPipeline):
|
| 115 |
+
r"""
|
| 116 |
+
Pipeline for text-to-video generation using Wan.
|
| 117 |
+
|
| 118 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 119 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 120 |
+
"""
|
| 121 |
+
|
| 122 |
+
_optional_components = ["transformer_2", "clip_image_encoder"]
|
| 123 |
+
model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer_2->transformer->vae"
|
| 124 |
+
|
| 125 |
+
_callback_tensor_inputs = [
|
| 126 |
+
"latents",
|
| 127 |
+
"prompt_embeds",
|
| 128 |
+
"negative_prompt_embeds",
|
| 129 |
+
]
|
| 130 |
+
|
| 131 |
+
def __init__(
|
| 132 |
+
self,
|
| 133 |
+
tokenizer: AutoTokenizer,
|
| 134 |
+
text_encoder: WanT5EncoderModel,
|
| 135 |
+
vae: AutoencoderKLWan,
|
| 136 |
+
transformer: Wan2_2Transformer3DModel_Animate,
|
| 137 |
+
transformer_2: Wan2_2Transformer3DModel_Animate = None,
|
| 138 |
+
clip_image_encoder: CLIPModel = None,
|
| 139 |
+
scheduler: FlowMatchEulerDiscreteScheduler = None,
|
| 140 |
+
):
|
| 141 |
+
super().__init__()
|
| 142 |
+
|
| 143 |
+
self.register_modules(
|
| 144 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
|
| 145 |
+
transformer_2=transformer_2, clip_image_encoder=clip_image_encoder, scheduler=scheduler
|
| 146 |
+
)
|
| 147 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
|
| 148 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
|
| 149 |
+
self.mask_processor = VaeImageProcessor(
|
| 150 |
+
vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
def _get_t5_prompt_embeds(
|
| 154 |
+
self,
|
| 155 |
+
prompt: Union[str, List[str]] = None,
|
| 156 |
+
num_videos_per_prompt: int = 1,
|
| 157 |
+
max_sequence_length: int = 512,
|
| 158 |
+
device: Optional[torch.device] = None,
|
| 159 |
+
dtype: Optional[torch.dtype] = None,
|
| 160 |
+
):
|
| 161 |
+
device = device or self._execution_device
|
| 162 |
+
dtype = dtype or self.text_encoder.dtype
|
| 163 |
+
|
| 164 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 165 |
+
batch_size = len(prompt)
|
| 166 |
+
|
| 167 |
+
text_inputs = self.tokenizer(
|
| 168 |
+
prompt,
|
| 169 |
+
padding="max_length",
|
| 170 |
+
max_length=max_sequence_length,
|
| 171 |
+
truncation=True,
|
| 172 |
+
add_special_tokens=True,
|
| 173 |
+
return_tensors="pt",
|
| 174 |
+
)
|
| 175 |
+
text_input_ids = text_inputs.input_ids
|
| 176 |
+
prompt_attention_mask = text_inputs.attention_mask
|
| 177 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 178 |
+
|
| 179 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 180 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 181 |
+
logger.warning(
|
| 182 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 183 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
|
| 187 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
|
| 188 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 189 |
+
|
| 190 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 191 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 192 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 193 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 194 |
+
|
| 195 |
+
return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
| 196 |
+
|
| 197 |
+
def encode_prompt(
|
| 198 |
+
self,
|
| 199 |
+
prompt: Union[str, List[str]],
|
| 200 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 201 |
+
do_classifier_free_guidance: bool = True,
|
| 202 |
+
num_videos_per_prompt: int = 1,
|
| 203 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 204 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 205 |
+
max_sequence_length: int = 512,
|
| 206 |
+
device: Optional[torch.device] = None,
|
| 207 |
+
dtype: Optional[torch.dtype] = None,
|
| 208 |
+
):
|
| 209 |
+
r"""
|
| 210 |
+
Encodes the prompt into text encoder hidden states.
|
| 211 |
+
|
| 212 |
+
Args:
|
| 213 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 214 |
+
prompt to be encoded
|
| 215 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 216 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 217 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 218 |
+
less than `1`).
|
| 219 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 220 |
+
Whether to use classifier free guidance or not.
|
| 221 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 222 |
+
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
| 223 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 224 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 225 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 226 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 227 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 228 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 229 |
+
argument.
|
| 230 |
+
device: (`torch.device`, *optional*):
|
| 231 |
+
torch device
|
| 232 |
+
dtype: (`torch.dtype`, *optional*):
|
| 233 |
+
torch dtype
|
| 234 |
+
"""
|
| 235 |
+
device = device or self._execution_device
|
| 236 |
+
|
| 237 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 238 |
+
if prompt is not None:
|
| 239 |
+
batch_size = len(prompt)
|
| 240 |
+
else:
|
| 241 |
+
batch_size = prompt_embeds.shape[0]
|
| 242 |
+
|
| 243 |
+
if prompt_embeds is None:
|
| 244 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 245 |
+
prompt=prompt,
|
| 246 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 247 |
+
max_sequence_length=max_sequence_length,
|
| 248 |
+
device=device,
|
| 249 |
+
dtype=dtype,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 253 |
+
negative_prompt = negative_prompt or ""
|
| 254 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 255 |
+
|
| 256 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 257 |
+
raise TypeError(
|
| 258 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 259 |
+
f" {type(prompt)}."
|
| 260 |
+
)
|
| 261 |
+
elif batch_size != len(negative_prompt):
|
| 262 |
+
raise ValueError(
|
| 263 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 264 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 265 |
+
" the batch size of `prompt`."
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 269 |
+
prompt=negative_prompt,
|
| 270 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 271 |
+
max_sequence_length=max_sequence_length,
|
| 272 |
+
device=device,
|
| 273 |
+
dtype=dtype,
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
return prompt_embeds, negative_prompt_embeds
|
| 277 |
+
|
| 278 |
+
def prepare_latents(
|
| 279 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
| 280 |
+
):
|
| 281 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 282 |
+
raise ValueError(
|
| 283 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 284 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
shape = (
|
| 288 |
+
batch_size,
|
| 289 |
+
num_channels_latents,
|
| 290 |
+
(num_frames - 1) // self.vae.temporal_compression_ratio + 1,
|
| 291 |
+
height // self.vae.spatial_compression_ratio,
|
| 292 |
+
width // self.vae.spatial_compression_ratio,
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
if latents is None:
|
| 296 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 297 |
+
else:
|
| 298 |
+
latents = latents.to(device)
|
| 299 |
+
|
| 300 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 301 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
| 302 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 303 |
+
return latents
|
| 304 |
+
|
| 305 |
+
def padding_resize(self, img_ori, height=512, width=512, padding_color=(0, 0, 0), interpolation=cv2.INTER_LINEAR):
|
| 306 |
+
ori_height = img_ori.shape[0]
|
| 307 |
+
ori_width = img_ori.shape[1]
|
| 308 |
+
channel = img_ori.shape[2]
|
| 309 |
+
|
| 310 |
+
img_pad = np.zeros((height, width, channel))
|
| 311 |
+
if channel == 1:
|
| 312 |
+
img_pad[:, :, 0] = padding_color[0]
|
| 313 |
+
else:
|
| 314 |
+
img_pad[:, :, 0] = padding_color[0]
|
| 315 |
+
img_pad[:, :, 1] = padding_color[1]
|
| 316 |
+
img_pad[:, :, 2] = padding_color[2]
|
| 317 |
+
|
| 318 |
+
if (ori_height / ori_width) > (height / width):
|
| 319 |
+
new_width = int(height / ori_height * ori_width)
|
| 320 |
+
img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation)
|
| 321 |
+
padding = int((width - new_width) / 2)
|
| 322 |
+
if len(img.shape) == 2:
|
| 323 |
+
img = img[:, :, np.newaxis]
|
| 324 |
+
img_pad[:, padding: padding + new_width, :] = img
|
| 325 |
+
else:
|
| 326 |
+
new_height = int(width / ori_width * ori_height)
|
| 327 |
+
img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation)
|
| 328 |
+
padding = int((height - new_height) / 2)
|
| 329 |
+
if len(img.shape) == 2:
|
| 330 |
+
img = img[:, :, np.newaxis]
|
| 331 |
+
img_pad[padding: padding + new_height, :, :] = img
|
| 332 |
+
|
| 333 |
+
img_pad = np.uint8(img_pad)
|
| 334 |
+
|
| 335 |
+
return img_pad
|
| 336 |
+
|
| 337 |
+
def inputs_padding(self, x, target_len):
|
| 338 |
+
ndim = x.ndim
|
| 339 |
+
|
| 340 |
+
if ndim == 4:
|
| 341 |
+
f = x.shape[0]
|
| 342 |
+
if target_len <= f:
|
| 343 |
+
return [deepcopy(x[i]) for i in range(target_len)]
|
| 344 |
+
|
| 345 |
+
idx = 0
|
| 346 |
+
flip = False
|
| 347 |
+
target_array = []
|
| 348 |
+
while len(target_array) < target_len:
|
| 349 |
+
target_array.append(deepcopy(x[idx]))
|
| 350 |
+
if flip:
|
| 351 |
+
idx -= 1
|
| 352 |
+
else:
|
| 353 |
+
idx += 1
|
| 354 |
+
if idx == 0 or idx == f - 1:
|
| 355 |
+
flip = not flip
|
| 356 |
+
return target_array[:target_len]
|
| 357 |
+
|
| 358 |
+
elif ndim == 5:
|
| 359 |
+
b, c, f, h, w = x.shape
|
| 360 |
+
|
| 361 |
+
if target_len <= f:
|
| 362 |
+
return x[:, :, :target_len, :, :]
|
| 363 |
+
|
| 364 |
+
indices = []
|
| 365 |
+
idx = 0
|
| 366 |
+
flip = False
|
| 367 |
+
while len(indices) < target_len:
|
| 368 |
+
indices.append(idx)
|
| 369 |
+
if flip:
|
| 370 |
+
idx -= 1
|
| 371 |
+
else:
|
| 372 |
+
idx += 1
|
| 373 |
+
if idx == 0 or idx == f - 1:
|
| 374 |
+
flip = not flip
|
| 375 |
+
indices = indices[:target_len]
|
| 376 |
+
|
| 377 |
+
if isinstance(x, torch.Tensor):
|
| 378 |
+
indices_tensor = torch.tensor(indices, device=x.device, dtype=torch.long)
|
| 379 |
+
return x[:, :, indices_tensor, :, :]
|
| 380 |
+
else:
|
| 381 |
+
indices_array = np.array(indices)
|
| 382 |
+
return x[:, :, indices_array, :, :]
|
| 383 |
+
|
| 384 |
+
else:
|
| 385 |
+
raise ValueError(f"Unsupported input dimension: {ndim}. Expected 4D or 5D.")
|
| 386 |
+
|
| 387 |
+
def get_valid_len(self, real_len, clip_len=81, overlap=1):
|
| 388 |
+
real_clip_len = clip_len - overlap
|
| 389 |
+
last_clip_num = (real_len - overlap) % real_clip_len
|
| 390 |
+
if last_clip_num == 0:
|
| 391 |
+
extra = 0
|
| 392 |
+
else:
|
| 393 |
+
extra = real_clip_len - last_clip_num
|
| 394 |
+
target_len = real_len + extra
|
| 395 |
+
return target_len
|
| 396 |
+
|
| 397 |
+
def prepare_source(self, src_pose_path, src_face_path, src_ref_path):
|
| 398 |
+
pose_video_reader = VideoReader(src_pose_path)
|
| 399 |
+
pose_len = len(pose_video_reader)
|
| 400 |
+
pose_idxs = list(range(pose_len))
|
| 401 |
+
pose_video = pose_video_reader.get_batch(pose_idxs).asnumpy()
|
| 402 |
+
|
| 403 |
+
face_video_reader = VideoReader(src_face_path)
|
| 404 |
+
face_len = len(face_video_reader)
|
| 405 |
+
face_idxs = list(range(face_len))
|
| 406 |
+
face_video = face_video_reader.get_batch(face_idxs).asnumpy()
|
| 407 |
+
height, width = pose_video[0].shape[:2]
|
| 408 |
+
|
| 409 |
+
ref_image = cv2.imread(src_ref_path)[..., ::-1]
|
| 410 |
+
ref_image = self.padding_resize(ref_image, height=height, width=width)
|
| 411 |
+
return pose_video, face_video, ref_image
|
| 412 |
+
|
| 413 |
+
def prepare_source_for_replace(self, src_bg_path, src_mask_path):
|
| 414 |
+
bg_video_reader = VideoReader(src_bg_path)
|
| 415 |
+
bg_len = len(bg_video_reader)
|
| 416 |
+
bg_idxs = list(range(bg_len))
|
| 417 |
+
bg_video = bg_video_reader.get_batch(bg_idxs).asnumpy()
|
| 418 |
+
|
| 419 |
+
mask_video_reader = VideoReader(src_mask_path)
|
| 420 |
+
mask_len = len(mask_video_reader)
|
| 421 |
+
mask_idxs = list(range(mask_len))
|
| 422 |
+
mask_video = mask_video_reader.get_batch(mask_idxs).asnumpy()
|
| 423 |
+
mask_video = mask_video[:, :, :, 0] / 255
|
| 424 |
+
return bg_video, mask_video
|
| 425 |
+
|
| 426 |
+
def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
|
| 427 |
+
if mask_pixel_values is None:
|
| 428 |
+
msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
|
| 429 |
+
else:
|
| 430 |
+
msk = mask_pixel_values.clone()
|
| 431 |
+
msk[:, :mask_len] = 1
|
| 432 |
+
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
| 433 |
+
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
|
| 434 |
+
msk = msk.transpose(1, 2)
|
| 435 |
+
return msk
|
| 436 |
+
|
| 437 |
+
def prepare_control_latents(
|
| 438 |
+
self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
| 439 |
+
):
|
| 440 |
+
# resize the control to latents shape as we concatenate the control to the latents
|
| 441 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
| 442 |
+
# and half precision
|
| 443 |
+
|
| 444 |
+
if control is not None:
|
| 445 |
+
control = control.to(device=device, dtype=dtype)
|
| 446 |
+
bs = 1
|
| 447 |
+
new_control = []
|
| 448 |
+
for i in range(0, control.shape[0], bs):
|
| 449 |
+
control_bs = control[i : i + bs]
|
| 450 |
+
control_bs = self.vae.encode(control_bs)[0]
|
| 451 |
+
control_bs = control_bs.mode()
|
| 452 |
+
new_control.append(control_bs)
|
| 453 |
+
control = torch.cat(new_control, dim = 0)
|
| 454 |
+
|
| 455 |
+
if control_image is not None:
|
| 456 |
+
control_image = control_image.to(device=device, dtype=dtype)
|
| 457 |
+
bs = 1
|
| 458 |
+
new_control_pixel_values = []
|
| 459 |
+
for i in range(0, control_image.shape[0], bs):
|
| 460 |
+
control_pixel_values_bs = control_image[i : i + bs]
|
| 461 |
+
control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
|
| 462 |
+
control_pixel_values_bs = control_pixel_values_bs.mode()
|
| 463 |
+
new_control_pixel_values.append(control_pixel_values_bs)
|
| 464 |
+
control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
|
| 465 |
+
else:
|
| 466 |
+
control_image_latents = None
|
| 467 |
+
|
| 468 |
+
return control, control_image_latents
|
| 469 |
+
|
| 470 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 471 |
+
frames = self.vae.decode(latents.to(self.vae.dtype)).sample
|
| 472 |
+
frames = (frames / 2 + 0.5).clamp(0, 1)
|
| 473 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
| 474 |
+
# frames = frames.cpu().float().numpy()
|
| 475 |
+
return frames
|
| 476 |
+
|
| 477 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 478 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 479 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 480 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 481 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 482 |
+
# and should be between [0, 1]
|
| 483 |
+
|
| 484 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 485 |
+
extra_step_kwargs = {}
|
| 486 |
+
if accepts_eta:
|
| 487 |
+
extra_step_kwargs["eta"] = eta
|
| 488 |
+
|
| 489 |
+
# check if the scheduler accepts generator
|
| 490 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 491 |
+
if accepts_generator:
|
| 492 |
+
extra_step_kwargs["generator"] = generator
|
| 493 |
+
return extra_step_kwargs
|
| 494 |
+
|
| 495 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
| 496 |
+
def check_inputs(
|
| 497 |
+
self,
|
| 498 |
+
prompt,
|
| 499 |
+
height,
|
| 500 |
+
width,
|
| 501 |
+
negative_prompt,
|
| 502 |
+
callback_on_step_end_tensor_inputs,
|
| 503 |
+
prompt_embeds=None,
|
| 504 |
+
negative_prompt_embeds=None,
|
| 505 |
+
):
|
| 506 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 507 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 508 |
+
|
| 509 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 510 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 511 |
+
):
|
| 512 |
+
raise ValueError(
|
| 513 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 514 |
+
)
|
| 515 |
+
if prompt is not None and prompt_embeds is not None:
|
| 516 |
+
raise ValueError(
|
| 517 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 518 |
+
" only forward one of the two."
|
| 519 |
+
)
|
| 520 |
+
elif prompt is None and prompt_embeds is None:
|
| 521 |
+
raise ValueError(
|
| 522 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 523 |
+
)
|
| 524 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 525 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 526 |
+
|
| 527 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 528 |
+
raise ValueError(
|
| 529 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 530 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 534 |
+
raise ValueError(
|
| 535 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 536 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 540 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 541 |
+
raise ValueError(
|
| 542 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 543 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 544 |
+
f" {negative_prompt_embeds.shape}."
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
@property
|
| 548 |
+
def guidance_scale(self):
|
| 549 |
+
return self._guidance_scale
|
| 550 |
+
|
| 551 |
+
@property
|
| 552 |
+
def num_timesteps(self):
|
| 553 |
+
return self._num_timesteps
|
| 554 |
+
|
| 555 |
+
@property
|
| 556 |
+
def attention_kwargs(self):
|
| 557 |
+
return self._attention_kwargs
|
| 558 |
+
|
| 559 |
+
@property
|
| 560 |
+
def interrupt(self):
|
| 561 |
+
return self._interrupt
|
| 562 |
+
|
| 563 |
+
@torch.no_grad()
|
| 564 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 565 |
+
def __call__(
|
| 566 |
+
self,
|
| 567 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 568 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 569 |
+
height: int = 480,
|
| 570 |
+
width: int = 720,
|
| 571 |
+
clip_len=77,
|
| 572 |
+
num_frames: int = 49,
|
| 573 |
+
num_inference_steps: int = 50,
|
| 574 |
+
pose_video = None,
|
| 575 |
+
face_video = None,
|
| 576 |
+
ref_image = None,
|
| 577 |
+
bg_video = None,
|
| 578 |
+
mask_video = None,
|
| 579 |
+
replace_flag = True,
|
| 580 |
+
timesteps: Optional[List[int]] = None,
|
| 581 |
+
guidance_scale: float = 6,
|
| 582 |
+
num_videos_per_prompt: int = 1,
|
| 583 |
+
eta: float = 0.0,
|
| 584 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 585 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 586 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 587 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 588 |
+
output_type: str = "numpy",
|
| 589 |
+
return_dict: bool = False,
|
| 590 |
+
callback_on_step_end: Optional[
|
| 591 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 592 |
+
] = None,
|
| 593 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 594 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 595 |
+
max_sequence_length: int = 512,
|
| 596 |
+
boundary: float = 0.875,
|
| 597 |
+
comfyui_progressbar: bool = False,
|
| 598 |
+
shift: int = 5,
|
| 599 |
+
refert_num = 1,
|
| 600 |
+
) -> Union[WanPipelineOutput, Tuple]:
|
| 601 |
+
"""
|
| 602 |
+
Function invoked when calling the pipeline for generation.
|
| 603 |
+
Args:
|
| 604 |
+
|
| 605 |
+
Examples:
|
| 606 |
+
|
| 607 |
+
Returns:
|
| 608 |
+
|
| 609 |
+
"""
|
| 610 |
+
|
| 611 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 612 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 613 |
+
num_videos_per_prompt = 1
|
| 614 |
+
|
| 615 |
+
# 1. Check inputs. Raise error if not correct
|
| 616 |
+
self.check_inputs(
|
| 617 |
+
prompt,
|
| 618 |
+
height,
|
| 619 |
+
width,
|
| 620 |
+
negative_prompt,
|
| 621 |
+
callback_on_step_end_tensor_inputs,
|
| 622 |
+
prompt_embeds,
|
| 623 |
+
negative_prompt_embeds,
|
| 624 |
+
)
|
| 625 |
+
self._guidance_scale = guidance_scale
|
| 626 |
+
self._attention_kwargs = attention_kwargs
|
| 627 |
+
self._interrupt = False
|
| 628 |
+
|
| 629 |
+
# 2. Default call parameters
|
| 630 |
+
if prompt is not None and isinstance(prompt, str):
|
| 631 |
+
batch_size = 1
|
| 632 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 633 |
+
batch_size = len(prompt)
|
| 634 |
+
else:
|
| 635 |
+
batch_size = prompt_embeds.shape[0]
|
| 636 |
+
|
| 637 |
+
device = self._execution_device
|
| 638 |
+
weight_dtype = self.text_encoder.dtype
|
| 639 |
+
|
| 640 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 641 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 642 |
+
# corresponds to doing no classifier free guidance.
|
| 643 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 644 |
+
|
| 645 |
+
# 3. Encode input prompt
|
| 646 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 647 |
+
prompt,
|
| 648 |
+
negative_prompt,
|
| 649 |
+
do_classifier_free_guidance,
|
| 650 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 651 |
+
prompt_embeds=prompt_embeds,
|
| 652 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 653 |
+
max_sequence_length=max_sequence_length,
|
| 654 |
+
device=device,
|
| 655 |
+
)
|
| 656 |
+
if do_classifier_free_guidance:
|
| 657 |
+
in_prompt_embeds = negative_prompt_embeds + prompt_embeds
|
| 658 |
+
else:
|
| 659 |
+
in_prompt_embeds = prompt_embeds
|
| 660 |
+
|
| 661 |
+
if comfyui_progressbar:
|
| 662 |
+
from comfy.utils import ProgressBar
|
| 663 |
+
pbar = ProgressBar(num_inference_steps + 1)
|
| 664 |
+
|
| 665 |
+
# 4. Prepare latents
|
| 666 |
+
if pose_video is not None:
|
| 667 |
+
video_length = pose_video.shape[2]
|
| 668 |
+
pose_video = self.image_processor.preprocess(rearrange(pose_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 669 |
+
pose_video = pose_video.to(dtype=torch.float32)
|
| 670 |
+
pose_video = rearrange(pose_video, "(b f) c h w -> b c f h w", f=video_length)
|
| 671 |
+
else:
|
| 672 |
+
pose_video = None
|
| 673 |
+
|
| 674 |
+
if face_video is not None:
|
| 675 |
+
video_length = face_video.shape[2]
|
| 676 |
+
face_video = self.image_processor.preprocess(rearrange(face_video, "b c f h w -> (b f) c h w"))
|
| 677 |
+
face_video = face_video.to(dtype=torch.float32)
|
| 678 |
+
face_video = rearrange(face_video, "(b f) c h w -> b c f h w", f=video_length)
|
| 679 |
+
else:
|
| 680 |
+
face_video = None
|
| 681 |
+
|
| 682 |
+
real_frame_len = pose_video.size()[2]
|
| 683 |
+
target_len = self.get_valid_len(real_frame_len, clip_len, overlap=refert_num)
|
| 684 |
+
print('real frames: {} target frames: {}'.format(real_frame_len, target_len))
|
| 685 |
+
pose_video = self.inputs_padding(pose_video, target_len).to(device, weight_dtype)
|
| 686 |
+
face_video = self.inputs_padding(face_video, target_len).to(device, weight_dtype)
|
| 687 |
+
ref_image = self.padding_resize(np.array(ref_image), height=height, width=width)
|
| 688 |
+
ref_image = torch.tensor(ref_image / 127.5 - 1).unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0).to(device, weight_dtype)
|
| 689 |
+
|
| 690 |
+
if replace_flag:
|
| 691 |
+
if bg_video is not None:
|
| 692 |
+
video_length = bg_video.shape[2]
|
| 693 |
+
bg_video = self.image_processor.preprocess(rearrange(bg_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 694 |
+
bg_video = bg_video.to(dtype=torch.float32)
|
| 695 |
+
bg_video = rearrange(bg_video, "(b f) c h w -> b c f h w", f=video_length)
|
| 696 |
+
else:
|
| 697 |
+
bg_video = None
|
| 698 |
+
bg_video = self.inputs_padding(bg_video, target_len).to(device, weight_dtype)
|
| 699 |
+
mask_video = self.inputs_padding(mask_video, target_len).to(device, weight_dtype)
|
| 700 |
+
|
| 701 |
+
if comfyui_progressbar:
|
| 702 |
+
pbar.update(1)
|
| 703 |
+
|
| 704 |
+
# 5. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 705 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 706 |
+
|
| 707 |
+
target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
|
| 708 |
+
seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
|
| 709 |
+
|
| 710 |
+
# 6. Denoising loop
|
| 711 |
+
start = 0
|
| 712 |
+
end = clip_len
|
| 713 |
+
all_out_frames = []
|
| 714 |
+
copy_timesteps = copy.deepcopy(timesteps)
|
| 715 |
+
copy_latents = copy.deepcopy(latents)
|
| 716 |
+
bs = pose_video.size()[0]
|
| 717 |
+
while True:
|
| 718 |
+
if start + refert_num >= pose_video.size()[2]:
|
| 719 |
+
break
|
| 720 |
+
|
| 721 |
+
# Prepare timesteps
|
| 722 |
+
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
| 723 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, copy_timesteps, mu=1)
|
| 724 |
+
elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
|
| 725 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
|
| 726 |
+
timesteps = self.scheduler.timesteps
|
| 727 |
+
elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
|
| 728 |
+
sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
|
| 729 |
+
timesteps, _ = retrieve_timesteps(
|
| 730 |
+
self.scheduler,
|
| 731 |
+
device=device,
|
| 732 |
+
sigmas=sampling_sigmas)
|
| 733 |
+
else:
|
| 734 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, copy_timesteps)
|
| 735 |
+
self._num_timesteps = len(timesteps)
|
| 736 |
+
|
| 737 |
+
latent_channels = self.transformer.config.in_channels
|
| 738 |
+
latents = self.prepare_latents(
|
| 739 |
+
batch_size * num_videos_per_prompt,
|
| 740 |
+
latent_channels,
|
| 741 |
+
num_frames,
|
| 742 |
+
height,
|
| 743 |
+
width,
|
| 744 |
+
weight_dtype,
|
| 745 |
+
device,
|
| 746 |
+
generator,
|
| 747 |
+
copy_latents,
|
| 748 |
+
)
|
| 749 |
+
|
| 750 |
+
if start == 0:
|
| 751 |
+
mask_reft_len = 0
|
| 752 |
+
else:
|
| 753 |
+
mask_reft_len = refert_num
|
| 754 |
+
|
| 755 |
+
conditioning_pixel_values = pose_video[:, :, start:end]
|
| 756 |
+
face_pixel_values = face_video[:, :, start:end]
|
| 757 |
+
ref_pixel_values = ref_image.clone().detach()
|
| 758 |
+
if start > 0:
|
| 759 |
+
refer_t_pixel_values = out_frames[:, :, -refert_num:].clone().detach()
|
| 760 |
+
refer_t_pixel_values = (refer_t_pixel_values - 0.5) / 0.5
|
| 761 |
+
else:
|
| 762 |
+
refer_t_pixel_values = torch.zeros(bs, 3, refert_num, height, width)
|
| 763 |
+
refer_t_pixel_values = refer_t_pixel_values.to(device=device, dtype=weight_dtype)
|
| 764 |
+
|
| 765 |
+
pose_latents, ref_latents = self.prepare_control_latents(
|
| 766 |
+
conditioning_pixel_values,
|
| 767 |
+
ref_pixel_values,
|
| 768 |
+
batch_size,
|
| 769 |
+
height,
|
| 770 |
+
width,
|
| 771 |
+
weight_dtype,
|
| 772 |
+
device,
|
| 773 |
+
generator,
|
| 774 |
+
do_classifier_free_guidance
|
| 775 |
+
)
|
| 776 |
+
|
| 777 |
+
mask_ref = self.get_i2v_mask(1, target_shape[-1], target_shape[-2], 1, device=device)
|
| 778 |
+
y_ref = torch.concat([mask_ref, ref_latents], dim=1).to(device=device, dtype=weight_dtype)
|
| 779 |
+
if mask_reft_len > 0:
|
| 780 |
+
if replace_flag:
|
| 781 |
+
# Image.fromarray(np.array((refer_t_pixel_values[0, :, 0].permute(1,2,0) * 0.5 + 0.5).float().cpu().numpy() *255, np.uint8)).save("1.jpg")
|
| 782 |
+
bg_pixel_values = bg_video[:, :, start:end]
|
| 783 |
+
y_reft = self.vae.encode(
|
| 784 |
+
torch.concat(
|
| 785 |
+
[
|
| 786 |
+
refer_t_pixel_values[:, :, :mask_reft_len],
|
| 787 |
+
bg_pixel_values[:, :, mask_reft_len:]
|
| 788 |
+
], dim=2
|
| 789 |
+
).to(device=device, dtype=weight_dtype)
|
| 790 |
+
)[0].mode()
|
| 791 |
+
|
| 792 |
+
mask_pixel_values = 1 - mask_video[:, :, start:end]
|
| 793 |
+
mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w")
|
| 794 |
+
mask_pixel_values = F.interpolate(mask_pixel_values, size=(target_shape[-1], target_shape[-2]), mode='nearest')
|
| 795 |
+
mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b c t h w", b = bs)[:, 0]
|
| 796 |
+
msk_reft = self.get_i2v_mask(
|
| 797 |
+
int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, mask_pixel_values=mask_pixel_values, device=device
|
| 798 |
+
)
|
| 799 |
+
else:
|
| 800 |
+
refer_t_pixel_values = rearrange(refer_t_pixel_values[:, :, :mask_reft_len], "b c t h w -> (b t) c h w")
|
| 801 |
+
refer_t_pixel_values = F.interpolate(refer_t_pixel_values, size=(height, width), mode="bicubic")
|
| 802 |
+
refer_t_pixel_values = rearrange(refer_t_pixel_values, "(b t) c h w -> b c t h w", b = bs)
|
| 803 |
+
|
| 804 |
+
y_reft = self.vae.encode(
|
| 805 |
+
torch.concat(
|
| 806 |
+
[
|
| 807 |
+
refer_t_pixel_values,
|
| 808 |
+
torch.zeros(bs, 3, clip_len - mask_reft_len, height, width).to(device=device, dtype=weight_dtype),
|
| 809 |
+
], dim=2,
|
| 810 |
+
).to(device=device, dtype=weight_dtype)
|
| 811 |
+
)[0].mode()
|
| 812 |
+
msk_reft = self.get_i2v_mask(
|
| 813 |
+
int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, device=device
|
| 814 |
+
)
|
| 815 |
+
else:
|
| 816 |
+
if replace_flag:
|
| 817 |
+
bg_pixel_values = bg_video[:, :, start:end]
|
| 818 |
+
y_reft = self.vae.encode(
|
| 819 |
+
bg_pixel_values.to(device=device, dtype=weight_dtype)
|
| 820 |
+
)[0].mode()
|
| 821 |
+
|
| 822 |
+
mask_pixel_values = 1 - mask_video[:, :, start:end]
|
| 823 |
+
mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w")
|
| 824 |
+
mask_pixel_values = F.interpolate(mask_pixel_values, size=(target_shape[-1], target_shape[-2]), mode='nearest')
|
| 825 |
+
mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b c t h w", b = bs)[:, 0]
|
| 826 |
+
msk_reft = self.get_i2v_mask(
|
| 827 |
+
int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, mask_pixel_values=mask_pixel_values, device=device
|
| 828 |
+
)
|
| 829 |
+
else:
|
| 830 |
+
y_reft = self.vae.encode(
|
| 831 |
+
torch.zeros(1, 3, clip_len - mask_reft_len, height, width).to(device=device, dtype=weight_dtype)
|
| 832 |
+
)[0].mode()
|
| 833 |
+
msk_reft = self.get_i2v_mask(
|
| 834 |
+
int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, device=device
|
| 835 |
+
)
|
| 836 |
+
|
| 837 |
+
y_reft = torch.concat([msk_reft, y_reft], dim=1).to(device=device, dtype=weight_dtype)
|
| 838 |
+
y = torch.concat([y_ref, y_reft], dim=2)
|
| 839 |
+
|
| 840 |
+
clip_context = self.clip_image_encoder([ref_pixel_values[0, :, :, :]]).to(device=device, dtype=weight_dtype)
|
| 841 |
+
|
| 842 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 843 |
+
self.transformer.num_inference_steps = num_inference_steps
|
| 844 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 845 |
+
for i, t in enumerate(timesteps):
|
| 846 |
+
self.transformer.current_steps = i
|
| 847 |
+
|
| 848 |
+
if self.interrupt:
|
| 849 |
+
continue
|
| 850 |
+
|
| 851 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 852 |
+
if hasattr(self.scheduler, "scale_model_input"):
|
| 853 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 854 |
+
|
| 855 |
+
y_in = torch.cat([y] * 2) if do_classifier_free_guidance else y
|
| 856 |
+
clip_context_input = (
|
| 857 |
+
torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context
|
| 858 |
+
)
|
| 859 |
+
pose_latents_input = (
|
| 860 |
+
torch.cat([pose_latents] * 2) if do_classifier_free_guidance else pose_latents
|
| 861 |
+
)
|
| 862 |
+
face_pixel_values_input = (
|
| 863 |
+
torch.cat([torch.ones_like(face_pixel_values) * -1] + [face_pixel_values]) if do_classifier_free_guidance else face_pixel_values
|
| 864 |
+
)
|
| 865 |
+
|
| 866 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 867 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 868 |
+
|
| 869 |
+
if self.transformer_2 is not None:
|
| 870 |
+
if t >= boundary * self.scheduler.config.num_train_timesteps:
|
| 871 |
+
local_transformer = self.transformer_2
|
| 872 |
+
else:
|
| 873 |
+
local_transformer = self.transformer
|
| 874 |
+
else:
|
| 875 |
+
local_transformer = self.transformer
|
| 876 |
+
|
| 877 |
+
# predict noise model_output
|
| 878 |
+
with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
|
| 879 |
+
noise_pred = local_transformer(
|
| 880 |
+
x=latent_model_input,
|
| 881 |
+
context=in_prompt_embeds,
|
| 882 |
+
t=timestep,
|
| 883 |
+
seq_len=seq_len,
|
| 884 |
+
y=y_in,
|
| 885 |
+
clip_fea=clip_context_input,
|
| 886 |
+
pose_latents=pose_latents_input,
|
| 887 |
+
face_pixel_values=face_pixel_values_input,
|
| 888 |
+
)
|
| 889 |
+
|
| 890 |
+
# Perform guidance
|
| 891 |
+
if do_classifier_free_guidance:
|
| 892 |
+
if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))):
|
| 893 |
+
sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0]
|
| 894 |
+
else:
|
| 895 |
+
sample_guide_scale = self.guidance_scale
|
| 896 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 897 |
+
noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond)
|
| 898 |
+
|
| 899 |
+
# Compute the previous noisy sample x_t -> x_t-1
|
| 900 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 901 |
+
|
| 902 |
+
if callback_on_step_end is not None:
|
| 903 |
+
callback_kwargs = {}
|
| 904 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 905 |
+
callback_kwargs[k] = locals()[k]
|
| 906 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 907 |
+
|
| 908 |
+
latents = callback_outputs.pop("latents", latents)
|
| 909 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 910 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 911 |
+
|
| 912 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 913 |
+
progress_bar.update()
|
| 914 |
+
if comfyui_progressbar:
|
| 915 |
+
pbar.update(1)
|
| 916 |
+
|
| 917 |
+
out_frames = self.decode_latents(latents[:, :, 1:])
|
| 918 |
+
if start != 0:
|
| 919 |
+
out_frames = out_frames[:, :, refert_num:]
|
| 920 |
+
all_out_frames.append(out_frames.cpu())
|
| 921 |
+
start += clip_len - refert_num
|
| 922 |
+
end += clip_len - refert_num
|
| 923 |
+
|
| 924 |
+
videos = torch.cat(all_out_frames, dim=2)[:, :, :real_frame_len]
|
| 925 |
+
|
| 926 |
+
# Offload all models
|
| 927 |
+
self.maybe_free_model_hooks()
|
| 928 |
+
|
| 929 |
+
return WanPipelineOutput(videos=videos.float().cpu())
|
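The denoising loop above routes each step to one of two transformers based on a timestep boundary and then applies classifier-free guidance to the stacked unconditional/conditional batch. Below is a minimal, self-contained sketch of that control flow; the dummy models, shapes, and the 0.875 boundary are illustrative stand-ins, not values taken from a checkpoint.

```python
import torch

# Toy stand-ins for the two Wan transformers (illustrative only): they just
# scale the input so the selection/guidance control flow can run end to end.
def high_noise_model(x, t):   # plays the role of transformer_2
    return x * 0.9

def low_noise_model(x, t):    # plays the role of transformer
    return x * 1.1

num_train_timesteps = 1000
boundary = 0.875                                   # assumed boundary value
guidance_scale = 5.0
latents = torch.randn(1, 16, 13, 60, 104)          # (b, c, t, h, w), made-up shape

for t in torch.linspace(999, 0, steps=10):
    # Expert selection: high-noise (early) steps use transformer_2,
    # low-noise (late) steps fall back to transformer.
    model = high_noise_model if t >= boundary * num_train_timesteps else low_noise_model

    # Classifier-free guidance: stack uncond + cond into one batch, then split.
    latent_model_input = torch.cat([latents] * 2)
    noise_pred = model(latent_model_input, t)
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    # scheduler.step(noise_pred, t, latents) would follow in the real pipeline.
```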
videox_fun/pipeline/pipeline_wan2_2_fun_control.py
ADDED
|
@@ -0,0 +1,903 @@
| 1 |
+
import inspect
|
| 2 |
+
import math
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torchvision.transforms.functional as TF
|
| 10 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
| 11 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 12 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 13 |
+
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
| 14 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 15 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 16 |
+
from diffusers.utils import BaseOutput, logging, replace_example_docstring
|
| 17 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 18 |
+
from diffusers.video_processor import VideoProcessor
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
from PIL import Image
|
| 21 |
+
from transformers import T5Tokenizer
|
| 22 |
+
|
| 23 |
+
from ..models import (AutoencoderKLWan, AutoTokenizer,
|
| 24 |
+
Wan2_2Transformer3DModel, WanT5EncoderModel)
|
| 25 |
+
from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
| 26 |
+
get_sampling_sigmas)
|
| 27 |
+
from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 28 |
+
|
| 29 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
EXAMPLE_DOC_STRING = """
|
| 33 |
+
Examples:
|
| 34 |
+
```python
|
| 35 |
+
pass
|
| 36 |
+
```
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 41 |
+
def retrieve_timesteps(
|
| 42 |
+
scheduler,
|
| 43 |
+
num_inference_steps: Optional[int] = None,
|
| 44 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 45 |
+
timesteps: Optional[List[int]] = None,
|
| 46 |
+
sigmas: Optional[List[float]] = None,
|
| 47 |
+
**kwargs,
|
| 48 |
+
):
|
| 49 |
+
"""
|
| 50 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 51 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
scheduler (`SchedulerMixin`):
|
| 55 |
+
The scheduler to get timesteps from.
|
| 56 |
+
num_inference_steps (`int`):
|
| 57 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 58 |
+
must be `None`.
|
| 59 |
+
device (`str` or `torch.device`, *optional*):
|
| 60 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 61 |
+
timesteps (`List[int]`, *optional*):
|
| 62 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 63 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 64 |
+
sigmas (`List[float]`, *optional*):
|
| 65 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 66 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 70 |
+
second element is the number of inference steps.
|
| 71 |
+
"""
|
| 72 |
+
if timesteps is not None and sigmas is not None:
|
| 73 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 74 |
+
if timesteps is not None:
|
| 75 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 76 |
+
if not accepts_timesteps:
|
| 77 |
+
raise ValueError(
|
| 78 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 79 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 80 |
+
)
|
| 81 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 82 |
+
timesteps = scheduler.timesteps
|
| 83 |
+
num_inference_steps = len(timesteps)
|
| 84 |
+
elif sigmas is not None:
|
| 85 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 86 |
+
if not accept_sigmas:
|
| 87 |
+
raise ValueError(
|
| 88 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 89 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 90 |
+
)
|
| 91 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 92 |
+
timesteps = scheduler.timesteps
|
| 93 |
+
num_inference_steps = len(timesteps)
|
| 94 |
+
else:
|
| 95 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 96 |
+
timesteps = scheduler.timesteps
|
| 97 |
+
return timesteps, num_inference_steps
|
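A small usage sketch of the `retrieve_timesteps` helper defined above, using a stock diffusers scheduler purely for illustration (the scheduler choice and step count are not prescribed by this file):

```python
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()

# Plain step count: retrieve_timesteps forwards it to scheduler.set_timesteps
# and hands back the resulting schedule plus its length.
timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
print(n, len(timesteps))   # 30 30

# Passing both custom timesteps and custom sigmas is rejected up front.
try:
    retrieve_timesteps(scheduler, timesteps=[999, 500, 1], sigmas=[1.0, 0.5, 0.0])
except ValueError as err:
    print(err)
```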
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def resize_mask(mask, latent, process_first_frame_only=True):
|
| 101 |
+
latent_size = latent.size()
|
| 102 |
+
batch_size, channels, num_frames, height, width = mask.shape
|
| 103 |
+
|
| 104 |
+
if process_first_frame_only:
|
| 105 |
+
target_size = list(latent_size[2:])
|
| 106 |
+
target_size[0] = 1
|
| 107 |
+
first_frame_resized = F.interpolate(
|
| 108 |
+
mask[:, :, 0:1, :, :],
|
| 109 |
+
size=target_size,
|
| 110 |
+
mode='trilinear',
|
| 111 |
+
align_corners=False
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
target_size = list(latent_size[2:])
|
| 115 |
+
target_size[0] = target_size[0] - 1
|
| 116 |
+
if target_size[0] != 0:
|
| 117 |
+
remaining_frames_resized = F.interpolate(
|
| 118 |
+
mask[:, :, 1:, :, :],
|
| 119 |
+
size=target_size,
|
| 120 |
+
mode='trilinear',
|
| 121 |
+
align_corners=False
|
| 122 |
+
)
|
| 123 |
+
resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
|
| 124 |
+
else:
|
| 125 |
+
resized_mask = first_frame_resized
|
| 126 |
+
else:
|
| 127 |
+
target_size = list(latent_size[2:])
|
| 128 |
+
resized_mask = F.interpolate(
|
| 129 |
+
mask,
|
| 130 |
+
size=target_size,
|
| 131 |
+
mode='trilinear',
|
| 132 |
+
align_corners=False
|
| 133 |
+
)
|
| 134 |
+
return resized_mask
|
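`resize_mask` downsamples a pixel-space mask to the latent grid, optionally interpolating the first frame separately so it keeps its own latent slot. A shape-only sketch with invented sizes:

```python
import torch

# Illustrative sizes only: 1 sample, 1 channel, 17 pixel frames at 480x832,
# and a latent tensor with 5 temporal slots at 60x104.
mask = torch.ones(1, 1, 17, 480, 832)
latent = torch.zeros(1, 16, 5, 60, 104)

resized = resize_mask(mask, latent, process_first_frame_only=True)
print(resized.shape)       # torch.Size([1, 1, 5, 60, 104])

# With process_first_frame_only=False the whole clip is interpolated in one go
# to the same target shape.
resized_all = resize_mask(mask, latent, process_first_frame_only=False)
print(resized_all.shape)   # torch.Size([1, 1, 5, 60, 104])
```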
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@dataclass
|
| 138 |
+
class WanPipelineOutput(BaseOutput):
|
| 139 |
+
r"""
|
| 140 |
+
Output class for Wan pipelines.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
| 144 |
+
List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
|
| 145 |
+
denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
|
| 146 |
+
`(batch_size, num_frames, channels, height, width)`.
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
videos: torch.Tensor
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class Wan2_2FunControlPipeline(DiffusionPipeline):
|
| 153 |
+
r"""
|
| 154 |
+
Pipeline for text-to-video generation using Wan.
|
| 155 |
+
|
| 156 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 157 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
_optional_components = ["transformer_2"]
|
| 161 |
+
model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae"
|
| 162 |
+
|
| 163 |
+
_callback_tensor_inputs = [
|
| 164 |
+
"latents",
|
| 165 |
+
"prompt_embeds",
|
| 166 |
+
"negative_prompt_embeds",
|
| 167 |
+
]
|
| 168 |
+
|
| 169 |
+
def __init__(
|
| 170 |
+
self,
|
| 171 |
+
tokenizer: AutoTokenizer,
|
| 172 |
+
text_encoder: WanT5EncoderModel,
|
| 173 |
+
vae: AutoencoderKLWan,
|
| 174 |
+
transformer: Wan2_2Transformer3DModel,
|
| 175 |
+
transformer_2: Wan2_2Transformer3DModel = None,
|
| 176 |
+
scheduler: FlowMatchEulerDiscreteScheduler = None,
|
| 177 |
+
):
|
| 178 |
+
super().__init__()
|
| 179 |
+
|
| 180 |
+
self.register_modules(
|
| 181 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
|
| 182 |
+
transformer_2=transformer_2, scheduler=scheduler
|
| 183 |
+
)
|
| 184 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
|
| 185 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
|
| 186 |
+
self.mask_processor = VaeImageProcessor(
|
| 187 |
+
vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
def _get_t5_prompt_embeds(
|
| 191 |
+
self,
|
| 192 |
+
prompt: Union[str, List[str]] = None,
|
| 193 |
+
num_videos_per_prompt: int = 1,
|
| 194 |
+
max_sequence_length: int = 512,
|
| 195 |
+
device: Optional[torch.device] = None,
|
| 196 |
+
dtype: Optional[torch.dtype] = None,
|
| 197 |
+
):
|
| 198 |
+
device = device or self._execution_device
|
| 199 |
+
dtype = dtype or self.text_encoder.dtype
|
| 200 |
+
|
| 201 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 202 |
+
batch_size = len(prompt)
|
| 203 |
+
|
| 204 |
+
text_inputs = self.tokenizer(
|
| 205 |
+
prompt,
|
| 206 |
+
padding="max_length",
|
| 207 |
+
max_length=max_sequence_length,
|
| 208 |
+
truncation=True,
|
| 209 |
+
add_special_tokens=True,
|
| 210 |
+
return_tensors="pt",
|
| 211 |
+
)
|
| 212 |
+
text_input_ids = text_inputs.input_ids
|
| 213 |
+
prompt_attention_mask = text_inputs.attention_mask
|
| 214 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 215 |
+
|
| 216 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 217 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 218 |
+
logger.warning(
|
| 219 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 220 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
|
| 224 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
|
| 225 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 226 |
+
|
| 227 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 228 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 229 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 230 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 231 |
+
|
| 232 |
+
return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
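`_get_t5_prompt_embeds` returns a Python list of per-prompt embeddings trimmed to each prompt's real token count rather than one padded tensor, which is why later code concatenates lists instead of calling `torch.cat`. A toy illustration of the trimming step (shapes invented):

```python
import torch

# Pretend the encoder produced padded embeddings for two prompts whose true
# lengths are 7 and 4 tokens (sequence padded to 10, hidden size 8).
prompt_embeds = torch.randn(2, 10, 8)
seq_lens = [7, 4]

trimmed = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
print([t.shape for t in trimmed])   # [torch.Size([7, 8]), torch.Size([4, 8])]

# Classifier-free guidance later concatenates two such Python lists, keeping
# one variable-length tensor per prompt.
```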
| 233 |
+
|
| 234 |
+
def encode_prompt(
|
| 235 |
+
self,
|
| 236 |
+
prompt: Union[str, List[str]],
|
| 237 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 238 |
+
do_classifier_free_guidance: bool = True,
|
| 239 |
+
num_videos_per_prompt: int = 1,
|
| 240 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 241 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 242 |
+
max_sequence_length: int = 512,
|
| 243 |
+
device: Optional[torch.device] = None,
|
| 244 |
+
dtype: Optional[torch.dtype] = None,
|
| 245 |
+
):
|
| 246 |
+
r"""
|
| 247 |
+
Encodes the prompt into text encoder hidden states.
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 251 |
+
prompt to be encoded
|
| 252 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 253 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 254 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 255 |
+
less than `1`).
|
| 256 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 257 |
+
Whether to use classifier free guidance or not.
|
| 258 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 259 |
+
Number of videos that should be generated per prompt.
|
| 260 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 261 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 262 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 263 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 264 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 265 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 266 |
+
argument.
|
| 267 |
+
device: (`torch.device`, *optional*):
|
| 268 |
+
torch device
|
| 269 |
+
dtype: (`torch.dtype`, *optional*):
|
| 270 |
+
torch dtype
|
| 271 |
+
"""
|
| 272 |
+
device = device or self._execution_device
|
| 273 |
+
|
| 274 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 275 |
+
if prompt is not None:
|
| 276 |
+
batch_size = len(prompt)
|
| 277 |
+
else:
|
| 278 |
+
batch_size = prompt_embeds.shape[0]
|
| 279 |
+
|
| 280 |
+
if prompt_embeds is None:
|
| 281 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 282 |
+
prompt=prompt,
|
| 283 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 284 |
+
max_sequence_length=max_sequence_length,
|
| 285 |
+
device=device,
|
| 286 |
+
dtype=dtype,
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 290 |
+
negative_prompt = negative_prompt or ""
|
| 291 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 292 |
+
|
| 293 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 294 |
+
raise TypeError(
|
| 295 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 296 |
+
f" {type(prompt)}."
|
| 297 |
+
)
|
| 298 |
+
elif batch_size != len(negative_prompt):
|
| 299 |
+
raise ValueError(
|
| 300 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 301 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 302 |
+
" the batch size of `prompt`."
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 306 |
+
prompt=negative_prompt,
|
| 307 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 308 |
+
max_sequence_length=max_sequence_length,
|
| 309 |
+
device=device,
|
| 310 |
+
dtype=dtype,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
return prompt_embeds, negative_prompt_embeds
|
| 314 |
+
|
| 315 |
+
def prepare_latents(
|
| 316 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
| 317 |
+
):
|
| 318 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 319 |
+
raise ValueError(
|
| 320 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 321 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
shape = (
|
| 325 |
+
batch_size,
|
| 326 |
+
num_channels_latents,
|
| 327 |
+
(num_frames - 1) // self.vae.temporal_compression_ratio + 1,
|
| 328 |
+
height // self.vae.spatial_compression_ratio,
|
| 329 |
+
width // self.vae.spatial_compression_ratio,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
if latents is None:
|
| 333 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 334 |
+
else:
|
| 335 |
+
latents = latents.to(device)
|
| 336 |
+
|
| 337 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 338 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
| 339 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 340 |
+
return latents
|
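`prepare_latents` derives the latent grid from the VAE compression ratios. Assuming the usual Wan factors of 4x temporal and 8x spatial compression (an assumption about the checkpoint, not something fixed by the code above), an 81-frame 480x832 request maps to the following shape:

```python
# Worked example of the latent shape computed in prepare_latents, assuming
# temporal_compression_ratio = 4 and spatial_compression_ratio = 8.
batch_size, num_channels_latents = 1, 16
num_frames, height, width = 81, 480, 832
temporal_compression_ratio, spatial_compression_ratio = 4, 8

shape = (
    batch_size,
    num_channels_latents,
    (num_frames - 1) // temporal_compression_ratio + 1,  # 21 latent frames
    height // spatial_compression_ratio,                 # 60
    width // spatial_compression_ratio,                  # 104
)
print(shape)   # (1, 16, 21, 60, 104)
```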
| 341 |
+
|
| 342 |
+
def prepare_mask_latents(
|
| 343 |
+
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
|
| 344 |
+
):
|
| 345 |
+
# resize the mask to latents shape as we concatenate the mask to the latents
|
| 346 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
| 347 |
+
# and half precision
|
| 348 |
+
|
| 349 |
+
if mask is not None:
|
| 350 |
+
mask = mask.to(device=device, dtype=self.vae.dtype)
|
| 351 |
+
bs = 1
|
| 352 |
+
new_mask = []
|
| 353 |
+
for i in range(0, mask.shape[0], bs):
|
| 354 |
+
mask_bs = mask[i : i + bs]
|
| 355 |
+
mask_bs = self.vae.encode(mask_bs)[0]
|
| 356 |
+
mask_bs = mask_bs.mode()
|
| 357 |
+
new_mask.append(mask_bs)
|
| 358 |
+
mask = torch.cat(new_mask, dim = 0)
|
| 359 |
+
# mask = mask * self.vae.config.scaling_factor
|
| 360 |
+
|
| 361 |
+
if masked_image is not None:
|
| 362 |
+
masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
|
| 363 |
+
bs = 1
|
| 364 |
+
new_mask_pixel_values = []
|
| 365 |
+
for i in range(0, masked_image.shape[0], bs):
|
| 366 |
+
mask_pixel_values_bs = masked_image[i : i + bs]
|
| 367 |
+
mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
|
| 368 |
+
mask_pixel_values_bs = mask_pixel_values_bs.mode()
|
| 369 |
+
new_mask_pixel_values.append(mask_pixel_values_bs)
|
| 370 |
+
masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
|
| 371 |
+
# masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
|
| 372 |
+
else:
|
| 373 |
+
masked_image_latents = None
|
| 374 |
+
|
| 375 |
+
return mask, masked_image_latents
|
| 376 |
+
|
| 377 |
+
def prepare_control_latents(
|
| 378 |
+
self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
| 379 |
+
):
|
| 380 |
+
# resize the control to latents shape as we concatenate the control to the latents
|
| 381 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
| 382 |
+
# and half precision
|
| 383 |
+
|
| 384 |
+
if control is not None:
|
| 385 |
+
control = control.to(device=device, dtype=dtype)
|
| 386 |
+
bs = 1
|
| 387 |
+
new_control = []
|
| 388 |
+
for i in range(0, control.shape[0], bs):
|
| 389 |
+
control_bs = control[i : i + bs]
|
| 390 |
+
control_bs = self.vae.encode(control_bs)[0]
|
| 391 |
+
control_bs = control_bs.mode()
|
| 392 |
+
new_control.append(control_bs)
|
| 393 |
+
control = torch.cat(new_control, dim = 0)
|
| 394 |
+
|
| 395 |
+
if control_image is not None:
|
| 396 |
+
control_image = control_image.to(device=device, dtype=dtype)
|
| 397 |
+
bs = 1
|
| 398 |
+
new_control_pixel_values = []
|
| 399 |
+
for i in range(0, control_image.shape[0], bs):
|
| 400 |
+
control_pixel_values_bs = control_image[i : i + bs]
|
| 401 |
+
control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
|
| 402 |
+
control_pixel_values_bs = control_pixel_values_bs.mode()
|
| 403 |
+
new_control_pixel_values.append(control_pixel_values_bs)
|
| 404 |
+
control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
|
| 405 |
+
else:
|
| 406 |
+
control_image_latents = None
|
| 407 |
+
|
| 408 |
+
return control, control_image_latents
|
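Both `prepare_mask_latents` and `prepare_control_latents` feed frames through the VAE one sample at a time and take the posterior mode rather than a random sample, keeping memory bounded and the conditioning deterministic. A stripped-down sketch of that pattern with a fake encoder standing in for `AutoencoderKLWan` (the stand-in class below is hypothetical, not the real VAE API):

```python
import torch

class FakePosterior:
    """Minimal stand-in for the latent distribution returned by vae.encode()."""
    def __init__(self, x):
        self._mode = x.mean(dim=1, keepdim=True).expand(-1, 16, -1, -1, -1)
    def mode(self):
        return self._mode

def fake_vae_encode(x):
    # Mimics `self.vae.encode(x)[0]` returning a posterior object.
    return (FakePosterior(x),)

control = torch.randn(3, 3, 17, 96, 96)      # three conditioning clips
bs = 1
chunks = []
for i in range(0, control.shape[0], bs):     # encode one sample at a time
    posterior = fake_vae_encode(control[i : i + bs])[0]
    chunks.append(posterior.mode())          # deterministic latent, no sampling
control_latents = torch.cat(chunks, dim=0)
print(control_latents.shape)                 # torch.Size([3, 16, 17, 96, 96])
```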
| 409 |
+
|
| 410 |
+
def decode_latents(self, latents: torch.Tensor) -> np.ndarray:
|
| 411 |
+
frames = self.vae.decode(latents.to(self.vae.dtype)).sample
|
| 412 |
+
frames = (frames / 2 + 0.5).clamp(0, 1)
|
| 413 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 414 |
+
frames = frames.cpu().float().numpy()
|
| 415 |
+
return frames
|
| 416 |
+
|
| 417 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 418 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 419 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 420 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 421 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 422 |
+
# and should be between [0, 1]
|
| 423 |
+
|
| 424 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 425 |
+
extra_step_kwargs = {}
|
| 426 |
+
if accepts_eta:
|
| 427 |
+
extra_step_kwargs["eta"] = eta
|
| 428 |
+
|
| 429 |
+
# check if the scheduler accepts generator
|
| 430 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 431 |
+
if accepts_generator:
|
| 432 |
+
extra_step_kwargs["generator"] = generator
|
| 433 |
+
return extra_step_kwargs
|
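`prepare_extra_step_kwargs` inspects the scheduler's `step` signature so the same pipeline works with schedulers that do or do not accept `eta`/`generator`. The same introspection trick in isolation, with two stock diffusers schedulers chosen only for illustration:

```python
import inspect
from diffusers import DDIMScheduler, FlowMatchEulerDiscreteScheduler

def step_kwargs_for(scheduler, generator=None, eta=0.0):
    params = set(inspect.signature(scheduler.step).parameters.keys())
    kwargs = {}
    if "eta" in params:            # only DDIM-style schedulers take eta
        kwargs["eta"] = eta
    if "generator" in params:      # most schedulers accept a generator
        kwargs["generator"] = generator
    return kwargs

print(step_kwargs_for(DDIMScheduler()))                    # {'eta': 0.0, 'generator': None}
print(step_kwargs_for(FlowMatchEulerDiscreteScheduler()))  # {'generator': None}
```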
| 434 |
+
|
| 435 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
| 436 |
+
def check_inputs(
|
| 437 |
+
self,
|
| 438 |
+
prompt,
|
| 439 |
+
height,
|
| 440 |
+
width,
|
| 441 |
+
negative_prompt,
|
| 442 |
+
callback_on_step_end_tensor_inputs,
|
| 443 |
+
prompt_embeds=None,
|
| 444 |
+
negative_prompt_embeds=None,
|
| 445 |
+
):
|
| 446 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 447 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 448 |
+
|
| 449 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 450 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 451 |
+
):
|
| 452 |
+
raise ValueError(
|
| 453 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 454 |
+
)
|
| 455 |
+
if prompt is not None and prompt_embeds is not None:
|
| 456 |
+
raise ValueError(
|
| 457 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 458 |
+
" only forward one of the two."
|
| 459 |
+
)
|
| 460 |
+
elif prompt is None and prompt_embeds is None:
|
| 461 |
+
raise ValueError(
|
| 462 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 463 |
+
)
|
| 464 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 465 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 466 |
+
|
| 467 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 468 |
+
raise ValueError(
|
| 469 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 470 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 474 |
+
raise ValueError(
|
| 475 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 476 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 480 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 481 |
+
raise ValueError(
|
| 482 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 483 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 484 |
+
f" {negative_prompt_embeds.shape}."
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
@property
|
| 488 |
+
def guidance_scale(self):
|
| 489 |
+
return self._guidance_scale
|
| 490 |
+
|
| 491 |
+
@property
|
| 492 |
+
def num_timesteps(self):
|
| 493 |
+
return self._num_timesteps
|
| 494 |
+
|
| 495 |
+
@property
|
| 496 |
+
def attention_kwargs(self):
|
| 497 |
+
return self._attention_kwargs
|
| 498 |
+
|
| 499 |
+
@property
|
| 500 |
+
def interrupt(self):
|
| 501 |
+
return self._interrupt
|
| 502 |
+
|
| 503 |
+
@torch.no_grad()
|
| 504 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 505 |
+
def __call__(
|
| 506 |
+
self,
|
| 507 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 508 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 509 |
+
height: int = 480,
|
| 510 |
+
width: int = 720,
|
| 511 |
+
video: Union[torch.FloatTensor] = None,
|
| 512 |
+
mask_video: Union[torch.FloatTensor] = None,
|
| 513 |
+
control_video: Union[torch.FloatTensor] = None,
|
| 514 |
+
control_camera_video: Union[torch.FloatTensor] = None,
|
| 515 |
+
start_image: Union[torch.FloatTensor] = None,
|
| 516 |
+
ref_image: Union[torch.FloatTensor] = None,
|
| 517 |
+
num_frames: int = 49,
|
| 518 |
+
num_inference_steps: int = 50,
|
| 519 |
+
timesteps: Optional[List[int]] = None,
|
| 520 |
+
guidance_scale: float = 6,
|
| 521 |
+
num_videos_per_prompt: int = 1,
|
| 522 |
+
eta: float = 0.0,
|
| 523 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 524 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 525 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 526 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 527 |
+
output_type: str = "numpy",
|
| 528 |
+
return_dict: bool = False,
|
| 529 |
+
callback_on_step_end: Optional[
|
| 530 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 531 |
+
] = None,
|
| 532 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 533 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 534 |
+
max_sequence_length: int = 512,
|
| 535 |
+
boundary: float = 0.875,
|
| 536 |
+
comfyui_progressbar: bool = False,
|
| 537 |
+
shift: int = 5,
|
| 538 |
+
) -> Union[WanPipelineOutput, Tuple]:
|
| 539 |
+
"""
|
| 540 |
+
Function invoked when calling the pipeline for generation.
|
| 541 |
+
Args:
|
| 542 |
+
|
| 543 |
+
Examples:
|
| 544 |
+
|
| 545 |
+
Returns:
|
| 546 |
+
|
| 547 |
+
"""
|
| 548 |
+
|
| 549 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 550 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 551 |
+
num_videos_per_prompt = 1
|
| 552 |
+
|
| 553 |
+
# 1. Check inputs. Raise error if not correct
|
| 554 |
+
self.check_inputs(
|
| 555 |
+
prompt,
|
| 556 |
+
height,
|
| 557 |
+
width,
|
| 558 |
+
negative_prompt,
|
| 559 |
+
callback_on_step_end_tensor_inputs,
|
| 560 |
+
prompt_embeds,
|
| 561 |
+
negative_prompt_embeds,
|
| 562 |
+
)
|
| 563 |
+
self._guidance_scale = guidance_scale
|
| 564 |
+
self._attention_kwargs = attention_kwargs
|
| 565 |
+
self._interrupt = False
|
| 566 |
+
|
| 567 |
+
# 2. Default call parameters
|
| 568 |
+
if prompt is not None and isinstance(prompt, str):
|
| 569 |
+
batch_size = 1
|
| 570 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 571 |
+
batch_size = len(prompt)
|
| 572 |
+
else:
|
| 573 |
+
batch_size = prompt_embeds.shape[0]
|
| 574 |
+
|
| 575 |
+
device = self._execution_device
|
| 576 |
+
weight_dtype = self.text_encoder.dtype
|
| 577 |
+
|
| 578 |
+
# here `guidance_scale` is defined analogous to the guidance weight `w` of equation (2)
|
| 579 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 580 |
+
# corresponds to doing no classifier free guidance.
|
| 581 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 582 |
+
|
| 583 |
+
# 3. Encode input prompt
|
| 584 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 585 |
+
prompt,
|
| 586 |
+
negative_prompt,
|
| 587 |
+
do_classifier_free_guidance,
|
| 588 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 589 |
+
prompt_embeds=prompt_embeds,
|
| 590 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 591 |
+
max_sequence_length=max_sequence_length,
|
| 592 |
+
device=device,
|
| 593 |
+
)
|
| 594 |
+
if do_classifier_free_guidance:
|
| 595 |
+
in_prompt_embeds = negative_prompt_embeds + prompt_embeds
|
| 596 |
+
else:
|
| 597 |
+
in_prompt_embeds = prompt_embeds
|
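Because `encode_prompt` returns lists of per-prompt tensors, the `+` above is list concatenation rather than tensor addition: the unconditional embeddings come first, matching the `torch.cat([latents] * 2)` ordering used later, where the first half of the batch is the unconditional branch. A tiny illustration with string stand-ins:

```python
# List concatenation, not tensor addition: the context passed to the
# transformer is [neg_0, ..., neg_{B-1}, pos_0, ..., pos_{B-1}].
negative_prompt_embeds = ["neg_0", "neg_1"]   # stand-ins for trimmed tensors
prompt_embeds = ["pos_0", "pos_1"]
in_prompt_embeds = negative_prompt_embeds + prompt_embeds
print(in_prompt_embeds)   # ['neg_0', 'neg_1', 'pos_0', 'pos_1']
```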
| 598 |
+
|
| 599 |
+
# 4. Prepare timesteps
|
| 600 |
+
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
| 601 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
|
| 602 |
+
elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
|
| 603 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
|
| 604 |
+
timesteps = self.scheduler.timesteps
|
| 605 |
+
elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
|
| 606 |
+
sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
|
| 607 |
+
timesteps, _ = retrieve_timesteps(
|
| 608 |
+
self.scheduler,
|
| 609 |
+
device=device,
|
| 610 |
+
sigmas=sampling_sigmas)
|
| 611 |
+
else:
|
| 612 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 613 |
+
self._num_timesteps = len(timesteps)
|
| 614 |
+
if comfyui_progressbar:
|
| 615 |
+
from comfy.utils import ProgressBar
|
| 616 |
+
pbar = ProgressBar(num_inference_steps + 2)
|
| 617 |
+
|
| 618 |
+
# 5. Prepare latents.
|
| 619 |
+
if video is not None:
|
| 620 |
+
video_length = video.shape[2]
|
| 621 |
+
init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 622 |
+
init_video = init_video.to(dtype=torch.float32)
|
| 623 |
+
init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
|
| 624 |
+
else:
|
| 625 |
+
init_video = None
|
| 626 |
+
|
| 627 |
+
latent_channels = self.vae.config.latent_channels
|
| 628 |
+
latents = self.prepare_latents(
|
| 629 |
+
batch_size * num_videos_per_prompt,
|
| 630 |
+
latent_channels,
|
| 631 |
+
num_frames,
|
| 632 |
+
height,
|
| 633 |
+
width,
|
| 634 |
+
weight_dtype,
|
| 635 |
+
device,
|
| 636 |
+
generator,
|
| 637 |
+
latents,
|
| 638 |
+
)
|
| 639 |
+
if comfyui_progressbar:
|
| 640 |
+
pbar.update(1)
|
| 641 |
+
|
| 642 |
+
# Prepare mask latent variables
|
| 643 |
+
if init_video is not None:
|
| 644 |
+
if (mask_video == 255).all():
|
| 645 |
+
mask_latents = torch.tile(
|
| 646 |
+
torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1]
|
| 647 |
+
)
|
| 648 |
+
masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype)
|
| 649 |
+
if self.vae.spatial_compression_ratio >= 16:
|
| 650 |
+
mask = torch.ones_like(latents).to(device, weight_dtype)[:, :1].to(device, weight_dtype)
|
| 651 |
+
else:
|
| 652 |
+
bs, _, video_length, height, width = video.size()
|
| 653 |
+
mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 654 |
+
mask_condition = mask_condition.to(dtype=torch.float32)
|
| 655 |
+
mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
|
| 656 |
+
|
| 657 |
+
masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5)
|
| 658 |
+
_, masked_video_latents = self.prepare_mask_latents(
|
| 659 |
+
None,
|
| 660 |
+
masked_video,
|
| 661 |
+
batch_size,
|
| 662 |
+
height,
|
| 663 |
+
width,
|
| 664 |
+
weight_dtype,
|
| 665 |
+
device,
|
| 666 |
+
generator,
|
| 667 |
+
do_classifier_free_guidance,
|
| 668 |
+
noise_aug_strength=None,
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
mask_condition = torch.concat(
|
| 672 |
+
[
|
| 673 |
+
torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2),
|
| 674 |
+
mask_condition[:, :, 1:]
|
| 675 |
+
], dim=2
|
| 676 |
+
)
|
| 677 |
+
mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width)
|
| 678 |
+
mask_condition = mask_condition.transpose(1, 2)
|
| 679 |
+
mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype)
|
| 680 |
+
|
| 681 |
+
if self.vae.spatial_compression_ratio >= 16:
|
| 682 |
+
mask = F.interpolate(mask_condition[:, :1], size=latents.size()[-3:], mode='trilinear', align_corners=True).to(device, weight_dtype)
|
| 683 |
+
if not mask[:, :, 0, :, :].any():
|
| 684 |
+
mask[:, :, 1:, :, :] = 1
|
| 685 |
+
latents = (1 - mask) * masked_video_latents + mask * latents
|
| 686 |
+
|
| 687 |
+
# Prepare mask latent variables
|
| 688 |
+
if control_camera_video is not None:
|
| 689 |
+
control_latents = None
|
| 690 |
+
# Rearrange dimensions
|
| 691 |
+
# Concatenate and transpose dimensions
|
| 692 |
+
control_camera_latents = torch.concat(
|
| 693 |
+
[
|
| 694 |
+
torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
|
| 695 |
+
control_camera_video[:, :, 1:]
|
| 696 |
+
], dim=2
|
| 697 |
+
).transpose(1, 2)
|
| 698 |
+
|
| 699 |
+
# Reshape, transpose, and view into desired shape
|
| 700 |
+
b, f, c, h, w = control_camera_latents.shape
|
| 701 |
+
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
|
| 702 |
+
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
|
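This packing mirrors the VAE's temporal layout: the first camera frame is repeated four times and every block of four frames is then folded into the channel dimension, so a clip of 1 + 4k frames ends up with k + 1 temporal slots and 4x the channels. A standalone shape check with made-up sizes:

```python
import torch

b, c, frames, h, w = 1, 6, 17, 30, 52        # 17 = 1 + 4*4 camera frames
control_camera_video = torch.randn(b, c, frames, h, w)

x = torch.concat(
    [torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
     control_camera_video[:, :, 1:]],
    dim=2,
).transpose(1, 2)                            # (b, f, c, h, w) with f = 20

b_, f, c_, h_, w_ = x.shape
x = x.contiguous().view(b_, f // 4, 4, c_, h_, w_).transpose(2, 3)
x = x.contiguous().view(b_, f // 4, c_ * 4, h_, w_).transpose(1, 2)
print(x.shape)                               # torch.Size([1, 24, 5, 30, 52])
```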
| 703 |
+
elif control_video is not None:
|
| 704 |
+
video_length = control_video.shape[2]
|
| 705 |
+
control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 706 |
+
control_video = control_video.to(dtype=torch.float32)
|
| 707 |
+
control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
|
| 708 |
+
control_video_latents = self.prepare_control_latents(
|
| 709 |
+
None,
|
| 710 |
+
control_video,
|
| 711 |
+
batch_size,
|
| 712 |
+
height,
|
| 713 |
+
width,
|
| 714 |
+
weight_dtype,
|
| 715 |
+
device,
|
| 716 |
+
generator,
|
| 717 |
+
do_classifier_free_guidance
|
| 718 |
+
)[1]
|
| 719 |
+
control_camera_latents = None
|
| 720 |
+
else:
|
| 721 |
+
control_video_latents = torch.zeros_like(latents).to(device, weight_dtype)
|
| 722 |
+
control_camera_latents = None
|
| 723 |
+
|
| 724 |
+
if start_image is not None:
|
| 725 |
+
video_length = start_image.shape[2]
|
| 726 |
+
start_image = self.image_processor.preprocess(rearrange(start_image, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 727 |
+
start_image = start_image.to(dtype=torch.float32)
|
| 728 |
+
start_image = rearrange(start_image, "(b f) c h w -> b c f h w", f=video_length)
|
| 729 |
+
|
| 730 |
+
start_image_latentes = self.prepare_control_latents(
|
| 731 |
+
None,
|
| 732 |
+
start_image,
|
| 733 |
+
batch_size,
|
| 734 |
+
height,
|
| 735 |
+
width,
|
| 736 |
+
weight_dtype,
|
| 737 |
+
device,
|
| 738 |
+
generator,
|
| 739 |
+
do_classifier_free_guidance
|
| 740 |
+
)[1]
|
| 741 |
+
|
| 742 |
+
start_image_latentes_conv_in = torch.zeros_like(latents)
|
| 743 |
+
if latents.size()[2] != 1:
|
| 744 |
+
start_image_latentes_conv_in[:, :, :1] = start_image_latentes
|
| 745 |
+
else:
|
| 746 |
+
start_image_latentes_conv_in = torch.zeros_like(latents)
|
| 747 |
+
|
| 748 |
+
if self.transformer.config.get("add_ref_conv", False):
|
| 749 |
+
if ref_image is not None:
|
| 750 |
+
video_length = ref_image.shape[2]
|
| 751 |
+
ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 752 |
+
ref_image = ref_image.to(dtype=torch.float32)
|
| 753 |
+
ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length)
|
| 754 |
+
|
| 755 |
+
ref_image_latentes = self.prepare_control_latents(
|
| 756 |
+
None,
|
| 757 |
+
ref_image,
|
| 758 |
+
batch_size,
|
| 759 |
+
height,
|
| 760 |
+
width,
|
| 761 |
+
weight_dtype,
|
| 762 |
+
device,
|
| 763 |
+
generator,
|
| 764 |
+
do_classifier_free_guidance
|
| 765 |
+
)[1]
|
| 766 |
+
ref_image_latentes = ref_image_latentes[:, :, 0]
|
| 767 |
+
else:
|
| 768 |
+
ref_image_latentes = torch.zeros_like(latents)[:, :, 0]
|
| 769 |
+
else:
|
| 770 |
+
if ref_image is not None:
|
| 771 |
+
raise ValueError("The add_ref_conv is False, but ref_image is not None")
|
| 772 |
+
else:
|
| 773 |
+
ref_image_latentes = None
|
| 774 |
+
|
| 775 |
+
if comfyui_progressbar:
|
| 776 |
+
pbar.update(1)
|
| 777 |
+
|
| 778 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 779 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 780 |
+
|
| 781 |
+
target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
|
| 782 |
+
seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
|
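A worked example of the token count computed above, assuming 4x temporal compression, 8x spatial compression and a (1, 2, 2) patch size (typical Wan settings, assumed here rather than read from a config):

```python
import math

num_frames, height, width = 81, 480, 832
latent_frames = (num_frames - 1) // 4 + 1        # 21
lat_w, lat_h = width // 8, height // 8           # 104, 60
patch_h, patch_w = 2, 2

seq_len = math.ceil((lat_w * lat_h) / (patch_h * patch_w) * latent_frames)
print(seq_len)   # 32760
```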
| 783 |
+
# 7. Denoising loop
|
| 784 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 785 |
+
self.transformer.num_inference_steps = num_inference_steps
|
| 786 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 787 |
+
for i, t in enumerate(timesteps):
|
| 788 |
+
self.transformer.current_steps = i
|
| 789 |
+
|
| 790 |
+
if self.interrupt:
|
| 791 |
+
continue
|
| 792 |
+
|
| 793 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 794 |
+
if hasattr(self.scheduler, "scale_model_input"):
|
| 795 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 796 |
+
|
| 797 |
+
# Prepare mask latent variables
|
| 798 |
+
if control_camera_video is not None:
|
| 799 |
+
control_latents_input = None
|
| 800 |
+
control_camera_latents_input = (
|
| 801 |
+
torch.cat([control_camera_latents] * 2) if do_classifier_free_guidance else control_camera_latents
|
| 802 |
+
).to(device, weight_dtype)
|
| 803 |
+
else:
|
| 804 |
+
control_latents_input = (
|
| 805 |
+
torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
|
| 806 |
+
).to(device, weight_dtype)
|
| 807 |
+
control_camera_latents_input = None
|
| 808 |
+
|
| 809 |
+
if init_video is not None:
|
| 810 |
+
mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
|
| 811 |
+
masked_video_latents_input = (
|
| 812 |
+
torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
|
| 813 |
+
)
|
| 814 |
+
y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype)
|
| 815 |
+
control_latents_input = y if control_latents_input is None else \
|
| 816 |
+
torch.cat([control_latents_input, y], dim = 1)
|
| 817 |
+
else:
|
| 818 |
+
start_image_latentes_conv_in_input = (
|
| 819 |
+
torch.cat([start_image_latentes_conv_in] * 2) if do_classifier_free_guidance else start_image_latentes_conv_in
|
| 820 |
+
).to(device, weight_dtype)
|
| 821 |
+
control_latents_input = start_image_latentes_conv_in_input if control_latents_input is None else \
|
| 822 |
+
torch.cat([control_latents_input, start_image_latentes_conv_in_input], dim = 1)
|
| 823 |
+
|
| 824 |
+
if ref_image_latentes is not None:
|
| 825 |
+
full_ref = (
|
| 826 |
+
torch.cat([ref_image_latentes] * 2) if do_classifier_free_guidance else ref_image_latentes
|
| 827 |
+
).to(device, weight_dtype)
|
| 828 |
+
else:
|
| 829 |
+
full_ref = None
|
| 830 |
+
|
| 831 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 832 |
+
if self.vae.spatial_compression_ratio >= 16 and init_video is not None:
|
| 833 |
+
temp_ts = ((mask[0][0][:, ::2, ::2]) * t).flatten()
|
| 834 |
+
temp_ts = torch.cat([
|
| 835 |
+
temp_ts,
|
| 836 |
+
temp_ts.new_ones(seq_len - temp_ts.size(0)) * t
|
| 837 |
+
])
|
| 838 |
+
temp_ts = temp_ts.unsqueeze(0)
|
| 839 |
+
timestep = temp_ts.expand(latent_model_input.shape[0], temp_ts.size(1))
|
| 840 |
+
else:
|
| 841 |
+
timestep = t.expand(latent_model_input.shape[0])
|
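For the 16x-compressed VAE variant with an inpainting video, the timestep becomes a per-token vector: tokens whose mask value is 0 (content to keep) receive timestep 0, while tokens to be generated carry the current `t`; the `[:, ::2, ::2]` stride matches a 2x2 spatial patchify so one value lands on each token. A sketch with toy numbers (shapes invented; the real `seq_len` comes from the patch grid):

```python
import torch

t = torch.tensor(800.0)
seq_len = 24                          # pretend patch-token count
mask_slice = torch.zeros(2, 4, 8)     # (latent_frames, H_lat, W_lat), 0 = keep
mask_slice[1] = 1                     # second latent frame is to be generated

temp_ts = (mask_slice[:, ::2, ::2] * t).flatten()          # one value per token
temp_ts = torch.cat([temp_ts, temp_ts.new_ones(seq_len - temp_ts.size(0)) * t])
timestep = temp_ts.unsqueeze(0).expand(2, temp_ts.size(1)) # batch of 2 (CFG)
print(timestep.shape, timestep[0, :10])
# torch.Size([2, 24]) tensor([  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0., 800., 800.])
```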
| 842 |
+
|
| 843 |
+
if self.transformer_2 is not None:
|
| 844 |
+
if t >= boundary * self.scheduler.config.num_train_timesteps:
|
| 845 |
+
local_transformer = self.transformer_2
|
| 846 |
+
else:
|
| 847 |
+
local_transformer = self.transformer
|
| 848 |
+
else:
|
| 849 |
+
local_transformer = self.transformer
|
| 850 |
+
|
| 851 |
+
# predict noise model_output
|
| 852 |
+
with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
|
| 853 |
+
noise_pred = local_transformer(
|
| 854 |
+
x=latent_model_input,
|
| 855 |
+
context=in_prompt_embeds,
|
| 856 |
+
t=timestep,
|
| 857 |
+
seq_len=seq_len,
|
| 858 |
+
y=control_latents_input,
|
| 859 |
+
y_camera=control_camera_latents_input,
|
| 860 |
+
full_ref=full_ref,
|
| 861 |
+
)
|
| 862 |
+
|
| 863 |
+
# perform guidance
|
| 864 |
+
if do_classifier_free_guidance:
|
| 865 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 866 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 867 |
+
|
| 868 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 869 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 870 |
+
|
| 871 |
+
if self.vae.spatial_compression_ratio >= 16 and not mask[:, :, 0, :, :].any():
|
| 872 |
+
latents = (1 - mask) * masked_video_latents + mask * latents
|
| 873 |
+
|
| 874 |
+
if callback_on_step_end is not None:
|
| 875 |
+
callback_kwargs = {}
|
| 876 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 877 |
+
callback_kwargs[k] = locals()[k]
|
| 878 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 879 |
+
|
| 880 |
+
latents = callback_outputs.pop("latents", latents)
|
| 881 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 882 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 883 |
+
|
| 884 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 885 |
+
progress_bar.update()
|
| 886 |
+
if comfyui_progressbar:
|
| 887 |
+
pbar.update(1)
|
| 888 |
+
|
| 889 |
+
if output_type == "numpy":
|
| 890 |
+
video = self.decode_latents(latents)
|
| 891 |
+
elif not output_type == "latent":
|
| 892 |
+
video = self.decode_latents(latents)
|
| 893 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 894 |
+
else:
|
| 895 |
+
video = latents
|
| 896 |
+
|
| 897 |
+
# Offload all models
|
| 898 |
+
self.maybe_free_model_hooks()
|
| 899 |
+
|
| 900 |
+
if not return_dict:
|
| 901 |
+
video = torch.from_numpy(video)
|
| 902 |
+
|
| 903 |
+
return WanPipelineOutput(videos=video)
|
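End to end, the control pipeline expects clips in `(batch, channels, frames, height, width)` layout. The sketch below only illustrates the call signature; assembling `pipe` (tokenizer, text encoder, VAE, transformer(s), scheduler) is assumed to happen elsewhere, and the prompt, resolution, and step count are arbitrary examples.

```python
import torch

def run_control_demo(pipe):
    """Sketch of calling an already-assembled Wan2_2FunControlPipeline."""
    control_video = torch.rand(1, 3, 49, 480, 832)   # e.g. a pose/depth clip in [0, 1]
    out = pipe(
        prompt="a dancer spinning in a neon-lit alley",
        negative_prompt="blurry, low quality",
        control_video=control_video,
        height=480, width=832, num_frames=49,
        num_inference_steps=30, guidance_scale=6.0,
        generator=torch.Generator().manual_seed(0),
    )
    # With the defaults (output_type="numpy", return_dict=False) the decoded
    # frames come back as a torch tensor in [0, 1].
    return out.videos
```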
videox_fun/pipeline/pipeline_wan2_2_fun_inpaint.py
ADDED
|
@@ -0,0 +1,752 @@
import inspect
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.embeddings import get_1d_rotary_pos_embed
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from einops import rearrange
from PIL import Image
from transformers import T5Tokenizer

from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
                      WanT5EncoderModel, Wan2_2Transformer3DModel)
from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                get_sampling_sigmas)
from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        pass
        ```
"""


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
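
# Illustrative only: a minimal sketch of how `retrieve_timesteps` above is typically called.
# `scheduler` stands for any scheduler instance already attached to a pipeline; the values
# below are placeholders, not part of this module.
#
#   timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cuda")
#   # or, with a custom sigma schedule (only one of `timesteps` / `sigmas` may be passed):
#   timesteps, _ = retrieve_timesteps(scheduler, device="cuda", sigmas=get_sampling_sigmas(30, 5))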


def resize_mask(mask, latent, process_first_frame_only=True):
    latent_size = latent.size()
    batch_size, channels, num_frames, height, width = mask.shape

    if process_first_frame_only:
        target_size = list(latent_size[2:])
        target_size[0] = 1
        first_frame_resized = F.interpolate(
            mask[:, :, 0:1, :, :],
            size=target_size,
            mode='trilinear',
            align_corners=False
        )

        target_size = list(latent_size[2:])
        target_size[0] = target_size[0] - 1
        if target_size[0] != 0:
            remaining_frames_resized = F.interpolate(
                mask[:, :, 1:, :, :],
                size=target_size,
                mode='trilinear',
                align_corners=False
            )
            resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
        else:
            resized_mask = first_frame_resized
    else:
        target_size = list(latent_size[2:])
        resized_mask = F.interpolate(
            mask,
            size=target_size,
            mode='trilinear',
            align_corners=False
        )
    return resized_mask

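# Shape sketch for `resize_mask` above (illustrative assumptions: a VAE with 8x spatial and
# 4x temporal compression, 49 input frames at 480x832 resolution):
#
#   mask    = torch.ones(1, 1, 49, 480, 832)    # pixel-space mask, (B, C, F, H, W)
#   latent  = torch.zeros(1, 16, 13, 60, 104)   # target latents, (B, C', F', H', W')
#   resized = resize_mask(mask, latent)         # -> (1, 1, 13, 60, 104)
#
# With `process_first_frame_only=True` the first frame is interpolated to the first latent
# frame on its own and the remaining 48 frames to the remaining 12 latent frames, matching
# how `(num_frames - 1) // temporal_compression_ratio + 1` maps the first frame to a single
# latent frame.
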
@dataclass
class WanPipelineOutput(BaseOutput):
    r"""
    Output class for Wan pipelines.

    Args:
        video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    videos: torch.Tensor


class Wan2_2FunInpaintPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-video generation using Wan.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
    """

    _optional_components = ["transformer_2"]
    model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae"

    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: WanT5EncoderModel,
        vae: AutoencoderKLWan,
        transformer: Wan2_2Transformer3DModel,
        transformer_2: Wan2_2Transformer3DModel = None,
        scheduler: FlowMatchEulerDiscreteScheduler = None,
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
            transformer_2=transformer_2, scheduler=scheduler
        )
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
        )

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_attention_mask = text_inputs.attention_mask
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds

    def prepare_latents(
        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
    ):
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        shape = (
            batch_size,
            num_channels_latents,
            (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
            height // self.vae.spatial_compression_ratio,
            width // self.vae.spatial_compression_ratio,
        )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        if hasattr(self.scheduler, "init_noise_sigma"):
            latents = latents * self.scheduler.init_noise_sigma
        return latents


    def prepare_mask_latents(
        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
    ):
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision

        if mask is not None:
            mask = mask.to(device=device, dtype=self.vae.dtype)
            bs = 1
            new_mask = []
            for i in range(0, mask.shape[0], bs):
                mask_bs = mask[i : i + bs]
                mask_bs = self.vae.encode(mask_bs)[0]
                mask_bs = mask_bs.mode()
                new_mask.append(mask_bs)
            mask = torch.cat(new_mask, dim = 0)
            # mask = mask * self.vae.config.scaling_factor

        if masked_image is not None:
            masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
            bs = 1
            new_mask_pixel_values = []
            for i in range(0, masked_image.shape[0], bs):
                mask_pixel_values_bs = masked_image[i : i + bs]
                mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
                mask_pixel_values_bs = mask_pixel_values_bs.mode()
                new_mask_pixel_values.append(mask_pixel_values_bs)
            masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
            # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
        else:
            masked_image_latents = None

        return mask, masked_image_latents

    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
        frames = self.vae.decode(latents.to(self.vae.dtype)).sample
        frames = (frames / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        frames = frames.cpu().float().numpy()
        return frames

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        return self._attention_kwargs

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 480,
        width: int = 720,
        video: Union[torch.FloatTensor] = None,
        mask_video: Union[torch.FloatTensor] = None,
        num_frames: int = 49,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 6,
        num_videos_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: str = "numpy",
        return_dict: bool = False,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        boundary: float = 0.875,
        comfyui_progressbar: bool = False,
        shift: int = 5,
    ) -> Union[WanPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.
        Args:

        Examples:

        Returns:

        """
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_on_step_end_tensor_inputs,
            prompt_embeds,
            negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._interrupt = False

        # 2. Default call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        weight_dtype = self.text_encoder.dtype

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            negative_prompt,
            do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )
        if do_classifier_free_guidance:
            in_prompt_embeds = negative_prompt_embeds + prompt_embeds
        else:
            in_prompt_embeds = prompt_embeds

        # 4. Prepare timesteps
        if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
        elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
            self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
            timesteps = self.scheduler.timesteps
        elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
            sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
            timesteps, _ = retrieve_timesteps(
                self.scheduler,
                device=device,
                sigmas=sampling_sigmas)
        else:
            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        self._num_timesteps = len(timesteps)
        if comfyui_progressbar:
            from comfy.utils import ProgressBar
            pbar = ProgressBar(num_inference_steps + 2)

        # 5. Prepare latents.
        if video is not None:
            video_length = video.shape[2]
            init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
            init_video = init_video.to(dtype=torch.float32)
            init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
        else:
            init_video = None

        latent_channels = self.vae.config.latent_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            latent_channels,
            num_frames,
            height,
            width,
            weight_dtype,
            device,
            generator,
            latents,
        )
        if comfyui_progressbar:
            pbar.update(1)

        # Prepare mask latent variables
        if init_video is not None:
            if (mask_video == 255).all():
                mask_latents = torch.tile(
                    torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1]
                )
                masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype)
                if self.vae.spatial_compression_ratio >= 16:
                    mask = torch.ones_like(latents).to(device, weight_dtype)[:, :1].to(device, weight_dtype)
            else:
                bs, _, video_length, height, width = video.size()
                mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
                mask_condition = mask_condition.to(dtype=torch.float32)
                mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)

                masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5)
                _, masked_video_latents = self.prepare_mask_latents(
                    None,
                    masked_video,
                    batch_size,
                    height,
                    width,
                    weight_dtype,
                    device,
                    generator,
                    do_classifier_free_guidance,
                    noise_aug_strength=None,
                )

                mask_condition = torch.concat(
                    [
                        torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2),
                        mask_condition[:, :, 1:]
                    ], dim=2
                )
                mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width)
                mask_condition = mask_condition.transpose(1, 2)
                mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype)

                if self.vae.spatial_compression_ratio >= 16:
                    mask = F.interpolate(mask_condition[:, :1], size=latents.size()[-3:], mode='trilinear', align_corners=True).to(device, weight_dtype)
                    if not mask[:, :, 0, :, :].any():
                        mask[:, :, 1:, :, :] = 1
                    latents = (1 - mask) * masked_video_latents + mask * latents

        if comfyui_progressbar:
            pbar.update(1)

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
        seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
        # 7. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self.transformer.num_inference_steps = num_inference_steps
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                self.transformer.current_steps = i

                if self.interrupt:
                    continue

                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                if hasattr(self.scheduler, "scale_model_input"):
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                if init_video is not None:
                    mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
                    masked_video_latents_input = (
                        torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
                    )
                    y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                if self.vae.spatial_compression_ratio >= 16 and init_video is not None:
                    temp_ts = ((mask[0][0][:, ::2, ::2]) * t).flatten()
                    temp_ts = torch.cat([
                        temp_ts,
                        temp_ts.new_ones(seq_len - temp_ts.size(0)) * t
                    ])
                    temp_ts = temp_ts.unsqueeze(0)
                    timestep = temp_ts.expand(latent_model_input.shape[0], temp_ts.size(1))
                else:
                    timestep = t.expand(latent_model_input.shape[0])

                if self.transformer_2 is not None:
                    if t >= boundary * self.scheduler.config.num_train_timesteps:
                        local_transformer = self.transformer_2
                    else:
                        local_transformer = self.transformer
                else:
                    local_transformer = self.transformer

                # predict noise model_output
                with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
                    noise_pred = local_transformer(
                        x=latent_model_input,
                        context=in_prompt_embeds,
                        t=timestep,
                        seq_len=seq_len,
                        y=y,
                    )

                # perform guidance
                if do_classifier_free_guidance:
                    if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))):
                        sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0]
                    else:
                        sample_guide_scale = self.guidance_scale
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if self.vae.spatial_compression_ratio >= 16 and not mask[:, :, 0, :, :].any():
                    latents = (1 - mask) * masked_video_latents + mask * latents

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                if comfyui_progressbar:
                    pbar.update(1)

        if output_type == "numpy":
            video = self.decode_latents(latents)
        elif not output_type == "latent":
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            video = torch.from_numpy(video)

        return WanPipelineOutput(videos=video)
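
For orientation, a minimal and hypothetical usage sketch of the Wan2_2FunInpaintPipeline defined above. The component objects, the checkpoint they are loaded from, and the example prompt and shapes are placeholders, not part of this commit; per the __call__ above, transformer_2 handles the high-noise timesteps (t >= boundary) and transformer the rest.

import torch

from videox_fun.pipeline.pipeline_wan2_2_fun_inpaint import Wan2_2FunInpaintPipeline

# Assumed to be loaded elsewhere from a Wan2.2-Fun checkpoint:
# tokenizer, text_encoder, vae, transformer, transformer_2, scheduler.
pipeline = Wan2_2FunInpaintPipeline(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    vae=vae,
    transformer=transformer,
    transformer_2=transformer_2,
    scheduler=scheduler,
).to("cuda")

# `video` / `mask_video` are (B, C, F, H, W) tensors; a mask_video that is 255 everywhere
# is treated by the pipeline as "regenerate every pixel".
output = pipeline(
    prompt="a corgi running on the beach at sunset",
    video=video,
    mask_video=mask_video,
    height=480,
    width=832,
    num_frames=49,
    num_inference_steps=40,
    guidance_scale=5.0,
    generator=torch.Generator(device="cuda").manual_seed(43),
    output_type="numpy",
)
frames = output.videos  # numpy array in [0, 1], shape (B, C, F, H, W)
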
videox_fun/pipeline/pipeline_wan2_2_s2v.py
ADDED
@@ -0,0 +1,815 @@
import inspect
import math
import copy
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.embeddings import get_1d_rotary_pos_embed
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from einops import rearrange
from PIL import Image
from torchvision import transforms
from transformers import T5Tokenizer

from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
                      Wan2_2Transformer3DModel_S2V, WanAudioEncoder,
                      WanT5EncoderModel)
from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                get_sampling_sigmas)
from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        pass
        ```
"""


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


def resize_mask(mask, latent, process_first_frame_only=True):
    latent_size = latent.size()
    batch_size, channels, num_frames, height, width = mask.shape

    if process_first_frame_only:
        target_size = list(latent_size[2:])
        target_size[0] = 1
        first_frame_resized = F.interpolate(
            mask[:, :, 0:1, :, :],
            size=target_size,
            mode='trilinear',
            align_corners=False
        )

        target_size = list(latent_size[2:])
        target_size[0] = target_size[0] - 1
        if target_size[0] != 0:
            remaining_frames_resized = F.interpolate(
                mask[:, :, 1:, :, :],
                size=target_size,
                mode='trilinear',
                align_corners=False
            )
            resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
        else:
            resized_mask = first_frame_resized
    else:
        target_size = list(latent_size[2:])
        resized_mask = F.interpolate(
            mask,
            size=target_size,
            mode='trilinear',
            align_corners=False
        )
    return resized_mask


@dataclass
class WanPipelineOutput(BaseOutput):
    r"""
    Output class for Wan pipelines.

    Args:
        video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    videos: torch.Tensor


class Wan2_2S2VPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-video generation using Wan.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
    """

    _optional_components = ["transformer_2", "audio_encoder"]
    model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae"

    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: WanT5EncoderModel,
        audio_encoder: WanAudioEncoder,
        vae: AutoencoderKLWan,
        transformer: Wan2_2Transformer3DModel_S2V,
        transformer_2: Wan2_2Transformer3DModel_S2V = None,
        scheduler: FlowMatchEulerDiscreteScheduler = None,
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
            transformer_2=transformer_2, scheduler=scheduler, audio_encoder=audio_encoder
        )
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
        )
        self.motion_frames = 73
        self.audio_sample_m = 0
        self.drop_first_motion = True

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_attention_mask = text_inputs.attention_mask
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds

    def encode_audio_embeddings(self, audio_path, num_frames, fps, weight_dtype, device):
        z = self.audio_encoder.extract_audio_feat(
            audio_path, return_all_layers=True)
        audio_embed_bucket, num_repeat = self.audio_encoder.get_audio_embed_bucket_fps(
            z, fps=fps, batch_frames=num_frames, m=self.audio_sample_m)
        audio_embed_bucket = audio_embed_bucket.to(device,
                                                   weight_dtype)
        audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
        if len(audio_embed_bucket.shape) == 3:
            audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
        elif len(audio_embed_bucket.shape) == 4:
            audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
        return audio_embed_bucket, num_repeat

    def encode_pose_latents(self, pose_video, num_repeat, num_frames, size, fps, weight_dtype, device):
        height, width = size
        if not pose_video is None:
            padding_frame_num = num_repeat * num_frames - pose_video.shape[2]
            pose_video = torch.cat(
                [
                    pose_video,
                    -torch.ones([1, 3, padding_frame_num, height, width])
                ],
                dim=2
            )

            cond_tensors = torch.chunk(pose_video, num_repeat, dim=2)
        else:
            cond_tensors = [-torch.ones([1, 3, num_frames, height, width])]

        pose_latents = []
        for r in range(len(cond_tensors)):
            cond = cond_tensors[r]
            cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond],
                             dim=2)
            cond_lat = self.vae.encode(cond.to(dtype=weight_dtype, device=device))[0].mode()[:, :, 1:]
            pose_latents.append(cond_lat)
        return pose_latents
| 358 |
+
|
| 359 |
+
def prepare_latents(
|
| 360 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, num_length_latents=None
|
| 361 |
+
):
|
| 362 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 363 |
+
raise ValueError(
|
| 364 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 365 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
shape = (
|
| 369 |
+
batch_size,
|
| 370 |
+
num_channels_latents,
|
| 371 |
+
(num_frames - 1) // self.vae.temporal_compression_ratio + 1 if num_length_latents is None else num_length_latents,
|
| 372 |
+
height // self.vae.spatial_compression_ratio,
|
| 373 |
+
width // self.vae.spatial_compression_ratio,
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
if latents is None:
|
| 377 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 378 |
+
else:
|
| 379 |
+
latents = latents.to(device)
|
| 380 |
+
|
| 381 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 382 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
| 383 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 384 |
+
return latents
|
| 385 |
+
|
| 386 |
+
def prepare_control_latents(
|
| 387 |
+
self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
| 388 |
+
):
|
| 389 |
+
# resize the control to latents shape as we concatenate the control to the latents
|
| 390 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
| 391 |
+
# and half precision
|
| 392 |
+
|
| 393 |
+
if control is not None:
|
| 394 |
+
control = control.to(device=device, dtype=dtype)
|
| 395 |
+
bs = 1
|
| 396 |
+
new_control = []
|
| 397 |
+
for i in range(0, control.shape[0], bs):
|
| 398 |
+
control_bs = control[i : i + bs]
|
| 399 |
+
control_bs = self.vae.encode(control_bs)[0]
|
| 400 |
+
control_bs = control_bs.mode()
|
| 401 |
+
new_control.append(control_bs)
|
| 402 |
+
control = torch.cat(new_control, dim = 0)
|
| 403 |
+
|
| 404 |
+
if control_image is not None:
|
| 405 |
+
control_image = control_image.to(device=device, dtype=dtype)
|
| 406 |
+
bs = 1
|
| 407 |
+
new_control_pixel_values = []
|
| 408 |
+
for i in range(0, control_image.shape[0], bs):
|
| 409 |
+
control_pixel_values_bs = control_image[i : i + bs]
|
| 410 |
+
control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
|
| 411 |
+
control_pixel_values_bs = control_pixel_values_bs.mode()
|
| 412 |
+
new_control_pixel_values.append(control_pixel_values_bs)
|
| 413 |
+
control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
|
| 414 |
+
else:
|
| 415 |
+
control_image_latents = None
|
| 416 |
+
|
| 417 |
+
return control, control_image_latents
|
| 418 |
+
|
| 419 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 420 |
+
frames = self.vae.decode(latents.to(self.vae.dtype)).sample
|
| 421 |
+
frames = (frames / 2 + 0.5).clamp(0, 1)
|
| 422 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 423 |
+
# frames = frames.cpu().float().numpy()
|
| 424 |
+
return frames
|
| 425 |
+
|
| 426 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 427 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 428 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 429 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 430 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 431 |
+
# and should be between [0, 1]
|
| 432 |
+
|
| 433 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 434 |
+
extra_step_kwargs = {}
|
| 435 |
+
if accepts_eta:
|
| 436 |
+
extra_step_kwargs["eta"] = eta
|
| 437 |
+
|
| 438 |
+
# check if the scheduler accepts generator
|
| 439 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 440 |
+
if accepts_generator:
|
| 441 |
+
extra_step_kwargs["generator"] = generator
|
| 442 |
+
return extra_step_kwargs
|
| 443 |
+
|
| 444 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
| 445 |
+
def check_inputs(
|
| 446 |
+
self,
|
| 447 |
+
prompt,
|
| 448 |
+
height,
|
| 449 |
+
width,
|
| 450 |
+
negative_prompt,
|
| 451 |
+
callback_on_step_end_tensor_inputs,
|
| 452 |
+
prompt_embeds=None,
|
| 453 |
+
negative_prompt_embeds=None,
|
| 454 |
+
):
|
| 455 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 456 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 457 |
+
|
| 458 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 459 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 460 |
+
):
|
| 461 |
+
raise ValueError(
|
| 462 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 463 |
+
)
|
| 464 |
+
if prompt is not None and prompt_embeds is not None:
|
| 465 |
+
raise ValueError(
|
| 466 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 467 |
+
" only forward one of the two."
|
| 468 |
+
)
|
| 469 |
+
elif prompt is None and prompt_embeds is None:
|
| 470 |
+
raise ValueError(
|
| 471 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 472 |
+
)
|
| 473 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 474 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 475 |
+
|
| 476 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 477 |
+
raise ValueError(
|
| 478 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 479 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 480 |
+
)
|
| 481 |
+
|
| 482 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 483 |
+
raise ValueError(
|
| 484 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 485 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 489 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 490 |
+
raise ValueError(
|
| 491 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 492 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 493 |
+
f" {negative_prompt_embeds.shape}."
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
@property
|
| 497 |
+
def guidance_scale(self):
|
| 498 |
+
return self._guidance_scale
|
| 499 |
+
|
| 500 |
+
@property
|
| 501 |
+
def num_timesteps(self):
|
| 502 |
+
return self._num_timesteps
|
| 503 |
+
|
| 504 |
+
@property
|
| 505 |
+
def attention_kwargs(self):
|
| 506 |
+
return self._attention_kwargs
|
| 507 |
+
|
| 508 |
+
@property
|
| 509 |
+
def interrupt(self):
|
| 510 |
+
return self._interrupt
|
| 511 |
+
|
| 512 |
+
@torch.no_grad()
|
| 513 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 514 |
+
def __call__(
|
| 515 |
+
self,
|
| 516 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 517 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 518 |
+
height: int = 480,
|
| 519 |
+
width: int = 720,
|
| 520 |
+
ref_image: Optional[torch.FloatTensor] = None,
|
| 521 |
+
audio_path = None,
|
| 522 |
+
pose_video = None,
|
| 523 |
+
num_frames: int = 49,
|
| 524 |
+
num_inference_steps: int = 50,
|
| 525 |
+
timesteps: Optional[List[int]] = None,
|
| 526 |
+
guidance_scale: float = 6,
|
| 527 |
+
num_videos_per_prompt: int = 1,
|
| 528 |
+
eta: float = 0.0,
|
| 529 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 530 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 531 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 532 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 533 |
+
output_type: str = "numpy",
|
| 534 |
+
return_dict: bool = False,
|
| 535 |
+
callback_on_step_end: Optional[
|
| 536 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 537 |
+
] = None,
|
| 538 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 539 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 540 |
+
max_sequence_length: int = 512,
|
| 541 |
+
boundary: float = 0.875,
|
| 542 |
+
comfyui_progressbar: bool = False,
|
| 543 |
+
shift: int = 5,
|
| 544 |
+
fps: int = 16,
|
| 545 |
+
init_first_frame: bool = False,
|
| 546 |
+
) -> Union[WanPipelineOutput, Tuple]:
|
| 547 |
+
"""
|
| 548 |
+
Function invoked when calling the pipeline for generation.
|
| 549 |
+
Args:
|
| 550 |
+
|
| 551 |
+
Examples:
|
| 552 |
+
|
| 553 |
+
Returns:
|
| 554 |
+
|
| 555 |
+
"""
|
| 556 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 557 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 558 |
+
num_videos_per_prompt = 1
|
| 559 |
+
|
| 560 |
+
# 1. Check inputs. Raise error if not correct
|
| 561 |
+
self.check_inputs(
|
| 562 |
+
prompt,
|
| 563 |
+
height,
|
| 564 |
+
width,
|
| 565 |
+
negative_prompt,
|
| 566 |
+
callback_on_step_end_tensor_inputs,
|
| 567 |
+
prompt_embeds,
|
| 568 |
+
negative_prompt_embeds,
|
| 569 |
+
)
|
| 570 |
+
self._guidance_scale = guidance_scale
|
| 571 |
+
self._attention_kwargs = attention_kwargs
|
| 572 |
+
self._interrupt = False
|
| 573 |
+
|
| 574 |
+
# 2. Default call parameters
|
| 575 |
+
if prompt is not None and isinstance(prompt, str):
|
| 576 |
+
batch_size = 1
|
| 577 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 578 |
+
batch_size = len(prompt)
|
| 579 |
+
else:
|
| 580 |
+
batch_size = prompt_embeds.shape[0]
|
| 581 |
+
|
| 582 |
+
device = self._execution_device
|
| 583 |
+
weight_dtype = self.text_encoder.dtype
|
| 584 |
+
|
| 585 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 586 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 587 |
+
# corresponds to doing no classifier free guidance.
|
| 588 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 589 |
+
|
| 590 |
+
lat_motion_frames = (self.motion_frames + 3) // 4
|
| 591 |
+
lat_target_frames = (num_frames + 3 + self.motion_frames) // 4 - lat_motion_frames
|
| 592 |
+
|
| 593 |
+
# 3. Encode input prompt
|
| 594 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 595 |
+
prompt,
|
| 596 |
+
negative_prompt,
|
| 597 |
+
do_classifier_free_guidance,
|
| 598 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 599 |
+
prompt_embeds=prompt_embeds,
|
| 600 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 601 |
+
max_sequence_length=max_sequence_length,
|
| 602 |
+
device=device,
|
| 603 |
+
)
|
| 604 |
+
if do_classifier_free_guidance:
|
| 605 |
+
in_prompt_embeds = negative_prompt_embeds + prompt_embeds
|
| 606 |
+
else:
|
| 607 |
+
in_prompt_embeds = prompt_embeds
|
| 608 |
+
|
| 609 |
+
if comfyui_progressbar:
|
| 610 |
+
from comfy.utils import ProgressBar
|
| 611 |
+
pbar = ProgressBar(num_inference_steps + 2)
|
| 612 |
+
|
| 613 |
+
# 5. Prepare latents.
|
| 614 |
+
latent_channels = self.vae.config.latent_channels
|
| 615 |
+
if comfyui_progressbar:
|
| 616 |
+
pbar.update(1)
|
| 617 |
+
|
| 618 |
+
video_length = ref_image.shape[2]
|
| 619 |
+
ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 620 |
+
ref_image = ref_image.to(dtype=torch.float32)
|
| 621 |
+
ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length)
|
| 622 |
+
|
| 623 |
+
ref_image_latentes = self.prepare_control_latents(
|
| 624 |
+
None,
|
| 625 |
+
ref_image,
|
| 626 |
+
batch_size,
|
| 627 |
+
height,
|
| 628 |
+
width,
|
| 629 |
+
weight_dtype,
|
| 630 |
+
device,
|
| 631 |
+
generator,
|
| 632 |
+
do_classifier_free_guidance
|
| 633 |
+
)[1]
|
| 634 |
+
ref_image_latentes = ref_image_latentes[:, :, :1]
|
| 635 |
+
|
| 636 |
+
# Extract audio embeddings
|
| 637 |
+
audio_emb, num_repeat = self.encode_audio_embeddings(
|
| 638 |
+
audio_path, num_frames=num_frames, fps=fps, weight_dtype=weight_dtype, device=device
|
| 639 |
+
)
|
| 640 |
+
|
| 641 |
+
# Encode the motion latents
|
| 642 |
+
motion_latents = torch.zeros(
|
| 643 |
+
[1, 3, self.motion_frames, height, width],
|
| 644 |
+
dtype=weight_dtype,
|
| 645 |
+
device=device
|
| 646 |
+
)
|
| 647 |
+
videos_last_frames = motion_latents.detach()
|
| 648 |
+
drop_first_motion = self.drop_first_motion
|
| 649 |
+
if init_first_frame:
|
| 650 |
+
drop_first_motion = False
|
| 651 |
+
motion_latents[:, :, -6:] = ref_image
|
| 652 |
+
motion_latents = self.vae.encode(motion_latents)[0].mode()
|
| 653 |
+
|
| 654 |
+
# Get pose cond input if need
|
| 655 |
+
if pose_video is not None:
|
| 656 |
+
video_length = pose_video.shape[2]
|
| 657 |
+
pose_video = self.image_processor.preprocess(rearrange(pose_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 658 |
+
pose_video = pose_video.to(dtype=torch.float32)
|
| 659 |
+
pose_video = rearrange(pose_video, "(b f) c h w -> b c f h w", f=video_length)
|
| 660 |
+
pose_latents = self.encode_pose_latents(
|
| 661 |
+
pose_video=pose_video,
|
| 662 |
+
num_repeat=num_repeat,
|
| 663 |
+
num_frames=num_frames,
|
| 664 |
+
size=(height, width),
|
| 665 |
+
fps=fps,
|
| 666 |
+
weight_dtype=weight_dtype,
|
| 667 |
+
device=device
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
if comfyui_progressbar:
|
| 671 |
+
pbar.update(1)
|
| 672 |
+
|
| 673 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 674 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 675 |
+
|
| 676 |
+
videos = []
|
| 677 |
+
copy_timesteps = copy.deepcopy(timesteps)
|
| 678 |
+
copy_latents = copy.deepcopy(latents)
|
| 679 |
+
for r in range(num_repeat):
|
| 680 |
+
# Prepare timesteps
|
| 681 |
+
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
| 682 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, copy_timesteps, mu=1)
|
| 683 |
+
elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
|
| 684 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
|
| 685 |
+
timesteps = self.scheduler.timesteps
|
| 686 |
+
elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
|
| 687 |
+
sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
|
| 688 |
+
timesteps, _ = retrieve_timesteps(
|
| 689 |
+
self.scheduler,
|
| 690 |
+
device=device,
|
| 691 |
+
sigmas=sampling_sigmas)
|
| 692 |
+
else:
|
| 693 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, copy_timesteps)
|
| 694 |
+
self._num_timesteps = len(timesteps)
|
| 695 |
+
|
| 696 |
+
target_shape = (self.vae.latent_channels, lat_target_frames, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
|
| 697 |
+
seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
|
| 698 |
+
|
| 699 |
+
latents = self.prepare_latents(
|
| 700 |
+
batch_size * num_videos_per_prompt,
|
| 701 |
+
latent_channels,
|
| 702 |
+
num_frames,
|
| 703 |
+
height,
|
| 704 |
+
width,
|
| 705 |
+
weight_dtype,
|
| 706 |
+
device,
|
| 707 |
+
generator,
|
| 708 |
+
copy_latents,
|
| 709 |
+
num_length_latents=target_shape[1]
|
| 710 |
+
)
|
| 711 |
+
# 7. Denoising loop
|
| 712 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 713 |
+
self.transformer.num_inference_steps = num_inference_steps
|
| 714 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 715 |
+
for i, t in enumerate(timesteps):
|
| 716 |
+
self.transformer.current_steps = i
|
| 717 |
+
|
| 718 |
+
if self.interrupt:
|
| 719 |
+
continue
|
| 720 |
+
|
| 721 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 722 |
+
if hasattr(self.scheduler, "scale_model_input"):
|
| 723 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 724 |
+
|
| 725 |
+
with torch.no_grad():
|
| 726 |
+
left_idx = r * num_frames
|
| 727 |
+
right_idx = r * num_frames + num_frames
|
| 728 |
+
cond_latents = pose_latents[r] if pose_video is not None else pose_latents[0] * 0
|
| 729 |
+
cond_latents = cond_latents.to(dtype=weight_dtype, device=device)
|
| 730 |
+
audio_input = audio_emb[..., left_idx:right_idx]
|
| 731 |
+
|
| 732 |
+
pose_latents_input = torch.cat([cond_latents] * 2) if do_classifier_free_guidance else cond_latents
|
| 733 |
+
motion_latents_input = torch.cat([motion_latents] * 2) if do_classifier_free_guidance else motion_latents
|
| 734 |
+
audio_emb_input = torch.cat([audio_input * 0] + [audio_input]) if do_classifier_free_guidance else audio_input
|
| 735 |
+
ref_image_latentes_input = torch.cat([ref_image_latentes] * 2) if do_classifier_free_guidance else ref_image_latentes
|
| 736 |
+
motion_frames=[[self.motion_frames, (self.motion_frames + 3) // 4]] * 2 if do_classifier_free_guidance else [[self.motion_frames, (self.motion_frames + 3) // 4]]
|
| 737 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 738 |
+
|
| 739 |
+
if self.transformer_2 is not None:
|
| 740 |
+
if t >= boundary * self.scheduler.config.num_train_timesteps:
|
| 741 |
+
local_transformer = self.transformer_2
|
| 742 |
+
else:
|
| 743 |
+
local_transformer = self.transformer
|
| 744 |
+
else:
|
| 745 |
+
local_transformer = self.transformer
|
| 746 |
+
|
| 747 |
+
# predict noise model_output
|
| 748 |
+
with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
|
| 749 |
+
noise_pred = local_transformer(
|
| 750 |
+
x=latent_model_input,
|
| 751 |
+
context=in_prompt_embeds,
|
| 752 |
+
t=timestep,
|
| 753 |
+
seq_len=seq_len,
|
| 754 |
+
cond_states=pose_latents_input,
|
| 755 |
+
motion_latents=motion_latents_input,
|
| 756 |
+
ref_latents=ref_image_latentes_input,
|
| 757 |
+
audio_input=audio_emb_input,
|
| 758 |
+
motion_frames=motion_frames,
|
| 759 |
+
drop_motion_frames=drop_first_motion and r == 0,
|
| 760 |
+
)
|
| 761 |
+
# perform guidance
|
| 762 |
+
if do_classifier_free_guidance:
|
| 763 |
+
if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))):
|
| 764 |
+
sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0]
|
| 765 |
+
else:
|
| 766 |
+
sample_guide_scale = self.guidance_scale
|
| 767 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 768 |
+
noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond)
|
| 769 |
+
|
| 770 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 771 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 772 |
+
|
| 773 |
+
if callback_on_step_end is not None:
|
| 774 |
+
callback_kwargs = {}
|
| 775 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 776 |
+
callback_kwargs[k] = locals()[k]
|
| 777 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 778 |
+
|
| 779 |
+
latents = callback_outputs.pop("latents", latents)
|
| 780 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 781 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 782 |
+
|
| 783 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 784 |
+
progress_bar.update()
|
| 785 |
+
if comfyui_progressbar:
|
| 786 |
+
pbar.update(1)
|
| 787 |
+
|
| 788 |
+
if not (drop_first_motion and r == 0):
|
| 789 |
+
decode_latents = torch.cat([motion_latents, latents], dim=2)
|
| 790 |
+
else:
|
| 791 |
+
decode_latents = torch.cat([ref_image_latentes, latents], dim=2)
|
| 792 |
+
|
| 793 |
+
image = self.vae.decode(decode_latents).sample
|
| 794 |
+
image = image[:, :, -(num_frames):]
|
| 795 |
+
if (drop_first_motion and r == 0):
|
| 796 |
+
image = image[:, :, 3:]
|
| 797 |
+
|
| 798 |
+
overlap_frames_num = min(self.motion_frames, image.shape[2])
|
| 799 |
+
videos_last_frames = torch.cat(
|
| 800 |
+
[
|
| 801 |
+
videos_last_frames[:, :, overlap_frames_num:],
|
| 802 |
+
image[:, :, -overlap_frames_num:]
|
| 803 |
+
],
|
| 804 |
+
dim=2
|
| 805 |
+
).to(dtype=motion_latents.dtype, device=motion_latents.device)
|
| 806 |
+
motion_latents = self.vae.encode(videos_last_frames)[0].mode()
|
| 807 |
+
videos.append(image)
|
| 808 |
+
|
| 809 |
+
videos = torch.cat(videos, dim=2)
|
| 810 |
+
videos = (videos / 2 + 0.5).clamp(0, 1)
|
| 811 |
+
|
| 812 |
+
# Offload all models
|
| 813 |
+
self.maybe_free_model_hooks()
|
| 814 |
+
|
| 815 |
+
return WanPipelineOutput(videos=videos.float().cpu())
|
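The `__call__` above stitches prompt encoding, audio and pose conditioning, and the windowed denoising loop into a single entry point. The sketch below shows roughly how it would be invoked once the pipeline object has been constructed; `pipe`, the file path, the seed, and the tensor sizes are illustrative assumptions, not values taken from this repository.

```python
# Hedged usage sketch for the audio-driven pipeline above.
# `pipe` is assumed to be an already-built instance of this pipeline class
# (tokenizer, text encoder, VAE, transformer(s) and scheduler loaded elsewhere).
import torch

ref_image = torch.rand(1, 3, 1, 480, 832)      # b c f h w: a single reference frame
generator = torch.Generator("cuda").manual_seed(42)

output = pipe(
    prompt="a person speaking to the camera",
    negative_prompt="blurry, distorted",
    ref_image=ref_image,
    audio_path="speech.wav",                    # placeholder path
    pose_video=None,                            # optional pose conditioning tensor
    height=480, width=832,
    num_frames=49, fps=16,
    num_inference_steps=50,
    guidance_scale=4.5,
    generator=generator,
)
video = output.videos                           # float tensor in [0, 1], shape b c f h w
```

Each `num_frames` window reuses the most recently decoded frames as motion context (re-encoded into `motion_latents`), which is why the returned tensor is the concatenation of the per-window decodes along the frame axis.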
videox_fun/pipeline/pipeline_wan2_2_ti2v.py
ADDED
|
@@ -0,0 +1,732 @@
|
| 1 |
+
import inspect
|
| 2 |
+
import math
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torchvision.transforms.functional as TF
|
| 10 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
| 11 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 12 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 13 |
+
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
| 14 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 15 |
+
from diffusers.utils import BaseOutput, logging, replace_example_docstring
|
| 16 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 17 |
+
from diffusers.video_processor import VideoProcessor
|
| 18 |
+
from einops import rearrange
|
| 19 |
+
from PIL import Image
|
| 20 |
+
from transformers import T5Tokenizer
|
| 21 |
+
|
| 22 |
+
from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
|
| 23 |
+
WanT5EncoderModel, Wan2_2Transformer3DModel)
|
| 24 |
+
from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
| 25 |
+
get_sampling_sigmas)
|
| 26 |
+
from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
EXAMPLE_DOC_STRING = """
|
| 32 |
+
Examples:
|
| 33 |
+
```python
|
| 34 |
+
pass
|
| 35 |
+
```
|
| 36 |
+
"""
|
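The example docstring above is left as a `pass` placeholder. The following is a minimal sketch of what such an example could look like for the `Wan2_2TI2VPipeline` defined below; component loading is elided, the prompt, seed, and resolution are arbitrary, and the return handling past the end of this excerpt is assumed to mirror the `WanPipelineOutput` dataclass defined later in the file.

```python
# Minimal sketch, assuming `tokenizer`, `text_encoder`, `vae`, `transformer`
# and `scheduler` are already-instantiated Wan2.2 TI2V components.
import torch

pipe = Wan2_2TI2VPipeline(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    vae=vae,
    transformer=transformer,
    scheduler=scheduler,
).to("cuda")

out = pipe(
    prompt="a red fox running through fresh snow",
    num_frames=49,
    height=480,
    width=720,
    num_inference_steps=50,
    guidance_scale=6.0,
    generator=torch.Generator("cuda").manual_seed(0),
)
frames = out.videos   # decoded frames (NumPy by default, since output_type="numpy")
```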
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 40 |
+
def retrieve_timesteps(
|
| 41 |
+
scheduler,
|
| 42 |
+
num_inference_steps: Optional[int] = None,
|
| 43 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 44 |
+
timesteps: Optional[List[int]] = None,
|
| 45 |
+
sigmas: Optional[List[float]] = None,
|
| 46 |
+
**kwargs,
|
| 47 |
+
):
|
| 48 |
+
"""
|
| 49 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 50 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
scheduler (`SchedulerMixin`):
|
| 54 |
+
The scheduler to get timesteps from.
|
| 55 |
+
num_inference_steps (`int`):
|
| 56 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 57 |
+
must be `None`.
|
| 58 |
+
device (`str` or `torch.device`, *optional*):
|
| 59 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 60 |
+
timesteps (`List[int]`, *optional*):
|
| 61 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 62 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 63 |
+
sigmas (`List[float]`, *optional*):
|
| 64 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 65 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 69 |
+
second element is the number of inference steps.
|
| 70 |
+
"""
|
| 71 |
+
if timesteps is not None and sigmas is not None:
|
| 72 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 73 |
+
if timesteps is not None:
|
| 74 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 75 |
+
if not accepts_timesteps:
|
| 76 |
+
raise ValueError(
|
| 77 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 78 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 79 |
+
)
|
| 80 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 81 |
+
timesteps = scheduler.timesteps
|
| 82 |
+
num_inference_steps = len(timesteps)
|
| 83 |
+
elif sigmas is not None:
|
| 84 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 85 |
+
if not accept_sigmas:
|
| 86 |
+
raise ValueError(
|
| 87 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 88 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 89 |
+
)
|
| 90 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 91 |
+
timesteps = scheduler.timesteps
|
| 92 |
+
num_inference_steps = len(timesteps)
|
| 93 |
+
else:
|
| 94 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 95 |
+
timesteps = scheduler.timesteps
|
| 96 |
+
return timesteps, num_inference_steps
|
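`retrieve_timesteps` is a thin dispatcher: it forwards either a step count, an explicit timestep list, or an explicit sigma schedule to `scheduler.set_timesteps` and returns whatever schedule the scheduler built. A small sketch, assuming a stock diffusers `FlowMatchEulerDiscreteScheduler`; the step count and sigma values are arbitrary.

```python
# Assumes a diffusers version whose FlowMatchEulerDiscreteScheduler.set_timesteps
# accepts `sigmas`; values here are illustrative.
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()

# Let the scheduler derive its own spacing from a step count.
timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")

# Or impose an explicit sigma schedule instead (the step count is then inferred).
timesteps, n_steps = retrieve_timesteps(
    scheduler, device="cpu", sigmas=[1.0, 0.75, 0.5, 0.25]
)
```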
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def resize_mask(mask, latent, process_first_frame_only=True):
|
| 100 |
+
latent_size = latent.size()
|
| 101 |
+
batch_size, channels, num_frames, height, width = mask.shape
|
| 102 |
+
|
| 103 |
+
if process_first_frame_only:
|
| 104 |
+
target_size = list(latent_size[2:])
|
| 105 |
+
target_size[0] = 1
|
| 106 |
+
first_frame_resized = F.interpolate(
|
| 107 |
+
mask[:, :, 0:1, :, :],
|
| 108 |
+
size=target_size,
|
| 109 |
+
mode='trilinear',
|
| 110 |
+
align_corners=False
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
target_size = list(latent_size[2:])
|
| 114 |
+
target_size[0] = target_size[0] - 1
|
| 115 |
+
if target_size[0] != 0:
|
| 116 |
+
remaining_frames_resized = F.interpolate(
|
| 117 |
+
mask[:, :, 1:, :, :],
|
| 118 |
+
size=target_size,
|
| 119 |
+
mode='trilinear',
|
| 120 |
+
align_corners=False
|
| 121 |
+
)
|
| 122 |
+
resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
|
| 123 |
+
else:
|
| 124 |
+
resized_mask = first_frame_resized
|
| 125 |
+
else:
|
| 126 |
+
target_size = list(latent_size[2:])
|
| 127 |
+
resized_mask = F.interpolate(
|
| 128 |
+
mask,
|
| 129 |
+
size=target_size,
|
| 130 |
+
mode='trilinear',
|
| 131 |
+
align_corners=False
|
| 132 |
+
)
|
| 133 |
+
return resized_mask
|
| 134 |
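`resize_mask` shrinks a pixel-space mask to the latent grid, optionally treating the first frame separately so that it maps to exactly one latent frame while the remaining frames share the rest of the temporal slots. Illustrative shapes below, assuming a VAE with 4x temporal and 8x spatial compression (an assumption for this sketch, not a value read from any config):

```python
import torch

mask = torch.ones(1, 1, 49, 480, 720)      # b c f h w, pixel space
latent = torch.zeros(1, 16, 13, 60, 90)    # matching latent grid

resized = resize_mask(mask, latent, process_first_frame_only=True)
print(resized.shape)                        # torch.Size([1, 1, 13, 60, 90])
# frame 0 -> 1 latent frame, frames 1..48 -> the remaining 12 latent frames
```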
+
|
| 135 |
+
|
| 136 |
+
@dataclass
|
| 137 |
+
class WanPipelineOutput(BaseOutput):
|
| 138 |
+
r"""
|
| 139 |
+
Output class for Wan pipelines.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
videos (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
| 143 |
+
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
| 144 |
+
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
| 145 |
+
`(batch_size, num_frames, channels, height, width)`.
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
videos: torch.Tensor
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class Wan2_2TI2VPipeline(DiffusionPipeline):
|
| 152 |
+
r"""
|
| 153 |
+
Pipeline for text-to-video generation using Wan.
|
| 154 |
+
|
| 155 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 156 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
_optional_components = ["transformer_2"]
|
| 160 |
+
model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae"
|
| 161 |
+
|
| 162 |
+
_callback_tensor_inputs = [
|
| 163 |
+
"latents",
|
| 164 |
+
"prompt_embeds",
|
| 165 |
+
"negative_prompt_embeds",
|
| 166 |
+
]
|
| 167 |
+
|
| 168 |
+
def __init__(
|
| 169 |
+
self,
|
| 170 |
+
tokenizer: AutoTokenizer,
|
| 171 |
+
text_encoder: WanT5EncoderModel,
|
| 172 |
+
vae: AutoencoderKLWan,
|
| 173 |
+
transformer: Wan2_2Transformer3DModel,
|
| 174 |
+
transformer_2: Wan2_2Transformer3DModel = None,
|
| 175 |
+
scheduler: FlowMatchEulerDiscreteScheduler = None,
|
| 176 |
+
):
|
| 177 |
+
super().__init__()
|
| 178 |
+
|
| 179 |
+
self.register_modules(
|
| 180 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
|
| 181 |
+
transformer_2=transformer_2, scheduler=scheduler
|
| 182 |
+
)
|
| 183 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
|
| 184 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
|
| 185 |
+
self.mask_processor = VaeImageProcessor(
|
| 186 |
+
vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
def _get_t5_prompt_embeds(
|
| 190 |
+
self,
|
| 191 |
+
prompt: Union[str, List[str]] = None,
|
| 192 |
+
num_videos_per_prompt: int = 1,
|
| 193 |
+
max_sequence_length: int = 512,
|
| 194 |
+
device: Optional[torch.device] = None,
|
| 195 |
+
dtype: Optional[torch.dtype] = None,
|
| 196 |
+
):
|
| 197 |
+
device = device or self._execution_device
|
| 198 |
+
dtype = dtype or self.text_encoder.dtype
|
| 199 |
+
|
| 200 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 201 |
+
batch_size = len(prompt)
|
| 202 |
+
|
| 203 |
+
text_inputs = self.tokenizer(
|
| 204 |
+
prompt,
|
| 205 |
+
padding="max_length",
|
| 206 |
+
max_length=max_sequence_length,
|
| 207 |
+
truncation=True,
|
| 208 |
+
add_special_tokens=True,
|
| 209 |
+
return_tensors="pt",
|
| 210 |
+
)
|
| 211 |
+
text_input_ids = text_inputs.input_ids
|
| 212 |
+
prompt_attention_mask = text_inputs.attention_mask
|
| 213 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 214 |
+
|
| 215 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 216 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 217 |
+
logger.warning(
|
| 218 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 219 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
|
| 223 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
|
| 224 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 225 |
+
|
| 226 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 227 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 228 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 229 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 230 |
+
|
| 231 |
+
return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
| 232 |
+
|
| 233 |
+
def encode_prompt(
|
| 234 |
+
self,
|
| 235 |
+
prompt: Union[str, List[str]],
|
| 236 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 237 |
+
do_classifier_free_guidance: bool = True,
|
| 238 |
+
num_videos_per_prompt: int = 1,
|
| 239 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 240 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 241 |
+
max_sequence_length: int = 512,
|
| 242 |
+
device: Optional[torch.device] = None,
|
| 243 |
+
dtype: Optional[torch.dtype] = None,
|
| 244 |
+
):
|
| 245 |
+
r"""
|
| 246 |
+
Encodes the prompt into text encoder hidden states.
|
| 247 |
+
|
| 248 |
+
Args:
|
| 249 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 250 |
+
prompt to be encoded
|
| 251 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 252 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 253 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 254 |
+
less than `1`).
|
| 255 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 256 |
+
Whether to use classifier free guidance or not.
|
| 257 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 258 |
+
Number of videos that should be generated per prompt.
|
| 259 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 260 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 261 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 262 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 263 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 264 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 265 |
+
argument.
|
| 266 |
+
device: (`torch.device`, *optional*):
|
| 267 |
+
torch device
|
| 268 |
+
dtype: (`torch.dtype`, *optional*):
|
| 269 |
+
torch dtype
|
| 270 |
+
"""
|
| 271 |
+
device = device or self._execution_device
|
| 272 |
+
|
| 273 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 274 |
+
if prompt is not None:
|
| 275 |
+
batch_size = len(prompt)
|
| 276 |
+
else:
|
| 277 |
+
batch_size = prompt_embeds.shape[0]
|
| 278 |
+
|
| 279 |
+
if prompt_embeds is None:
|
| 280 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 281 |
+
prompt=prompt,
|
| 282 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 283 |
+
max_sequence_length=max_sequence_length,
|
| 284 |
+
device=device,
|
| 285 |
+
dtype=dtype,
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 289 |
+
negative_prompt = negative_prompt or ""
|
| 290 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 291 |
+
|
| 292 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 293 |
+
raise TypeError(
|
| 294 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 295 |
+
f" {type(prompt)}."
|
| 296 |
+
)
|
| 297 |
+
elif batch_size != len(negative_prompt):
|
| 298 |
+
raise ValueError(
|
| 299 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 300 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 301 |
+
" the batch size of `prompt`."
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 305 |
+
prompt=negative_prompt,
|
| 306 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 307 |
+
max_sequence_length=max_sequence_length,
|
| 308 |
+
device=device,
|
| 309 |
+
dtype=dtype,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
return prompt_embeds, negative_prompt_embeds
|
| 313 |
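Because `_get_t5_prompt_embeds` trims each encoded prompt to its true token length, `encode_prompt` returns lists of per-prompt tensors whose sequence lengths can differ. A sketch of the call, assuming a constructed pipeline `pipe`; the hidden size and token counts in the comment are illustrative only.

```python
positive, negative = pipe.encode_prompt(
    prompt=["a cat", "a cat sitting on a sunlit windowsill, cinematic lighting"],
    negative_prompt="low quality, watermark",
    do_classifier_free_guidance=True,
    device="cuda",
)
# Each element is one prompt's T5 hidden states, trimmed to its unpadded length,
# e.g. [torch.Size([3, 4096]), torch.Size([12, 4096])] -- shapes differ per prompt.
print([p.shape for p in positive], [n.shape for n in negative])
```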
+
|
| 314 |
+
def prepare_latents(
|
| 315 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
| 316 |
+
):
|
| 317 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 318 |
+
raise ValueError(
|
| 319 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 320 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
shape = (
|
| 324 |
+
batch_size,
|
| 325 |
+
num_channels_latents,
|
| 326 |
+
(num_frames - 1) // self.vae.temporal_compression_ratio + 1,
|
| 327 |
+
height // self.vae.spatial_compression_ratio,
|
| 328 |
+
width // self.vae.spatial_compression_ratio,
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
if latents is None:
|
| 332 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 333 |
+
else:
|
| 334 |
+
latents = latents.to(device)
|
| 335 |
+
|
| 336 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 337 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
| 338 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 339 |
+
return latents
|
| 340 |
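The latent shape built above follows directly from the VAE compression ratios. A worked example with the default call parameters (49 frames at 480x720), assuming 4x temporal and 8x spatial compression purely for illustration, since the real ratios come from the loaded `AutoencoderKLWan` config:

```python
num_frames, height, width = 49, 480, 720
temporal_ratio, spatial_ratio = 4, 8          # assumed, illustrative values

latent_frames = (num_frames - 1) // temporal_ratio + 1   # 13
latent_height = height // spatial_ratio                   # 60
latent_width = width // spatial_ratio                     # 90
print(latent_frames, latent_height, latent_width)         # 13 60 90
```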
+
|
| 341 |
+
def prepare_mask_latents(
|
| 342 |
+
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
|
| 343 |
+
):
|
| 344 |
+
# resize the mask to latents shape as we concatenate the mask to the latents
|
| 345 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
| 346 |
+
# and half precision
|
| 347 |
+
|
| 348 |
+
if mask is not None:
|
| 349 |
+
mask = mask.to(device=device, dtype=self.vae.dtype)
|
| 350 |
+
bs = 1
|
| 351 |
+
new_mask = []
|
| 352 |
+
for i in range(0, mask.shape[0], bs):
|
| 353 |
+
mask_bs = mask[i : i + bs]
|
| 354 |
+
mask_bs = self.vae.encode(mask_bs)[0]
|
| 355 |
+
mask_bs = mask_bs.mode()
|
| 356 |
+
new_mask.append(mask_bs)
|
| 357 |
+
mask = torch.cat(new_mask, dim = 0)
|
| 358 |
+
# mask = mask * self.vae.config.scaling_factor
|
| 359 |
+
|
| 360 |
+
if masked_image is not None:
|
| 361 |
+
masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
|
| 362 |
+
bs = 1
|
| 363 |
+
new_mask_pixel_values = []
|
| 364 |
+
for i in range(0, masked_image.shape[0], bs):
|
| 365 |
+
mask_pixel_values_bs = masked_image[i : i + bs]
|
| 366 |
+
mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
|
| 367 |
+
mask_pixel_values_bs = mask_pixel_values_bs.mode()
|
| 368 |
+
new_mask_pixel_values.append(mask_pixel_values_bs)
|
| 369 |
+
masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
|
| 370 |
+
# masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
|
| 371 |
+
else:
|
| 372 |
+
masked_image_latents = None
|
| 373 |
+
|
| 374 |
+
return mask, masked_image_latents
|
| 375 |
+
|
| 376 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 377 |
+
frames = self.vae.decode(latents.to(self.vae.dtype)).sample
|
| 378 |
+
frames = (frames / 2 + 0.5).clamp(0, 1)
|
| 379 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 380 |
+
frames = frames.cpu().float().numpy()
|
| 381 |
+
return frames
|
| 382 |
+
|
| 383 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 384 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 385 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 386 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 387 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 388 |
+
# and should be between [0, 1]
|
| 389 |
+
|
| 390 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 391 |
+
extra_step_kwargs = {}
|
| 392 |
+
if accepts_eta:
|
| 393 |
+
extra_step_kwargs["eta"] = eta
|
| 394 |
+
|
| 395 |
+
# check if the scheduler accepts generator
|
| 396 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 397 |
+
if accepts_generator:
|
| 398 |
+
extra_step_kwargs["generator"] = generator
|
| 399 |
+
return extra_step_kwargs
|
| 400 |
+
|
| 401 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
| 402 |
+
def check_inputs(
|
| 403 |
+
self,
|
| 404 |
+
prompt,
|
| 405 |
+
height,
|
| 406 |
+
width,
|
| 407 |
+
negative_prompt,
|
| 408 |
+
callback_on_step_end_tensor_inputs,
|
| 409 |
+
prompt_embeds=None,
|
| 410 |
+
negative_prompt_embeds=None,
|
| 411 |
+
):
|
| 412 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 413 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 414 |
+
|
| 415 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 416 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 417 |
+
):
|
| 418 |
+
raise ValueError(
|
| 419 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 420 |
+
)
|
| 421 |
+
if prompt is not None and prompt_embeds is not None:
|
| 422 |
+
raise ValueError(
|
| 423 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 424 |
+
" only forward one of the two."
|
| 425 |
+
)
|
| 426 |
+
elif prompt is None and prompt_embeds is None:
|
| 427 |
+
raise ValueError(
|
| 428 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 429 |
+
)
|
| 430 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 431 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 432 |
+
|
| 433 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 434 |
+
raise ValueError(
|
| 435 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 436 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 440 |
+
raise ValueError(
|
| 441 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 442 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 446 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 447 |
+
raise ValueError(
|
| 448 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 449 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 450 |
+
f" {negative_prompt_embeds.shape}."
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
@property
|
| 454 |
+
def guidance_scale(self):
|
| 455 |
+
return self._guidance_scale
|
| 456 |
+
|
| 457 |
+
@property
|
| 458 |
+
def num_timesteps(self):
|
| 459 |
+
return self._num_timesteps
|
| 460 |
+
|
| 461 |
+
@property
|
| 462 |
+
def attention_kwargs(self):
|
| 463 |
+
return self._attention_kwargs
|
| 464 |
+
|
| 465 |
+
@property
|
| 466 |
+
def interrupt(self):
|
| 467 |
+
return self._interrupt
|
| 468 |
+
|
| 469 |
+
@torch.no_grad()
|
| 470 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 471 |
+
def __call__(
|
| 472 |
+
self,
|
| 473 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 474 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 475 |
+
height: int = 480,
|
| 476 |
+
width: int = 720,
|
| 477 |
+
video: Optional[torch.FloatTensor] = None,
|
| 478 |
+
mask_video: Optional[torch.FloatTensor] = None,
|
| 479 |
+
num_frames: int = 49,
|
| 480 |
+
num_inference_steps: int = 50,
|
| 481 |
+
timesteps: Optional[List[int]] = None,
|
| 482 |
+
guidance_scale: float = 6,
|
| 483 |
+
num_videos_per_prompt: int = 1,
|
| 484 |
+
eta: float = 0.0,
|
| 485 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 486 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 487 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 488 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 489 |
+
output_type: str = "numpy",
|
| 490 |
+
return_dict: bool = False,
|
| 491 |
+
callback_on_step_end: Optional[
|
| 492 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 493 |
+
] = None,
|
| 494 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 495 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 496 |
+
max_sequence_length: int = 512,
|
| 497 |
+
boundary: float = 0.875,
|
| 498 |
+
comfyui_progressbar: bool = False,
|
| 499 |
+
shift: int = 5,
|
| 500 |
+
) -> Union[WanPipelineOutput, Tuple]:
|
| 501 |
+
"""
|
| 502 |
+
Function invoked when calling the pipeline for generation.
|
| 503 |
+
Args:
|
| 504 |
+
|
| 505 |
+
Examples:
|
| 506 |
+
|
| 507 |
+
Returns:
|
| 508 |
+
|
| 509 |
+
"""
|
| 510 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 511 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 512 |
+
num_videos_per_prompt = 1
|
| 513 |
+
|
| 514 |
+
# 1. Check inputs. Raise error if not correct
|
| 515 |
+
self.check_inputs(
|
| 516 |
+
prompt,
|
| 517 |
+
height,
|
| 518 |
+
width,
|
| 519 |
+
negative_prompt,
|
| 520 |
+
callback_on_step_end_tensor_inputs,
|
| 521 |
+
prompt_embeds,
|
| 522 |
+
negative_prompt_embeds,
|
| 523 |
+
)
|
| 524 |
+
self._guidance_scale = guidance_scale
|
| 525 |
+
self._attention_kwargs = attention_kwargs
|
| 526 |
+
self._interrupt = False
|
| 527 |
+
|
| 528 |
+
# 2. Default call parameters
|
| 529 |
+
if prompt is not None and isinstance(prompt, str):
|
| 530 |
+
batch_size = 1
|
| 531 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 532 |
+
batch_size = len(prompt)
|
| 533 |
+
else:
|
| 534 |
+
batch_size = prompt_embeds.shape[0]
|
| 535 |
+
|
| 536 |
+
device = self._execution_device
|
| 537 |
+
weight_dtype = self.text_encoder.dtype
|
| 538 |
+
|
| 539 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 540 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 541 |
+
# corresponds to doing no classifier free guidance.
|
| 542 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 543 |
+
|
| 544 |
+
# 3. Encode input prompt
|
| 545 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 546 |
+
prompt,
|
| 547 |
+
negative_prompt,
|
| 548 |
+
do_classifier_free_guidance,
|
| 549 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 550 |
+
prompt_embeds=prompt_embeds,
|
| 551 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 552 |
+
max_sequence_length=max_sequence_length,
|
| 553 |
+
device=device,
|
| 554 |
+
)
|
| 555 |
+
if do_classifier_free_guidance:
|
| 556 |
+
in_prompt_embeds = negative_prompt_embeds + prompt_embeds
|
| 557 |
+
else:
|
| 558 |
+
in_prompt_embeds = prompt_embeds
|
| 559 |
+
|
| 560 |
+
# 4. Prepare timesteps
|
| 561 |
+
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
| 562 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
|
| 563 |
+
elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
|
| 564 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
|
| 565 |
+
timesteps = self.scheduler.timesteps
|
| 566 |
+
elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
|
| 567 |
+
sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
|
| 568 |
+
timesteps, _ = retrieve_timesteps(
|
| 569 |
+
self.scheduler,
|
| 570 |
+
device=device,
|
| 571 |
+
sigmas=sampling_sigmas)
|
| 572 |
+
else:
|
| 573 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 574 |
+
self._num_timesteps = len(timesteps)
|
| 575 |
+
if comfyui_progressbar:
|
| 576 |
+
from comfy.utils import ProgressBar
|
| 577 |
+
pbar = ProgressBar(num_inference_steps + 2)
|
| 578 |
+
|
| 579 |
+
# 5. Prepare latents.
|
| 580 |
+
if video is not None:
|
| 581 |
+
video_length = video.shape[2]
|
| 582 |
+
init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 583 |
+
init_video = init_video.to(dtype=torch.float32)
|
| 584 |
+
init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
|
| 585 |
+
else:
|
| 586 |
+
init_video = None
|
| 587 |
+
|
| 588 |
+
latent_channels = self.vae.config.latent_channels
|
| 589 |
+
latents = self.prepare_latents(
|
| 590 |
+
batch_size * num_videos_per_prompt,
|
| 591 |
+
latent_channels,
|
| 592 |
+
num_frames,
|
| 593 |
+
height,
|
| 594 |
+
width,
|
| 595 |
+
weight_dtype,
|
| 596 |
+
device,
|
| 597 |
+
generator,
|
| 598 |
+
latents,
|
| 599 |
+
)
|
| 600 |
+
if comfyui_progressbar:
|
| 601 |
+
pbar.update(1)
|
| 602 |
+
|
| 603 |
+
# Prepare mask latent variables
|
| 604 |
+
if init_video is not None and not (mask_video == 255).all():
|
| 605 |
+
bs, _, video_length, height, width = video.size()
|
| 606 |
+
mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 607 |
+
mask_condition = mask_condition.to(dtype=torch.float32)
|
| 608 |
+
mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
|
| 609 |
+
|
| 610 |
+
masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5)
|
| 611 |
+
_, masked_video_latents = self.prepare_mask_latents(
|
| 612 |
+
None,
|
| 613 |
+
masked_video,
|
| 614 |
+
batch_size,
|
| 615 |
+
height,
|
| 616 |
+
width,
|
| 617 |
+
weight_dtype,
|
| 618 |
+
device,
|
| 619 |
+
generator,
|
| 620 |
+
do_classifier_free_guidance,
|
| 621 |
+
noise_aug_strength=None,
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
mask_condition = torch.concat(
|
| 625 |
+
[
|
| 626 |
+
torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2),
|
| 627 |
+
mask_condition[:, :, 1:]
|
| 628 |
+
], dim=2
|
| 629 |
+
)
|
| 630 |
+
mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width)
|
| 631 |
+
mask_condition = mask_condition.transpose(1, 2)
|
| 632 |
+
|
| 633 |
+
mask = F.interpolate(mask_condition[:, :1], size=latents.size()[-3:], mode='trilinear', align_corners=True).to(device, weight_dtype)
|
| 634 |
+
latents = (1 - mask) * masked_video_latents + mask * latents
|
| 635 |
+
else:
|
| 636 |
+
init_video = None
|
| 637 |
+
|
| 638 |
+
if comfyui_progressbar:
|
| 639 |
+
pbar.update(1)
|
| 640 |
+
|
| 641 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 642 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 643 |
+
|
| 644 |
+
target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
|
| 645 |
+
seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
|
| 646 |
+
# 7. Denoising loop
|
| 647 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 648 |
+
self.transformer.num_inference_steps = num_inference_steps
|
| 649 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 650 |
+
for i, t in enumerate(timesteps):
|
| 651 |
+
self.transformer.current_steps = i
|
| 652 |
+
|
| 653 |
+
if self.interrupt:
|
| 654 |
+
continue
|
| 655 |
+
|
| 656 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 657 |
+
if hasattr(self.scheduler, "scale_model_input"):
|
| 658 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 659 |
+
|
| 660 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 661 |
+
if init_video is not None:
|
| 662 |
+
temp_ts = ((mask[0][0][:, ::2, ::2]) * t).flatten()
|
| 663 |
+
temp_ts = torch.cat([
|
| 664 |
+
temp_ts,
|
| 665 |
+
temp_ts.new_ones(seq_len - temp_ts.size(0)) * t
|
| 666 |
+
])
|
| 667 |
+
temp_ts = temp_ts.unsqueeze(0)
|
| 668 |
+
timestep = temp_ts.expand(latent_model_input.shape[0], temp_ts.size(1))
|
| 669 |
+
else:
|
| 670 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 671 |
+
|
| 672 |
+
if self.transformer_2 is not None:
|
| 673 |
+
if t >= boundary * self.scheduler.config.num_train_timesteps:
|
| 674 |
+
local_transformer = self.transformer_2
|
| 675 |
+
else:
|
| 676 |
+
local_transformer = self.transformer
|
| 677 |
+
else:
|
| 678 |
+
local_transformer = self.transformer
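
Reviewer note: when a second transformer is loaded, the `boundary` fraction splits the trajectory by timestep value: steps with a timestep at or above `boundary * num_train_timesteps` are routed to `transformer_2` (the high-noise stage in the two-expert Wan2.2 setup), the rest to `transformer`. A tiny illustration, assuming 1000 training timesteps and the 0.875 default used elsewhere in this repo:

```python
boundary, num_train_timesteps = 0.875, 1000
threshold = boundary * num_train_timesteps  # 875.0

for t in (999, 900, 875, 874, 500, 10):
    expert = "transformer_2" if t >= threshold else "transformer"
    print(t, "->", expert)
# 999, 900 and 875 are handled by transformer_2; 874, 500 and 10 by transformer
```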
|
| 679 |
+
|
| 680 |
+
# predict noise model_output
|
| 681 |
+
with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
|
| 682 |
+
noise_pred = local_transformer(
|
| 683 |
+
x=latent_model_input,
|
| 684 |
+
context=in_prompt_embeds,
|
| 685 |
+
t=timestep,
|
| 686 |
+
seq_len=seq_len,
|
| 687 |
+
)
|
| 688 |
+
|
| 689 |
+
# perform guidance
|
| 690 |
+
if do_classifier_free_guidance:
|
| 691 |
+
if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))):
|
| 692 |
+
sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0]
|
| 693 |
+
else:
|
| 694 |
+
sample_guide_scale = self.guidance_scale
|
| 695 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 696 |
+
noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond)
|
| 697 |
+
|
| 698 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 699 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 700 |
+
if init_video is not None:
|
| 701 |
+
latents = (1 - mask) * masked_video_latents + mask * latents
|
| 702 |
+
|
| 703 |
+
if callback_on_step_end is not None:
|
| 704 |
+
callback_kwargs = {}
|
| 705 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 706 |
+
callback_kwargs[k] = locals()[k]
|
| 707 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 708 |
+
|
| 709 |
+
latents = callback_outputs.pop("latents", latents)
|
| 710 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 711 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 712 |
+
|
| 713 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 714 |
+
progress_bar.update()
|
| 715 |
+
if comfyui_progressbar:
|
| 716 |
+
pbar.update(1)
|
| 717 |
+
|
| 718 |
+
if output_type == "numpy":
|
| 719 |
+
video = self.decode_latents(latents)
|
| 720 |
+
elif not output_type == "latent":
|
| 721 |
+
video = self.decode_latents(latents)
|
| 722 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 723 |
+
else:
|
| 724 |
+
video = latents
|
| 725 |
+
|
| 726 |
+
# Offload all models
|
| 727 |
+
self.maybe_free_model_hooks()
|
| 728 |
+
|
| 729 |
+
if not return_dict:
|
| 730 |
+
video = torch.from_numpy(video)
|
| 731 |
+
|
| 732 |
+
return WanPipelineOutput(videos=video)
|
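
Reviewer note: the core of the inpaint behaviour in this file is the repeated blend `latents = (1 - mask) * masked_video_latents + mask * latents` (lines 634 and 701): wherever the resized mask is 0 the latent is clamped to the encoded masked video, wherever it is 1 the freshly denoised latent is kept. A minimal sketch with dummy tensors (shapes illustrative only):

```python
import torch

# 16 latent channels, 13 latent frames, 60x104 latent pixels (illustrative sizes).
latents = torch.randn(1, 16, 13, 60, 104)
masked_video_latents = torch.randn(1, 16, 13, 60, 104)
# Binary mask in latent space: 1 = region to generate, 0 = region kept from the input video.
mask = torch.zeros(1, 1, 13, 60, 104)
mask[..., 52:] = 1.0  # regenerate the right half of every frame

blended = (1 - mask) * masked_video_latents + mask * latents
assert torch.allclose(blended[..., :52], masked_video_latents[..., :52])  # kept region
assert torch.allclose(blended[..., 52:], latents[..., 52:])               # generated region
```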
videox_fun/pipeline/pipeline_wan2_2_vace_fun.py
ADDED
|
@@ -0,0 +1,801 @@
| 1 |
+
import inspect
|
| 2 |
+
import math
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torchvision.transforms.functional as TF
|
| 10 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
| 11 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 12 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 13 |
+
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
| 14 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 15 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 16 |
+
from diffusers.utils import BaseOutput, logging, replace_example_docstring
|
| 17 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 18 |
+
from diffusers.video_processor import VideoProcessor
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
from PIL import Image
|
| 21 |
+
from transformers import T5Tokenizer
|
| 22 |
+
|
| 23 |
+
from ..models import (AutoencoderKLWan, AutoTokenizer,
|
| 24 |
+
WanT5EncoderModel, VaceWanTransformer3DModel)
|
| 25 |
+
from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
| 26 |
+
get_sampling_sigmas)
|
| 27 |
+
from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 28 |
+
|
| 29 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
EXAMPLE_DOC_STRING = """
|
| 33 |
+
Examples:
|
| 34 |
+
```python
|
| 35 |
+
pass
|
| 36 |
+
```
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 41 |
+
def retrieve_timesteps(
|
| 42 |
+
scheduler,
|
| 43 |
+
num_inference_steps: Optional[int] = None,
|
| 44 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 45 |
+
timesteps: Optional[List[int]] = None,
|
| 46 |
+
sigmas: Optional[List[float]] = None,
|
| 47 |
+
**kwargs,
|
| 48 |
+
):
|
| 49 |
+
"""
|
| 50 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 51 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
scheduler (`SchedulerMixin`):
|
| 55 |
+
The scheduler to get timesteps from.
|
| 56 |
+
num_inference_steps (`int`):
|
| 57 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 58 |
+
must be `None`.
|
| 59 |
+
device (`str` or `torch.device`, *optional*):
|
| 60 |
+
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
| 61 |
+
timesteps (`List[int]`, *optional*):
|
| 62 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 63 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 64 |
+
sigmas (`List[float]`, *optional*):
|
| 65 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 66 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 70 |
+
second element is the number of inference steps.
|
| 71 |
+
"""
|
| 72 |
+
if timesteps is not None and sigmas is not None:
|
| 73 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 74 |
+
if timesteps is not None:
|
| 75 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 76 |
+
if not accepts_timesteps:
|
| 77 |
+
raise ValueError(
|
| 78 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 79 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 80 |
+
)
|
| 81 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 82 |
+
timesteps = scheduler.timesteps
|
| 83 |
+
num_inference_steps = len(timesteps)
|
| 84 |
+
elif sigmas is not None:
|
| 85 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 86 |
+
if not accept_sigmas:
|
| 87 |
+
raise ValueError(
|
| 88 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 89 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 90 |
+
)
|
| 91 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 92 |
+
timesteps = scheduler.timesteps
|
| 93 |
+
num_inference_steps = len(timesteps)
|
| 94 |
+
else:
|
| 95 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 96 |
+
timesteps = scheduler.timesteps
|
| 97 |
+
return timesteps, num_inference_steps
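
Reviewer note: `retrieve_timesteps` is the usual diffusers helper: it forwards either a step count, explicit `timesteps`, or explicit `sigmas` to `scheduler.set_timesteps` and returns the resulting schedule. A small usage sketch with a stock diffusers scheduler, assuming this module's `retrieve_timesteps` is importable (the concrete scheduler is only for illustration):

```python
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()

# Plain usage: ask for 30 steps on the CPU; custom `timesteps`/`sigmas` go through the same route.
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
print(num_inference_steps)   # 30
print(timesteps[:3])         # first three timesteps of the schedule
```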
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def resize_mask(mask, latent, process_first_frame_only=True):
|
| 101 |
+
latent_size = latent.size()
|
| 102 |
+
batch_size, channels, num_frames, height, width = mask.shape
|
| 103 |
+
|
| 104 |
+
if process_first_frame_only:
|
| 105 |
+
target_size = list(latent_size[2:])
|
| 106 |
+
target_size[0] = 1
|
| 107 |
+
first_frame_resized = F.interpolate(
|
| 108 |
+
mask[:, :, 0:1, :, :],
|
| 109 |
+
size=target_size,
|
| 110 |
+
mode='trilinear',
|
| 111 |
+
align_corners=False
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
target_size = list(latent_size[2:])
|
| 115 |
+
target_size[0] = target_size[0] - 1
|
| 116 |
+
if target_size[0] != 0:
|
| 117 |
+
remaining_frames_resized = F.interpolate(
|
| 118 |
+
mask[:, :, 1:, :, :],
|
| 119 |
+
size=target_size,
|
| 120 |
+
mode='trilinear',
|
| 121 |
+
align_corners=False
|
| 122 |
+
)
|
| 123 |
+
resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
|
| 124 |
+
else:
|
| 125 |
+
resized_mask = first_frame_resized
|
| 126 |
+
else:
|
| 127 |
+
target_size = list(latent_size[2:])
|
| 128 |
+
resized_mask = F.interpolate(
|
| 129 |
+
mask,
|
| 130 |
+
size=target_size,
|
| 131 |
+
mode='trilinear',
|
| 132 |
+
align_corners=False
|
| 133 |
+
)
|
| 134 |
+
return resized_mask
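
Reviewer note: `resize_mask` shrinks a pixel-space video mask to the latent grid; with `process_first_frame_only=True` the first frame is resized on its own (mirroring the VAE's special handling of the first frame) and the remaining frames fill the rest of the latent frame axis. A shape-only sketch with dummy tensors:

```python
import torch

mask = torch.ones(1, 1, 17, 64, 64)    # (batch, channels, frames, height, width), pixel space
latent = torch.zeros(1, 16, 5, 8, 8)   # (batch, channels, latent frames, latent h, latent w)

resized = resize_mask(mask, latent, process_first_frame_only=True)
print(resized.shape)  # torch.Size([1, 1, 5, 8, 8]) -- frame 0 -> 1 latent frame, frames 1..16 -> 4
```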
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@dataclass
|
| 138 |
+
class WanPipelineOutput(BaseOutput):
|
| 139 |
+
r"""
|
| 140 |
+
Output class for Wan pipelines.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
| 144 |
+
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
| 145 |
+
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
| 146 |
+
`(batch_size, num_frames, channels, height, width)`.
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
videos: torch.Tensor
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class Wan2_2VaceFunPipeline(DiffusionPipeline):
|
| 153 |
+
r"""
|
| 154 |
+
Pipeline for text-to-video generation using Wan.
|
| 155 |
+
|
| 156 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 157 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
_optional_components = ["transformer_2"]
|
| 161 |
+
model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae"
|
| 162 |
+
|
| 163 |
+
_callback_tensor_inputs = [
|
| 164 |
+
"latents",
|
| 165 |
+
"prompt_embeds",
|
| 166 |
+
"negative_prompt_embeds",
|
| 167 |
+
]
|
| 168 |
+
|
| 169 |
+
def __init__(
|
| 170 |
+
self,
|
| 171 |
+
tokenizer: AutoTokenizer,
|
| 172 |
+
text_encoder: WanT5EncoderModel,
|
| 173 |
+
vae: AutoencoderKLWan,
|
| 174 |
+
transformer: VaceWanTransformer3DModel,
|
| 175 |
+
transformer_2: VaceWanTransformer3DModel = None,
|
| 176 |
+
scheduler: FlowMatchEulerDiscreteScheduler = None,
|
| 177 |
+
):
|
| 178 |
+
super().__init__()
|
| 179 |
+
|
| 180 |
+
self.register_modules(
|
| 181 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
|
| 182 |
+
transformer_2=transformer_2, scheduler=scheduler
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
|
| 186 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
|
| 187 |
+
self.mask_processor = VaeImageProcessor(
|
| 188 |
+
vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
def _get_t5_prompt_embeds(
|
| 192 |
+
self,
|
| 193 |
+
prompt: Union[str, List[str]] = None,
|
| 194 |
+
num_videos_per_prompt: int = 1,
|
| 195 |
+
max_sequence_length: int = 512,
|
| 196 |
+
device: Optional[torch.device] = None,
|
| 197 |
+
dtype: Optional[torch.dtype] = None,
|
| 198 |
+
):
|
| 199 |
+
device = device or self._execution_device
|
| 200 |
+
dtype = dtype or self.text_encoder.dtype
|
| 201 |
+
|
| 202 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 203 |
+
batch_size = len(prompt)
|
| 204 |
+
|
| 205 |
+
text_inputs = self.tokenizer(
|
| 206 |
+
prompt,
|
| 207 |
+
padding="max_length",
|
| 208 |
+
max_length=max_sequence_length,
|
| 209 |
+
truncation=True,
|
| 210 |
+
add_special_tokens=True,
|
| 211 |
+
return_tensors="pt",
|
| 212 |
+
)
|
| 213 |
+
text_input_ids = text_inputs.input_ids
|
| 214 |
+
prompt_attention_mask = text_inputs.attention_mask
|
| 215 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 216 |
+
|
| 217 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 218 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 219 |
+
logger.warning(
|
| 220 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 221 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
|
| 225 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
|
| 226 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 227 |
+
|
| 228 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 229 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 230 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 231 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 232 |
+
|
| 233 |
+
return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
| 234 |
+
|
| 235 |
+
def encode_prompt(
|
| 236 |
+
self,
|
| 237 |
+
prompt: Union[str, List[str]],
|
| 238 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 239 |
+
do_classifier_free_guidance: bool = True,
|
| 240 |
+
num_videos_per_prompt: int = 1,
|
| 241 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 242 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 243 |
+
max_sequence_length: int = 512,
|
| 244 |
+
device: Optional[torch.device] = None,
|
| 245 |
+
dtype: Optional[torch.dtype] = None,
|
| 246 |
+
):
|
| 247 |
+
r"""
|
| 248 |
+
Encodes the prompt into text encoder hidden states.
|
| 249 |
+
|
| 250 |
+
Args:
|
| 251 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 252 |
+
prompt to be encoded
|
| 253 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 254 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 255 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 256 |
+
less than `1`).
|
| 257 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 258 |
+
Whether to use classifier free guidance or not.
|
| 259 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 260 |
+
Number of videos that should be generated per prompt.
|
| 261 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 262 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 263 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 264 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 265 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 266 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 267 |
+
argument.
|
| 268 |
+
device: (`torch.device`, *optional*):
|
| 269 |
+
torch device
|
| 270 |
+
dtype: (`torch.dtype`, *optional*):
|
| 271 |
+
torch dtype
|
| 272 |
+
"""
|
| 273 |
+
device = device or self._execution_device
|
| 274 |
+
|
| 275 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 276 |
+
if prompt is not None:
|
| 277 |
+
batch_size = len(prompt)
|
| 278 |
+
else:
|
| 279 |
+
batch_size = prompt_embeds.shape[0]
|
| 280 |
+
|
| 281 |
+
if prompt_embeds is None:
|
| 282 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 283 |
+
prompt=prompt,
|
| 284 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 285 |
+
max_sequence_length=max_sequence_length,
|
| 286 |
+
device=device,
|
| 287 |
+
dtype=dtype,
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 291 |
+
negative_prompt = negative_prompt or ""
|
| 292 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 293 |
+
|
| 294 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 295 |
+
raise TypeError(
|
| 296 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 297 |
+
f" {type(prompt)}."
|
| 298 |
+
)
|
| 299 |
+
elif batch_size != len(negative_prompt):
|
| 300 |
+
raise ValueError(
|
| 301 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 302 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 303 |
+
" the batch size of `prompt`."
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 307 |
+
prompt=negative_prompt,
|
| 308 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 309 |
+
max_sequence_length=max_sequence_length,
|
| 310 |
+
device=device,
|
| 311 |
+
dtype=dtype,
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
return prompt_embeds, negative_prompt_embeds
|
| 315 |
+
|
| 316 |
+
def prepare_latents(
|
| 317 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, num_length_latents=None
|
| 318 |
+
):
|
| 319 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 320 |
+
raise ValueError(
|
| 321 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 322 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
shape = (
|
| 326 |
+
batch_size,
|
| 327 |
+
num_channels_latents,
|
| 328 |
+
(num_frames - 1) // self.vae.temporal_compression_ratio + 1 if num_length_latents is None else num_length_latents,
|
| 329 |
+
height // self.vae.spatial_compression_ratio,
|
| 330 |
+
width // self.vae.spatial_compression_ratio,
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
if latents is None:
|
| 334 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 335 |
+
else:
|
| 336 |
+
latents = latents.to(device)
|
| 337 |
+
|
| 338 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 339 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
| 340 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 341 |
+
return latents
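
Reviewer note: `prepare_latents` draws noise of the latent shape implied by the frame count and the VAE compression ratios, unless the caller passes `num_length_latents` directly (as the VACE path below does once the reference frames are known). A quick shape check, assuming temporal compression 4 and spatial compression 8 (illustrative values, the real ones come from the VAE config):

```python
batch_size, latent_channels = 1, 16
num_frames, height, width = 49, 480, 720
temporal_compression_ratio, spatial_compression_ratio = 4, 8

shape = (
    batch_size,
    latent_channels,
    (num_frames - 1) // temporal_compression_ratio + 1,  # 13 latent frames
    height // spatial_compression_ratio,                 # 60
    width // spatial_compression_ratio,                  # 90
)
print(shape)  # (1, 16, 13, 60, 90)
```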
|
| 342 |
+
|
| 343 |
+
def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
|
| 344 |
+
vae = self.vae if vae is None else vae
|
| 345 |
+
weight_dtype = frames.dtype
|
| 346 |
+
if ref_images is None:
|
| 347 |
+
ref_images = [None] * len(frames)
|
| 348 |
+
else:
|
| 349 |
+
assert len(frames) == len(ref_images)
|
| 350 |
+
|
| 351 |
+
if masks is None:
|
| 352 |
+
latents = vae.encode(frames)[0].mode()
|
| 353 |
+
else:
|
| 354 |
+
masks = [torch.where(m > 0.5, 1.0, 0.0).to(weight_dtype) for m in masks]
|
| 355 |
+
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
|
| 356 |
+
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
|
| 357 |
+
inactive = vae.encode(inactive)[0].mode()
|
| 358 |
+
reactive = vae.encode(reactive)[0].mode()
|
| 359 |
+
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
|
| 360 |
+
|
| 361 |
+
cat_latents = []
|
| 362 |
+
for latent, refs in zip(latents, ref_images):
|
| 363 |
+
if refs is not None:
|
| 364 |
+
if masks is None:
|
| 365 |
+
ref_latent = vae.encode(refs)[0].mode()
|
| 366 |
+
else:
|
| 367 |
+
ref_latent = vae.encode(refs)[0].mode()
|
| 368 |
+
ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
|
| 369 |
+
assert all([x.shape[1] == 1 for x in ref_latent])
|
| 370 |
+
latent = torch.cat([*ref_latent, latent], dim=1)
|
| 371 |
+
cat_latents.append(latent)
|
| 372 |
+
return cat_latents
|
| 373 |
+
|
| 374 |
+
def vace_encode_masks(self, masks, ref_images=None, vae_stride=[4, 8, 8]):
|
| 375 |
+
if ref_images is None:
|
| 376 |
+
ref_images = [None] * len(masks)
|
| 377 |
+
else:
|
| 378 |
+
assert len(masks) == len(ref_images)
|
| 379 |
+
|
| 380 |
+
result_masks = []
|
| 381 |
+
for mask, refs in zip(masks, ref_images):
|
| 382 |
+
c, depth, height, width = mask.shape
|
| 383 |
+
new_depth = int((depth + 3) // vae_stride[0])
|
| 384 |
+
height = 2 * (int(height) // (vae_stride[1] * 2))
|
| 385 |
+
width = 2 * (int(width) // (vae_stride[2] * 2))
|
| 386 |
+
|
| 387 |
+
# reshape
|
| 388 |
+
mask = mask[0, :, :, :]
|
| 389 |
+
mask = mask.view(
|
| 390 |
+
depth, height, vae_stride[1], width, vae_stride[1]
|
| 391 |
+
) # depth, height, 8, width, 8
|
| 392 |
+
mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
|
| 393 |
+
mask = mask.reshape(
|
| 394 |
+
vae_stride[1] * vae_stride[2], depth, height, width
|
| 395 |
+
) # 8*8, depth, height, width
|
| 396 |
+
|
| 397 |
+
# interpolation
|
| 398 |
+
mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
|
| 399 |
+
|
| 400 |
+
if refs is not None:
|
| 401 |
+
length = len(refs)
|
| 402 |
+
mask_pad = torch.zeros_like(mask[:, :length, :, :])
|
| 403 |
+
mask = torch.cat((mask_pad, mask), dim=1)
|
| 404 |
+
result_masks.append(mask)
|
| 405 |
+
return result_masks
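
Reviewer note: `vace_encode_masks` turns a pixel-space mask of shape (channels, frames, H, W) into a latent-aligned tensor: the 8x8 spatial neighbourhood is folded into 64 channels (a pixel-unshuffle), the frame axis is resampled to the latent frame count, and zero frames are prepended for any reference images. The sketch below is a standalone re-implementation of the core reshaping for one mask, purely to show the shapes, assuming the default `vae_stride=[4, 8, 8]`:

```python
import torch
import torch.nn.functional as F

vae_stride = [4, 8, 8]
mask = torch.ones(1, 17, 64, 64)   # (channels, frames, H, W) for a single video

c, depth, height, width = mask.shape
new_depth = int((depth + 3) // vae_stride[0])       # 5 latent frames
height = 2 * (int(height) // (vae_stride[1] * 2))   # 8
width = 2 * (int(width) // (vae_stride[2] * 2))     # 8

m = mask[0].view(depth, height, vae_stride[1], width, vae_stride[1])
m = m.permute(2, 4, 0, 1, 3).reshape(vae_stride[1] * vae_stride[2], depth, height, width)
m = F.interpolate(m.unsqueeze(0), size=(new_depth, height, width), mode="nearest-exact").squeeze(0)
print(m.shape)  # torch.Size([64, 5, 8, 8]) -- 8x8 pixel blocks folded into 64 channels
```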
|
| 406 |
+
|
| 407 |
+
def vace_latent(self, z, m):
|
| 408 |
+
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
|
| 409 |
+
|
| 410 |
+
def prepare_control_latents(
|
| 411 |
+
self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
| 412 |
+
):
|
| 413 |
+
# resize the control to latents shape as we concatenate the control to the latents
|
| 414 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
| 415 |
+
# and half precision
|
| 416 |
+
|
| 417 |
+
if control is not None:
|
| 418 |
+
control = control.to(device=device, dtype=dtype)
|
| 419 |
+
bs = 1
|
| 420 |
+
new_control = []
|
| 421 |
+
for i in range(0, control.shape[0], bs):
|
| 422 |
+
control_bs = control[i : i + bs]
|
| 423 |
+
control_bs = self.vae.encode(control_bs)[0]
|
| 424 |
+
control_bs = control_bs.mode()
|
| 425 |
+
new_control.append(control_bs)
|
| 426 |
+
control = torch.cat(new_control, dim = 0)
|
| 427 |
+
|
| 428 |
+
if control_image is not None:
|
| 429 |
+
control_image = control_image.to(device=device, dtype=dtype)
|
| 430 |
+
bs = 1
|
| 431 |
+
new_control_pixel_values = []
|
| 432 |
+
for i in range(0, control_image.shape[0], bs):
|
| 433 |
+
control_pixel_values_bs = control_image[i : i + bs]
|
| 434 |
+
control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
|
| 435 |
+
control_pixel_values_bs = control_pixel_values_bs.mode()
|
| 436 |
+
new_control_pixel_values.append(control_pixel_values_bs)
|
| 437 |
+
control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
|
| 438 |
+
else:
|
| 439 |
+
control_image_latents = None
|
| 440 |
+
|
| 441 |
+
return control, control_image_latents
|
| 442 |
+
|
| 443 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 444 |
+
frames = self.vae.decode(latents.to(self.vae.dtype)).sample
|
| 445 |
+
frames = (frames / 2 + 0.5).clamp(0, 1)
|
| 446 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 447 |
+
frames = frames.cpu().float().numpy()
|
| 448 |
+
return frames
|
| 449 |
+
|
| 450 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 451 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 452 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 453 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 454 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 455 |
+
# and should be between [0, 1]
|
| 456 |
+
|
| 457 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 458 |
+
extra_step_kwargs = {}
|
| 459 |
+
if accepts_eta:
|
| 460 |
+
extra_step_kwargs["eta"] = eta
|
| 461 |
+
|
| 462 |
+
# check if the scheduler accepts generator
|
| 463 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 464 |
+
if accepts_generator:
|
| 465 |
+
extra_step_kwargs["generator"] = generator
|
| 466 |
+
return extra_step_kwargs
|
| 467 |
+
|
| 468 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
| 469 |
+
def check_inputs(
|
| 470 |
+
self,
|
| 471 |
+
prompt,
|
| 472 |
+
height,
|
| 473 |
+
width,
|
| 474 |
+
negative_prompt,
|
| 475 |
+
callback_on_step_end_tensor_inputs,
|
| 476 |
+
prompt_embeds=None,
|
| 477 |
+
negative_prompt_embeds=None,
|
| 478 |
+
):
|
| 479 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 480 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 481 |
+
|
| 482 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 483 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 484 |
+
):
|
| 485 |
+
raise ValueError(
|
| 486 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 487 |
+
)
|
| 488 |
+
if prompt is not None and prompt_embeds is not None:
|
| 489 |
+
raise ValueError(
|
| 490 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 491 |
+
" only forward one of the two."
|
| 492 |
+
)
|
| 493 |
+
elif prompt is None and prompt_embeds is None:
|
| 494 |
+
raise ValueError(
|
| 495 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 496 |
+
)
|
| 497 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 498 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 499 |
+
|
| 500 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 501 |
+
raise ValueError(
|
| 502 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 503 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 507 |
+
raise ValueError(
|
| 508 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 509 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 513 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 514 |
+
raise ValueError(
|
| 515 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 516 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 517 |
+
f" {negative_prompt_embeds.shape}."
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
@property
|
| 521 |
+
def guidance_scale(self):
|
| 522 |
+
return self._guidance_scale
|
| 523 |
+
|
| 524 |
+
@property
|
| 525 |
+
def num_timesteps(self):
|
| 526 |
+
return self._num_timesteps
|
| 527 |
+
|
| 528 |
+
@property
|
| 529 |
+
def attention_kwargs(self):
|
| 530 |
+
return self._attention_kwargs
|
| 531 |
+
|
| 532 |
+
@property
|
| 533 |
+
def interrupt(self):
|
| 534 |
+
return self._interrupt
|
| 535 |
+
|
| 536 |
+
@torch.no_grad()
|
| 537 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 538 |
+
def __call__(
|
| 539 |
+
self,
|
| 540 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 541 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 542 |
+
height: int = 480,
|
| 543 |
+
width: int = 720,
|
| 544 |
+
video: Union[torch.FloatTensor] = None,
|
| 545 |
+
mask_video: Union[torch.FloatTensor] = None,
|
| 546 |
+
control_video: Union[torch.FloatTensor] = None,
|
| 547 |
+
subject_ref_images: Union[torch.FloatTensor] = None,
|
| 548 |
+
num_frames: int = 49,
|
| 549 |
+
num_inference_steps: int = 50,
|
| 550 |
+
timesteps: Optional[List[int]] = None,
|
| 551 |
+
guidance_scale: float = 6,
|
| 552 |
+
num_videos_per_prompt: int = 1,
|
| 553 |
+
eta: float = 0.0,
|
| 554 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 555 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 556 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 557 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 558 |
+
output_type: str = "numpy",
|
| 559 |
+
return_dict: bool = False,
|
| 560 |
+
callback_on_step_end: Optional[
|
| 561 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 562 |
+
] = None,
|
| 563 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 564 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 565 |
+
max_sequence_length: int = 512,
|
| 566 |
+
boundary: float = 0.875,
|
| 567 |
+
comfyui_progressbar: bool = False,
|
| 568 |
+
shift: int = 5,
|
| 569 |
+
vace_context_scale: float = 1.0,
|
| 570 |
+
) -> Union[WanPipelineOutput, Tuple]:
|
| 571 |
+
"""
|
| 572 |
+
Function invoked when calling the pipeline for generation.
|
| 573 |
+
Args:
|
| 574 |
+
|
| 575 |
+
Examples:
|
| 576 |
+
|
| 577 |
+
Returns:
|
| 578 |
+
|
| 579 |
+
"""
|
| 580 |
+
|
| 581 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 582 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 583 |
+
num_videos_per_prompt = 1
|
| 584 |
+
|
| 585 |
+
# 1. Check inputs. Raise error if not correct
|
| 586 |
+
self.check_inputs(
|
| 587 |
+
prompt,
|
| 588 |
+
height,
|
| 589 |
+
width,
|
| 590 |
+
negative_prompt,
|
| 591 |
+
callback_on_step_end_tensor_inputs,
|
| 592 |
+
prompt_embeds,
|
| 593 |
+
negative_prompt_embeds,
|
| 594 |
+
)
|
| 595 |
+
self._guidance_scale = guidance_scale
|
| 596 |
+
self._attention_kwargs = attention_kwargs
|
| 597 |
+
self._interrupt = False
|
| 598 |
+
|
| 599 |
+
# 2. Default call parameters
|
| 600 |
+
if prompt is not None and isinstance(prompt, str):
|
| 601 |
+
batch_size = 1
|
| 602 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 603 |
+
batch_size = len(prompt)
|
| 604 |
+
else:
|
| 605 |
+
batch_size = prompt_embeds.shape[0]
|
| 606 |
+
|
| 607 |
+
device = self._execution_device
|
| 608 |
+
weight_dtype = self.text_encoder.dtype
|
| 609 |
+
|
| 610 |
+
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
|
| 611 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 612 |
+
# corresponds to doing no classifier free guidance.
|
| 613 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 614 |
+
|
| 615 |
+
# 3. Encode input prompt
|
| 616 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 617 |
+
prompt,
|
| 618 |
+
negative_prompt,
|
| 619 |
+
do_classifier_free_guidance,
|
| 620 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 621 |
+
prompt_embeds=prompt_embeds,
|
| 622 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 623 |
+
max_sequence_length=max_sequence_length,
|
| 624 |
+
device=device,
|
| 625 |
+
)
|
| 626 |
+
if do_classifier_free_guidance:
|
| 627 |
+
in_prompt_embeds = negative_prompt_embeds + prompt_embeds
|
| 628 |
+
else:
|
| 629 |
+
in_prompt_embeds = prompt_embeds
|
| 630 |
+
|
| 631 |
+
# 4. Prepare timesteps
|
| 632 |
+
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
| 633 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
|
| 634 |
+
elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
|
| 635 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
|
| 636 |
+
timesteps = self.scheduler.timesteps
|
| 637 |
+
elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
|
| 638 |
+
sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
|
| 639 |
+
timesteps, _ = retrieve_timesteps(
|
| 640 |
+
self.scheduler,
|
| 641 |
+
device=device,
|
| 642 |
+
sigmas=sampling_sigmas)
|
| 643 |
+
else:
|
| 644 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 645 |
+
self._num_timesteps = len(timesteps)
|
| 646 |
+
if comfyui_progressbar:
|
| 647 |
+
from comfy.utils import ProgressBar
|
| 648 |
+
pbar = ProgressBar(num_inference_steps + 2)
|
| 649 |
+
|
| 650 |
+
latent_channels = self.vae.config.latent_channels
|
| 651 |
+
|
| 652 |
+
if comfyui_progressbar:
|
| 653 |
+
pbar.update(1)
|
| 654 |
+
|
| 655 |
+
# Prepare mask latent variables
|
| 656 |
+
if mask_video is not None:
|
| 657 |
+
bs, _, video_length, height, width = video.size()
|
| 658 |
+
mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 659 |
+
mask_condition = mask_condition.to(dtype=torch.float32)
|
| 660 |
+
mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
|
| 661 |
+
mask_condition = torch.tile(mask_condition, [1, 3, 1, 1, 1]).to(dtype=weight_dtype, device=device)
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
if control_video is not None:
|
| 665 |
+
video_length = control_video.shape[2]
|
| 666 |
+
control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 667 |
+
control_video = control_video.to(dtype=torch.float32)
|
| 668 |
+
input_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
|
| 669 |
+
|
| 670 |
+
input_video = input_video.to(dtype=weight_dtype, device=device)
|
| 671 |
+
|
| 672 |
+
elif video is not None:
|
| 673 |
+
video_length = video.shape[2]
|
| 674 |
+
init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 675 |
+
init_video = init_video.to(dtype=torch.float32)
|
| 676 |
+
init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length).to(dtype=weight_dtype, device=device)
|
| 677 |
+
|
| 678 |
+
input_video = init_video * (mask_condition < 0.5)
|
| 679 |
+
input_video = input_video.to(dtype=weight_dtype, device=device)
|
| 680 |
+
|
| 681 |
+
if subject_ref_images is not None:
|
| 682 |
+
video_length = subject_ref_images.shape[2]
|
| 683 |
+
subject_ref_images = self.image_processor.preprocess(rearrange(subject_ref_images, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 684 |
+
subject_ref_images = subject_ref_images.to(dtype=torch.float32)
|
| 685 |
+
subject_ref_images = rearrange(subject_ref_images, "(b f) c h w -> b c f h w", f=video_length)
|
| 686 |
+
subject_ref_images = subject_ref_images.to(dtype=weight_dtype, device=device)
|
| 687 |
+
|
| 688 |
+
bs, c, f, h, w = subject_ref_images.size()
|
| 689 |
+
new_subject_ref_images = []
|
| 690 |
+
for i in range(bs):
|
| 691 |
+
new_subject_ref_images.append([])
|
| 692 |
+
for j in range(f):
|
| 693 |
+
new_subject_ref_images[i].append(subject_ref_images[i, :, j:j+1])
|
| 694 |
+
subject_ref_images = new_subject_ref_images
|
| 695 |
+
|
| 696 |
+
vace_latents = self.vace_encode_frames(input_video, subject_ref_images, masks=mask_condition, vae=self.vae)
|
| 697 |
+
mask_latents = self.vace_encode_masks(mask_condition, subject_ref_images, vae_stride=[4, self.vae.spatial_compression_ratio, self.vae.spatial_compression_ratio])
|
| 698 |
+
vace_context = self.vace_latent(vace_latents, mask_latents)
|
| 699 |
+
|
| 700 |
+
# 5. Prepare latents.
|
| 701 |
+
latents = self.prepare_latents(
|
| 702 |
+
batch_size * num_videos_per_prompt,
|
| 703 |
+
latent_channels,
|
| 704 |
+
num_frames,
|
| 705 |
+
height,
|
| 706 |
+
width,
|
| 707 |
+
weight_dtype,
|
| 708 |
+
device,
|
| 709 |
+
generator,
|
| 710 |
+
latents,
|
| 711 |
+
num_length_latents=vace_latents[0].size(1)
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
if comfyui_progressbar:
|
| 715 |
+
pbar.update(1)
|
| 716 |
+
|
| 717 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 718 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 719 |
+
|
| 720 |
+
target_shape = (self.vae.latent_channels, vace_latents[0].size(1), vace_latents[0].size(2), vace_latents[0].size(3))
|
| 721 |
+
seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
|
| 722 |
+
# 7. Denoising loop
|
| 723 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 724 |
+
self.transformer.num_inference_steps = num_inference_steps
|
| 725 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 726 |
+
for i, t in enumerate(timesteps):
|
| 727 |
+
self.transformer.current_steps = i
|
| 728 |
+
|
| 729 |
+
if self.interrupt:
|
| 730 |
+
continue
|
| 731 |
+
|
| 732 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 733 |
+
if hasattr(self.scheduler, "scale_model_input"):
|
| 734 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 735 |
+
|
| 736 |
+
vace_context_input = torch.stack(vace_context * 2) if do_classifier_free_guidance else vace_context
|
| 737 |
+
|
| 738 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 739 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 740 |
+
|
| 741 |
+
if self.transformer_2 is not None:
|
| 742 |
+
if t >= boundary * self.scheduler.config.num_train_timesteps:
|
| 743 |
+
local_transformer = self.transformer_2
|
| 744 |
+
else:
|
| 745 |
+
local_transformer = self.transformer
|
| 746 |
+
else:
|
| 747 |
+
local_transformer = self.transformer
|
| 748 |
+
|
| 749 |
+
# predict noise model_output
|
| 750 |
+
with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
|
| 751 |
+
noise_pred = local_transformer(
|
| 752 |
+
x=latent_model_input,
|
| 753 |
+
context=in_prompt_embeds,
|
| 754 |
+
t=timestep,
|
| 755 |
+
vace_context=vace_context_input,
|
| 756 |
+
seq_len=seq_len,
|
| 757 |
+
vace_context_scale=vace_context_scale,
|
| 758 |
+
)
|
| 759 |
+
|
| 760 |
+
# perform guidance
|
| 761 |
+
if do_classifier_free_guidance:
|
| 762 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 763 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 764 |
+
|
| 765 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 766 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 767 |
+
|
| 768 |
+
if callback_on_step_end is not None:
|
| 769 |
+
callback_kwargs = {}
|
| 770 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 771 |
+
callback_kwargs[k] = locals()[k]
|
| 772 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 773 |
+
|
| 774 |
+
latents = callback_outputs.pop("latents", latents)
|
| 775 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 776 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 777 |
+
|
| 778 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 779 |
+
progress_bar.update()
|
| 780 |
+
if comfyui_progressbar:
|
| 781 |
+
pbar.update(1)
|
| 782 |
+
|
| 783 |
+
if subject_ref_images is not None:
|
| 784 |
+
len_subject_ref_images = len(subject_ref_images[0])
|
| 785 |
+
latents = latents[:, :, len_subject_ref_images:, :, :]
|
| 786 |
+
|
| 787 |
+
if output_type == "numpy":
|
| 788 |
+
video = self.decode_latents(latents)
|
| 789 |
+
elif not output_type == "latent":
|
| 790 |
+
video = self.decode_latents(latents)
|
| 791 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 792 |
+
else:
|
| 793 |
+
video = latents
|
| 794 |
+
|
| 795 |
+
# Offload all models
|
| 796 |
+
self.maybe_free_model_hooks()
|
| 797 |
+
|
| 798 |
+
if not return_dict:
|
| 799 |
+
video = torch.from_numpy(video)
|
| 800 |
+
|
| 801 |
+
return WanPipelineOutput(videos=video)
|
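
Reviewer note: putting this file together, a call looks roughly like the sketch below. Model loading is elided and the checkpoint/tensor values are purely illustrative assumptions; `pipeline` is assumed to be an already-constructed `Wan2_2VaceFunPipeline`, and the keyword arguments mirror the `__call__` signature above.

```python
import torch

num_frames, height, width = 49, 480, 720

video = torch.zeros(1, 3, num_frames, height, width)             # source clip (ignored where mask=255)
mask_video = torch.ones(1, 1, num_frames, height, width) * 255   # 255 = regenerate everything
control_video = torch.zeros(1, 3, num_frames, height, width)     # e.g. a pose/depth control clip

output = pipeline(
    prompt="a corgi running on the beach at sunset",
    negative_prompt="blurry, low quality",
    video=video,
    mask_video=mask_video,
    control_video=control_video,
    num_frames=num_frames,
    height=height,
    width=width,
    num_inference_steps=30,
    guidance_scale=6.0,
    shift=5,
    output_type="numpy",
)
videos = output.videos  # decoded video tensor (batch, channels, frames, height, width), values in [0, 1]
```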
videox_fun/pipeline/pipeline_wan_fun_control.py
ADDED
|
@@ -0,0 +1,799 @@
| 1 |
+
import inspect
|
| 2 |
+
import math
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torchvision.transforms.functional as TF
|
| 10 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
| 11 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 12 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 13 |
+
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
| 14 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 15 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 16 |
+
from diffusers.utils import BaseOutput, logging, replace_example_docstring
|
| 17 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 18 |
+
from diffusers.video_processor import VideoProcessor
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
from PIL import Image
|
| 21 |
+
from transformers import T5Tokenizer
|
| 22 |
+
|
| 23 |
+
from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
|
| 24 |
+
WanT5EncoderModel, WanTransformer3DModel)
|
| 25 |
+
from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
| 26 |
+
get_sampling_sigmas)
|
| 27 |
+
from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 28 |
+
|
| 29 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
EXAMPLE_DOC_STRING = """
|
| 33 |
+
Examples:
|
| 34 |
+
```python
|
| 35 |
+
pass
|
| 36 |
+
```
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 41 |
+
def retrieve_timesteps(
|
| 42 |
+
scheduler,
|
| 43 |
+
num_inference_steps: Optional[int] = None,
|
| 44 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 45 |
+
timesteps: Optional[List[int]] = None,
|
| 46 |
+
sigmas: Optional[List[float]] = None,
|
| 47 |
+
**kwargs,
|
| 48 |
+
):
|
| 49 |
+
"""
|
| 50 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 51 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
scheduler (`SchedulerMixin`):
|
| 55 |
+
The scheduler to get timesteps from.
|
| 56 |
+
num_inference_steps (`int`):
|
| 57 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 58 |
+
must be `None`.
|
| 59 |
+
device (`str` or `torch.device`, *optional*):
|
| 60 |
+
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
| 61 |
+
timesteps (`List[int]`, *optional*):
|
| 62 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 63 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 64 |
+
sigmas (`List[float]`, *optional*):
|
| 65 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 66 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 70 |
+
second element is the number of inference steps.
|
| 71 |
+
"""
|
| 72 |
+
if timesteps is not None and sigmas is not None:
|
| 73 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 74 |
+
if timesteps is not None:
|
| 75 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 76 |
+
if not accepts_timesteps:
|
| 77 |
+
raise ValueError(
|
| 78 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 79 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 80 |
+
)
|
| 81 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 82 |
+
timesteps = scheduler.timesteps
|
| 83 |
+
num_inference_steps = len(timesteps)
|
| 84 |
+
elif sigmas is not None:
|
| 85 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 86 |
+
if not accept_sigmas:
|
| 87 |
+
raise ValueError(
|
| 88 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 89 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 90 |
+
)
|
| 91 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 92 |
+
timesteps = scheduler.timesteps
|
| 93 |
+
num_inference_steps = len(timesteps)
|
| 94 |
+
else:
|
| 95 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 96 |
+
timesteps = scheduler.timesteps
|
| 97 |
+
return timesteps, num_inference_steps
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def resize_mask(mask, latent, process_first_frame_only=True):
|
| 101 |
+
latent_size = latent.size()
|
| 102 |
+
batch_size, channels, num_frames, height, width = mask.shape
|
| 103 |
+
|
| 104 |
+
if process_first_frame_only:
|
| 105 |
+
target_size = list(latent_size[2:])
|
| 106 |
+
target_size[0] = 1
|
| 107 |
+
first_frame_resized = F.interpolate(
|
| 108 |
+
mask[:, :, 0:1, :, :],
|
| 109 |
+
size=target_size,
|
| 110 |
+
mode='trilinear',
|
| 111 |
+
align_corners=False
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
target_size = list(latent_size[2:])
|
| 115 |
+
target_size[0] = target_size[0] - 1
|
| 116 |
+
if target_size[0] != 0:
|
| 117 |
+
remaining_frames_resized = F.interpolate(
|
| 118 |
+
mask[:, :, 1:, :, :],
|
| 119 |
+
size=target_size,
|
| 120 |
+
mode='trilinear',
|
| 121 |
+
align_corners=False
|
| 122 |
+
)
|
| 123 |
+
resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
|
| 124 |
+
else:
|
| 125 |
+
resized_mask = first_frame_resized
|
| 126 |
+
else:
|
| 127 |
+
target_size = list(latent_size[2:])
|
| 128 |
+
resized_mask = F.interpolate(
|
| 129 |
+
mask,
|
| 130 |
+
size=target_size,
|
| 131 |
+
mode='trilinear',
|
| 132 |
+
align_corners=False
|
| 133 |
+
)
|
| 134 |
+
return resized_mask
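An illustrative sketch of `resize_mask`, assuming a VAE with 4x temporal and 8x spatial compression (the tensor shapes are hypothetical and only meant to show the contract between `mask` and `latent`):

```python
import torch

# Pixel-space mask for 49 frames at 480x720, and a matching latent tensor:
# (49 - 1) // 4 + 1 = 13 latent frames on a 480 // 8 x 720 // 8 = 60 x 90 grid.
mask = torch.ones(1, 1, 49, 480, 720)
latent = torch.zeros(1, 16, 13, 60, 90)

resized = resize_mask(mask, latent, process_first_frame_only=True)
print(resized.shape)  # torch.Size([1, 1, 13, 60, 90])
```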
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@dataclass
|
| 138 |
+
class WanPipelineOutput(BaseOutput):
|
| 139 |
+
r"""
|
| 140 |
+
Output class for Wan pipelines.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
| 144 |
+
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
| 145 |
+
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
| 146 |
+
`(batch_size, num_frames, channels, height, width)`.
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
videos: torch.Tensor
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class WanFunControlPipeline(DiffusionPipeline):
|
| 153 |
+
r"""
|
| 154 |
+
Pipeline for controllable video generation (text/image-to-video with control signals) using Wan.
|
| 155 |
+
|
| 156 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 157 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
_optional_components = []
|
| 161 |
+
model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae"
|
| 162 |
+
|
| 163 |
+
_callback_tensor_inputs = [
|
| 164 |
+
"latents",
|
| 165 |
+
"prompt_embeds",
|
| 166 |
+
"negative_prompt_embeds",
|
| 167 |
+
]
|
| 168 |
+
|
| 169 |
+
def __init__(
|
| 170 |
+
self,
|
| 171 |
+
tokenizer: AutoTokenizer,
|
| 172 |
+
text_encoder: WanT5EncoderModel,
|
| 173 |
+
vae: AutoencoderKLWan,
|
| 174 |
+
transformer: WanTransformer3DModel,
|
| 175 |
+
clip_image_encoder: CLIPModel,
|
| 176 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 177 |
+
):
|
| 178 |
+
super().__init__()
|
| 179 |
+
|
| 180 |
+
self.register_modules(
|
| 181 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, clip_image_encoder=clip_image_encoder, scheduler=scheduler
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
|
| 185 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
|
| 186 |
+
self.mask_processor = VaeImageProcessor(
|
| 187 |
+
vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
def _get_t5_prompt_embeds(
|
| 191 |
+
self,
|
| 192 |
+
prompt: Union[str, List[str]] = None,
|
| 193 |
+
num_videos_per_prompt: int = 1,
|
| 194 |
+
max_sequence_length: int = 512,
|
| 195 |
+
device: Optional[torch.device] = None,
|
| 196 |
+
dtype: Optional[torch.dtype] = None,
|
| 197 |
+
):
|
| 198 |
+
device = device or self._execution_device
|
| 199 |
+
dtype = dtype or self.text_encoder.dtype
|
| 200 |
+
|
| 201 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 202 |
+
batch_size = len(prompt)
|
| 203 |
+
|
| 204 |
+
text_inputs = self.tokenizer(
|
| 205 |
+
prompt,
|
| 206 |
+
padding="max_length",
|
| 207 |
+
max_length=max_sequence_length,
|
| 208 |
+
truncation=True,
|
| 209 |
+
add_special_tokens=True,
|
| 210 |
+
return_tensors="pt",
|
| 211 |
+
)
|
| 212 |
+
text_input_ids = text_inputs.input_ids
|
| 213 |
+
prompt_attention_mask = text_inputs.attention_mask
|
| 214 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 215 |
+
|
| 216 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 217 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 218 |
+
logger.warning(
|
| 219 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 220 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
|
| 224 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
|
| 225 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 226 |
+
|
| 227 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 228 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 229 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 230 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 231 |
+
|
| 232 |
+
return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
| 233 |
+
|
| 234 |
+
def encode_prompt(
|
| 235 |
+
self,
|
| 236 |
+
prompt: Union[str, List[str]],
|
| 237 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 238 |
+
do_classifier_free_guidance: bool = True,
|
| 239 |
+
num_videos_per_prompt: int = 1,
|
| 240 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 241 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 242 |
+
max_sequence_length: int = 512,
|
| 243 |
+
device: Optional[torch.device] = None,
|
| 244 |
+
dtype: Optional[torch.dtype] = None,
|
| 245 |
+
):
|
| 246 |
+
r"""
|
| 247 |
+
Encodes the prompt into text encoder hidden states.
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 251 |
+
prompt to be encoded
|
| 252 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 253 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 254 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 255 |
+
less than `1`).
|
| 256 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 257 |
+
Whether to use classifier free guidance or not.
|
| 258 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 259 |
+
Number of videos that should be generated per prompt.
|
| 260 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 261 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 262 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 263 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 264 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 265 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 266 |
+
argument.
|
| 267 |
+
device: (`torch.device`, *optional*):
|
| 268 |
+
torch device
|
| 269 |
+
dtype: (`torch.dtype`, *optional*):
|
| 270 |
+
torch dtype
|
| 271 |
+
"""
|
| 272 |
+
device = device or self._execution_device
|
| 273 |
+
|
| 274 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 275 |
+
if prompt is not None:
|
| 276 |
+
batch_size = len(prompt)
|
| 277 |
+
else:
|
| 278 |
+
batch_size = prompt_embeds.shape[0]
|
| 279 |
+
|
| 280 |
+
if prompt_embeds is None:
|
| 281 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 282 |
+
prompt=prompt,
|
| 283 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 284 |
+
max_sequence_length=max_sequence_length,
|
| 285 |
+
device=device,
|
| 286 |
+
dtype=dtype,
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 290 |
+
negative_prompt = negative_prompt or ""
|
| 291 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 292 |
+
|
| 293 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 294 |
+
raise TypeError(
|
| 295 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 296 |
+
f" {type(prompt)}."
|
| 297 |
+
)
|
| 298 |
+
elif batch_size != len(negative_prompt):
|
| 299 |
+
raise ValueError(
|
| 300 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 301 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 302 |
+
" the batch size of `prompt`."
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 306 |
+
prompt=negative_prompt,
|
| 307 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 308 |
+
max_sequence_length=max_sequence_length,
|
| 309 |
+
device=device,
|
| 310 |
+
dtype=dtype,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
return prompt_embeds, negative_prompt_embeds
|
| 314 |
+
|
| 315 |
+
def prepare_latents(
|
| 316 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
| 317 |
+
):
|
| 318 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 319 |
+
raise ValueError(
|
| 320 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 321 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
shape = (
|
| 325 |
+
batch_size,
|
| 326 |
+
num_channels_latents,
|
| 327 |
+
(num_frames - 1) // self.vae.temporal_compression_ratio + 1,
|
| 328 |
+
height // self.vae.spatial_compression_ratio,
|
| 329 |
+
width // self.vae.spatial_compression_ratio,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
if latents is None:
|
| 333 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 334 |
+
else:
|
| 335 |
+
latents = latents.to(device)
|
| 336 |
+
|
| 337 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 338 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
| 339 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 340 |
+
return latents
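# Illustrative shape check for prepare_latents (assuming the Wan VAE's 4x temporal and 8x spatial
# compression): num_frames=49, height=480, width=720 gives latents of shape
# (batch, C, (49 - 1) // 4 + 1, 480 // 8, 720 // 8) = (batch, C, 13, 60, 90).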
|
| 341 |
+
|
| 342 |
+
def prepare_control_latents(
|
| 343 |
+
self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
| 344 |
+
):
|
| 345 |
+
# resize the control to latents shape as we concatenate the control to the latents
|
| 346 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
| 347 |
+
# and half precision
|
| 348 |
+
|
| 349 |
+
if control is not None:
|
| 350 |
+
control = control.to(device=device, dtype=dtype)
|
| 351 |
+
bs = 1
|
| 352 |
+
new_control = []
|
| 353 |
+
for i in range(0, control.shape[0], bs):
|
| 354 |
+
control_bs = control[i : i + bs]
|
| 355 |
+
control_bs = self.vae.encode(control_bs)[0]
|
| 356 |
+
control_bs = control_bs.mode()
|
| 357 |
+
new_control.append(control_bs)
|
| 358 |
+
control = torch.cat(new_control, dim = 0)
|
| 359 |
+
|
| 360 |
+
if control_image is not None:
|
| 361 |
+
control_image = control_image.to(device=device, dtype=dtype)
|
| 362 |
+
bs = 1
|
| 363 |
+
new_control_pixel_values = []
|
| 364 |
+
for i in range(0, control_image.shape[0], bs):
|
| 365 |
+
control_pixel_values_bs = control_image[i : i + bs]
|
| 366 |
+
control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
|
| 367 |
+
control_pixel_values_bs = control_pixel_values_bs.mode()
|
| 368 |
+
new_control_pixel_values.append(control_pixel_values_bs)
|
| 369 |
+
control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
|
| 370 |
+
else:
|
| 371 |
+
control_image_latents = None
|
| 372 |
+
|
| 373 |
+
return control, control_image_latents
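# Note: control frames are encoded through the VAE one sample at a time (bs = 1) to bound peak
# memory, and `.mode()` takes the mode of the latent posterior instead of sampling from it, so the
# control latents are deterministic for a given input.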
|
| 374 |
+
|
| 375 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 376 |
+
frames = self.vae.decode(latents.to(self.vae.dtype)).sample
|
| 377 |
+
frames = (frames / 2 + 0.5).clamp(0, 1)
|
| 378 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 379 |
+
frames = frames.cpu().float().numpy()
|
| 380 |
+
return frames
|
| 381 |
+
|
| 382 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 383 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 384 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 385 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 386 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 387 |
+
# and should be between [0, 1]
|
| 388 |
+
|
| 389 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 390 |
+
extra_step_kwargs = {}
|
| 391 |
+
if accepts_eta:
|
| 392 |
+
extra_step_kwargs["eta"] = eta
|
| 393 |
+
|
| 394 |
+
# check if the scheduler accepts generator
|
| 395 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 396 |
+
if accepts_generator:
|
| 397 |
+
extra_step_kwargs["generator"] = generator
|
| 398 |
+
return extra_step_kwargs
|
| 399 |
+
|
| 400 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
| 401 |
+
def check_inputs(
|
| 402 |
+
self,
|
| 403 |
+
prompt,
|
| 404 |
+
height,
|
| 405 |
+
width,
|
| 406 |
+
negative_prompt,
|
| 407 |
+
callback_on_step_end_tensor_inputs,
|
| 408 |
+
prompt_embeds=None,
|
| 409 |
+
negative_prompt_embeds=None,
|
| 410 |
+
):
|
| 411 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 412 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 413 |
+
|
| 414 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 415 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 416 |
+
):
|
| 417 |
+
raise ValueError(
|
| 418 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 419 |
+
)
|
| 420 |
+
if prompt is not None and prompt_embeds is not None:
|
| 421 |
+
raise ValueError(
|
| 422 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 423 |
+
" only forward one of the two."
|
| 424 |
+
)
|
| 425 |
+
elif prompt is None and prompt_embeds is None:
|
| 426 |
+
raise ValueError(
|
| 427 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 428 |
+
)
|
| 429 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 430 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 431 |
+
|
| 432 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 433 |
+
raise ValueError(
|
| 434 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 435 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 439 |
+
raise ValueError(
|
| 440 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 441 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 445 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 446 |
+
raise ValueError(
|
| 447 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 448 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 449 |
+
f" {negative_prompt_embeds.shape}."
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
@property
|
| 453 |
+
def guidance_scale(self):
|
| 454 |
+
return self._guidance_scale
|
| 455 |
+
|
| 456 |
+
@property
|
| 457 |
+
def num_timesteps(self):
|
| 458 |
+
return self._num_timesteps
|
| 459 |
+
|
| 460 |
+
@property
|
| 461 |
+
def attention_kwargs(self):
|
| 462 |
+
return self._attention_kwargs
|
| 463 |
+
|
| 464 |
+
@property
|
| 465 |
+
def interrupt(self):
|
| 466 |
+
return self._interrupt
|
| 467 |
+
|
| 468 |
+
@torch.no_grad()
|
| 469 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 470 |
+
def __call__(
|
| 471 |
+
self,
|
| 472 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 473 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 474 |
+
height: int = 480,
|
| 475 |
+
width: int = 720,
|
| 476 |
+
control_video: Union[torch.FloatTensor] = None,
|
| 477 |
+
control_camera_video: Union[torch.FloatTensor] = None,
|
| 478 |
+
start_image: Union[torch.FloatTensor] = None,
|
| 479 |
+
ref_image: Union[torch.FloatTensor] = None,
|
| 480 |
+
num_frames: int = 49,
|
| 481 |
+
num_inference_steps: int = 50,
|
| 482 |
+
timesteps: Optional[List[int]] = None,
|
| 483 |
+
guidance_scale: float = 6,
|
| 484 |
+
num_videos_per_prompt: int = 1,
|
| 485 |
+
eta: float = 0.0,
|
| 486 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 487 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 488 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 489 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 490 |
+
output_type: str = "numpy",
|
| 491 |
+
return_dict: bool = False,
|
| 492 |
+
callback_on_step_end: Optional[
|
| 493 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 494 |
+
] = None,
|
| 495 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 496 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 497 |
+
clip_image: Image.Image = None,
|
| 498 |
+
max_sequence_length: int = 512,
|
| 499 |
+
comfyui_progressbar: bool = False,
|
| 500 |
+
shift: int = 5,
|
| 501 |
+
) -> Union[WanPipelineOutput, Tuple]:
|
| 502 |
+
"""
|
| 503 |
+
Function invoked when calling the pipeline for generation.
|
| 504 |
+
Args:
|
| 505 |
+
|
| 506 |
+
Examples:
|
| 507 |
+
|
| 508 |
+
Returns:
    [`WanPipelineOutput`]: its `videos` field holds the decoded video frames, or the raw latents when `output_type="latent"`.
|
| 509 |
+
|
| 510 |
+
"""
|
| 511 |
+
|
| 512 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 513 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 514 |
+
num_videos_per_prompt = 1
|
| 515 |
+
|
| 516 |
+
# 1. Check inputs. Raise error if not correct
|
| 517 |
+
self.check_inputs(
|
| 518 |
+
prompt,
|
| 519 |
+
height,
|
| 520 |
+
width,
|
| 521 |
+
negative_prompt,
|
| 522 |
+
callback_on_step_end_tensor_inputs,
|
| 523 |
+
prompt_embeds,
|
| 524 |
+
negative_prompt_embeds,
|
| 525 |
+
)
|
| 526 |
+
self._guidance_scale = guidance_scale
|
| 527 |
+
self._attention_kwargs = attention_kwargs
|
| 528 |
+
self._interrupt = False
|
| 529 |
+
|
| 530 |
+
# 2. Default call parameters
|
| 531 |
+
if prompt is not None and isinstance(prompt, str):
|
| 532 |
+
batch_size = 1
|
| 533 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 534 |
+
batch_size = len(prompt)
|
| 535 |
+
else:
|
| 536 |
+
batch_size = prompt_embeds.shape[0]
|
| 537 |
+
|
| 538 |
+
device = self._execution_device
|
| 539 |
+
weight_dtype = self.text_encoder.dtype
|
| 540 |
+
|
| 541 |
+
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
|
| 542 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 543 |
+
# corresponds to doing no classifier free guidance.
|
| 544 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 545 |
+
|
| 546 |
+
# 3. Encode input prompt
|
| 547 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 548 |
+
prompt,
|
| 549 |
+
negative_prompt,
|
| 550 |
+
do_classifier_free_guidance,
|
| 551 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 552 |
+
prompt_embeds=prompt_embeds,
|
| 553 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 554 |
+
max_sequence_length=max_sequence_length,
|
| 555 |
+
device=device,
|
| 556 |
+
)
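# encode_prompt returns lists of per-sample embedding tensors trimmed to their unpadded lengths,
# so `+` below concatenates the two lists (negative prompts first); it is not a tensor addition.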
|
| 557 |
+
if do_classifier_free_guidance:
|
| 558 |
+
in_prompt_embeds = negative_prompt_embeds + prompt_embeds
|
| 559 |
+
else:
|
| 560 |
+
in_prompt_embeds = prompt_embeds
|
| 561 |
+
|
| 562 |
+
# 4. Prepare timesteps
|
| 563 |
+
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
| 564 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
|
| 565 |
+
elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
|
| 566 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
|
| 567 |
+
timesteps = self.scheduler.timesteps
|
| 568 |
+
elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
|
| 569 |
+
sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
|
| 570 |
+
timesteps, _ = retrieve_timesteps(
|
| 571 |
+
self.scheduler,
|
| 572 |
+
device=device,
|
| 573 |
+
sigmas=sampling_sigmas)
|
| 574 |
+
else:
|
| 575 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 576 |
+
self._num_timesteps = len(timesteps)
|
| 577 |
+
if comfyui_progressbar:
|
| 578 |
+
from comfy.utils import ProgressBar
|
| 579 |
+
pbar = ProgressBar(num_inference_steps + 2)
|
| 580 |
+
|
| 581 |
+
# 5. Prepare latents.
|
| 582 |
+
latent_channels = self.vae.config.latent_channels
|
| 583 |
+
latents = self.prepare_latents(
|
| 584 |
+
batch_size * num_videos_per_prompt,
|
| 585 |
+
latent_channels,
|
| 586 |
+
num_frames,
|
| 587 |
+
height,
|
| 588 |
+
width,
|
| 589 |
+
weight_dtype,
|
| 590 |
+
device,
|
| 591 |
+
generator,
|
| 592 |
+
latents,
|
| 593 |
+
)
|
| 594 |
+
if comfyui_progressbar:
|
| 595 |
+
pbar.update(1)
|
| 596 |
+
|
| 597 |
+
# Prepare mask latent variables
|
| 598 |
+
if control_camera_video is not None:
|
| 599 |
+
control_latents = None
|
| 600 |
+
# Rearrange dimensions
|
| 601 |
+
# Concatenate and transpose dimensions
|
| 602 |
+
control_camera_latents = torch.concat(
|
| 603 |
+
[
|
| 604 |
+
torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
|
| 605 |
+
control_camera_video[:, :, 1:]
|
| 606 |
+
], dim=2
|
| 607 |
+
).transpose(1, 2)
|
| 608 |
+
|
| 609 |
+
# Reshape, transpose, and view into desired shape
|
| 610 |
+
b, f, c, h, w = control_camera_latents.shape
|
| 611 |
+
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
|
| 612 |
+
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
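# After these reshapes the camera conditioning has shape (b, 4 * c, f // 4, h, w): each group of
# four consecutive camera frames is folded into the channel dimension, mirroring the VAE's 4x
# temporal compression of the video latents (illustrative reading of the code above).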
|
| 613 |
+
elif control_video is not None:
|
| 614 |
+
video_length = control_video.shape[2]
|
| 615 |
+
control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 616 |
+
control_video = control_video.to(dtype=torch.float32)
|
| 617 |
+
control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
|
| 618 |
+
control_video_latents = self.prepare_control_latents(
|
| 619 |
+
None,
|
| 620 |
+
control_video,
|
| 621 |
+
batch_size,
|
| 622 |
+
height,
|
| 623 |
+
width,
|
| 624 |
+
weight_dtype,
|
| 625 |
+
device,
|
| 626 |
+
generator,
|
| 627 |
+
do_classifier_free_guidance
|
| 628 |
+
)[1]
|
| 629 |
+
control_camera_latents = None
|
| 630 |
+
else:
|
| 631 |
+
control_video_latents = torch.zeros_like(latents).to(device, weight_dtype)
|
| 632 |
+
control_camera_latents = None
|
| 633 |
+
|
| 634 |
+
if start_image is not None:
|
| 635 |
+
video_length = start_image.shape[2]
|
| 636 |
+
start_image = self.image_processor.preprocess(rearrange(start_image, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 637 |
+
start_image = start_image.to(dtype=torch.float32)
|
| 638 |
+
start_image = rearrange(start_image, "(b f) c h w -> b c f h w", f=video_length)
|
| 639 |
+
|
| 640 |
+
start_image_latentes = self.prepare_control_latents(
|
| 641 |
+
None,
|
| 642 |
+
start_image,
|
| 643 |
+
batch_size,
|
| 644 |
+
height,
|
| 645 |
+
width,
|
| 646 |
+
weight_dtype,
|
| 647 |
+
device,
|
| 648 |
+
generator,
|
| 649 |
+
do_classifier_free_guidance
|
| 650 |
+
)[1]
|
| 651 |
+
|
| 652 |
+
start_image_latentes_conv_in = torch.zeros_like(latents)
|
| 653 |
+
if latents.size()[2] != 1:
|
| 654 |
+
start_image_latentes_conv_in[:, :, :1] = start_image_latentes
|
| 655 |
+
else:
|
| 656 |
+
start_image_latentes_conv_in = torch.zeros_like(latents)
|
| 657 |
+
|
| 658 |
+
# Prepare clip latent variables
|
| 659 |
+
if clip_image is not None:
|
| 660 |
+
clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
|
| 661 |
+
clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
|
| 662 |
+
else:
|
| 663 |
+
clip_image = Image.new("RGB", (512, 512), color=(0, 0, 0))
|
| 664 |
+
clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
|
| 665 |
+
clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
|
| 666 |
+
clip_context = torch.zeros_like(clip_context)
|
| 667 |
+
|
| 668 |
+
if self.transformer.config.get("add_ref_conv", False):
|
| 669 |
+
if ref_image is not None:
|
| 670 |
+
video_length = ref_image.shape[2]
|
| 671 |
+
ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 672 |
+
ref_image = ref_image.to(dtype=torch.float32)
|
| 673 |
+
ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length)
|
| 674 |
+
|
| 675 |
+
ref_image_latentes = self.prepare_control_latents(
|
| 676 |
+
None,
|
| 677 |
+
ref_image,
|
| 678 |
+
batch_size,
|
| 679 |
+
height,
|
| 680 |
+
width,
|
| 681 |
+
weight_dtype,
|
| 682 |
+
device,
|
| 683 |
+
generator,
|
| 684 |
+
do_classifier_free_guidance
|
| 685 |
+
)[1]
|
| 686 |
+
ref_image_latentes = ref_image_latentes[:, :, 0]
|
| 687 |
+
else:
|
| 688 |
+
ref_image_latentes = torch.zeros_like(latents)[:, :, 0]
|
| 689 |
+
else:
|
| 690 |
+
if ref_image is not None:
|
| 691 |
+
raise ValueError("The add_ref_conv is False, but ref_image is not None")
|
| 692 |
+
else:
|
| 693 |
+
ref_image_latentes = None
|
| 694 |
+
|
| 695 |
+
if comfyui_progressbar:
|
| 696 |
+
pbar.update(1)
|
| 697 |
+
|
| 698 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 699 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 700 |
+
|
| 701 |
+
target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
|
| 702 |
+
seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
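# Worked example (illustrative, assuming patch_size=(1, 2, 2) and 8x spatial / 4x temporal VAE
# compression): height=480, width=720, num_frames=49 gives a 13 x 90 x 60 latent grid and
# seq_len = ceil((90 * 60) / (2 * 2) * 13) = 17550 transformer tokens.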
|
| 703 |
+
# 7. Denoising loop
|
| 704 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 705 |
+
self.transformer.num_inference_steps = num_inference_steps
|
| 706 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 707 |
+
for i, t in enumerate(timesteps):
|
| 708 |
+
self.transformer.current_steps = i
|
| 709 |
+
|
| 710 |
+
if self.interrupt:
|
| 711 |
+
continue
|
| 712 |
+
|
| 713 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 714 |
+
if hasattr(self.scheduler, "scale_model_input"):
|
| 715 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 716 |
+
|
| 717 |
+
# Prepare mask latent variables
|
| 718 |
+
if control_camera_video is not None:
|
| 719 |
+
control_latents_input = None
|
| 720 |
+
control_camera_latents_input = (
|
| 721 |
+
torch.cat([control_camera_latents] * 2) if do_classifier_free_guidance else control_camera_latents
|
| 722 |
+
).to(device, weight_dtype)
|
| 723 |
+
else:
|
| 724 |
+
control_latents_input = (
|
| 725 |
+
torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
|
| 726 |
+
).to(device, weight_dtype)
|
| 727 |
+
control_camera_latents_input = None
|
| 728 |
+
|
| 729 |
+
start_image_latentes_conv_in_input = (
|
| 730 |
+
torch.cat([start_image_latentes_conv_in] * 2) if do_classifier_free_guidance else start_image_latentes_conv_in
|
| 731 |
+
).to(device, weight_dtype)
|
| 732 |
+
control_latents_input = start_image_latentes_conv_in_input if control_latents_input is None else \
|
| 733 |
+
torch.cat([control_latents_input, start_image_latentes_conv_in_input], dim = 1)
|
| 734 |
+
|
| 735 |
+
clip_context_input = (
|
| 736 |
+
torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context
|
| 737 |
+
)
|
| 738 |
+
|
| 739 |
+
if ref_image_latentes is not None:
|
| 740 |
+
full_ref = (
|
| 741 |
+
torch.cat([ref_image_latentes] * 2) if do_classifier_free_guidance else ref_image_latentes
|
| 742 |
+
).to(device, weight_dtype)
|
| 743 |
+
else:
|
| 744 |
+
full_ref = None
|
| 745 |
+
|
| 746 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 747 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 748 |
+
|
| 749 |
+
# predict noise model_output
|
| 750 |
+
with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
|
| 751 |
+
noise_pred = self.transformer(
|
| 752 |
+
x=latent_model_input,
|
| 753 |
+
context=in_prompt_embeds,
|
| 754 |
+
t=timestep,
|
| 755 |
+
seq_len=seq_len,
|
| 756 |
+
y=control_latents_input,
|
| 757 |
+
y_camera=control_camera_latents_input,
|
| 758 |
+
full_ref=full_ref,
|
| 759 |
+
clip_fea=clip_context_input,
|
| 760 |
+
)
|
| 761 |
+
|
| 762 |
+
# perform guidance
|
| 763 |
+
if do_classifier_free_guidance:
|
| 764 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 765 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 766 |
+
|
| 767 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 768 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 769 |
+
|
| 770 |
+
if callback_on_step_end is not None:
|
| 771 |
+
callback_kwargs = {}
|
| 772 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 773 |
+
callback_kwargs[k] = locals()[k]
|
| 774 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 775 |
+
|
| 776 |
+
latents = callback_outputs.pop("latents", latents)
|
| 777 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 778 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 779 |
+
|
| 780 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 781 |
+
progress_bar.update()
|
| 782 |
+
if comfyui_progressbar:
|
| 783 |
+
pbar.update(1)
|
| 784 |
+
|
| 785 |
+
if output_type == "numpy":
|
| 786 |
+
video = self.decode_latents(latents)
|
| 787 |
+
elif output_type != "latent":
|
| 788 |
+
video = self.decode_latents(latents)
|
| 789 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 790 |
+
else:
|
| 791 |
+
video = latents
|
| 792 |
+
|
| 793 |
+
# Offload all models
|
| 794 |
+
self.maybe_free_model_hooks()
|
| 795 |
+
|
| 796 |
+
if not return_dict:
|
| 797 |
+
video = torch.from_numpy(video)
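# Note: torch.from_numpy assumes output_type="numpy" here, and a WanPipelineOutput is returned
# below even when return_dict=False, which differs from the usual diffusers tuple convention.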
|
| 798 |
+
|
| 799 |
+
return WanPipelineOutput(videos=video)
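For reference, a hedged end-to-end usage sketch of `WanFunControlPipeline.__call__`; `pipe` and `control_video` are placeholders (component construction lives elsewhere in this repository), and the argument values simply echo the defaults in the signature above:

```python
import torch

# `pipe` is assumed to be an already-constructed WanFunControlPipeline and `control_video`
# a (B, C, F, H, W) float tensor of control frames; both are placeholders, not part of this file.
generator = torch.Generator().manual_seed(42)
with torch.no_grad():
    out = pipe(
        prompt="a corgi running on the beach at sunset",
        negative_prompt="blurry, low quality",
        control_video=control_video,
        height=480,
        width=720,
        num_frames=49,
        num_inference_steps=50,
        guidance_scale=6.0,
        generator=generator,
        output_type="numpy",
    )

videos = out.videos  # decoded frames in [0, 1] when output_type="numpy"
```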
|
videox_fun/pipeline/pipeline_wan_fun_inpaint.py
ADDED
|
@@ -0,0 +1,734 @@
| 1 |
+
import inspect
|
| 2 |
+
import math
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torchvision.transforms.functional as TF
|
| 10 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
| 11 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 12 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 13 |
+
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
| 14 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 15 |
+
from diffusers.utils import BaseOutput, logging, replace_example_docstring
|
| 16 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 17 |
+
from diffusers.video_processor import VideoProcessor
|
| 18 |
+
from einops import rearrange
|
| 19 |
+
from PIL import Image
|
| 20 |
+
from transformers import T5Tokenizer
|
| 21 |
+
|
| 22 |
+
from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
|
| 23 |
+
WanT5EncoderModel, WanTransformer3DModel)
|
| 24 |
+
from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
| 25 |
+
get_sampling_sigmas)
|
| 26 |
+
from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
EXAMPLE_DOC_STRING = """
|
| 32 |
+
Examples:
|
| 33 |
+
```python
|
| 34 |
+
pass
|
| 35 |
+
```
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 40 |
+
def retrieve_timesteps(
|
| 41 |
+
scheduler,
|
| 42 |
+
num_inference_steps: Optional[int] = None,
|
| 43 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 44 |
+
timesteps: Optional[List[int]] = None,
|
| 45 |
+
sigmas: Optional[List[float]] = None,
|
| 46 |
+
**kwargs,
|
| 47 |
+
):
|
| 48 |
+
"""
|
| 49 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 50 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
scheduler (`SchedulerMixin`):
|
| 54 |
+
The scheduler to get timesteps from.
|
| 55 |
+
num_inference_steps (`int`):
|
| 56 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 57 |
+
must be `None`.
|
| 58 |
+
device (`str` or `torch.device`, *optional*):
|
| 59 |
+
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
| 60 |
+
timesteps (`List[int]`, *optional*):
|
| 61 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 62 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 63 |
+
sigmas (`List[float]`, *optional*):
|
| 64 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 65 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 69 |
+
second element is the number of inference steps.
|
| 70 |
+
"""
|
| 71 |
+
if timesteps is not None and sigmas is not None:
|
| 72 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 73 |
+
if timesteps is not None:
|
| 74 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 75 |
+
if not accepts_timesteps:
|
| 76 |
+
raise ValueError(
|
| 77 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 78 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 79 |
+
)
|
| 80 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 81 |
+
timesteps = scheduler.timesteps
|
| 82 |
+
num_inference_steps = len(timesteps)
|
| 83 |
+
elif sigmas is not None:
|
| 84 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 85 |
+
if not accept_sigmas:
|
| 86 |
+
raise ValueError(
|
| 87 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 88 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 89 |
+
)
|
| 90 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 91 |
+
timesteps = scheduler.timesteps
|
| 92 |
+
num_inference_steps = len(timesteps)
|
| 93 |
+
else:
|
| 94 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 95 |
+
timesteps = scheduler.timesteps
|
| 96 |
+
return timesteps, num_inference_steps
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def resize_mask(mask, latent, process_first_frame_only=True):
|
| 100 |
+
latent_size = latent.size()
|
| 101 |
+
batch_size, channels, num_frames, height, width = mask.shape
|
| 102 |
+
|
| 103 |
+
if process_first_frame_only:
|
| 104 |
+
target_size = list(latent_size[2:])
|
| 105 |
+
target_size[0] = 1
|
| 106 |
+
first_frame_resized = F.interpolate(
|
| 107 |
+
mask[:, :, 0:1, :, :],
|
| 108 |
+
size=target_size,
|
| 109 |
+
mode='trilinear',
|
| 110 |
+
align_corners=False
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
target_size = list(latent_size[2:])
|
| 114 |
+
target_size[0] = target_size[0] - 1
|
| 115 |
+
if target_size[0] != 0:
|
| 116 |
+
remaining_frames_resized = F.interpolate(
|
| 117 |
+
mask[:, :, 1:, :, :],
|
| 118 |
+
size=target_size,
|
| 119 |
+
mode='trilinear',
|
| 120 |
+
align_corners=False
|
| 121 |
+
)
|
| 122 |
+
resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
|
| 123 |
+
else:
|
| 124 |
+
resized_mask = first_frame_resized
|
| 125 |
+
else:
|
| 126 |
+
target_size = list(latent_size[2:])
|
| 127 |
+
resized_mask = F.interpolate(
|
| 128 |
+
mask,
|
| 129 |
+
size=target_size,
|
| 130 |
+
mode='trilinear',
|
| 131 |
+
align_corners=False
|
| 132 |
+
)
|
| 133 |
+
return resized_mask
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@dataclass
|
| 137 |
+
class WanPipelineOutput(BaseOutput):
|
| 138 |
+
r"""
|
| 139 |
+
Output class for Wan pipelines.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
| 143 |
+
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
| 144 |
+
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
| 145 |
+
`(batch_size, num_frames, channels, height, width)`.
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
videos: torch.Tensor
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class WanFunInpaintPipeline(DiffusionPipeline):
|
| 152 |
+
r"""
|
| 153 |
+
Pipeline for video inpainting and image-to-video generation using Wan.
|
| 154 |
+
|
| 155 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 156 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
_optional_components = []
|
| 160 |
+
model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae"
|
| 161 |
+
|
| 162 |
+
_callback_tensor_inputs = [
|
| 163 |
+
"latents",
|
| 164 |
+
"prompt_embeds",
|
| 165 |
+
"negative_prompt_embeds",
|
| 166 |
+
]
|
| 167 |
+
|
| 168 |
+
def __init__(
|
| 169 |
+
self,
|
| 170 |
+
tokenizer: AutoTokenizer,
|
| 171 |
+
text_encoder: WanT5EncoderModel,
|
| 172 |
+
vae: AutoencoderKLWan,
|
| 173 |
+
transformer: WanTransformer3DModel,
|
| 174 |
+
clip_image_encoder: CLIPModel,
|
| 175 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 176 |
+
):
|
| 177 |
+
super().__init__()
|
| 178 |
+
|
| 179 |
+
self.register_modules(
|
| 180 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, clip_image_encoder=clip_image_encoder, scheduler=scheduler
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
|
| 184 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
|
| 185 |
+
self.mask_processor = VaeImageProcessor(
|
| 186 |
+
vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
def _get_t5_prompt_embeds(
|
| 190 |
+
self,
|
| 191 |
+
prompt: Union[str, List[str]] = None,
|
| 192 |
+
num_videos_per_prompt: int = 1,
|
| 193 |
+
max_sequence_length: int = 512,
|
| 194 |
+
device: Optional[torch.device] = None,
|
| 195 |
+
dtype: Optional[torch.dtype] = None,
|
| 196 |
+
):
|
| 197 |
+
device = device or self._execution_device
|
| 198 |
+
dtype = dtype or self.text_encoder.dtype
|
| 199 |
+
|
| 200 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 201 |
+
batch_size = len(prompt)
|
| 202 |
+
|
| 203 |
+
text_inputs = self.tokenizer(
|
| 204 |
+
prompt,
|
| 205 |
+
padding="max_length",
|
| 206 |
+
max_length=max_sequence_length,
|
| 207 |
+
truncation=True,
|
| 208 |
+
add_special_tokens=True,
|
| 209 |
+
return_tensors="pt",
|
| 210 |
+
)
|
| 211 |
+
text_input_ids = text_inputs.input_ids
|
| 212 |
+
prompt_attention_mask = text_inputs.attention_mask
|
| 213 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 214 |
+
|
| 215 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 216 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 217 |
+
logger.warning(
|
| 218 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 219 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
|
| 223 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
|
| 224 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 225 |
+
|
| 226 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 227 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 228 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 229 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 230 |
+
|
| 231 |
+
return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
| 232 |
+
|
| 233 |
+
def encode_prompt(
|
| 234 |
+
self,
|
| 235 |
+
prompt: Union[str, List[str]],
|
| 236 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 237 |
+
do_classifier_free_guidance: bool = True,
|
| 238 |
+
num_videos_per_prompt: int = 1,
|
| 239 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 240 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 241 |
+
max_sequence_length: int = 512,
|
| 242 |
+
device: Optional[torch.device] = None,
|
| 243 |
+
dtype: Optional[torch.dtype] = None,
|
| 244 |
+
):
|
| 245 |
+
r"""
|
| 246 |
+
Encodes the prompt into text encoder hidden states.
|
| 247 |
+
|
| 248 |
+
Args:
|
| 249 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 250 |
+
prompt to be encoded
|
| 251 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 252 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 253 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 254 |
+
less than `1`).
|
| 255 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 256 |
+
Whether to use classifier free guidance or not.
|
| 257 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 258 |
+
Number of videos that should be generated per prompt.
|
| 259 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 260 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 261 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 262 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 263 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 264 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 265 |
+
argument.
|
| 266 |
+
device: (`torch.device`, *optional*):
|
| 267 |
+
torch device
|
| 268 |
+
dtype: (`torch.dtype`, *optional*):
|
| 269 |
+
torch dtype
|
| 270 |
+
"""
|
| 271 |
+
device = device or self._execution_device
|
| 272 |
+
|
| 273 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 274 |
+
if prompt is not None:
|
| 275 |
+
batch_size = len(prompt)
|
| 276 |
+
else:
|
| 277 |
+
batch_size = prompt_embeds.shape[0]
|
| 278 |
+
|
| 279 |
+
if prompt_embeds is None:
|
| 280 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 281 |
+
prompt=prompt,
|
| 282 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 283 |
+
max_sequence_length=max_sequence_length,
|
| 284 |
+
device=device,
|
| 285 |
+
dtype=dtype,
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 289 |
+
negative_prompt = negative_prompt or ""
|
| 290 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 291 |
+
|
| 292 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 293 |
+
raise TypeError(
|
| 294 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 295 |
+
f" {type(prompt)}."
|
| 296 |
+
)
|
| 297 |
+
elif batch_size != len(negative_prompt):
|
| 298 |
+
raise ValueError(
|
| 299 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 300 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 301 |
+
" the batch size of `prompt`."
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 305 |
+
prompt=negative_prompt,
|
| 306 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 307 |
+
max_sequence_length=max_sequence_length,
|
| 308 |
+
device=device,
|
| 309 |
+
dtype=dtype,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
return prompt_embeds, negative_prompt_embeds
|
| 313 |
+
|
| 314 |
+
def prepare_latents(
|
| 315 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
| 316 |
+
):
|
| 317 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 318 |
+
raise ValueError(
|
| 319 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 320 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
shape = (
|
| 324 |
+
batch_size,
|
| 325 |
+
num_channels_latents,
|
| 326 |
+
(num_frames - 1) // self.vae.temporal_compression_ratio + 1,
|
| 327 |
+
height // self.vae.spatial_compression_ratio,
|
| 328 |
+
width // self.vae.spatial_compression_ratio,
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
if latents is None:
|
| 332 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 333 |
+
else:
|
| 334 |
+
latents = latents.to(device)
|
| 335 |
+
|
| 336 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 337 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
| 338 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 339 |
+
return latents
|
| 340 |
+
|
| 341 |
+
def prepare_mask_latents(
|
| 342 |
+
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
|
| 343 |
+
):
|
| 344 |
+
# resize the mask to latents shape as we concatenate the mask to the latents
|
| 345 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
| 346 |
+
# and half precision
|
| 347 |
+
|
| 348 |
+
if mask is not None:
|
| 349 |
+
mask = mask.to(device=device, dtype=self.vae.dtype)
|
| 350 |
+
bs = 1
|
| 351 |
+
new_mask = []
|
| 352 |
+
for i in range(0, mask.shape[0], bs):
|
| 353 |
+
mask_bs = mask[i : i + bs]
|
| 354 |
+
mask_bs = self.vae.encode(mask_bs)[0]
|
| 355 |
+
mask_bs = mask_bs.mode()
|
| 356 |
+
new_mask.append(mask_bs)
|
| 357 |
+
mask = torch.cat(new_mask, dim = 0)
|
| 358 |
+
# mask = mask * self.vae.config.scaling_factor
|
| 359 |
+
|
| 360 |
+
if masked_image is not None:
|
| 361 |
+
masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
|
| 362 |
+
bs = 1
|
| 363 |
+
new_mask_pixel_values = []
|
| 364 |
+
for i in range(0, masked_image.shape[0], bs):
|
| 365 |
+
mask_pixel_values_bs = masked_image[i : i + bs]
|
| 366 |
+
mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
|
| 367 |
+
mask_pixel_values_bs = mask_pixel_values_bs.mode()
|
| 368 |
+
new_mask_pixel_values.append(mask_pixel_values_bs)
|
| 369 |
+
masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
|
| 370 |
+
# masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
|
| 371 |
+
else:
|
| 372 |
+
masked_image_latents = None
|
| 373 |
+
|
| 374 |
+
return mask, masked_image_latents
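# Unlike prepare_control_latents in the control pipeline, this method VAE-encodes the mask itself
# (again one sample at a time), and the scaling_factor multiplications are left commented out
# above, so the returned latents stay in the VAE's native range.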
|
| 375 |
+
|
| 376 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 377 |
+
frames = self.vae.decode(latents.to(self.vae.dtype)).sample
|
| 378 |
+
frames = (frames / 2 + 0.5).clamp(0, 1)
|
| 379 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 380 |
+
frames = frames.cpu().float().numpy()
|
| 381 |
+
return frames
|
| 382 |
+
|
| 383 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 384 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 385 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 386 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 387 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 388 |
+
# and should be between [0, 1]
|
| 389 |
+
|
| 390 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 391 |
+
extra_step_kwargs = {}
|
| 392 |
+
if accepts_eta:
|
| 393 |
+
extra_step_kwargs["eta"] = eta
|
| 394 |
+
|
| 395 |
+
# check if the scheduler accepts generator
|
| 396 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 397 |
+
if accepts_generator:
|
| 398 |
+
extra_step_kwargs["generator"] = generator
|
| 399 |
+
return extra_step_kwargs
|
| 400 |
+
|
| 401 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
| 402 |
+
def check_inputs(
|
| 403 |
+
self,
|
| 404 |
+
prompt,
|
| 405 |
+
height,
|
| 406 |
+
width,
|
| 407 |
+
negative_prompt,
|
| 408 |
+
callback_on_step_end_tensor_inputs,
|
| 409 |
+
prompt_embeds=None,
|
| 410 |
+
negative_prompt_embeds=None,
|
| 411 |
+
):
|
| 412 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 413 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 414 |
+
|
| 415 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 416 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 417 |
+
):
|
| 418 |
+
raise ValueError(
|
| 419 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 420 |
+
)
|
| 421 |
+
if prompt is not None and prompt_embeds is not None:
|
| 422 |
+
raise ValueError(
|
| 423 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 424 |
+
" only forward one of the two."
|
| 425 |
+
)
|
| 426 |
+
elif prompt is None and prompt_embeds is None:
|
| 427 |
+
raise ValueError(
|
| 428 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 429 |
+
)
|
| 430 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 431 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 432 |
+
|
| 433 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 434 |
+
raise ValueError(
|
| 435 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 436 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 440 |
+
raise ValueError(
|
| 441 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 442 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 446 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 447 |
+
raise ValueError(
|
| 448 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 449 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 450 |
+
f" {negative_prompt_embeds.shape}."
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
@property
|
| 454 |
+
def guidance_scale(self):
|
| 455 |
+
return self._guidance_scale
|
| 456 |
+
|
| 457 |
+
@property
|
| 458 |
+
def num_timesteps(self):
|
| 459 |
+
return self._num_timesteps
|
| 460 |
+
|
| 461 |
+
@property
|
| 462 |
+
def attention_kwargs(self):
|
| 463 |
+
return self._attention_kwargs
|
| 464 |
+
|
| 465 |
+
@property
|
| 466 |
+
def interrupt(self):
|
| 467 |
+
return self._interrupt
|
| 468 |
+
|
| 469 |
+
@torch.no_grad()
|
| 470 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 471 |
+
def __call__(
|
| 472 |
+
self,
|
| 473 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 474 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 475 |
+
height: int = 480,
|
| 476 |
+
width: int = 720,
|
| 477 |
+
video: Union[torch.FloatTensor] = None,
|
| 478 |
+
mask_video: Union[torch.FloatTensor] = None,
|
| 479 |
+
num_frames: int = 49,
|
| 480 |
+
num_inference_steps: int = 50,
|
| 481 |
+
timesteps: Optional[List[int]] = None,
|
| 482 |
+
guidance_scale: float = 6,
|
| 483 |
+
num_videos_per_prompt: int = 1,
|
| 484 |
+
eta: float = 0.0,
|
| 485 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 486 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 487 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 488 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 489 |
+
output_type: str = "numpy",
|
| 490 |
+
return_dict: bool = False,
|
| 491 |
+
callback_on_step_end: Optional[
|
| 492 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 493 |
+
] = None,
|
| 494 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 495 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 496 |
+
clip_image: Image = None,
|
| 497 |
+
max_sequence_length: int = 512,
|
| 498 |
+
comfyui_progressbar: bool = False,
|
| 499 |
+
shift: int = 5,
|
| 500 |
+
) -> Union[WanPipelineOutput, Tuple]:
|
| 501 |
+
"""
|
| 502 |
+
Function invoked when calling the pipeline for generation.
|
| 503 |
+
Args:
|
| 504 |
+
|
| 505 |
+
Examples:
|
| 506 |
+
|
| 507 |
+
Returns:
|
| 508 |
+
|
| 509 |
+
"""
|
| 510 |
+
|
| 511 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 512 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 513 |
+
num_videos_per_prompt = 1
|
| 514 |
+
|
| 515 |
+
# 1. Check inputs. Raise error if not correct
|
| 516 |
+
self.check_inputs(
|
| 517 |
+
prompt,
|
| 518 |
+
height,
|
| 519 |
+
width,
|
| 520 |
+
negative_prompt,
|
| 521 |
+
callback_on_step_end_tensor_inputs,
|
| 522 |
+
prompt_embeds,
|
| 523 |
+
negative_prompt_embeds,
|
| 524 |
+
)
|
| 525 |
+
self._guidance_scale = guidance_scale
|
| 526 |
+
self._attention_kwargs = attention_kwargs
|
| 527 |
+
self._interrupt = False
|
| 528 |
+
|
| 529 |
+
# 2. Default call parameters
|
| 530 |
+
if prompt is not None and isinstance(prompt, str):
|
| 531 |
+
batch_size = 1
|
| 532 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 533 |
+
batch_size = len(prompt)
|
| 534 |
+
else:
|
| 535 |
+
batch_size = prompt_embeds.shape[0]
|
| 536 |
+
|
| 537 |
+
device = self._execution_device
|
| 538 |
+
weight_dtype = self.text_encoder.dtype
|
| 539 |
+
|
| 540 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 541 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 542 |
+
# corresponds to doing no classifier free guidance.
|
| 543 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 544 |
+
|
| 545 |
+
# 3. Encode input prompt
|
| 546 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 547 |
+
prompt,
|
| 548 |
+
negative_prompt,
|
| 549 |
+
do_classifier_free_guidance,
|
| 550 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 551 |
+
prompt_embeds=prompt_embeds,
|
| 552 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 553 |
+
max_sequence_length=max_sequence_length,
|
| 554 |
+
device=device,
|
| 555 |
+
)
|
| 556 |
+
if do_classifier_free_guidance:
|
| 557 |
+
in_prompt_embeds = negative_prompt_embeds + prompt_embeds
|
| 558 |
+
else:
|
| 559 |
+
in_prompt_embeds = prompt_embeds
|
| 560 |
+
|
| 561 |
+
# 4. Prepare timesteps
|
| 562 |
+
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
| 563 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
|
| 564 |
+
elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
|
| 565 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
|
| 566 |
+
timesteps = self.scheduler.timesteps
|
| 567 |
+
elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
|
| 568 |
+
sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
|
| 569 |
+
timesteps, _ = retrieve_timesteps(
|
| 570 |
+
self.scheduler,
|
| 571 |
+
device=device,
|
| 572 |
+
sigmas=sampling_sigmas)
|
| 573 |
+
else:
|
| 574 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 575 |
+
self._num_timesteps = len(timesteps)
|
| 576 |
+
if comfyui_progressbar:
|
| 577 |
+
from comfy.utils import ProgressBar
|
| 578 |
+
pbar = ProgressBar(num_inference_steps + 2)
|
| 579 |
+
|
| 580 |
+
# 5. Prepare latents.
|
| 581 |
+
if video is not None:
|
| 582 |
+
video_length = video.shape[2]
|
| 583 |
+
init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 584 |
+
init_video = init_video.to(dtype=torch.float32)
|
| 585 |
+
init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
|
| 586 |
+
else:
|
| 587 |
+
init_video = None
|
| 588 |
+
|
| 589 |
+
latent_channels = self.vae.config.latent_channels
|
| 590 |
+
latents = self.prepare_latents(
|
| 591 |
+
batch_size * num_videos_per_prompt,
|
| 592 |
+
latent_channels,
|
| 593 |
+
num_frames,
|
| 594 |
+
height,
|
| 595 |
+
width,
|
| 596 |
+
weight_dtype,
|
| 597 |
+
device,
|
| 598 |
+
generator,
|
| 599 |
+
latents,
|
| 600 |
+
)
|
| 601 |
+
if comfyui_progressbar:
|
| 602 |
+
pbar.update(1)
|
| 603 |
+
|
| 604 |
+
# Prepare mask latent variables
|
| 605 |
+
if init_video is not None:
|
| 606 |
+
if (mask_video == 255).all():
|
| 607 |
+
mask_latents = torch.tile(
|
| 608 |
+
torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1]
|
| 609 |
+
)
|
| 610 |
+
masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype)
|
| 611 |
+
else:
|
| 612 |
+
bs, _, video_length, height, width = video.size()
|
| 613 |
+
mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 614 |
+
mask_condition = mask_condition.to(dtype=torch.float32)
|
| 615 |
+
mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
|
| 616 |
+
|
| 617 |
+
masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5)
|
| 618 |
+
_, masked_video_latents = self.prepare_mask_latents(
|
| 619 |
+
None,
|
| 620 |
+
masked_video,
|
| 621 |
+
batch_size,
|
| 622 |
+
height,
|
| 623 |
+
width,
|
| 624 |
+
weight_dtype,
|
| 625 |
+
device,
|
| 626 |
+
generator,
|
| 627 |
+
do_classifier_free_guidance,
|
| 628 |
+
noise_aug_strength=None,
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
mask_condition = torch.concat(
|
| 632 |
+
[
|
| 633 |
+
torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2),
|
| 634 |
+
mask_condition[:, :, 1:]
|
| 635 |
+
], dim=2
|
| 636 |
+
)
|
| 637 |
+
mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width)
|
| 638 |
+
mask_condition = mask_condition.transpose(1, 2)
|
| 639 |
+
mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype)
|
| 640 |
+
|
| 641 |
+
# Prepare clip latent variables
|
| 642 |
+
if clip_image is not None:
|
| 643 |
+
clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
|
| 644 |
+
clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
|
| 645 |
+
else:
|
| 646 |
+
clip_image = Image.new("RGB", (512, 512), color=(0, 0, 0))
|
| 647 |
+
clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
|
| 648 |
+
clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
|
| 649 |
+
clip_context = torch.zeros_like(clip_context)
|
| 650 |
+
if comfyui_progressbar:
|
| 651 |
+
pbar.update(1)
|
| 652 |
+
|
| 653 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 654 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 655 |
+
|
| 656 |
+
target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
|
| 657 |
+
seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
|
| 658 |
+
# 7. Denoising loop
|
| 659 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 660 |
+
self.transformer.num_inference_steps = num_inference_steps
|
| 661 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 662 |
+
for i, t in enumerate(timesteps):
|
| 663 |
+
self.transformer.current_steps = i
|
| 664 |
+
|
| 665 |
+
if self.interrupt:
|
| 666 |
+
continue
|
| 667 |
+
|
| 668 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 669 |
+
if hasattr(self.scheduler, "scale_model_input"):
|
| 670 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 671 |
+
|
| 672 |
+
if init_video is not None:
|
| 673 |
+
mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
|
| 674 |
+
masked_video_latents_input = (
|
| 675 |
+
torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
|
| 676 |
+
)
|
| 677 |
+
y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype)
|
| 678 |
+
|
| 679 |
+
clip_context_input = (
|
| 680 |
+
torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context
|
| 681 |
+
)
|
| 682 |
+
|
| 683 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 684 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 685 |
+
|
| 686 |
+
# predict noise model_output
|
| 687 |
+
with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
|
| 688 |
+
noise_pred = self.transformer(
|
| 689 |
+
x=latent_model_input,
|
| 690 |
+
context=in_prompt_embeds,
|
| 691 |
+
t=timestep,
|
| 692 |
+
seq_len=seq_len,
|
| 693 |
+
y=y,
|
| 694 |
+
clip_fea=clip_context_input,
|
| 695 |
+
)
|
| 696 |
+
|
| 697 |
+
# perform guidance
|
| 698 |
+
if do_classifier_free_guidance:
|
| 699 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 700 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 701 |
+
|
| 702 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 703 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 704 |
+
|
| 705 |
+
if callback_on_step_end is not None:
|
| 706 |
+
callback_kwargs = {}
|
| 707 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 708 |
+
callback_kwargs[k] = locals()[k]
|
| 709 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 710 |
+
|
| 711 |
+
latents = callback_outputs.pop("latents", latents)
|
| 712 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 713 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 714 |
+
|
| 715 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 716 |
+
progress_bar.update()
|
| 717 |
+
if comfyui_progressbar:
|
| 718 |
+
pbar.update(1)
|
| 719 |
+
|
| 720 |
+
if output_type == "numpy":
|
| 721 |
+
video = self.decode_latents(latents)
|
| 722 |
+
elif not output_type == "latent":
|
| 723 |
+
video = self.decode_latents(latents)
|
| 724 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 725 |
+
else:
|
| 726 |
+
video = latents
|
| 727 |
+
|
| 728 |
+
# Offload all models
|
| 729 |
+
self.maybe_free_model_hooks()
|
| 730 |
+
|
| 731 |
+
if not return_dict:
|
| 732 |
+
video = torch.from_numpy(video)
|
| 733 |
+
|
| 734 |
+
return WanPipelineOutput(videos=video)
|
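The inpainting-style Wan video `__call__` that ends above takes a full pixel video, a `mask_video` (255 everywhere means "regenerate everything"), and an optional `clip_image` reference. The sketch below is a minimal, hypothetical invocation only: it assumes a pipeline object `pipe` of this class has already been constructed with its Wan components and moved to GPU; the prompt text, tensor contents, and seed are illustrative and are not part of this diff.

```python
import torch
from PIL import Image

# Hypothetical: `pipe` is an already-constructed instance of the inpaint-style
# pipeline defined above, with tokenizer, text encoder, VAE, transformer,
# CLIP image encoder and scheduler loaded, e.g. pipe = pipe.to("cuda").

num_frames, height, width = 49, 480, 720

# Dummy inputs in the layout the pipeline expects: (batch, channels, frames, height, width).
video = torch.zeros(1, 3, num_frames, height, width)             # placeholder pixel video
mask_video = torch.ones(1, 1, num_frames, height, width) * 255   # all 255 -> regenerate every pixel
clip_image = Image.new("RGB", (512, 512))                        # optional CLIP reference image

output = pipe(
    prompt="a corgi running on a beach",
    negative_prompt="blurry, low quality",
    video=video,
    mask_video=mask_video,
    clip_image=clip_image,
    height=height,
    width=width,
    num_frames=num_frames,
    num_inference_steps=50,
    guidance_scale=6.0,
    shift=5,
    generator=torch.Generator(device="cuda").manual_seed(42),
    output_type="numpy",
)
# WanPipelineOutput.videos; with the default return_dict=False the decoded numpy
# frames are converted back to a torch tensor before being returned.
frames = output.videos
```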
videox_fun/pipeline/pipeline_wan_phantom.py
ADDED
|
@@ -0,0 +1,695 @@
| 1 |
+
import inspect
|
| 2 |
+
import math
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torchvision.transforms.functional as TF
|
| 10 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
| 11 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 12 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 13 |
+
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
| 14 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 15 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 16 |
+
from diffusers.utils import BaseOutput, logging, replace_example_docstring
|
| 17 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 18 |
+
from diffusers.video_processor import VideoProcessor
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
from PIL import Image
|
| 21 |
+
from transformers import T5Tokenizer
|
| 22 |
+
|
| 23 |
+
from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
|
| 24 |
+
WanT5EncoderModel, WanTransformer3DModel)
|
| 25 |
+
from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
| 26 |
+
get_sampling_sigmas)
|
| 27 |
+
from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 28 |
+
|
| 29 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
EXAMPLE_DOC_STRING = """
|
| 33 |
+
Examples:
|
| 34 |
+
```python
|
| 35 |
+
pass
|
| 36 |
+
```
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 41 |
+
def retrieve_timesteps(
|
| 42 |
+
scheduler,
|
| 43 |
+
num_inference_steps: Optional[int] = None,
|
| 44 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 45 |
+
timesteps: Optional[List[int]] = None,
|
| 46 |
+
sigmas: Optional[List[float]] = None,
|
| 47 |
+
**kwargs,
|
| 48 |
+
):
|
| 49 |
+
"""
|
| 50 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 51 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
scheduler (`SchedulerMixin`):
|
| 55 |
+
The scheduler to get timesteps from.
|
| 56 |
+
num_inference_steps (`int`):
|
| 57 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 58 |
+
must be `None`.
|
| 59 |
+
device (`str` or `torch.device`, *optional*):
|
| 60 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 61 |
+
timesteps (`List[int]`, *optional*):
|
| 62 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 63 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 64 |
+
sigmas (`List[float]`, *optional*):
|
| 65 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 66 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 70 |
+
second element is the number of inference steps.
|
| 71 |
+
"""
|
| 72 |
+
if timesteps is not None and sigmas is not None:
|
| 73 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 74 |
+
if timesteps is not None:
|
| 75 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 76 |
+
if not accepts_timesteps:
|
| 77 |
+
raise ValueError(
|
| 78 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 79 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 80 |
+
)
|
| 81 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 82 |
+
timesteps = scheduler.timesteps
|
| 83 |
+
num_inference_steps = len(timesteps)
|
| 84 |
+
elif sigmas is not None:
|
| 85 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 86 |
+
if not accept_sigmas:
|
| 87 |
+
raise ValueError(
|
| 88 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 89 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 90 |
+
)
|
| 91 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 92 |
+
timesteps = scheduler.timesteps
|
| 93 |
+
num_inference_steps = len(timesteps)
|
| 94 |
+
else:
|
| 95 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 96 |
+
timesteps = scheduler.timesteps
|
| 97 |
+
return timesteps, num_inference_steps
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def resize_mask(mask, latent, process_first_frame_only=True):
|
| 101 |
+
latent_size = latent.size()
|
| 102 |
+
batch_size, channels, num_frames, height, width = mask.shape
|
| 103 |
+
|
| 104 |
+
if process_first_frame_only:
|
| 105 |
+
target_size = list(latent_size[2:])
|
| 106 |
+
target_size[0] = 1
|
| 107 |
+
first_frame_resized = F.interpolate(
|
| 108 |
+
mask[:, :, 0:1, :, :],
|
| 109 |
+
size=target_size,
|
| 110 |
+
mode='trilinear',
|
| 111 |
+
align_corners=False
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
target_size = list(latent_size[2:])
|
| 115 |
+
target_size[0] = target_size[0] - 1
|
| 116 |
+
if target_size[0] != 0:
|
| 117 |
+
remaining_frames_resized = F.interpolate(
|
| 118 |
+
mask[:, :, 1:, :, :],
|
| 119 |
+
size=target_size,
|
| 120 |
+
mode='trilinear',
|
| 121 |
+
align_corners=False
|
| 122 |
+
)
|
| 123 |
+
resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
|
| 124 |
+
else:
|
| 125 |
+
resized_mask = first_frame_resized
|
| 126 |
+
else:
|
| 127 |
+
target_size = list(latent_size[2:])
|
| 128 |
+
resized_mask = F.interpolate(
|
| 129 |
+
mask,
|
| 130 |
+
size=target_size,
|
| 131 |
+
mode='trilinear',
|
| 132 |
+
align_corners=False
|
| 133 |
+
)
|
| 134 |
+
return resized_mask
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@dataclass
|
| 138 |
+
class WanPipelineOutput(BaseOutput):
|
| 139 |
+
r"""
|
| 140 |
+
Output class for CogVideo pipelines.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
| 144 |
+
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
| 145 |
+
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
| 146 |
+
`(batch_size, num_frames, channels, height, width)`.
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
videos: torch.Tensor
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class WanFunPhantomPipeline(DiffusionPipeline):
|
| 153 |
+
r"""
|
| 154 |
+
Pipeline for text-to-video generation using Wan.
|
| 155 |
+
|
| 156 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 157 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
_optional_components = []
|
| 161 |
+
model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae"
|
| 162 |
+
|
| 163 |
+
_callback_tensor_inputs = [
|
| 164 |
+
"latents",
|
| 165 |
+
"prompt_embeds",
|
| 166 |
+
"negative_prompt_embeds",
|
| 167 |
+
]
|
| 168 |
+
|
| 169 |
+
def __init__(
|
| 170 |
+
self,
|
| 171 |
+
tokenizer: AutoTokenizer,
|
| 172 |
+
text_encoder: WanT5EncoderModel,
|
| 173 |
+
vae: AutoencoderKLWan,
|
| 174 |
+
transformer: WanTransformer3DModel,
|
| 175 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 176 |
+
):
|
| 177 |
+
super().__init__()
|
| 178 |
+
|
| 179 |
+
self.register_modules(
|
| 180 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
|
| 184 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
|
| 185 |
+
self.mask_processor = VaeImageProcessor(
|
| 186 |
+
vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
def _get_t5_prompt_embeds(
|
| 190 |
+
self,
|
| 191 |
+
prompt: Union[str, List[str]] = None,
|
| 192 |
+
num_videos_per_prompt: int = 1,
|
| 193 |
+
max_sequence_length: int = 512,
|
| 194 |
+
device: Optional[torch.device] = None,
|
| 195 |
+
dtype: Optional[torch.dtype] = None,
|
| 196 |
+
):
|
| 197 |
+
device = device or self._execution_device
|
| 198 |
+
dtype = dtype or self.text_encoder.dtype
|
| 199 |
+
|
| 200 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 201 |
+
batch_size = len(prompt)
|
| 202 |
+
|
| 203 |
+
text_inputs = self.tokenizer(
|
| 204 |
+
prompt,
|
| 205 |
+
padding="max_length",
|
| 206 |
+
max_length=max_sequence_length,
|
| 207 |
+
truncation=True,
|
| 208 |
+
add_special_tokens=True,
|
| 209 |
+
return_tensors="pt",
|
| 210 |
+
)
|
| 211 |
+
text_input_ids = text_inputs.input_ids
|
| 212 |
+
prompt_attention_mask = text_inputs.attention_mask
|
| 213 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 214 |
+
|
| 215 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 216 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 217 |
+
logger.warning(
|
| 218 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 219 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
|
| 223 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
|
| 224 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 225 |
+
|
| 226 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 227 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 228 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 229 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 230 |
+
|
| 231 |
+
return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
| 232 |
+
|
| 233 |
+
def encode_prompt(
|
| 234 |
+
self,
|
| 235 |
+
prompt: Union[str, List[str]],
|
| 236 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 237 |
+
do_classifier_free_guidance: bool = True,
|
| 238 |
+
num_videos_per_prompt: int = 1,
|
| 239 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 240 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 241 |
+
max_sequence_length: int = 512,
|
| 242 |
+
device: Optional[torch.device] = None,
|
| 243 |
+
dtype: Optional[torch.dtype] = None,
|
| 244 |
+
):
|
| 245 |
+
r"""
|
| 246 |
+
Encodes the prompt into text encoder hidden states.
|
| 247 |
+
|
| 248 |
+
Args:
|
| 249 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 250 |
+
prompt to be encoded
|
| 251 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 252 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 253 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 254 |
+
less than `1`).
|
| 255 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 256 |
+
Whether to use classifier free guidance or not.
|
| 257 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 258 |
+
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
| 259 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 260 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 261 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 262 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 263 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 264 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 265 |
+
argument.
|
| 266 |
+
device: (`torch.device`, *optional*):
|
| 267 |
+
torch device
|
| 268 |
+
dtype: (`torch.dtype`, *optional*):
|
| 269 |
+
torch dtype
|
| 270 |
+
"""
|
| 271 |
+
device = device or self._execution_device
|
| 272 |
+
|
| 273 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 274 |
+
if prompt is not None:
|
| 275 |
+
batch_size = len(prompt)
|
| 276 |
+
else:
|
| 277 |
+
batch_size = prompt_embeds.shape[0]
|
| 278 |
+
|
| 279 |
+
if prompt_embeds is None:
|
| 280 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 281 |
+
prompt=prompt,
|
| 282 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 283 |
+
max_sequence_length=max_sequence_length,
|
| 284 |
+
device=device,
|
| 285 |
+
dtype=dtype,
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 289 |
+
negative_prompt = negative_prompt or ""
|
| 290 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 291 |
+
|
| 292 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 293 |
+
raise TypeError(
|
| 294 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 295 |
+
f" {type(prompt)}."
|
| 296 |
+
)
|
| 297 |
+
elif batch_size != len(negative_prompt):
|
| 298 |
+
raise ValueError(
|
| 299 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 300 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 301 |
+
" the batch size of `prompt`."
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 305 |
+
prompt=negative_prompt,
|
| 306 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 307 |
+
max_sequence_length=max_sequence_length,
|
| 308 |
+
device=device,
|
| 309 |
+
dtype=dtype,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
return prompt_embeds, negative_prompt_embeds
|
| 313 |
+
|
| 314 |
+
def prepare_latents(
|
| 315 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
| 316 |
+
):
|
| 317 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 318 |
+
raise ValueError(
|
| 319 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 320 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
shape = (
|
| 324 |
+
batch_size,
|
| 325 |
+
num_channels_latents,
|
| 326 |
+
(num_frames - 1) // self.vae.temporal_compression_ratio + 1,
|
| 327 |
+
height // self.vae.spatial_compression_ratio,
|
| 328 |
+
width // self.vae.spatial_compression_ratio,
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
if latents is None:
|
| 332 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 333 |
+
else:
|
| 334 |
+
latents = latents.to(device)
|
| 335 |
+
|
| 336 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 337 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
| 338 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 339 |
+
return latents
|
| 340 |
+
|
| 341 |
+
def prepare_control_latents(
|
| 342 |
+
self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
| 343 |
+
):
|
| 344 |
+
# resize the control to latents shape as we concatenate the control to the latents
|
| 345 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
| 346 |
+
# and half precision
|
| 347 |
+
|
| 348 |
+
if control is not None:
|
| 349 |
+
control = control.to(device=device, dtype=dtype)
|
| 350 |
+
bs = 1
|
| 351 |
+
new_control = []
|
| 352 |
+
for i in range(0, control.shape[0], bs):
|
| 353 |
+
control_bs = control[i : i + bs]
|
| 354 |
+
control_bs = self.vae.encode(control_bs)[0]
|
| 355 |
+
control_bs = control_bs.mode()
|
| 356 |
+
new_control.append(control_bs)
|
| 357 |
+
control = torch.cat(new_control, dim = 0)
|
| 358 |
+
|
| 359 |
+
if control_image is not None:
|
| 360 |
+
control_image = control_image.to(device=device, dtype=dtype)
|
| 361 |
+
bs = 1
|
| 362 |
+
new_control_pixel_values = []
|
| 363 |
+
for i in range(0, control_image.shape[0], bs):
|
| 364 |
+
control_pixel_values_bs = control_image[i : i + bs]
|
| 365 |
+
control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
|
| 366 |
+
control_pixel_values_bs = control_pixel_values_bs.mode()
|
| 367 |
+
new_control_pixel_values.append(control_pixel_values_bs)
|
| 368 |
+
control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
|
| 369 |
+
else:
|
| 370 |
+
control_image_latents = None
|
| 371 |
+
|
| 372 |
+
return control, control_image_latents
|
| 373 |
+
|
| 374 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 375 |
+
frames = self.vae.decode(latents.to(self.vae.dtype)).sample
|
| 376 |
+
frames = (frames / 2 + 0.5).clamp(0, 1)
|
| 377 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 378 |
+
frames = frames.cpu().float().numpy()
|
| 379 |
+
return frames
|
| 380 |
+
|
| 381 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 382 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 383 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 384 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 385 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 386 |
+
# and should be between [0, 1]
|
| 387 |
+
|
| 388 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 389 |
+
extra_step_kwargs = {}
|
| 390 |
+
if accepts_eta:
|
| 391 |
+
extra_step_kwargs["eta"] = eta
|
| 392 |
+
|
| 393 |
+
# check if the scheduler accepts generator
|
| 394 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 395 |
+
if accepts_generator:
|
| 396 |
+
extra_step_kwargs["generator"] = generator
|
| 397 |
+
return extra_step_kwargs
|
| 398 |
+
|
| 399 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
| 400 |
+
def check_inputs(
|
| 401 |
+
self,
|
| 402 |
+
prompt,
|
| 403 |
+
height,
|
| 404 |
+
width,
|
| 405 |
+
negative_prompt,
|
| 406 |
+
callback_on_step_end_tensor_inputs,
|
| 407 |
+
prompt_embeds=None,
|
| 408 |
+
negative_prompt_embeds=None,
|
| 409 |
+
):
|
| 410 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 411 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 412 |
+
|
| 413 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 414 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 415 |
+
):
|
| 416 |
+
raise ValueError(
|
| 417 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 418 |
+
)
|
| 419 |
+
if prompt is not None and prompt_embeds is not None:
|
| 420 |
+
raise ValueError(
|
| 421 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 422 |
+
" only forward one of the two."
|
| 423 |
+
)
|
| 424 |
+
elif prompt is None and prompt_embeds is None:
|
| 425 |
+
raise ValueError(
|
| 426 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 427 |
+
)
|
| 428 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 429 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 430 |
+
|
| 431 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 432 |
+
raise ValueError(
|
| 433 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 434 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 438 |
+
raise ValueError(
|
| 439 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 440 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 444 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 445 |
+
raise ValueError(
|
| 446 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 447 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 448 |
+
f" {negative_prompt_embeds.shape}."
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
@property
|
| 452 |
+
def guidance_scale(self):
|
| 453 |
+
return self._guidance_scale
|
| 454 |
+
|
| 455 |
+
@property
|
| 456 |
+
def num_timesteps(self):
|
| 457 |
+
return self._num_timesteps
|
| 458 |
+
|
| 459 |
+
@property
|
| 460 |
+
def attention_kwargs(self):
|
| 461 |
+
return self._attention_kwargs
|
| 462 |
+
|
| 463 |
+
@property
|
| 464 |
+
def interrupt(self):
|
| 465 |
+
return self._interrupt
|
| 466 |
+
|
| 467 |
+
@torch.no_grad()
|
| 468 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 469 |
+
def __call__(
|
| 470 |
+
self,
|
| 471 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 472 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 473 |
+
height: int = 480,
|
| 474 |
+
width: int = 720,
|
| 475 |
+
subject_ref_images: Union[torch.FloatTensor] = None,
|
| 476 |
+
num_frames: int = 49,
|
| 477 |
+
num_inference_steps: int = 50,
|
| 478 |
+
timesteps: Optional[List[int]] = None,
|
| 479 |
+
guidance_scale: float = 6,
|
| 480 |
+
num_videos_per_prompt: int = 1,
|
| 481 |
+
eta: float = 0.0,
|
| 482 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 483 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 484 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 485 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 486 |
+
output_type: str = "numpy",
|
| 487 |
+
return_dict: bool = False,
|
| 488 |
+
callback_on_step_end: Optional[
|
| 489 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 490 |
+
] = None,
|
| 491 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 492 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 493 |
+
max_sequence_length: int = 512,
|
| 494 |
+
comfyui_progressbar: bool = False,
|
| 495 |
+
shift: int = 5,
|
| 496 |
+
) -> Union[WanPipelineOutput, Tuple]:
|
| 497 |
+
"""
|
| 498 |
+
Function invoked when calling the pipeline for generation.
|
| 499 |
+
Args:
|
| 500 |
+
|
| 501 |
+
Examples:
|
| 502 |
+
|
| 503 |
+
Returns:
|
| 504 |
+
|
| 505 |
+
"""
|
| 506 |
+
|
| 507 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 508 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 509 |
+
num_videos_per_prompt = 1
|
| 510 |
+
|
| 511 |
+
# 1. Check inputs. Raise error if not correct
|
| 512 |
+
self.check_inputs(
|
| 513 |
+
prompt,
|
| 514 |
+
height,
|
| 515 |
+
width,
|
| 516 |
+
negative_prompt,
|
| 517 |
+
callback_on_step_end_tensor_inputs,
|
| 518 |
+
prompt_embeds,
|
| 519 |
+
negative_prompt_embeds,
|
| 520 |
+
)
|
| 521 |
+
self._guidance_scale = guidance_scale
|
| 522 |
+
self._attention_kwargs = attention_kwargs
|
| 523 |
+
self._interrupt = False
|
| 524 |
+
|
| 525 |
+
# 2. Default call parameters
|
| 526 |
+
if prompt is not None and isinstance(prompt, str):
|
| 527 |
+
batch_size = 1
|
| 528 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 529 |
+
batch_size = len(prompt)
|
| 530 |
+
else:
|
| 531 |
+
batch_size = prompt_embeds.shape[0]
|
| 532 |
+
|
| 533 |
+
device = self._execution_device
|
| 534 |
+
weight_dtype = self.text_encoder.dtype
|
| 535 |
+
|
| 536 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 537 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 538 |
+
# corresponds to doing no classifier free guidance.
|
| 539 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 540 |
+
|
| 541 |
+
# 3. Encode input prompt
|
| 542 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 543 |
+
prompt,
|
| 544 |
+
negative_prompt,
|
| 545 |
+
do_classifier_free_guidance,
|
| 546 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 547 |
+
prompt_embeds=prompt_embeds,
|
| 548 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 549 |
+
max_sequence_length=max_sequence_length,
|
| 550 |
+
device=device,
|
| 551 |
+
)
|
| 552 |
+
if do_classifier_free_guidance:
|
| 553 |
+
in_prompt_embeds = negative_prompt_embeds + prompt_embeds
|
| 554 |
+
else:
|
| 555 |
+
in_prompt_embeds = prompt_embeds
|
| 556 |
+
|
| 557 |
+
# 4. Prepare timesteps
|
| 558 |
+
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
| 559 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
|
| 560 |
+
elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
|
| 561 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
|
| 562 |
+
timesteps = self.scheduler.timesteps
|
| 563 |
+
elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
|
| 564 |
+
sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
|
| 565 |
+
timesteps, _ = retrieve_timesteps(
|
| 566 |
+
self.scheduler,
|
| 567 |
+
device=device,
|
| 568 |
+
sigmas=sampling_sigmas)
|
| 569 |
+
else:
|
| 570 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 571 |
+
self._num_timesteps = len(timesteps)
|
| 572 |
+
if comfyui_progressbar:
|
| 573 |
+
from comfy.utils import ProgressBar
|
| 574 |
+
pbar = ProgressBar(num_inference_steps + 2)
|
| 575 |
+
|
| 576 |
+
# 5. Prepare latents.
|
| 577 |
+
latent_channels = self.vae.config.latent_channels
|
| 578 |
+
latents = self.prepare_latents(
|
| 579 |
+
batch_size * num_videos_per_prompt,
|
| 580 |
+
latent_channels,
|
| 581 |
+
num_frames,
|
| 582 |
+
height,
|
| 583 |
+
width,
|
| 584 |
+
weight_dtype,
|
| 585 |
+
device,
|
| 586 |
+
generator,
|
| 587 |
+
latents,
|
| 588 |
+
)
|
| 589 |
+
if comfyui_progressbar:
|
| 590 |
+
pbar.update(1)
|
| 591 |
+
|
| 592 |
+
if subject_ref_images is not None:
|
| 593 |
+
video_length = subject_ref_images.shape[2]
|
| 594 |
+
subject_ref_images = self.image_processor.preprocess(rearrange(subject_ref_images, "b c f h w -> (b f) c h w"), height=height, width=width)
|
| 595 |
+
subject_ref_images = subject_ref_images.to(dtype=torch.float32)
|
| 596 |
+
subject_ref_images = rearrange(subject_ref_images, "(b f) c h w -> b c f h w", f=video_length)
|
| 597 |
+
|
| 598 |
+
subject_ref_images_latentes = torch.cat(
|
| 599 |
+
[
|
| 600 |
+
self.prepare_control_latents(
|
| 601 |
+
None,
|
| 602 |
+
subject_ref_images[:, :, i:i+1],
|
| 603 |
+
batch_size,
|
| 604 |
+
height,
|
| 605 |
+
width,
|
| 606 |
+
weight_dtype,
|
| 607 |
+
device,
|
| 608 |
+
generator,
|
| 609 |
+
do_classifier_free_guidance
|
| 610 |
+
)[1] for i in range(video_length)
|
| 611 |
+
], dim = 2
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
if comfyui_progressbar:
|
| 615 |
+
pbar.update(1)
|
| 616 |
+
|
| 617 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 618 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 619 |
+
|
| 620 |
+
target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
|
| 621 |
+
seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
|
| 622 |
+
# 7. Denoising loop
|
| 623 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 624 |
+
self.transformer.num_inference_steps = num_inference_steps
|
| 625 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 626 |
+
for i, t in enumerate(timesteps):
|
| 627 |
+
self.transformer.current_steps = i
|
| 628 |
+
|
| 629 |
+
if self.interrupt:
|
| 630 |
+
continue
|
| 631 |
+
|
| 632 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 633 |
+
if hasattr(self.scheduler, "scale_model_input"):
|
| 634 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 635 |
+
|
| 636 |
+
if subject_ref_images is not None:
|
| 637 |
+
subject_ref = (
|
| 638 |
+
torch.cat(
|
| 639 |
+
[torch.zeros_like(subject_ref_images_latentes), subject_ref_images_latentes]
|
| 640 |
+
) if do_classifier_free_guidance else subject_ref_images_latentes
|
| 641 |
+
).to(device, weight_dtype)
|
| 642 |
+
else:
|
| 643 |
+
subject_ref = None
|
| 644 |
+
|
| 645 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 646 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 647 |
+
|
| 648 |
+
# predict noise model_output
|
| 649 |
+
with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
|
| 650 |
+
noise_pred = self.transformer(
|
| 651 |
+
x=latent_model_input,
|
| 652 |
+
context=in_prompt_embeds,
|
| 653 |
+
t=timestep,
|
| 654 |
+
seq_len=seq_len,
|
| 655 |
+
subject_ref=subject_ref,
|
| 656 |
+
)
|
| 657 |
+
|
| 658 |
+
# perform guidance
|
| 659 |
+
if do_classifier_free_guidance:
|
| 660 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 661 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 662 |
+
|
| 663 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 664 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 665 |
+
|
| 666 |
+
if callback_on_step_end is not None:
|
| 667 |
+
callback_kwargs = {}
|
| 668 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 669 |
+
callback_kwargs[k] = locals()[k]
|
| 670 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 671 |
+
|
| 672 |
+
latents = callback_outputs.pop("latents", latents)
|
| 673 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 674 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 675 |
+
|
| 676 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 677 |
+
progress_bar.update()
|
| 678 |
+
if comfyui_progressbar:
|
| 679 |
+
pbar.update(1)
|
| 680 |
+
|
| 681 |
+
if output_type == "numpy":
|
| 682 |
+
video = self.decode_latents(latents)
|
| 683 |
+
elif not output_type == "latent":
|
| 684 |
+
video = self.decode_latents(latents)
|
| 685 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 686 |
+
else:
|
| 687 |
+
video = latents
|
| 688 |
+
|
| 689 |
+
# Offload all models
|
| 690 |
+
self.maybe_free_model_hooks()
|
| 691 |
+
|
| 692 |
+
if not return_dict:
|
| 693 |
+
video = torch.from_numpy(video)
|
| 694 |
+
|
| 695 |
+
return WanPipelineOutput(videos=video)
|
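`WanFunPhantomPipeline.__call__` conditions generation on `subject_ref_images`, a `(batch, channels, frames, height, width)` tensor of reference images that is VAE-encoded frame by frame via `prepare_control_latents` and passed to the transformer as `subject_ref`. A minimal call sketch follows, assuming `pipe` is an already-instantiated `WanFunPhantomPipeline` with its components loaded; the construction of `pipe`, the device, and the tensor contents below are illustrative assumptions.

```python
import torch

# Hypothetical: `pipe` is a WanFunPhantomPipeline built from the Wan tokenizer,
# text encoder, VAE, transformer and scheduler, as in __init__ above,
# e.g. pipe = pipe.to("cuda").

height, width = 480, 720

# Two subject reference "frames", stacked on the frame axis: (b, c, f, h, w).
subject_ref_images = torch.zeros(1, 3, 2, height, width)

output = pipe(
    prompt="the same plush toy exploring a snowy forest",
    negative_prompt="distorted, low quality",
    subject_ref_images=subject_ref_images,
    height=height,
    width=width,
    num_frames=49,
    num_inference_steps=50,
    guidance_scale=6.0,
    shift=5,
    generator=torch.Generator(device="cuda").manual_seed(0),
    output_type="numpy",
)
video = output.videos  # WanPipelineOutput.videos, a tensor for output_type="numpy"
```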
videox_fun/pipeline/pipeline_wan_vace.py
ADDED
|
@@ -0,0 +1,787 @@
import inspect
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.embeddings import get_1d_rotary_pos_embed
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import BaseOutput, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from einops import rearrange
from PIL import Image
from transformers import T5Tokenizer

from ..models import (AutoencoderKLWan, AutoTokenizer,
                      WanT5EncoderModel, VaceWanTransformer3DModel)
from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                get_sampling_sigmas)
from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        pass
        ```
"""


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


def resize_mask(mask, latent, process_first_frame_only=True):
    latent_size = latent.size()
    batch_size, channels, num_frames, height, width = mask.shape

    if process_first_frame_only:
        target_size = list(latent_size[2:])
        target_size[0] = 1
        first_frame_resized = F.interpolate(
            mask[:, :, 0:1, :, :],
            size=target_size,
            mode='trilinear',
            align_corners=False
        )

        target_size = list(latent_size[2:])
        target_size[0] = target_size[0] - 1
        if target_size[0] != 0:
            remaining_frames_resized = F.interpolate(
                mask[:, :, 1:, :, :],
                size=target_size,
                mode='trilinear',
                align_corners=False
            )
            resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
        else:
            resized_mask = first_frame_resized
    else:
        target_size = list(latent_size[2:])
        resized_mask = F.interpolate(
            mask,
            size=target_size,
            mode='trilinear',
            align_corners=False
        )
    return resized_mask


@dataclass
class WanPipelineOutput(BaseOutput):
    r"""
    Output class for CogVideo pipelines.

    Args:
        video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    videos: torch.Tensor


class WanVacePipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-video generation using Wan.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
    """

    _optional_components = []
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: WanT5EncoderModel,
        vae: AutoencoderKLWan,
        transformer: VaceWanTransformer3DModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )

        self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
        )

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_attention_mask = text_inputs.attention_mask
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds

    def prepare_latents(
        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, num_length_latents=None
    ):
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        shape = (
            batch_size,
            num_channels_latents,
            (num_frames - 1) // self.vae.temporal_compression_ratio + 1 if num_length_latents is None else num_length_latents,
            height // self.vae.spatial_compression_ratio,
            width // self.vae.spatial_compression_ratio,
        )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        if hasattr(self.scheduler, "init_noise_sigma"):
            latents = latents * self.scheduler.init_noise_sigma
        return latents

    def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
        vae = self.vae if vae is None else vae
        weight_dtype = frames.dtype
        if ref_images is None:
            ref_images = [None] * len(frames)
        else:
            assert len(frames) == len(ref_images)

        if masks is None:
            latents = vae.encode(frames)[0].mode()
        else:
            masks = [torch.where(m > 0.5, 1.0, 0.0).to(weight_dtype) for m in masks]
            inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
            reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
            inactive = vae.encode(inactive)[0].mode()
            reactive = vae.encode(reactive)[0].mode()
            latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]

        cat_latents = []
        for latent, refs in zip(latents, ref_images):
            if refs is not None:
                if masks is None:
                    ref_latent = vae.encode(refs)[0].mode()
                else:
                    ref_latent = vae.encode(refs)[0].mode()
                    ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
                assert all([x.shape[1] == 1 for x in ref_latent])
                latent = torch.cat([*ref_latent, latent], dim=1)
            cat_latents.append(latent)
        return cat_latents

    def vace_encode_masks(self, masks, ref_images=None, vae_stride=[4, 8, 8]):
        if ref_images is None:
            ref_images = [None] * len(masks)
        else:
            assert len(masks) == len(ref_images)

        result_masks = []
        for mask, refs in zip(masks, ref_images):
            c, depth, height, width = mask.shape
            new_depth = int((depth + 3) // vae_stride[0])
            height = 2 * (int(height) // (vae_stride[1] * 2))
            width = 2 * (int(width) // (vae_stride[2] * 2))

            # reshape
            mask = mask[0, :, :, :]
            mask = mask.view(
                depth, height, vae_stride[1], width, vae_stride[1]
            )  # depth, height, 8, width, 8
            mask = mask.permute(2, 4, 0, 1, 3)  # 8, 8, depth, height, width
            mask = mask.reshape(
                vae_stride[1] * vae_stride[2], depth, height, width
            )  # 8*8, depth, height, width

            # interpolation
            mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)

            if refs is not None:
                length = len(refs)
                mask_pad = torch.zeros_like(mask[:, :length, :, :])
                mask = torch.cat((mask_pad, mask), dim=1)
            result_masks.append(mask)
        return result_masks

    def vace_latent(self, z, m):
        return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]

    def prepare_control_latents(
        self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
    ):
        # resize the control to latents shape as we concatenate the control to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision

        if control is not None:
            control = control.to(device=device, dtype=dtype)
            bs = 1
            new_control = []
            for i in range(0, control.shape[0], bs):
                control_bs = control[i : i + bs]
                control_bs = self.vae.encode(control_bs)[0]
                control_bs = control_bs.mode()
                new_control.append(control_bs)
            control = torch.cat(new_control, dim = 0)

        if control_image is not None:
            control_image = control_image.to(device=device, dtype=dtype)
            bs = 1
            new_control_pixel_values = []
            for i in range(0, control_image.shape[0], bs):
                control_pixel_values_bs = control_image[i : i + bs]
                control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
                control_pixel_values_bs = control_pixel_values_bs.mode()
                new_control_pixel_values.append(control_pixel_values_bs)
            control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
        else:
            control_image_latents = None

        return control, control_image_latents

    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
        frames = self.vae.decode(latents.to(self.vae.dtype)).sample
        frames = (frames / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
        frames = frames.cpu().float().numpy()
        return frames

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        return self._attention_kwargs

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 480,
        width: int = 720,
        video: Union[torch.FloatTensor] = None,
        mask_video: Union[torch.FloatTensor] = None,
        control_video: Union[torch.FloatTensor] = None,
        subject_ref_images: Union[torch.FloatTensor] = None,
        num_frames: int = 49,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 6,
        num_videos_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: str = "numpy",
        return_dict: bool = False,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        comfyui_progressbar: bool = False,
        shift: int = 5,
        vace_context_scale: float = 1.0
    ) -> Union[WanPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.
        Args:

        Examples:

        Returns:

        """

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_on_step_end_tensor_inputs,
            prompt_embeds,
            negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._interrupt = False

        # 2. Default call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        weight_dtype = self.text_encoder.dtype

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            negative_prompt,
            do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )
        if do_classifier_free_guidance:
            in_prompt_embeds = negative_prompt_embeds + prompt_embeds
        else:
            in_prompt_embeds = prompt_embeds

        # 4. Prepare timesteps
        if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
        elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
            self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
            timesteps = self.scheduler.timesteps
        elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
            sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
            timesteps, _ = retrieve_timesteps(
                self.scheduler,
                device=device,
                sigmas=sampling_sigmas)
        else:
            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        self._num_timesteps = len(timesteps)
        if comfyui_progressbar:
            from comfy.utils import ProgressBar
            pbar = ProgressBar(num_inference_steps + 2)

        latent_channels = self.vae.config.latent_channels

        if comfyui_progressbar:
            pbar.update(1)

        # Prepare mask latent variables
        if mask_video is not None:
            bs, _, video_length, height, width = video.size()
            mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
            mask_condition = mask_condition.to(dtype=torch.float32)
            mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
            mask_condition = torch.tile(mask_condition, [1, 3, 1, 1, 1]).to(dtype=weight_dtype, device=device)


        if control_video is not None:
            video_length = control_video.shape[2]
            control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
            control_video = control_video.to(dtype=torch.float32)
            input_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)

            input_video = input_video.to(dtype=weight_dtype, device=device)

        elif video is not None:
            video_length = video.shape[2]
            init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
            init_video = init_video.to(dtype=torch.float32)
            init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length).to(dtype=weight_dtype, device=device)

            input_video = init_video * (mask_condition < 0.5)
            input_video = input_video.to(dtype=weight_dtype, device=device)

        if subject_ref_images is not None:
            video_length = subject_ref_images.shape[2]
            subject_ref_images = self.image_processor.preprocess(rearrange(subject_ref_images, "b c f h w -> (b f) c h w"), height=height, width=width)
            subject_ref_images = subject_ref_images.to(dtype=torch.float32)
            subject_ref_images = rearrange(subject_ref_images, "(b f) c h w -> b c f h w", f=video_length)
            subject_ref_images = subject_ref_images.to(dtype=weight_dtype, device=device)

            bs, c, f, h, w = subject_ref_images.size()
            new_subject_ref_images = []
            for i in range(bs):
                new_subject_ref_images.append([])
                for j in range(f):
                    new_subject_ref_images[i].append(subject_ref_images[i, :, j:j+1])
            subject_ref_images = new_subject_ref_images

        vace_latents = self.vace_encode_frames(input_video, subject_ref_images, masks=mask_condition, vae=self.vae)
        mask_latents = self.vace_encode_masks(mask_condition, subject_ref_images)
        vace_context = self.vace_latent(vace_latents, mask_latents)

        # 5. Prepare latents.
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            latent_channels,
            num_frames,
            height,
            width,
            weight_dtype,
            device,
            generator,
            latents,
            num_length_latents=vace_latents[0].size(1)
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        target_shape = (self.vae.latent_channels, vace_latents[0].size(1), vace_latents[0].size(2), vace_latents[0].size(3))
        seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
        # 7. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self.transformer.num_inference_steps = num_inference_steps
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                self.transformer.current_steps = i

                if self.interrupt:
                    continue

                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                if hasattr(self.scheduler, "scale_model_input"):
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                vace_context_input = torch.stack(vace_context * 2) if do_classifier_free_guidance else vace_context

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
                with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
                    noise_pred = self.transformer(
                        x=latent_model_input,
                        context=in_prompt_embeds,
                        t=timestep,
                        vace_context=vace_context_input,
                        seq_len=seq_len,
                        vace_context_scale=vace_context_scale
                    )

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                if comfyui_progressbar:
                    pbar.update(1)

        if subject_ref_images is not None:
            len_subject_ref_images = len(subject_ref_images[0])
            latents = latents[:, :, len_subject_ref_images:, :, :]

        if output_type == "numpy":
            video = self.decode_latents(latents)
        elif not output_type == "latent":
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            video = torch.from_numpy(video)

        return WanPipelineOutput(videos=video)
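For orientation, a minimal invocation sketch of `WanVacePipeline` above. The component construction and the `load_control_video` helper are assumptions for illustration; only the `__call__` arguments come from the signature in this file:

import torch

# Assumed: tokenizer, text_encoder, vae, transformer and scheduler are built elsewhere
# (e.g. loaded from a local checkpoint); this file only defines how they are wired together.
pipe = WanVacePipeline(
    tokenizer=tokenizer, text_encoder=text_encoder, vae=vae,
    transformer=transformer, scheduler=scheduler,
).to("cuda")

control = load_control_video("pose.mp4")      # hypothetical helper -> (b, c, f, h, w) in [0, 1]
mask = torch.ones_like(control[:, :1])        # repaint every pixel of every frame

out = pipe(
    prompt="a dancer on a rooftop at sunset",
    video=control,                            # source frames; sizes are read from this tensor
    mask_video=mask,
    control_video=control,
    height=480, width=720, num_frames=49,
    num_inference_steps=50, guidance_scale=6.0,
    generator=torch.Generator("cuda").manual_seed(0),
)
frames = out.videos                           # decoded frames as a tensor in [0, 1]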
videox_fun/pipeline/pipeline_z_image.py
ADDED
@@ -0,0 +1,613 @@
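Before the listing, a worked example of the resolution bookkeeping used by `ZImagePipeline` below (a sketch; `vae_scale_factor = 8` and the 2x2 latent patchification per token are assumptions, the latter borrowed from the Flux pipelines that `calculate_shift` is copied from):

vae_scale_factor = 8                                  # 2 ** (len(block_out_channels) - 1) for a 4-level VAE
height = width = 1024

# Same rounding as prepare_latents(): snap to a multiple of (vae_scale_factor * 2).
latent_h = 2 * (height // (vae_scale_factor * 2))     # -> 128
latent_w = 2 * (width // (vae_scale_factor * 2))      # -> 128
image_seq_len = (latent_h // 2) * (latent_w // 2)     # -> 4096 tokens, assuming 2x2 patches

mu = calculate_shift(image_seq_len)                   # -> 1.15, i.e. max_shift at max_seq_len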
| 1 |
+
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
import numpy as np
|
| 17 |
+
import PIL
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 23 |
+
from diffusers.loaders import FromSingleFileMixin
|
| 24 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 25 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 26 |
+
from diffusers.utils import (BaseOutput, is_torch_xla_available, logging,
|
| 27 |
+
replace_example_docstring)
|
| 28 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 29 |
+
from transformers import AutoTokenizer, PreTrainedModel
|
| 30 |
+
|
| 31 |
+
from ..models import AutoencoderKL, ZImageTransformer2DModel
|
| 32 |
+
|
| 33 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 34 |
+
|
| 35 |
+
EXAMPLE_DOC_STRING = """
|
| 36 |
+
Examples:
|
| 37 |
+
```py
|
| 38 |
+
>>> import torch
|
| 39 |
+
>>> from diffusers import ZImagePipeline
|
| 40 |
+
|
| 41 |
+
>>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
|
| 42 |
+
>>> pipe.to("cuda")
|
| 43 |
+
|
| 44 |
+
>>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
|
| 45 |
+
>>> # (1) Use flash attention 2
|
| 46 |
+
>>> # pipe.transformer.set_attention_backend("flash")
|
| 47 |
+
>>> # (2) Use flash attention 3
|
| 48 |
+
>>> # pipe.transformer.set_attention_backend("_flash_3")
|
| 49 |
+
|
| 50 |
+
>>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。"
|
| 51 |
+
>>> image = pipe(
|
| 52 |
+
diffusers. prompt,
|
| 53 |
+
diffusers. height=1024,
|
| 54 |
+
diffusers. width=1024,
|
| 55 |
+
diffusers. num_inference_steps=9,
|
| 56 |
+
diffusers. guidance_scale=0.0,
|
| 57 |
+
diffusers. generator=torch.Generator("cuda").manual_seed(42),
|
| 58 |
+
diffusers. ).images[0]
|
| 59 |
+
>>> image.save("zimage.png")
|
| 60 |
+
```
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
| 65 |
+
def calculate_shift(
|
| 66 |
+
image_seq_len,
|
| 67 |
+
base_seq_len: int = 256,
|
| 68 |
+
max_seq_len: int = 4096,
|
| 69 |
+
base_shift: float = 0.5,
|
| 70 |
+
max_shift: float = 1.15,
|
| 71 |
+
):
|
| 72 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 73 |
+
b = base_shift - m * base_seq_len
|
| 74 |
+
mu = image_seq_len * m + b
|
| 75 |
+
return mu
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 79 |
+
def retrieve_timesteps(
|
| 80 |
+
scheduler,
|
| 81 |
+
num_inference_steps: Optional[int] = None,
|
| 82 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 83 |
+
timesteps: Optional[List[int]] = None,
|
| 84 |
+
sigmas: Optional[List[float]] = None,
|
| 85 |
+
**kwargs,
|
| 86 |
+
):
|
| 87 |
+
r"""
|
| 88 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 89 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
scheduler (`SchedulerMixin`):
|
| 93 |
+
The scheduler to get timesteps from.
|
| 94 |
+
num_inference_steps (`int`):
|
| 95 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 96 |
+
must be `None`.
|
| 97 |
+
device (`str` or `torch.device`, *optional*):
|
| 98 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 99 |
+
timesteps (`List[int]`, *optional*):
|
| 100 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 101 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 102 |
+
sigmas (`List[float]`, *optional*):
|
| 103 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 104 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 108 |
+
second element is the number of inference steps.
|
| 109 |
+
"""
|
| 110 |
+
if timesteps is not None and sigmas is not None:
|
| 111 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 112 |
+
if timesteps is not None:
|
| 113 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 114 |
+
if not accepts_timesteps:
|
| 115 |
+
raise ValueError(
|
| 116 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 117 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 118 |
+
)
|
| 119 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 120 |
+
timesteps = scheduler.timesteps
|
| 121 |
+
num_inference_steps = len(timesteps)
|
| 122 |
+
elif sigmas is not None:
|
| 123 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 124 |
+
if not accept_sigmas:
|
| 125 |
+
raise ValueError(
|
| 126 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 127 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 128 |
+
)
|
| 129 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 130 |
+
timesteps = scheduler.timesteps
|
| 131 |
+
num_inference_steps = len(timesteps)
|
| 132 |
+
else:
|
| 133 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 134 |
+
timesteps = scheduler.timesteps
|
| 135 |
+
return timesteps, num_inference_steps
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@dataclass
|
| 139 |
+
class ZImagePipelineOutput(BaseOutput):
|
| 140 |
+
"""
|
| 141 |
+
Output class for Z-Image image generation pipelines.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`)
|
| 145 |
+
List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
|
| 146 |
+
height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion
|
| 147 |
+
pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be
|
| 148 |
+
passed to the decoder.
|
| 149 |
+
"""
|
| 150 |
+
|
| 151 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
|
| 155 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 156 |
+
_optional_components = []
|
| 157 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 158 |
+
|
| 159 |
+
def __init__(
|
| 160 |
+
self,
|
| 161 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 162 |
+
vae: AutoencoderKL,
|
| 163 |
+
text_encoder: PreTrainedModel,
|
| 164 |
+
tokenizer: AutoTokenizer,
|
| 165 |
+
transformer: ZImageTransformer2DModel,
|
| 166 |
+
):
|
| 167 |
+
super().__init__()
|
| 168 |
+
|
| 169 |
+
self.register_modules(
|
| 170 |
+
vae=vae,
|
| 171 |
+
text_encoder=text_encoder,
|
| 172 |
+
tokenizer=tokenizer,
|
| 173 |
+
scheduler=scheduler,
|
| 174 |
+
transformer=transformer,
|
| 175 |
+
)
|
| 176 |
+
self.vae_scale_factor = (
|
| 177 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
| 178 |
+
)
|
| 179 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
| 180 |
+
|
| 181 |
+
def encode_prompt(
|
| 182 |
+
self,
|
| 183 |
+
prompt: Union[str, List[str]],
|
| 184 |
+
device: Optional[torch.device] = None,
|
| 185 |
+
do_classifier_free_guidance: bool = True,
|
| 186 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 187 |
+
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
| 188 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 189 |
+
max_sequence_length: int = 512,
|
| 190 |
+
):
|
| 191 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 192 |
+
prompt_embeds = self._encode_prompt(
|
| 193 |
+
prompt=prompt,
|
| 194 |
+
device=device,
|
| 195 |
+
prompt_embeds=prompt_embeds,
|
| 196 |
+
max_sequence_length=max_sequence_length,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
if do_classifier_free_guidance:
|
| 200 |
+
if negative_prompt is None:
|
| 201 |
+
negative_prompt = ["" for _ in prompt]
|
| 202 |
+
else:
|
| 203 |
+
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 204 |
+
assert len(prompt) == len(negative_prompt)
|
| 205 |
+
negative_prompt_embeds = self._encode_prompt(
|
| 206 |
+
prompt=negative_prompt,
|
| 207 |
+
device=device,
|
| 208 |
+
prompt_embeds=negative_prompt_embeds,
|
| 209 |
+
max_sequence_length=max_sequence_length,
|
| 210 |
+
)
|
| 211 |
+
else:
|
| 212 |
+
negative_prompt_embeds = []
|
| 213 |
+
return prompt_embeds, negative_prompt_embeds
|
| 214 |
+
|
| 215 |
+
def _encode_prompt(
|
| 216 |
+
self,
|
| 217 |
+
prompt: Union[str, List[str]],
|
| 218 |
+
device: Optional[torch.device] = None,
|
| 219 |
+
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
| 220 |
+
max_sequence_length: int = 512,
|
| 221 |
+
) -> List[torch.FloatTensor]:
|
| 222 |
+
device = device or self._execution_device
|
| 223 |
+
|
| 224 |
+
if prompt_embeds is not None:
|
| 225 |
+
return prompt_embeds
|
| 226 |
+
|
| 227 |
+
if isinstance(prompt, str):
|
| 228 |
+
prompt = [prompt]
|
| 229 |
+
|
| 230 |
+
for i, prompt_item in enumerate(prompt):
|
| 231 |
+
messages = [
|
| 232 |
+
{"role": "user", "content": prompt_item},
|
| 233 |
+
]
|
| 234 |
+
prompt_item = self.tokenizer.apply_chat_template(
|
| 235 |
+
messages,
|
| 236 |
+
tokenize=False,
|
| 237 |
+
add_generation_prompt=True,
|
| 238 |
+
enable_thinking=True,
|
| 239 |
+
)
|
| 240 |
+
prompt[i] = prompt_item
|
| 241 |
+
|
| 242 |
+
text_inputs = self.tokenizer(
|
| 243 |
+
prompt,
|
| 244 |
+
padding="max_length",
|
| 245 |
+
max_length=max_sequence_length,
|
| 246 |
+
truncation=True,
|
| 247 |
+
return_tensors="pt",
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
text_input_ids = text_inputs.input_ids.to(device)
|
| 251 |
+
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
| 252 |
+
|
| 253 |
+
prompt_embeds = self.text_encoder(
|
| 254 |
+
input_ids=text_input_ids,
|
| 255 |
+
attention_mask=prompt_masks,
|
| 256 |
+
output_hidden_states=True,
|
| 257 |
+
).hidden_states[-2]
|
| 258 |
+
|
| 259 |
+
embeddings_list = []
|
| 260 |
+
|
| 261 |
+
for i in range(len(prompt_embeds)):
|
| 262 |
+
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
|
| 263 |
+
|
| 264 |
+
return embeddings_list
|
| 265 |
+
|
| 266 |
+
def prepare_latents(
|
| 267 |
+
self,
|
| 268 |
+
batch_size,
|
| 269 |
+
num_channels_latents,
|
| 270 |
+
height,
|
| 271 |
+
width,
|
| 272 |
+
dtype,
|
| 273 |
+
device,
|
| 274 |
+
generator,
|
| 275 |
+
latents=None,
|
| 276 |
+
):
|
| 277 |
+
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
| 278 |
+
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
| 279 |
+
|
| 280 |
+
shape = (batch_size, num_channels_latents, height, width)
|
| 281 |
+
|
| 282 |
+
if latents is None:
|
| 283 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 284 |
+
else:
|
| 285 |
+
if latents.shape != shape:
|
| 286 |
+
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
| 287 |
+
latents = latents.to(device)
|
| 288 |
+
return latents
|
| 289 |
+
|
| 290 |
+
@property
|
| 291 |
+
def guidance_scale(self):
|
| 292 |
+
return self._guidance_scale
|
| 293 |
+
|
| 294 |
+
@property
|
| 295 |
+
def do_classifier_free_guidance(self):
|
| 296 |
+
return self._guidance_scale > 1
|
| 297 |
+
|
| 298 |
+
@property
|
| 299 |
+
def joint_attention_kwargs(self):
|
| 300 |
+
return self._joint_attention_kwargs
|
| 301 |
+
|
| 302 |
+
@property
|
| 303 |
+
def num_timesteps(self):
|
| 304 |
+
return self._num_timesteps
|
| 305 |
+
|
| 306 |
+
@property
|
| 307 |
+
def interrupt(self):
|
| 308 |
+
return self._interrupt
|
| 309 |
+
|
| 310 |
+
@torch.no_grad()
|
| 311 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 312 |
+
def __call__(
|
| 313 |
+
self,
|
| 314 |
+
prompt: Union[str, List[str]] = None,
|
| 315 |
+
height: Optional[int] = None,
|
| 316 |
+
width: Optional[int] = None,
|
| 317 |
+
num_inference_steps: int = 50,
|
| 318 |
+
sigmas: Optional[List[float]] = None,
|
| 319 |
+
guidance_scale: float = 5.0,
|
| 320 |
+
cfg_normalization: bool = False,
|
| 321 |
+
cfg_truncation: float = 1.0,
|
| 322 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 323 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 324 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 325 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 326 |
+
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
| 327 |
+
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
| 328 |
+
output_type: Optional[str] = "pil",
|
| 329 |
+
return_dict: bool = True,
|
| 330 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 331 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 332 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 333 |
+
max_sequence_length: int = 512,
|
| 334 |
+
):
|
| 335 |
+
r"""
|
| 336 |
+
Function invoked when calling the pipeline for generation.
|
| 337 |
+
|
| 338 |
+
Args:
|
| 339 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 340 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 341 |
+
instead.
|
| 342 |
+
height (`int`, *optional*, defaults to 1024):
|
| 343 |
+
The height in pixels of the generated image.
|
| 344 |
+
width (`int`, *optional*, defaults to 1024):
|
| 345 |
+
The width in pixels of the generated image.
|
| 346 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 347 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 348 |
+
expense of slower inference.
|
| 349 |
+
sigmas (`List[float]`, *optional*):
|
| 350 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 351 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 352 |
+
will be used.
|
| 353 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
| 354 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 355 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 356 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 357 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 358 |
+
usually at the expense of lower image quality.
|
| 359 |
+
cfg_normalization (`bool`, *optional*, defaults to False):
|
| 360 |
+
Whether to apply configuration normalization.
|
| 361 |
+
cfg_truncation (`float`, *optional*, defaults to 1.0):
|
| 362 |
+
The truncation value for configuration.
|
| 363 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 364 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 365 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 366 |
+
less than `1`).
|
| 367 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 368 |
+
The number of images to generate per prompt.
|
| 369 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 370 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 371 |
+
to make generation deterministic.
|
| 372 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 373 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 374 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 375 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 376 |
+
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
| 377 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 378 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 379 |
+
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
| 380 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 381 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 382 |
+
argument.
|
| 383 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 384 |
+
The output format of the generated image. Choose between
|
| 385 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 386 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 387 |
+
Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
|
| 388 |
+
tuple.
|
| 389 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 390 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 391 |
+
`self.processor` in
|
| 392 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 393 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 394 |
+
A function called at the end of each denoising step during inference. The function is called
|
| 395 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 396 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 397 |
+
`callback_on_step_end_tensor_inputs`.
|
| 398 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 399 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 400 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 401 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 402 |
+
max_sequence_length (`int`, *optional*, defaults to 512):
|
| 403 |
+
Maximum sequence length to use with the `prompt`.
|
| 404 |
+
|
| 405 |
+
Examples:
|
| 406 |
+
|
| 407 |
+
Returns:
|
| 408 |
+
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
|
| 409 |
+
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
| 410 |
+
generated images.
|
| 411 |
+
"""
|
| 412 |
+
height = height or 1024
|
| 413 |
+
width = width or 1024
|
| 414 |
+
|
| 415 |
+
vae_scale = self.vae_scale_factor * 2
|
| 416 |
+
if height % vae_scale != 0:
|
| 417 |
+
raise ValueError(
|
| 418 |
+
f"Height must be divisible by {vae_scale} (got {height}). "
|
| 419 |
+
f"Please adjust the height to a multiple of {vae_scale}."
|
| 420 |
+
)
|
| 421 |
+
if width % vae_scale != 0:
|
| 422 |
+
raise ValueError(
|
| 423 |
+
f"Width must be divisible by {vae_scale} (got {width}). "
|
| 424 |
+
f"Please adjust the width to a multiple of {vae_scale}."
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
device = self._execution_device
|
| 428 |
+
|
| 429 |
+
self._guidance_scale = guidance_scale
|
| 430 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 431 |
+
self._interrupt = False
|
| 432 |
+
self._cfg_normalization = cfg_normalization
|
| 433 |
+
self._cfg_truncation = cfg_truncation
|
| 434 |
+
# 2. Define call parameters
|
| 435 |
+
if prompt is not None and isinstance(prompt, str):
|
| 436 |
+
batch_size = 1
|
| 437 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 438 |
+
batch_size = len(prompt)
|
| 439 |
+
else:
|
| 440 |
+
batch_size = len(prompt_embeds)
|
| 441 |
+
|
| 442 |
+
# If prompt_embeds is provided and prompt is None, skip encoding
|
| 443 |
+
if prompt_embeds is not None and prompt is None:
|
| 444 |
+
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 445 |
+
raise ValueError(
|
| 446 |
+
"When `prompt_embeds` is provided without `prompt`, "
|
| 447 |
+
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
|
| 448 |
+
)
|
| 449 |
+
else:
|
| 450 |
+
(
|
| 451 |
+
prompt_embeds,
|
| 452 |
+
negative_prompt_embeds,
|
| 453 |
+
) = self.encode_prompt(
|
| 454 |
+
prompt=prompt,
|
| 455 |
+
negative_prompt=negative_prompt,
|
| 456 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 457 |
+
prompt_embeds=prompt_embeds,
|
| 458 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 459 |
+
device=device,
|
| 460 |
+
max_sequence_length=max_sequence_length,
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
# 4. Prepare latent variables
|
| 464 |
+
num_channels_latents = self.transformer.in_channels
|
| 465 |
+
|
| 466 |
+
latents = self.prepare_latents(
|
| 467 |
+
batch_size * num_images_per_prompt,
|
| 468 |
+
num_channels_latents,
|
| 469 |
+
height,
|
| 470 |
+
width,
|
| 471 |
+
torch.float32,
|
| 472 |
+
device,
|
| 473 |
+
generator,
|
| 474 |
+
latents,
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
# Repeat prompt_embeds for num_images_per_prompt
|
| 478 |
+
if num_images_per_prompt > 1:
|
| 479 |
+
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
|
| 480 |
+
if self.do_classifier_free_guidance and negative_prompt_embeds:
|
| 481 |
+
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
|
| 482 |
+
|
| 483 |
+
actual_batch_size = batch_size * num_images_per_prompt
|
| 484 |
+
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
|
| 485 |
+
|
| 486 |
+
# 5. Prepare timesteps
|
| 487 |
+
mu = calculate_shift(
|
| 488 |
+
image_seq_len,
|
| 489 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 490 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 491 |
+
self.scheduler.config.get("base_shift", 0.5),
|
| 492 |
+
self.scheduler.config.get("max_shift", 1.15),
|
| 493 |
+
)
|
| 494 |
+
self.scheduler.sigma_min = 0.0
|
| 495 |
+
scheduler_kwargs = {"mu": mu}
|
| 496 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 497 |
+
self.scheduler,
|
| 498 |
+
num_inference_steps,
|
| 499 |
+
device,
|
| 500 |
+
sigmas=sigmas,
|
| 501 |
+
**scheduler_kwargs,
|
| 502 |
+
)
|
| 503 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 504 |
+
self._num_timesteps = len(timesteps)
|
| 505 |
+
|
| 506 |
+
# 6. Denoising loop
|
| 507 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 508 |
+
for i, t in enumerate(timesteps):
|
| 509 |
+
if self.interrupt:
|
| 510 |
+
continue
|
| 511 |
+
|
| 512 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 513 |
+
timestep = t.expand(latents.shape[0])
|
| 514 |
+
timestep = (1000 - timestep) / 1000
|
| 515 |
+
# Normalized time for time-aware config (0 at start, 1 at end)
|
| 516 |
+
t_norm = timestep[0].item()
|
| 517 |
+
|
| 518 |
+
# Handle cfg truncation
|
| 519 |
+
current_guidance_scale = self.guidance_scale
|
| 520 |
+
if (
|
| 521 |
+
self.do_classifier_free_guidance
|
| 522 |
+
and self._cfg_truncation is not None
|
| 523 |
+
and float(self._cfg_truncation) <= 1
|
| 524 |
+
):
|
| 525 |
+
if t_norm > self._cfg_truncation:
|
| 526 |
+
current_guidance_scale = 0.0
|
| 527 |
+
|
| 528 |
+
# Run CFG only if configured AND scale is non-zero
|
| 529 |
+
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
|
| 530 |
+
|
| 531 |
+
if apply_cfg:
|
| 532 |
+
latents_typed = latents.to(self.transformer.dtype)
|
| 533 |
+
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
|
| 534 |
+
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
|
| 535 |
+
timestep_model_input = timestep.repeat(2)
|
| 536 |
+
else:
|
| 537 |
+
latent_model_input = latents.to(self.transformer.dtype)
|
| 538 |
+
prompt_embeds_model_input = prompt_embeds
|
| 539 |
+
timestep_model_input = timestep
|
| 540 |
+
|
| 541 |
+
latent_model_input = latent_model_input.unsqueeze(2)
|
| 542 |
+
latent_model_input_list = list(latent_model_input.unbind(dim=0))
|
| 543 |
+
|
| 544 |
+
model_out_list = self.transformer(
|
| 545 |
+
latent_model_input_list,
|
| 546 |
+
timestep_model_input,
|
| 547 |
+
prompt_embeds_model_input,
|
| 548 |
+
)[0]
|
| 549 |
+
|
| 550 |
+
if apply_cfg:
|
| 551 |
+
# Perform CFG
|
| 552 |
+
pos_out = model_out_list[:actual_batch_size]
|
| 553 |
+
neg_out = model_out_list[actual_batch_size:]
|
| 554 |
+
|
| 555 |
+
noise_pred = []
|
| 556 |
+
for j in range(actual_batch_size):
|
| 557 |
+
pos = pos_out[j].float()
|
| 558 |
+
neg = neg_out[j].float()
|
| 559 |
+
|
| 560 |
+
pred = pos + current_guidance_scale * (pos - neg)
|
| 561 |
+
|
| 562 |
+
# Renormalization
|
| 563 |
+
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
|
| 564 |
+
ori_pos_norm = torch.linalg.vector_norm(pos)
|
| 565 |
+
new_pos_norm = torch.linalg.vector_norm(pred)
|
| 566 |
+
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
|
| 567 |
+
if new_pos_norm > max_new_norm:
|
| 568 |
+
pred = pred * (max_new_norm / new_pos_norm)
|
| 569 |
+
|
| 570 |
+
noise_pred.append(pred)
|
| 571 |
+
|
| 572 |
+
noise_pred = torch.stack(noise_pred, dim=0)
|
| 573 |
+
else:
|
| 574 |
+
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
|
| 575 |
+
|
| 576 |
+
noise_pred = noise_pred.squeeze(2)
|
| 577 |
+
noise_pred = -noise_pred
|
| 578 |
+
|
| 579 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 580 |
+
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
|
| 581 |
+
assert latents.dtype == torch.float32
|
| 582 |
+
|
| 583 |
+
if callback_on_step_end is not None:
|
| 584 |
+
callback_kwargs = {}
|
| 585 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 586 |
+
callback_kwargs[k] = locals()[k]
|
| 587 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 588 |
+
|
| 589 |
+
latents = callback_outputs.pop("latents", latents)
|
| 590 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 591 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 592 |
+
|
| 593 |
+
# call the callback, if provided
|
| 594 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 595 |
+
progress_bar.update()
|
| 596 |
+
|
| 597 |
+
if output_type == "latent":
|
| 598 |
+
image = latents
|
| 599 |
+
|
| 600 |
+
else:
|
| 601 |
+
latents = latents.to(self.vae.dtype)
|
| 602 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 603 |
+
|
| 604 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 605 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 606 |
+
|
| 607 |
+
# Offload all models
|
| 608 |
+
self.maybe_free_model_hooks()
|
| 609 |
+
|
| 610 |
+
if not return_dict:
|
| 611 |
+
return (image,)
|
| 612 |
+
|
| 613 |
+
return ZImagePipelineOutput(images=image)
|
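For quick reference, a hedged usage sketch for the text-to-image pipeline defined above. The import path `videox_fun.pipeline.ZImagePipeline`, the local checkpoint directory `models/Z-Image-Turbo`, and loading via `from_pretrained` are assumptions based on this repository's layout, not a verified API; `app.py` shows the authoritative wiring.

    # Sketch only: model directory and exported class name are assumptions.
    import torch
    from videox_fun.pipeline import ZImagePipeline

    pipe = ZImagePipeline.from_pretrained("models/Z-Image-Turbo", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    image = pipe(
        prompt="a watercolor lighthouse at dawn",
        height=1024,                 # height/width must be divisible by vae_scale_factor * 2 (16 here)
        width=1024,
        num_inference_steps=9,       # Turbo-style checkpoints are tuned for few steps
        guidance_scale=0.0,          # CFG is skipped entirely when the scale is <= 1
        generator=torch.Generator("cuda").manual_seed(42),
    ).images[0]
    image.save("zimage.png")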
videox_fun/pipeline/pipeline_z_image_control.py
ADDED
|
@@ -0,0 +1,633 @@
| 1 |
+
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
import numpy as np
|
| 17 |
+
import PIL
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 24 |
+
from diffusers.loaders import FromSingleFileMixin
|
| 25 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 26 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 27 |
+
from diffusers.utils import (BaseOutput, is_torch_xla_available, logging,
|
| 28 |
+
replace_example_docstring)
|
| 29 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 30 |
+
from transformers import AutoTokenizer, PreTrainedModel
|
| 31 |
+
|
| 32 |
+
from ..models import AutoencoderKL, ZImageTransformer2DModel
|
| 33 |
+
|
| 34 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 35 |
+
|
| 36 |
+
EXAMPLE_DOC_STRING = """
|
| 37 |
+
Examples:
|
| 38 |
+
```py
|
| 39 |
+
>>> import torch
|
| 40 |
+
>>> from diffusers import ZImagePipeline
|
| 41 |
+
|
| 42 |
+
>>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
|
| 43 |
+
>>> pipe.to("cuda")
|
| 44 |
+
|
| 45 |
+
>>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
|
| 46 |
+
>>> # (1) Use flash attention 2
|
| 47 |
+
>>> # pipe.transformer.set_attention_backend("flash")
|
| 48 |
+
>>> # (2) Use flash attention 3
|
| 49 |
+
>>> # pipe.transformer.set_attention_backend("_flash_3")
|
| 50 |
+
|
| 51 |
+
>>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。"
|
| 52 |
+
>>> image = pipe(
|
| 53 |
+
... prompt,
|
| 54 |
+
... height=1024,
|
| 55 |
+
... width=1024,
|
| 56 |
+
... num_inference_steps=9,
|
| 57 |
+
... guidance_scale=0.0,
|
| 58 |
+
... generator=torch.Generator("cuda").manual_seed(42),
|
| 59 |
+
... ).images[0]
|
| 60 |
+
>>> image.save("zimage.png")
|
| 61 |
+
```
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
| 66 |
+
def calculate_shift(
|
| 67 |
+
image_seq_len,
|
| 68 |
+
base_seq_len: int = 256,
|
| 69 |
+
max_seq_len: int = 4096,
|
| 70 |
+
base_shift: float = 0.5,
|
| 71 |
+
max_shift: float = 1.15,
|
| 72 |
+
):
|
| 73 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 74 |
+
b = base_shift - m * base_seq_len
|
| 75 |
+
mu = image_seq_len * m + b
|
| 76 |
+
return mu
|
| 77 |
+
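As a quick numeric check of the resolution-dependent shift above (a sketch using the same default values the pipeline later reads from the scheduler config):

    # For a 1024x1024 image the latent grid is 128x128 and tokens are 2x2 patches,
    # so image_seq_len = (128 // 2) * (128 // 2) = 4096, i.e. the schedule's maximum:
    #   m  = (1.15 - 0.5) / (4096 - 256) ≈ 1.693e-4
    #   mu = 4096 * m + (0.5 - 256 * m) = 1.15
    mu = calculate_shift(4096, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15)
    assert abs(mu - 1.15) < 1e-6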
|
| 78 |
+
|
| 79 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 80 |
+
def retrieve_timesteps(
|
| 81 |
+
scheduler,
|
| 82 |
+
num_inference_steps: Optional[int] = None,
|
| 83 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 84 |
+
timesteps: Optional[List[int]] = None,
|
| 85 |
+
sigmas: Optional[List[float]] = None,
|
| 86 |
+
**kwargs,
|
| 87 |
+
):
|
| 88 |
+
r"""
|
| 89 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 90 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
scheduler (`SchedulerMixin`):
|
| 94 |
+
The scheduler to get timesteps from.
|
| 95 |
+
num_inference_steps (`int`):
|
| 96 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 97 |
+
must be `None`.
|
| 98 |
+
device (`str` or `torch.device`, *optional*):
|
| 99 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 100 |
+
timesteps (`List[int]`, *optional*):
|
| 101 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 102 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 103 |
+
sigmas (`List[float]`, *optional*):
|
| 104 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 105 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 109 |
+
second element is the number of inference steps.
|
| 110 |
+
"""
|
| 111 |
+
if timesteps is not None and sigmas is not None:
|
| 112 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 113 |
+
if timesteps is not None:
|
| 114 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 115 |
+
if not accepts_timesteps:
|
| 116 |
+
raise ValueError(
|
| 117 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 118 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 119 |
+
)
|
| 120 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 121 |
+
timesteps = scheduler.timesteps
|
| 122 |
+
num_inference_steps = len(timesteps)
|
| 123 |
+
elif sigmas is not None:
|
| 124 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 125 |
+
if not accept_sigmas:
|
| 126 |
+
raise ValueError(
|
| 127 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 128 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 129 |
+
)
|
| 130 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 131 |
+
timesteps = scheduler.timesteps
|
| 132 |
+
num_inference_steps = len(timesteps)
|
| 133 |
+
else:
|
| 134 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 135 |
+
timesteps = scheduler.timesteps
|
| 136 |
+
return timesteps, num_inference_steps
|
| 137 |
+
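A hedged sketch of driving the helper above with a custom sigma schedule; the scheduler construction and the sigma values are illustrative only.

    from diffusers import FlowMatchEulerDiscreteScheduler

    scheduler = FlowMatchEulerDiscreteScheduler()
    # With explicit sigmas the number of steps is derived from the schedule length;
    # `mu` is simply forwarded to scheduler.set_timesteps (only used with dynamic shifting).
    timesteps, n_steps = retrieve_timesteps(scheduler, device="cpu", sigmas=[1.0, 0.75, 0.5, 0.25], mu=1.15)
    assert n_steps == len(timesteps)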
|
| 138 |
+
|
| 139 |
+
@dataclass
|
| 140 |
+
class ZImagePipelineOutput(BaseOutput):
|
| 141 |
+
"""
|
| 142 |
+
Output class for Z-Image image generation pipelines.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`)
|
| 146 |
+
List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
|
| 147 |
+
height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion
|
| 148 |
+
pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be
|
| 149 |
+
passed to the decoder.
|
| 150 |
+
"""
|
| 151 |
+
|
| 152 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class ZImageControlPipeline(DiffusionPipeline, FromSingleFileMixin):
|
| 156 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 157 |
+
_optional_components = []
|
| 158 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 159 |
+
|
| 160 |
+
def __init__(
|
| 161 |
+
self,
|
| 162 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 163 |
+
vae: AutoencoderKL,
|
| 164 |
+
text_encoder: PreTrainedModel,
|
| 165 |
+
tokenizer: AutoTokenizer,
|
| 166 |
+
transformer: ZImageTransformer2DModel,
|
| 167 |
+
):
|
| 168 |
+
super().__init__()
|
| 169 |
+
|
| 170 |
+
self.register_modules(
|
| 171 |
+
vae=vae,
|
| 172 |
+
text_encoder=text_encoder,
|
| 173 |
+
tokenizer=tokenizer,
|
| 174 |
+
scheduler=scheduler,
|
| 175 |
+
transformer=transformer,
|
| 176 |
+
)
|
| 177 |
+
self.vae_scale_factor = (
|
| 178 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
| 179 |
+
)
|
| 180 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
| 181 |
+
self.mask_processor = VaeImageProcessor(
|
| 182 |
+
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
def encode_prompt(
|
| 186 |
+
self,
|
| 187 |
+
prompt: Union[str, List[str]],
|
| 188 |
+
device: Optional[torch.device] = None,
|
| 189 |
+
do_classifier_free_guidance: bool = True,
|
| 190 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 191 |
+
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
| 192 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 193 |
+
max_sequence_length: int = 512,
|
| 194 |
+
):
|
| 195 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 196 |
+
prompt_embeds = self._encode_prompt(
|
| 197 |
+
prompt=prompt,
|
| 198 |
+
device=device,
|
| 199 |
+
prompt_embeds=prompt_embeds,
|
| 200 |
+
max_sequence_length=max_sequence_length,
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
if do_classifier_free_guidance:
|
| 204 |
+
if negative_prompt is None:
|
| 205 |
+
negative_prompt = ["" for _ in prompt]
|
| 206 |
+
else:
|
| 207 |
+
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 208 |
+
assert len(prompt) == len(negative_prompt)
|
| 209 |
+
negative_prompt_embeds = self._encode_prompt(
|
| 210 |
+
prompt=negative_prompt,
|
| 211 |
+
device=device,
|
| 212 |
+
prompt_embeds=negative_prompt_embeds,
|
| 213 |
+
max_sequence_length=max_sequence_length,
|
| 214 |
+
)
|
| 215 |
+
else:
|
| 216 |
+
negative_prompt_embeds = []
|
| 217 |
+
return prompt_embeds, negative_prompt_embeds
|
| 218 |
+
|
| 219 |
+
def _encode_prompt(
|
| 220 |
+
self,
|
| 221 |
+
prompt: Union[str, List[str]],
|
| 222 |
+
device: Optional[torch.device] = None,
|
| 223 |
+
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
| 224 |
+
max_sequence_length: int = 512,
|
| 225 |
+
) -> List[torch.FloatTensor]:
|
| 226 |
+
device = device or self._execution_device
|
| 227 |
+
|
| 228 |
+
if prompt_embeds is not None:
|
| 229 |
+
return prompt_embeds
|
| 230 |
+
|
| 231 |
+
if isinstance(prompt, str):
|
| 232 |
+
prompt = [prompt]
|
| 233 |
+
|
| 234 |
+
for i, prompt_item in enumerate(prompt):
|
| 235 |
+
messages = [
|
| 236 |
+
{"role": "user", "content": prompt_item},
|
| 237 |
+
]
|
| 238 |
+
prompt_item = self.tokenizer.apply_chat_template(
|
| 239 |
+
messages,
|
| 240 |
+
tokenize=False,
|
| 241 |
+
add_generation_prompt=True,
|
| 242 |
+
enable_thinking=True,
|
| 243 |
+
)
|
| 244 |
+
prompt[i] = prompt_item
|
| 245 |
+
|
| 246 |
+
text_inputs = self.tokenizer(
|
| 247 |
+
prompt,
|
| 248 |
+
padding="max_length",
|
| 249 |
+
max_length=max_sequence_length,
|
| 250 |
+
truncation=True,
|
| 251 |
+
return_tensors="pt",
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
text_input_ids = text_inputs.input_ids.to(device)
|
| 255 |
+
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
| 256 |
+
|
| 257 |
+
prompt_embeds = self.text_encoder(
|
| 258 |
+
input_ids=text_input_ids,
|
| 259 |
+
attention_mask=prompt_masks,
|
| 260 |
+
output_hidden_states=True,
|
| 261 |
+
).hidden_states[-2]
|
| 262 |
+
|
| 263 |
+
embeddings_list = []
|
| 264 |
+
|
| 265 |
+
for i in range(len(prompt_embeds)):
|
| 266 |
+
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
|
| 267 |
+
|
| 268 |
+
return embeddings_list
|
| 269 |
+
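The encoder returns a ragged list with one tensor per prompt (padding removed via the attention mask). Below is a hedged illustration of the chat-template wrapping used above, assuming a Qwen-style `tokenizer` is in scope; the exact markup depends on the tokenizer bundled with the checkpoint.

    # Illustrative only.
    messages = [{"role": "user", "content": "a red bicycle leaning against a brick wall"}]
    wrapped = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
    )
    # `wrapped` embeds the prompt in the chat markup the text encoder was trained on; _encode_prompt then
    # tokenizes it (padded/truncated to max_sequence_length), takes hidden_states[-2], and keeps only the
    # non-padded positions, so each returned tensor has shape (num_real_tokens_i, hidden_dim).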
|
| 270 |
+
def prepare_latents(
|
| 271 |
+
self,
|
| 272 |
+
batch_size,
|
| 273 |
+
num_channels_latents,
|
| 274 |
+
height,
|
| 275 |
+
width,
|
| 276 |
+
dtype,
|
| 277 |
+
device,
|
| 278 |
+
generator,
|
| 279 |
+
latents=None,
|
| 280 |
+
):
|
| 281 |
+
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
| 282 |
+
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
| 283 |
+
|
| 284 |
+
shape = (batch_size, num_channels_latents, height, width)
|
| 285 |
+
|
| 286 |
+
if latents is None:
|
| 287 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 288 |
+
else:
|
| 289 |
+
if latents.shape != shape:
|
| 290 |
+
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
| 291 |
+
latents = latents.to(device)
|
| 292 |
+
return latents
|
| 293 |
+
|
| 294 |
+
@property
|
| 295 |
+
def guidance_scale(self):
|
| 296 |
+
return self._guidance_scale
|
| 297 |
+
|
| 298 |
+
@property
|
| 299 |
+
def do_classifier_free_guidance(self):
|
| 300 |
+
return self._guidance_scale > 1
|
| 301 |
+
|
| 302 |
+
@property
|
| 303 |
+
def joint_attention_kwargs(self):
|
| 304 |
+
return self._joint_attention_kwargs
|
| 305 |
+
|
| 306 |
+
@property
|
| 307 |
+
def num_timesteps(self):
|
| 308 |
+
return self._num_timesteps
|
| 309 |
+
|
| 310 |
+
@property
|
| 311 |
+
def interrupt(self):
|
| 312 |
+
return self._interrupt
|
| 313 |
+
|
| 314 |
+
@torch.no_grad()
|
| 315 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 316 |
+
def __call__(
|
| 317 |
+
self,
|
| 318 |
+
prompt: Union[str, List[str]] = None,
|
| 319 |
+
height: Optional[int] = None,
|
| 320 |
+
width: Optional[int] = None,
|
| 321 |
+
|
| 322 |
+
control_image: Optional[torch.FloatTensor] = None,
|
| 323 |
+
control_context_scale: float = 1.0,
|
| 324 |
+
|
| 325 |
+
num_inference_steps: int = 50,
|
| 326 |
+
sigmas: Optional[List[float]] = None,
|
| 327 |
+
guidance_scale: float = 5.0,
|
| 328 |
+
cfg_normalization: bool = False,
|
| 329 |
+
cfg_truncation: float = 1.0,
|
| 330 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 331 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 332 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 333 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 334 |
+
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
| 335 |
+
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
| 336 |
+
output_type: Optional[str] = "pil",
|
| 337 |
+
return_dict: bool = True,
|
| 338 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 339 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 340 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 341 |
+
max_sequence_length: int = 512,
|
| 342 |
+
):
|
| 343 |
+
r"""
|
| 344 |
+
Function invoked when calling the pipeline for generation.
|
| 345 |
+
|
| 346 |
+
Args:
|
| 347 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 348 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
| 349 |
+
instead.
|
| 350 |
+
height (`int`, *optional*, defaults to 1024):
|
| 351 |
+
The height in pixels of the generated image.
|
| 352 |
+
width (`int`, *optional*, defaults to 1024):
|
| 353 |
+
The width in pixels of the generated image.
|
| 354 |
+
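            control_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
                The control conditioning image (e.g. a Canny, HED, depth, pose or MLSD map). It is resized to
                `height` x `width`, VAE-encoded, and passed to the transformer as `control_context`.
            control_context_scale (`float`, *optional*, defaults to 1.0):
                The strength with which the control context modulates the transformer.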
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 355 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 356 |
+
expense of slower inference.
|
| 357 |
+
sigmas (`List[float]`, *optional*):
|
| 358 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 359 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 360 |
+
will be used.
|
| 361 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
| 362 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 363 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 364 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 365 |
+
1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
|
| 366 |
+
usually at the expense of lower image quality.
|
| 367 |
+
cfg_normalization (`bool`, *optional*, defaults to False):
|
| 368 |
+
Whether to rescale the classifier-free guidance prediction so its norm does not exceed the norm of the positive (conditional) prediction.
|
| 369 |
+
cfg_truncation (`float`, *optional*, defaults to 1.0):
|
| 370 |
+
Normalized timestep threshold above which classifier-free guidance is disabled (truncated).
|
| 371 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 372 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 373 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 374 |
+
less than `1`).
|
| 375 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 376 |
+
The number of images to generate per prompt.
|
| 377 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 378 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 379 |
+
to make generation deterministic.
|
| 380 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 381 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 382 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 383 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 384 |
+
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
| 385 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 386 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 387 |
+
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
| 388 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 389 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 390 |
+
argument.
|
| 391 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 392 |
+
The output format of the generated image. Choose between
|
| 393 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 394 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 395 |
+
Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
|
| 396 |
+
tuple.
|
| 397 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 398 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 399 |
+
`self.processor` in
|
| 400 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 401 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 402 |
+
A function called at the end of each denoising step during inference. The function is called
|
| 403 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 404 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 405 |
+
`callback_on_step_end_tensor_inputs`.
|
| 406 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 407 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 408 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 409 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 410 |
+
max_sequence_length (`int`, *optional*, defaults to 512):
|
| 411 |
+
Maximum sequence length to use with the `prompt`.
|
| 412 |
+
|
| 413 |
+
Examples:
|
| 414 |
+
|
| 415 |
+
Returns:
|
| 416 |
+
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
|
| 417 |
+
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
| 418 |
+
generated images.
|
| 419 |
+
"""
|
| 420 |
+
height = height or 1024
|
| 421 |
+
width = width or 1024
|
| 422 |
+
|
| 423 |
+
vae_scale = self.vae_scale_factor * 2
|
| 424 |
+
if height % vae_scale != 0:
|
| 425 |
+
raise ValueError(
|
| 426 |
+
f"Height must be divisible by {vae_scale} (got {height}). "
|
| 427 |
+
f"Please adjust the height to a multiple of {vae_scale}."
|
| 428 |
+
)
|
| 429 |
+
if width % vae_scale != 0:
|
| 430 |
+
raise ValueError(
|
| 431 |
+
f"Width must be divisible by {vae_scale} (got {width}). "
|
| 432 |
+
f"Please adjust the width to a multiple of {vae_scale}."
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
self._guidance_scale = guidance_scale
|
| 436 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 437 |
+
self._interrupt = False
|
| 438 |
+
self._cfg_normalization = cfg_normalization
|
| 439 |
+
self._cfg_truncation = cfg_truncation
|
| 440 |
+
# 2. Define call parameters
|
| 441 |
+
if prompt is not None and isinstance(prompt, str):
|
| 442 |
+
batch_size = 1
|
| 443 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 444 |
+
batch_size = len(prompt)
|
| 445 |
+
else:
|
| 446 |
+
batch_size = len(prompt_embeds)
|
| 447 |
+
|
| 448 |
+
device = self._execution_device
|
| 449 |
+
weight_dtype = self.text_encoder.dtype
|
| 450 |
+
num_channels_latents = self.transformer.in_channels
|
| 451 |
+
|
| 452 |
+
if control_image is not None:
|
| 453 |
+
control_image = self.image_processor.preprocess(control_image, height=height, width=width)
|
| 454 |
+
control_image = control_image.to(dtype=weight_dtype, device=device)
|
| 455 |
+
control_latents = self.vae.encode(control_image)[0].mode()
|
| 456 |
+
control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
| 457 |
+
else:
|
| 458 |
+
            # fix: `inpaint_latent` was undefined here; fall back to an all-zero control latent of the
            # expected shape (assumes transformer.in_channels matches the VAE latent channel count)
            control_latents = torch.zeros((batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor), dtype=weight_dtype, device=device)
|
| 459 |
+
|
| 460 |
+
control_context = control_latents.unsqueeze(2)
|
| 461 |
+
|
| 462 |
+
# If prompt_embeds is provided and prompt is None, skip encoding
|
| 463 |
+
if prompt_embeds is not None and prompt is None:
|
| 464 |
+
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 465 |
+
raise ValueError(
|
| 466 |
+
"When `prompt_embeds` is provided without `prompt`, "
|
| 467 |
+
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
|
| 468 |
+
)
|
| 469 |
+
else:
|
| 470 |
+
(
|
| 471 |
+
prompt_embeds,
|
| 472 |
+
negative_prompt_embeds,
|
| 473 |
+
) = self.encode_prompt(
|
| 474 |
+
prompt=prompt,
|
| 475 |
+
negative_prompt=negative_prompt,
|
| 476 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 477 |
+
prompt_embeds=prompt_embeds,
|
| 478 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 479 |
+
device=device,
|
| 480 |
+
max_sequence_length=max_sequence_length,
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
# 4. Prepare latent variables
|
| 484 |
+
latents = self.prepare_latents(
|
| 485 |
+
batch_size * num_images_per_prompt,
|
| 486 |
+
num_channels_latents,
|
| 487 |
+
height,
|
| 488 |
+
width,
|
| 489 |
+
torch.float32,
|
| 490 |
+
device,
|
| 491 |
+
generator,
|
| 492 |
+
latents,
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
# Repeat prompt_embeds for num_images_per_prompt
|
| 496 |
+
if num_images_per_prompt > 1:
|
| 497 |
+
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
|
| 498 |
+
if self.do_classifier_free_guidance and negative_prompt_embeds:
|
| 499 |
+
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
|
| 500 |
+
|
| 501 |
+
actual_batch_size = batch_size * num_images_per_prompt
|
| 502 |
+
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
|
| 503 |
+
|
| 504 |
+
# 5. Prepare timesteps
|
| 505 |
+
mu = calculate_shift(
|
| 506 |
+
image_seq_len,
|
| 507 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 508 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 509 |
+
self.scheduler.config.get("base_shift", 0.5),
|
| 510 |
+
self.scheduler.config.get("max_shift", 1.15),
|
| 511 |
+
)
|
| 512 |
+
self.scheduler.sigma_min = 0.0
|
| 513 |
+
scheduler_kwargs = {"mu": mu}
|
| 514 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 515 |
+
self.scheduler,
|
| 516 |
+
num_inference_steps,
|
| 517 |
+
device,
|
| 518 |
+
sigmas=sigmas,
|
| 519 |
+
**scheduler_kwargs,
|
| 520 |
+
)
|
| 521 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 522 |
+
self._num_timesteps = len(timesteps)
|
| 523 |
+
|
| 524 |
+
# 6. Denoising loop
|
| 525 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 526 |
+
for i, t in enumerate(timesteps):
|
| 527 |
+
if self.interrupt:
|
| 528 |
+
continue
|
| 529 |
+
|
| 530 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 531 |
+
timestep = t.expand(latents.shape[0])
|
| 532 |
+
timestep = (1000 - timestep) / 1000
|
| 533 |
+
# Normalized time for time-aware config (0 at start, 1 at end)
|
| 534 |
+
t_norm = timestep[0].item()
|
| 535 |
+
|
| 536 |
+
# Handle cfg truncation
|
| 537 |
+
current_guidance_scale = self.guidance_scale
|
| 538 |
+
if (
|
| 539 |
+
self.do_classifier_free_guidance
|
| 540 |
+
and self._cfg_truncation is not None
|
| 541 |
+
and float(self._cfg_truncation) <= 1
|
| 542 |
+
):
|
| 543 |
+
if t_norm > self._cfg_truncation:
|
| 544 |
+
current_guidance_scale = 0.0
|
| 545 |
+
|
| 546 |
+
# Run CFG only if configured AND scale is non-zero
|
| 547 |
+
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
|
| 548 |
+
|
| 549 |
+
if apply_cfg:
|
| 550 |
+
latents_typed = latents.to(self.transformer.dtype)
|
| 551 |
+
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
|
| 552 |
+
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
|
| 553 |
+
timestep_model_input = timestep.repeat(2)
|
| 554 |
+
else:
|
| 555 |
+
latent_model_input = latents.to(self.transformer.dtype)
|
| 556 |
+
prompt_embeds_model_input = prompt_embeds
|
| 557 |
+
timestep_model_input = timestep
|
| 558 |
+
|
| 559 |
+
latent_model_input = latent_model_input.unsqueeze(2)
|
| 560 |
+
latent_model_input_list = list(latent_model_input.unbind(dim=0))
|
| 561 |
+
|
| 562 |
+
model_out_list = self.transformer(
|
| 563 |
+
latent_model_input_list,
|
| 564 |
+
timestep_model_input,
|
| 565 |
+
prompt_embeds_model_input,
|
| 566 |
+
control_context=control_context,
|
| 567 |
+
control_context_scale=control_context_scale,
|
| 568 |
+
)[0]
|
| 569 |
+
|
| 570 |
+
if apply_cfg:
|
| 571 |
+
# Perform CFG
|
| 572 |
+
pos_out = model_out_list[:actual_batch_size]
|
| 573 |
+
neg_out = model_out_list[actual_batch_size:]
|
| 574 |
+
|
| 575 |
+
noise_pred = []
|
| 576 |
+
for j in range(actual_batch_size):
|
| 577 |
+
pos = pos_out[j].float()
|
| 578 |
+
neg = neg_out[j].float()
|
| 579 |
+
|
| 580 |
+
pred = pos + current_guidance_scale * (pos - neg)
|
| 581 |
+
|
| 582 |
+
# Renormalization
|
| 583 |
+
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
|
| 584 |
+
ori_pos_norm = torch.linalg.vector_norm(pos)
|
| 585 |
+
new_pos_norm = torch.linalg.vector_norm(pred)
|
| 586 |
+
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
|
| 587 |
+
if new_pos_norm > max_new_norm:
|
| 588 |
+
pred = pred * (max_new_norm / new_pos_norm)
|
| 589 |
+
|
| 590 |
+
noise_pred.append(pred)
|
| 591 |
+
|
| 592 |
+
noise_pred = torch.stack(noise_pred, dim=0)
|
| 593 |
+
else:
|
| 594 |
+
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
|
| 595 |
+
|
| 596 |
+
noise_pred = noise_pred.squeeze(2)
|
| 597 |
+
noise_pred = -noise_pred
|
| 598 |
+
|
| 599 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 600 |
+
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
|
| 601 |
+
assert latents.dtype == torch.float32
|
| 602 |
+
|
| 603 |
+
if callback_on_step_end is not None:
|
| 604 |
+
callback_kwargs = {}
|
| 605 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 606 |
+
callback_kwargs[k] = locals()[k]
|
| 607 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 608 |
+
|
| 609 |
+
latents = callback_outputs.pop("latents", latents)
|
| 610 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 611 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 612 |
+
|
| 613 |
+
# call the callback, if provided
|
| 614 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 615 |
+
progress_bar.update()
|
| 616 |
+
|
| 617 |
+
if output_type == "latent":
|
| 618 |
+
image = latents
|
| 619 |
+
|
| 620 |
+
else:
|
| 621 |
+
latents = latents.to(self.vae.dtype)
|
| 622 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 623 |
+
|
| 624 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 625 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 626 |
+
|
| 627 |
+
# Offload all models
|
| 628 |
+
self.maybe_free_model_hooks()
|
| 629 |
+
|
| 630 |
+
if not return_dict:
|
| 631 |
+
return (image,)
|
| 632 |
+
|
| 633 |
+
return ZImagePipelineOutput(images=image)
|
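Finally, a hedged end-to-end sketch for the control variant above. The exported class name, the checkpoint directory, and loading via `from_pretrained` are assumptions based on this repository's layout; `app.py` and `predict_t2i_control.py` show the authoritative wiring.

    import torch
    from PIL import Image
    from videox_fun.pipeline import ZImageControlPipeline  # assumed export, mirroring the other pipelines

    pipe = ZImageControlPipeline.from_pretrained("models/Z-Image-Turbo-Control", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    control = Image.open("examples/pose.jpg").convert("RGB")  # any Canny/HED/depth/pose/MLSD map
    image = pipe(
        prompt="a dancer on a rooftop at sunset",
        control_image=control,
        control_context_scale=1.0,
        height=1024,
        width=1024,
        num_inference_steps=9,
        guidance_scale=0.0,
        generator=torch.Generator("cuda").manual_seed(0),
    ).images[0]
    image.save("zimage_control.png")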
videox_fun/reward/MPS/README.md
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
This folder is modified from the official [MPS](https://github.com/Kwai-Kolors/MPS/tree/main) repository.
|