Update app.py
app.py
CHANGED
@@ -9,6 +9,7 @@ import tempfile
 import numpy as np
 from PIL import Image
 import os
+import traceback
 
 import gradio as gr
 from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
@@ -21,7 +22,7 @@ from torchao.quantization import Int8WeightOnlyConfig
 
 import aoti
 
-#
+# ------------------- constants -------------------
 MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
 
 MAX_DIM = 832
@@ -45,7 +46,7 @@ default_negative_prompt = (
     "形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
 )
 
-#
+# ------------------- load pipeline -------------------
 pipe = WanImageToVideoPipeline.from_pretrained(
     MODEL_ID,
     transformer=WanTransformer3DModel.from_pretrained(
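The hunk above cuts off inside the `from_pretrained` call, so here is a hedged sketch of how a two-transformer Wan 2.2 image-to-video pipeline is typically assembled with diffusers. The dtype, the `subfolder` names, the `transformer_2` keyword, and the `.to("cuda")` placement are assumptions for illustration, not taken from this commit.

```python
import torch
from diffusers import WanTransformer3DModel
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline

MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"

# Load the two expert transformers explicitly, then hand them to the pipeline.
pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID,
    transformer=WanTransformer3DModel.from_pretrained(
        MODEL_ID, subfolder="transformer", torch_dtype=torch.bfloat16
    ),
    transformer_2=WanTransformer3DModel.from_pretrained(
        MODEL_ID, subfolder="transformer_2", torch_dtype=torch.bfloat16
    ),
    torch_dtype=torch.bfloat16,
).to("cuda")
```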
@@ -90,13 +91,12 @@ aoti.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/Wan2", variant="fp8da")
 aoti.aoti_blocks_load(pipe.transformer_2, "zerogpu-aoti/Wan2", variant="fp8da")
 
 # ------------------------------------------------------------
-#
+# HELPERS
 # ------------------------------------------------------------
 def resize_image(image: Image.Image) -> Image.Image:
     """Resize / crop the input image to a size the model accepts."""
     w, h = image.size
 
-    # square shortcut
     if w == h:
         return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)
 
@@ -105,19 +105,19 @@ def resize_image(image: Image.Image) -> Image.Image:
     MIN_AR = MIN_DIM / MAX_DIM
     img = image
 
-    if aspect > MAX_AR:
+    if aspect > MAX_AR:  # very wide
         crop_w = int(round(h * MAX_AR))
         left = (w - crop_w) // 2
         img = image.crop((left, 0, left + crop_w, h))
-    elif aspect < MIN_AR:
+    elif aspect < MIN_AR:  # very tall
         crop_h = int(round(w / MIN_AR))
         top = (h - crop_h) // 2
         img = image.crop((0, top, w, top + crop_h))
     else:
-        if w > h:
+        if w > h:  # landscape
            target_w = MAX_DIM
            target_h = int(round(target_w / aspect))
-        else:
+        else:  # portrait
            target_h = MAX_DIM
            target_w = int(round(target_h * aspect))
         img = image
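The branch comments added here distinguish very wide, very tall, landscape, and portrait inputs. For readers skimming the diff, below is a minimal, self-contained sketch of the same clamp-then-center-crop idea; the function name `clamp_aspect` is illustrative, and since the `MIN_DIM` and `SQUARE_DIM` values are not visible in this hunk the limits are taken as parameters.

```python
from PIL import Image

def clamp_aspect(image: Image.Image, max_dim: int, min_dim: int) -> Image.Image:
    """Center-crop images whose aspect ratio falls outside the model's limits."""
    w, h = image.size
    aspect = w / h
    max_ar = max_dim / min_dim   # widest ratio the model accepts
    min_ar = min_dim / max_dim   # tallest ratio the model accepts
    if aspect > max_ar:          # very wide: crop width, keep full height
        crop_w = int(round(h * max_ar))
        left = (w - crop_w) // 2
        return image.crop((left, 0, left + crop_w, h))
    if aspect < min_ar:          # very tall: crop height, keep full width
        crop_h = int(round(w / min_ar))
        top = (h - crop_h) // 2
        return image.crop((0, top, w, top + crop_h))
    return image                 # aspect already within limits
```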
@@ -152,7 +152,7 @@ def get_duration(
     randomize_seed,
     progress,
 ):
-    """GPU‑time estimator
+    """GPU‑time estimator for @spaces.GPU."""
     BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
     BASE_STEP_DURATION = 15
 
@@ -162,12 +162,11 @@ def get_duration(
     step_duration = BASE_STEP_DURATION * factor ** 1.5
     est = 10 + int(steps) * step_duration
 
-
-    return min(est, 30)
+    return min(est, 30)  # safety cap
 
 
 # ------------------------------------------------------------
-# MAIN GENERATION FUNCTION
+# MAIN GENERATION FUNCTION – now with error logging
 # ------------------------------------------------------------
 @spaces.GPU(duration=get_duration)
 def generate_video(
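The estimator that feeds `@spaces.GPU(duration=get_duration)` is easier to read outside the diff. The sketch below assumes that `factor` (computed on lines this hunk does not show) is the frames × height × width workload relative to the 81 × 832 × 624 baseline; the function name `estimate_gpu_seconds` is illustrative.

```python
BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624   # reference workload
BASE_STEP_DURATION = 15                     # seconds per step at the reference workload

def estimate_gpu_seconds(steps: int, frames: int, height: int, width: int) -> float:
    # Assumed definition of `factor`; the commit only shows how it is used.
    factor = (frames * height * width) / BASE_FRAMES_HEIGHT_WIDTH
    step_duration = BASE_STEP_DURATION * factor ** 1.5
    est = 10 + int(steps) * step_duration   # 10 s fixed overhead plus per-step cost
    return min(est, 30)                     # hard cap, as in the commit

# At the reference workload, 6 steps would estimate 10 + 6 * 15 = 100 s,
# so the 30-second cap is what actually gets requested.
print(estimate_gpu_seconds(steps=6, frames=81, height=832, width=624))
```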
@@ -180,56 +179,70 @@ def generate_video(
     guidance_scale_2=1.5,
     seed=42,
     randomize_seed=False,
-    progress=None,
+    progress=None,  # optional – Gradio will inject if needed
 ):
-    """
-    [old docstring and body of generate_video not preserved in this view]
-    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
-        video_path = tmp.name
-    export_to_video(frames, video_path, fps=FIXED_FPS)
-
-    # Clean up GPU memory for the next request
-    # -----------------------------------------------------------------
-    gc.collect()
-    torch.cuda.empty_cache()
+    """
+    Run the Wan‑2.2 pipeline and return an MP4 file.
+    Any exception is caught, printed to the Space logs, and re‑raised as a Gradio error.
+    """
+    try:
+        if input_image is None:
+            raise gr.Error("Please upload an input image.")
+
+        num_frames = get_num_frames(duration_seconds)
+        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
+
+        resized = resize_image(input_image)
+
+        # -------------------- model inference --------------------
+        out = pipe(
+            image=resized,
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=resized.height,
+            width=resized.width,
+            num_frames=num_frames,
+            guidance_scale=float(guidance_scale),
+            guidance_scale_2=float(guidance_scale_2),
+            num_inference_steps=int(steps),
+            generator=torch.Generator(device="cuda").manual_seed(current_seed),
+        )
+        frames = out.frames[0]
+
+        if not frames or len(frames) == 0:
+            raise RuntimeError("Pipeline returned an empty frame list.")
+
+        # -------------------- write MP4 --------------------
+        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
+            video_path = tmp.name
+        export_to_video(frames, video_path, fps=FIXED_FPS)
+
+        # -------------------- clean up --------------------
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        return video_path, current_seed
+
+    except Exception as exc:
+        # -----------------------------------------------------------------
+        # Print a full traceback to the Space console – you’ll see it in the
+        # “Logs” tab. After you identify the problem you can simply delete
+        # this whole try/except block.
+        # -----------------------------------------------------------------
+        tb = traceback.format_exc()
+        print("\n=== VIDEO‑GENERATION ERROR =================================================")
+        print(tb)
+        print("============================================================================\n")
+        # Re‑raise as a user‑friendly Gradio error
+        raise gr.Error(f"Video generation failed: {type(exc).__name__}: {exc}")
 
 
 # ------------------------------------------------------------
-# UI –
+# UI – unchanged visual theme (all CSS, 500‑error guard, gap, etc.)
 # ------------------------------------------------------------
 def create_demo():
     with gr.Blocks(css="", title="Fast Image to Video") as demo:
-        #
+        # ----- 500‑error guard (exact copy) -----
         gr.HTML(
             """
             <script>
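The new error handling boils down to one reusable pattern: log the full traceback so it lands in the Space's Logs tab, then re-raise a short `gr.Error` that Gradio shows in the UI. A stripped-down sketch of that pattern, with an illustrative helper name:

```python
import traceback

import gradio as gr

def run_safely(fn, *args, **kwargs):
    """Call fn, logging the full traceback and surfacing a short message to Gradio."""
    try:
        return fn(*args, **kwargs)
    except Exception as exc:
        print(traceback.format_exc())   # full details end up in the Space logs
        raise gr.Error(f"Operation failed: {type(exc).__name__}: {exc}")
```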
@@ -241,7 +254,7 @@ def create_demo():
             """
         )
 
-        #
+        # ----- all custom CSS (exactly as you posted) -----
         gr.HTML(
             """
             <style>
@@ -268,7 +281,7 @@ def create_demo():
            body::before{
              content:"";
              display:block;
-              height:600px;   /*
+              height:600px;   /* top gap */
              background:#000 !important;
            }
            .gr-blocks,.container{
@@ -351,7 +364,7 @@ def create_demo():
              box-sizing:border-box !important;
              display:block !important;
            }
-            /* ---- hide
+            /* ---- hide all Gradio progress UI ---- */
            .image-container[aria-label="Generated Video"] .progress-text,
            .image-container[aria-label="Generated Video"] .gr-progress,
            .image-container[aria-label="Generated Video"] .gr-progress-bar,
@@ -378,7 +391,7 @@ def create_demo():
            .image-container[aria-label="Generated Video"] *[class*="progress"],
            .image-container[aria-label="Generated Video"] *[class*="loading"],
            .image-container[aria-label="Generated Video"] *[class*="status"],
-            .image-container[aria-label="Generated Video"] *[class*="spinner
+            .image-container[aria-label="Generated Video"] *[class*="spinner"],
            .progress-text,.gr-progress,.gr-progress-bar,.progress-bar,
            [data-testid="progress"],.status,.loading,.spinner,.gr-spinner,
            .gr-loading,.gr-status,.gpu-init,.initializing,.queue,
@@ -542,7 +555,7 @@ def create_demo():
             """
         )
 
-        # ------------------- UI
+        # ------------------- UI components (same layout as original) -------------------
         with gr.Row(elem_id="general_items"):
             gr.Markdown("# ")
             gr.Markdown(
@@ -566,8 +579,6 @@ def create_demo():
                 placeholder="Describe the desired animation or motion",
                 elem_classes=["gradio-component"],
             )
-            # (the rest of the advanced sliders you had in the “original” UI are omitted
-            # because your current functional code only expects the arguments below)
             generate_button = gr.Button(
                 "Generate Video",
                 variant="primary",
@@ -582,21 +593,21 @@ def create_demo():
                 elem_classes=["gradio-component", "image-container"],
             )
 
-            # -------------------
+            # ------------------- wiring -------------------
             generate_button.click(
                 fn=generate_video,
                 inputs=[
-                    input_image,
-                    prompt,
-                    gr.State(value=6),
+                    input_image,
+                    prompt,
+                    gr.State(value=6),                        # steps
                     gr.State(value=default_negative_prompt),  # negative_prompt
-                    gr.State(value=3.2),
-                    gr.State(value=1.5),
-                    gr.State(value=1.5),
-                    gr.State(value=42),
-                    gr.State(value=True),
+                    gr.State(value=3.2),                      # duration_seconds
+                    gr.State(value=1.5),                      # guidance_scale
+                    gr.State(value=1.5),                      # guidance_scale_2
+                    gr.State(value=42),                       # seed
+                    gr.State(value=True),                     # randomize_seed
                 ],
-                outputs=[output_video, gr.State(value=42)],
+                outputs=[output_video, gr.State(value=42)],
             )
 
     return demo
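A note on this wiring: the inline `gr.State(value=...)` components act as fixed inputs for parameters the UI no longer exposes, and the second return value (the seed) is written into a throw-away `gr.State` output that nothing reads back. A minimal sketch of the same pattern, with illustrative component and function names:

```python
import gradio as gr

def run(image_path, prompt, steps):
    # Stand-in for generate_video: returns a result plus the seed that was used.
    return f"{prompt!r} rendered in {steps} steps from {image_path}", 42

with gr.Blocks() as demo:
    image = gr.Image(type="filepath", label="Input Image")
    prompt = gr.Textbox(label="Prompt")
    result = gr.Textbox(label="Result")
    go = gr.Button("Generate")
    go.click(
        fn=run,
        inputs=[image, prompt, gr.State(value=6)],  # gr.State supplies a fixed steps value
        outputs=[result, gr.State()],               # the returned seed is stored but never read
    )

if __name__ == "__main__":
    demo.launch()
```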