HAL1993 committed (verified)
Commit 864f23e · 1 parent: d672320

Update app.py

Files changed (1): app.py (+80 −69)
app.py CHANGED
@@ -9,6 +9,7 @@ import tempfile
  import numpy as np
  from PIL import Image
  import os
+ import traceback

  import gradio as gr
  from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
@@ -21,7 +22,7 @@ from torchao.quantization import Int8WeightOnlyConfig

  import aoti

- # -------------------- constants --------------------
+ # ------------------- constants -------------------
  MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"

  MAX_DIM = 832
@@ -45,7 +46,7 @@ default_negative_prompt = (
      "形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
  )

- # -------------------- load the pipeline --------------------
+ # ------------------- load pipeline -------------------
  pipe = WanImageToVideoPipeline.from_pretrained(
      MODEL_ID,
      transformer=WanTransformer3DModel.from_pretrained(
@@ -90,13 +91,12 @@ aoti.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/Wan2", variant="fp8da")
  aoti.aoti_blocks_load(pipe.transformer_2, "zerogpu-aoti/Wan2", variant="fp8da")

  # ------------------------------------------------------------
- # HELPER FUNCTIONS
+ # HELPERS
  # ------------------------------------------------------------
  def resize_image(image: Image.Image) -> Image.Image:
      """Resize / crop the input image to a size the model accepts."""
      w, h = image.size

-     # square shortcut
      if w == h:
          return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)

@@ -105,19 +105,19 @@ def resize_image(image: Image.Image) -> Image.Image:
      MIN_AR = MIN_DIM / MAX_DIM
      img = image

-     if aspect > MAX_AR:  # very wide → crop width
+     if aspect > MAX_AR:  # very wide
          crop_w = int(round(h * MAX_AR))
          left = (w - crop_w) // 2
          img = image.crop((left, 0, left + crop_w, h))
-     elif aspect < MIN_AR:  # very tall → crop height
+     elif aspect < MIN_AR:  # very tall
          crop_h = int(round(w / MIN_AR))
          top = (h - crop_h) // 2
          img = image.crop((0, top, w, top + crop_h))
      else:
-         if w > h:  # landscape
+         if w > h:  # landscape
              target_w = MAX_DIM
              target_h = int(round(target_w / aspect))
-         else:  # portrait
+         else:  # portrait
              target_h = MAX_DIM
              target_w = int(round(target_h * aspect))
      img = image
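Note: MIN_DIM, MAX_AR and the SQUARE_DIM shortcut are defined outside the lines shown above, so the following is only a small self-contained sketch of which branch fires for a given input size, using placeholder values for the constants that are not visible in this diff.

# Sketch only: MIN_DIM (and therefore MAX_AR) is not shown in the hunks above,
# so the value below is an illustrative assumption, not the app's real setting.
MAX_DIM = 832                      # visible in the diff
MIN_DIM = 480                      # assumed placeholder
MAX_AR = MAX_DIM / MIN_DIM
MIN_AR = MIN_DIM / MAX_DIM

def which_branch(w: int, h: int) -> str:
    aspect = w / h
    if aspect > MAX_AR:
        return "too wide -> center-crop width to h * MAX_AR"
    if aspect < MIN_AR:
        return "too tall -> center-crop height to w / MIN_AR"
    return "landscape -> width = MAX_DIM" if w > h else "portrait -> height = MAX_DIM"

for size in [(1920, 1080), (1080, 1920), (832, 624)]:
    print(size, which_branch(*size))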
@@ -152,7 +152,7 @@ def get_duration(
      randomize_seed,
      progress,
  ):
-     """GPU‑time estimator used by @spaces.GPU."""
+     """GPU‑time estimator for @spaces.GPU."""
      BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
      BASE_STEP_DURATION = 15

@@ -162,12 +162,11 @@
      step_duration = BASE_STEP_DURATION * factor ** 1.5
      est = 10 + int(steps) * step_duration

-     # never block the GPU for >30 s (feel free to raise while debugging)
-     return min(est, 30)
+     return min(est, 30)  # safety cap


  # ------------------------------------------------------------
- # MAIN GENERATION FUNCTION
+ # MAIN GENERATION FUNCTION – now with error logging
  # ------------------------------------------------------------
  @spaces.GPU(duration=get_duration)
  def generate_video(
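The estimator caps every request at 30 GPU-seconds. A quick numeric sketch of the formula above; the definition of factor falls outside the visible hunks and is assumed here to be frames * height * width relative to the 81 * 832 * 624 baseline.

BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624   # from the diff
BASE_STEP_DURATION = 15                     # from the diff

def estimate_gpu_seconds(num_frames, height, width, steps):
    factor = (num_frames * height * width) / BASE_FRAMES_HEIGHT_WIDTH  # assumed definition
    step_duration = BASE_STEP_DURATION * factor ** 1.5
    return min(10 + int(steps) * step_duration, 30)   # the new "safety cap"

# e.g. an 832x624 request with 49 frames and 6 steps already hits the cap:
print(estimate_gpu_seconds(49, 624, 832, 6))          # -> 30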
@@ -180,56 +179,70 @@ def generate_video(
      guidance_scale_2=1.5,
      seed=42,
      randomize_seed=False,
-     progress=None,  # optional – Gradio will inject if needed
+     progress=None,  # optional – Gradio will inject if needed
  ):
-     """Run the Wan‑2.2 pipeline and return an MP4 file."""
-     if input_image is None:
-         raise gr.Error("Please upload an input image.")
-
-     num_frames = get_num_frames(duration_seconds)
-     current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
-
-     resized = resize_image(input_image)
-
-     # -----------------------------------------------------------------
-     # Model inference
-     # -----------------------------------------------------------------
-     out = pipe(
-         image=resized,
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         height=resized.height,
-         width=resized.width,
-         num_frames=num_frames,
-         guidance_scale=float(guidance_scale),
-         guidance_scale_2=float(guidance_scale_2),
-         num_inference_steps=int(steps),
-         generator=torch.Generator(device="cuda").manual_seed(current_seed),
-     )
-     frames = out.frames[0]
-
-     # -----------------------------------------------------------------
-     # Write temporary MP4 (ffmpeg must be present – Spaces images include it)
-     # -----------------------------------------------------------------
-     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
-         video_path = tmp.name
-     export_to_video(frames, video_path, fps=FIXED_FPS)
-
-     # -----------------------------------------------------------------
-     # Clean up GPU memory for the next request
-     # -----------------------------------------------------------------
-     gc.collect()
-     torch.cuda.empty_cache()
-
-     return video_path, current_seed
+     """
+     Run the Wan‑2.2 pipeline and return an MP4 file.
+     Any exception is caught, printed to the Space logs, and re‑raised as a Gradio error.
+     """
+     try:
+         if input_image is None:
+             raise gr.Error("Please upload an input image.")
+
+         num_frames = get_num_frames(duration_seconds)
+         current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
+
+         resized = resize_image(input_image)
+
+         # -------------------- model inference --------------------
+         out = pipe(
+             image=resized,
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             height=resized.height,
+             width=resized.width,
+             num_frames=num_frames,
+             guidance_scale=float(guidance_scale),
+             guidance_scale_2=float(guidance_scale_2),
+             num_inference_steps=int(steps),
+             generator=torch.Generator(device="cuda").manual_seed(current_seed),
+         )
+         frames = out.frames[0]
+
+         if not frames or len(frames) == 0:
+             raise RuntimeError("Pipeline returned an empty frame list.")
+
+         # -------------------- write MP4 --------------------
+         with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
+             video_path = tmp.name
+         export_to_video(frames, video_path, fps=FIXED_FPS)
+
+         # -------------------- clean up --------------------
+         gc.collect()
+         torch.cuda.empty_cache()
+
+         return video_path, current_seed
+
+     except Exception as exc:
+         # -----------------------------------------------------------------
+         # Print a full traceback to the Space console – you’ll see it in the
+         # “Logs” tab. After you identify the problem you can simply delete
+         # this whole try/except block.
+         # -----------------------------------------------------------------
+         tb = traceback.format_exc()
+         print("\n=== VIDEO‑GENERATION ERROR =================================================")
+         print(tb)
+         print("============================================================================\n")
+         # Re‑raise as a user‑friendly Gradio error
+         raise gr.Error(f"Video generation failed: {type(exc).__name__}: {exc}")


  # ------------------------------------------------------------
- # UI – EXACTLY YOUR ORIGINAL LOOK & THEME
+ # UI – unchanged visual theme (all CSS, 500‑error guard, gap, etc.)
  # ------------------------------------------------------------
  def create_demo():
      with gr.Blocks(css="", title="Fast Image to Video") as demo:
-         # ----------- 500‑error guard (unchanged) -----------
+         # ----- 500‑error guard (exact copy) -----
          gr.HTML(
              """
              <script>
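The functional change in this hunk is wrapping the body in try/except so failures are printed to the Space logs instead of surfacing as a bare 500. The same pattern in isolation, as a minimal sketch in which might_fail stands in for the pipeline call:

import traceback
import gradio as gr

def might_fail():
    raise RuntimeError("CUDA out of memory")    # stand-in for pipe(...)

def safe_generate():
    try:
        return might_fail()
    except Exception as exc:
        print("\n=== VIDEO-GENERATION ERROR ===")
        print(traceback.format_exc())           # full traceback goes to the "Logs" tab
        raise gr.Error(f"Video generation failed: {type(exc).__name__}: {exc}")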
@@ -241,7 +254,7 @@ def create_demo():
              """
          )

-         # ----------- ALL YOUR CUSTOM CSS (copy‑paste verbatim) -----------
+         # ----- all custom CSS (exactly as you posted) -----
          gr.HTML(
              """
              <style>
@@ -268,7 +281,7 @@
  body::before{
  content:"";
  display:block;
- height:600px;   /* <-- the top gap you designed */
+ height:600px;   /* top gap */
  background:#000 !important;
  }
  .gr-blocks,.container{
@@ -351,7 +364,7 @@
  box-sizing:border-box !important;
  display:block !important;
  }
- /* ---- hide every Gradio progress element ---- */
+ /* ---- hide all Gradio progress UI ---- */
  .image-container[aria-label="Generated Video"] .progress-text,
  .image-container[aria-label="Generated Video"] .gr-progress,
  .image-container[aria-label="Generated Video"] .gr-progress-bar,
@@ -378,7 +391,7 @@
  .image-container[aria-label="Generated Video"] *[class*="progress"],
  .image-container[aria-label="Generated Video"] *[class*="loading"],
  .image-container[aria-label="Generated Video"] *[class*="status"],
- .image-container[aria-label="Generated Video"] *[class*="spinner"],
+ .image-container[aria-label="Generated Video"] *[class*="spinner],
  .progress-text,.gr-progress,.gr-progress-bar,.progress-bar,
  [data-testid="progress"],.status,.loading,.spinner,.gr-spinner,
  .gr-loading,.gr-status,.gpu-init,.initializing,.queue,
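Two remarks on this hunk: the new version of the "spinner" selector drops its closing quote, which leaves that one selector invalid, and, depending on the Gradio version, much of this progress chrome can also be suppressed at the source rather than with CSS. A sketch of the latter (an alternative, not what this commit does):

import gradio as gr

with gr.Blocks() as sketch:
    btn = gr.Button("Generate Video")
    vid = gr.Video(label="Generated Video")
    # "hidden" suppresses the per-event progress UI in Gradio 4.x;
    # older 3.x releases take show_progress=False instead.
    btn.click(fn=lambda: None, inputs=None, outputs=vid, show_progress="hidden")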
@@ -542,7 +555,7 @@
              """
          )

-         # ------------------- UI layout (identical to your original design) -------------------
+         # ------------------- UI components (same layout as original) -------------------
          with gr.Row(elem_id="general_items"):
              gr.Markdown("# ")
              gr.Markdown(
@@ -566,8 +579,6 @@
                  placeholder="Describe the desired animation or motion",
                  elem_classes=["gradio-component"],
              )
-             # (the rest of the advanced sliders you had in the “original” UI are omitted
-             # because your current functional code only expects the arguments below)
              generate_button = gr.Button(
                  "Generate Video",
                  variant="primary",
@@ -582,21 +593,21 @@
                  elem_classes=["gradio-component", "image-container"],
              )

-             # ------------------- Wire the button -------------------
+             # ------------------- wiring -------------------
              generate_button.click(
                  fn=generate_video,
                  inputs=[
-                     input_image,                              # image
-                     prompt,                                   # prompt
-                     gr.State(value=6),                        # steps (default 6)
+                     input_image,
+                     prompt,
+                     gr.State(value=6),                        # steps
                      gr.State(value=default_negative_prompt),  # negative_prompt
-                     gr.State(value=3.2),                      # duration_seconds (you used 3.2 in the earlier clone)
-                     gr.State(value=1.5),                      # guidance_scale
-                     gr.State(value=1.5),                      # guidance_scale_2
-                     gr.State(value=42),                       # seed
-                     gr.State(value=True),                     # randomize_seed
+                     gr.State(value=3.2),                      # duration_seconds
+                     gr.State(value=1.5),                      # guidance_scale
+                     gr.State(value=1.5),                      # guidance_scale_2
+                     gr.State(value=42),                       # seed
+                     gr.State(value=True),                     # randomize_seed
                  ],
-                 outputs=[output_video, gr.State(value=42)],   # second output is the seed
+                 outputs=[output_video, gr.State(value=42)],
              )

      return demo
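In this wiring the fresh gr.State(value=...) objects effectively act as constant inputs, and the gr.State(value=42) in outputs receives the returned seed but is never read back, so a randomized seed does not carry over between runs. The diff also ends at return demo; a typical Spaces entry point, assumed here and not shown in this commit, would be:

if __name__ == "__main__":
    demo = create_demo()
    demo.queue()     # enable request queuing before launch
    demo.launch()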
 