Update app.py
app.py CHANGED
@@ -6,6 +6,7 @@ import gc
 import tempfile
 import numpy as np
 from PIL import Image
+import os

 import gradio as gr
 from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
@@ -18,9 +19,9 @@ from torchao.quantization import Int8WeightOnlyConfig

 import aoti

-#
-# -------------------------- CONFIG
-#
+# ------------------------------------------------------------------
+# -------------------------- CONFIG -------------------------------
+# ------------------------------------------------------------------
 MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"

 MAX_DIM = 832
@@ -44,9 +45,9 @@ default_negative_prompt = (
     "形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
 )

-#
-# ----------------------- MODEL LOADING
-#
+# ------------------------------------------------------------------
+# ----------------------- MODEL LOADING ---------------------------
+# ------------------------------------------------------------------
 pipe = WanImageToVideoPipeline.from_pretrained(
     MODEL_ID,
     transformer=WanTransformer3DModel.from_pretrained(
@@ -90,9 +91,9 @@ quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
 aoti.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/Wan2", variant="fp8da")
 aoti.aoti_blocks_load(pipe.transformer_2, "zerogpu-aoti/Wan2", variant="fp8da")

-#
-# -------------------------- HELPERS
-#
+# ------------------------------------------------------------------
+# -------------------------- HELPERS -----------------------------
+# ------------------------------------------------------------------
 def resize_image(image: Image.Image) -> Image.Image:
     """Resize / crop the input image so the model receives a valid size."""
     width, height = image.size
@@ -107,20 +108,18 @@ def resize_image(image: Image.Image) -> Image.Image:
     img = image

     if aspect_ratio > MAX_ASPECT_RATIO:
-        # Very wide → crop width
         crop_w = int(round(height * MAX_ASPECT_RATIO))
         left = (width - crop_w) // 2
         img = image.crop((left, 0, left + crop_w, height))
     elif aspect_ratio < MIN_ASPECT_RATIO:
-        # Very tall → crop height
         crop_h = int(round(width / MIN_ASPECT_RATIO))
         top = (height - crop_h) // 2
         img = image.crop((0, top, width, top + crop_h))
     else:
-        if width > height:
+        if width > height:
             target_w = MAX_DIM
             target_h = int(round(target_w / aspect_ratio))
-        else:
+        else:
             target_h = MAX_DIM
             target_w = int(round(target_h * aspect_ratio))
         img = image
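Note: the three branches above implement a clamp-then-scale rule: over-wide inputs are centre-cropped in width, over-tall inputs in height, and everything else is only rescaled so its long side becomes MAX_DIM. A standalone sketch of the cropping arithmetic (the MIN_ASPECT_RATIO / MAX_ASPECT_RATIO values are not visible in this hunk, so the constants below are assumptions):

    from PIL import Image

    MAX_ASPECT_RATIO = 832 / 480   # assumed; the real constants live in the CONFIG block
    MIN_ASPECT_RATIO = 480 / 832   # assumed

    def clamp_aspect(image: Image.Image) -> Image.Image:
        width, height = image.size
        aspect_ratio = width / height
        if aspect_ratio > MAX_ASPECT_RATIO:        # too wide: crop width, keep height
            crop_w = int(round(height * MAX_ASPECT_RATIO))
            left = (width - crop_w) // 2
            return image.crop((left, 0, left + crop_w, height))
        if aspect_ratio < MIN_ASPECT_RATIO:        # too tall: crop height, keep width
            crop_h = int(round(width / MIN_ASPECT_RATIO))
            top = (height - crop_h) // 2
            return image.crop((0, top, width, top + crop_h))
        return image                               # in range: only rescaled afterwards

    # e.g. a 4000x1000 input (ratio 4.0) is centre-cropped to about 1733x1000 first.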
@@ -155,7 +154,7 @@ def get_duration(
     randomize_seed,
     progress,
 ):
-    """
+    """GPU‑time estimator used by @spaces.GPU."""
     BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
     BASE_STEP_DURATION = 15

@@ -165,13 +164,13 @@ def get_duration(
     step_duration = BASE_STEP_DURATION * factor ** 1.5
     est = 10 + int(steps) * step_duration

-    #
+    # never block the GPU for >30 s (feel free to raise while debugging)
     return min(est, 30)


 @spaces.GPU
 def translate_albanian_to_english(text):
-    """
+    """Helper – kept unchanged (not used in the UI)."""
     if not text.strip():
         raise gr.Error("Please enter a description.")
     for attempt in range(2):
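Note: get_duration is handed to @spaces.GPU(duration=...) below, where ZeroGPU treats a callable as a per-request estimator that receives the same arguments as the wrapped function and returns the number of seconds to reserve. A quick check of the arithmetic, assuming the size factor works out to 1.0 for a base-sized job (the factor computation itself sits in lines this capture does not show):

    steps = 4
    factor = 1.0                                  # assumed: frames*height*width equals the base workload
    step_duration = 15 * factor ** 1.5
    print(min(10 + steps * step_duration, 30))    # -> 30: even 4 steps hits the 30 s cap

With the cap in place the estimate is effectively constant, so raising the cap is the knob that matters for longer jobs.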
@@ -190,9 +189,9 @@ def translate_albanian_to_english(text):
     raise gr.Error("Translation failed. Please try again.")


-#
-# -------------------------- MAIN FUNCTION
-#
+# ------------------------------------------------------------------
+# -------------------------- MAIN FUNCTION -------------------------
+# ------------------------------------------------------------------
 @spaces.GPU(duration=get_duration)
 def generate_video(
     input_image,
@@ -204,56 +203,93 @@ def generate_video(
     guidance_scale_2=1.5,
     seed=42,
     randomize_seed=False,
-    progress=None, #
+    progress=None,  # optional – no UI impact
 ):
-    """Generate a video from an image + prompt."""
-    [… 24 lines of the old function body, not preserved in this page capture …]
+    """Generate a video from an image + prompt – now wrapped in robust try/except."""
+    try:
+        if input_image is None:
+            raise gr.Error("Please upload an input image.")
+
+        # --------------------------------------------------------------
+        # 1️⃣ Compute number of frames & seed
+        # --------------------------------------------------------------
+        num_frames = get_num_frames(duration_seconds)
+        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
+
+        # --------------------------------------------------------------
+        # 2️⃣ Resize image to model‑compatible dimensions
+        # --------------------------------------------------------------
+        resized = resize_image(input_image)
+
+        # --------------------------------------------------------------
+        # 3️⃣ Model inference
+        # --------------------------------------------------------------
+        out = pipe(
+            image=resized,
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=resized.height,
+            width=resized.width,
+            num_frames=num_frames,
+            guidance_scale=float(guidance_scale),
+            guidance_scale_2=float(guidance_scale_2),
+            num_inference_steps=int(steps),
+            generator=torch.Generator(device="cuda").manual_seed(current_seed),
+        )
+        # `out.frames` is a list of batches → we want the first batch
+        output_frames = out.frames[0]
+
+        if not output_frames or len(output_frames) == 0:
+            raise RuntimeError("Pipeline returned an empty frame list.")
+
+        # --------------------------------------------------------------
+        # 4️⃣ Write temporary MP4 (requires ffmpeg)
+        # --------------------------------------------------------------
+        # Ensure ffmpeg is present – the Space image usually has it, but just in case:
+        if not any(
+            os.access(os.path.join(p, "ffmpeg"), os.X_OK) for p in os.getenv("PATH", "").split(":")
+        ):
+            # If ffmpeg is missing we raise a clear error; you can install it via
+            # `!apt-get update && apt-get install -y ffmpeg` in a startup cell.
+            raise FileNotFoundError(
+                "ffmpeg binary not found. Install it in the Space with `apt-get install -y ffmpeg`."
+            )

-
-    # -----------------------------------------------------------------
-    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
-        video_path = tmp.name
-    export_to_video(output_frames, video_path, fps=FIXED_FPS)
+        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
+            video_path = tmp.name

-    gc.collect()
-    torch.cuda.empty_cache()
+        export_to_video(output_frames, video_path, fps=FIXED_FPS)

+        # --------------------------------------------------------------
+        # 5️⃣ Clean‑up GPU memory before returning (helps repeated calls)
+        # --------------------------------------------------------------
+        gc.collect()
+        torch.cuda.empty_cache()

+        return video_path, current_seed

-    # --------------------------- UI -------------------------------
-    # ------------------------------------------------------------
-    def create_demo():
-        with gr.Blocks(css="", title="Fast Image to Video") as demo:
+    except Exception as exc:
         # -----------------------------------------------------------------
-        #
+        # Log the full traceback to the Space console – this is what you’ll see
+        # in the “View logs” tab. Gradio will display a nice red error box.
         # -----------------------------------------------------------------
+        import traceback
+
+        tb = traceback.format_exc()
+        print("\n--- VIDEO‑GENERATION ERROR ------------------------------------------------")
+        print(tb)
+        print("----------------------------------------------------------------------------\n")
+
+        # Re‑raise as a Gradio‑friendly error (the message will appear in the UI)
+        raise gr.Error(f"Video generation failed: {str(exc)}")
+
+
+# ------------------------------------------------------------------
+# --------------------------- UI -----------------------------------
+# ------------------------------------------------------------------
+def create_demo():
+    with gr.Blocks(css="", title="Fast Image to Video") as demo:
+        # ------------------- 500‑error guard (unchanged) -------------------
         gr.HTML(
             """
             <script>
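Note: the PATH loop above is a hand-rolled existence check for the ffmpeg binary. The standard library expresses the same test in one call; a minimal equivalent sketch (an alternative, not what the commit uses):

    import shutil

    # shutil.which scans PATH just like the os.access loop, returning the
    # executable's full path or None when it is missing.
    if shutil.which("ffmpeg") is None:
        raise FileNotFoundError("ffmpeg binary not found; install it with `apt-get install -y ffmpeg`.")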
@@ -265,9 +301,7 @@ def create_demo():
             """
         )

-        #
-        # All your custom CSS / visual theme – **unaltered**
-        # -----------------------------------------------------------------
+        # ------------------- All your custom CSS (exact copy) -------------
         gr.HTML(
             """
             <style>
@@ -295,7 +329,7 @@ def create_demo():
             body::before{
               content:"";
               display:block;
-              height:600px; /* <--
+              height:600px; /* <-- the 600 px top gap you requested */
               background:#000 !important;
             }
             .gr-blocks,.container{
@@ -378,7 +412,7 @@ def create_demo():
               box-sizing:border-box !important;
               display:block !important;
             }
-            /*
+            /* ---- hide all Gradio progress UI ---- */
             .image-container[aria-label="Generated Video"] .progress-text,
             .image-container[aria-label="Generated Video"] .gr-progress,
             .image-container[aria-label="Generated Video"] .gr-progress-bar,
@@ -429,7 +463,7 @@ def create_demo():
               pointer-events:none!important;
               overflow:hidden!important;
             }
-            /*
+            /* ---- hide toolbar buttons ---- */
             .image-container[aria-label="Input Image"] .file-upload,
             .image-container[aria-label="Input Image"] .file-preview,
             .image-container[aria-label="Input Image"] .image-actions,
@@ -523,9 +557,7 @@ def create_demo():
               animation:slide 4s ease-in-out infinite,glow-hover 3s ease-in-out infinite;
               transform:scale(1.05);
             }
-            button[aria-label="Fullscreen"],button[aria-label="Share"]{
-              display:none!important;
-            }
+            button[aria-label="Fullscreen"],button[aria-label="Share"]{display:none!important;}
             button[aria-label="Download"]{
               transform:scale(3);
               transform-origin:top right;
@@ -541,9 +573,7 @@ def create_demo():
             button[aria-label="Download"]:hover{
               box-shadow:0 0 12px rgba(255,255,255,0.5)!important;
             }
-            footer,.gr-button-secondary{
-              display:none!important;
-            }
+            footer,.gr-button-secondary{display:none!important;}
             .gr-group{
               background:#000!important;
               border:none!important;
@@ -573,9 +603,7 @@ def create_demo():
             """
         )

-        #
-        # UI layout – **exactly the same structure you built**
-        # -----------------------------------------------------------------
+        # ------------------- UI layout (unchanged) --------------------
         with gr.Row(elem_id="general_items"):
             gr.Markdown("# ")
             gr.Markdown(
@@ -613,9 +641,7 @@ def create_demo():
                 elem_classes=["gradio-component", "image-container"],
             )

-        #
-        # Wiring – unchanged component order (matches generate_video signature)
-        # -----------------------------------------------------------------
+        # ------------------- Wiring (exact order) --------------------
         generate_button.click(
             fn=generate_video,
             inputs=[
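Note: the inputs list is truncated in this capture, but the component order must match the generate_video signature positionally, and since the function returns (video_path, current_seed) the wiring needs two outputs in that order. A generic sketch of the pattern (every component name below is hypothetical; the Space's real ones are elided above):

    generate_button.click(
        fn=generate_video,
        inputs=[input_image, prompt_box, steps_slider, duration_slider,
                guidance_slider, guidance_2_slider, seed_number, randomize_checkbox],
        outputs=[output_video, seed_number],   # (video_path, current_seed)
    )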
@@ -637,5 +663,5 @@ def create_demo():

 if __name__ == "__main__":
     demo = create_demo()
-    #
+    # Keep the same launch flags you originally used
     demo.queue().launch(share=True)
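Note: on Spaces the queue is what serializes requests into the GPU worker; share=True only matters for local runs and is typically ignored on a hosted Space. A common variant that also bounds the waiting line (a sketch: max_size is a standard gradio queue argument, the value here is illustrative):

    if __name__ == "__main__":
        demo = create_demo()
        # Cap how many requests may wait so overload fails fast instead of piling up.
        demo.queue(max_size=20).launch()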