| """Image-to-video generation using Wan 2.1 on-device via diffusers. |
| |
| Runs Wan 2.1 14B I2V locally on GPU (designed for HF Spaces ZeroGPU). |
| Same public interface as video_generator_api.py so app.py can swap backends. |
| """ |
|
|
| import json |
| import time |
| from pathlib import Path |
| from typing import Optional |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
|
|
| |
| |
| |
|
|
# Hugging Face repo id of the Diffusers-format Wan 2.1 image-to-video checkpoint.
MODEL_ID: str = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"

# Generation settings applied to every clip.
NUM_FRAMES: int = 81  # frames per clip (81 frames at FPS=16 is ~5 s)
FPS: int = 16  # frame rate passed to export_to_video
NUM_INFERENCE_STEPS: int = 25  # passed to the pipeline as num_inference_steps
GUIDANCE_SCALE: float = 5.0  # passed to the pipeline as guidance_scale
SEED: int = 42  # default base seed for generate_all / run

# Pixel budget for the resized input frame: 832 x 480 = 399,360 pixels.
MAX_AREA: int = 832 * 480

# Lazily created pipeline: populated by _get_pipe(), cleared by unload().
_pipe = None
|
|
|
|
def _get_pipe():
    """Load the Wan 2.1 I2V pipeline once and cache it (lazy singleton).

    Steps on first call:

    1. Load the CLIP image encoder and the VAE in float32.
    2. Load the rest of the pipeline in bfloat16.
    3. Quantize the transformer weights with torchao's float8 weight-only config.
    4. Move the pipeline to CUDA.

    The module-level cache is populated only after every step has succeeded.
    A failure part-way through (quantization, CUDA move, ...) therefore leaves
    nothing cached, and the next call retries from scratch. It does not return
    a half-initialized pipeline.

    Returns:
        The ready-to-use ``WanImageToVideoPipeline``.
    """
    global _pipe
    if _pipe is not None:
        return _pipe

    # Imported lazily, only when the pipeline is first needed. torchao is
    # imported here too so a missing install fails before any weights load.
    from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
    from torchao.quantization import Float8WeightOnlyConfig, quantize_
    from transformers import CLIPVisionModel

    print(f"Loading Wan 2.1 I2V pipeline ({MODEL_ID})...")

    # Image encoder and VAE stay in float32; the other components use bfloat16.
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32,
    )
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=torch.float32,
    )
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        image_encoder=image_encoder,
        torch_dtype=torch.bfloat16,
    )

    # Weight-only float8 quantization of the transformer.
    quantize_(pipe.transformer, Float8WeightOnlyConfig())
    pipe.to("cuda")

    # Publish only a fully initialized pipeline.
    _pipe = pipe
    print("Wan 2.1 I2V pipeline ready.")
    return _pipe
|
|
|
|
def unload():
    """Drop the cached pipeline and release the GPU memory it held.

    Safe to call when nothing is loaded. After this, the next ``_get_pipe()``
    call loads the pipeline again from scratch.
    """
    import gc

    global _pipe
    if _pipe is not None:
        # Clear the module cache first so it is reset even if the move below fails.
        pipe, _pipe = _pipe, None
        pipe.to("cpu")
        del pipe
        # Collect reference cycles that could still pin GPU tensors, so the
        # allocator cache flush below can actually hand their memory back.
        gc.collect()
    torch.cuda.empty_cache()
    print("Wan 2.1 I2V pipeline unloaded.")
|
|
|
|
| def _resize_for_480p(image: Image.Image, pipe) -> tuple[Image.Image, int, int]: |
| """Resize image to fit 480p area while respecting model patch constraints.""" |
| aspect_ratio = image.height / image.width |
| mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] |
| height = round(np.sqrt(MAX_AREA * aspect_ratio)) // mod_value * mod_value |
| width = round(np.sqrt(MAX_AREA / aspect_ratio)) // mod_value * mod_value |
| return image.resize((width, height)), height, width |
|
|
|
|
def generate_clip(
    image_path: str | Path,
    prompt: str,
    output_path: str | Path,
    negative_prompt: str = "",
    seed: Optional[int] = None,
) -> Path:
    """Generate a video clip from an image using on-device Wan 2.1.

    Args:
        image_path: Path to the source image.
        prompt: Motion/scene description.
        output_path: Where to save the .mp4 clip. Parent dirs are created.
        negative_prompt: What to avoid.
        seed: Random seed. ``None`` means no explicit generator is passed.

    Returns:
        Path to the saved video clip.

    Raises:
        OSError: If the source image cannot be opened. This is checked before
            the model is loaded.
    """
    from diffusers.utils import export_to_video

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Open the image before loading the (large) pipeline so a bad path fails
    # fast. The context manager closes the file once the RGB copy exists.
    with Image.open(image_path) as src:
        image = src.convert("RGB")

    pipe = _get_pipe()
    image, height, width = _resize_for_480p(image, pipe)

    # Seeded CPU generator only when a seed is given.
    generator = (
        torch.Generator(device="cpu").manual_seed(seed) if seed is not None else None
    )

    output = pipe(
        image=image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=NUM_FRAMES,
        num_inference_steps=NUM_INFERENCE_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        generator=generator,
    )

    # output.frames[0] holds the frames of the single generated video.
    export_to_video(output.frames[0], str(output_path), fps=FPS)
    return output_path
|
|
|
|
def generate_all(
    segments: list[dict],
    images_dir: str | Path,
    output_dir: str | Path,
    seed: int = SEED,
    progress_callback=None,
) -> list[Path]:
    """Generate video clips for all segments, reusing clips already on disk.

    Args:
        segments: Segment dicts. Each needs an integer ``'segment'`` index.
            The prompt is the first non-empty value among ``'video_prompt'``,
            ``'scene'`` and ``'prompt'``. ``'negative_prompt'`` is optional.
        images_dir: Directory containing ``segment_NNN.png`` source images.
        output_dir: Directory where ``clip_NNN.mp4`` clips are saved.
        seed: Base seed. Segment ``idx`` uses ``seed + idx``.
        progress_callback: Optional ``callback(idx, total)``. It is called for
            every segment whose clip is available, whether newly generated or
            reused from disk.

    Returns:
        List of clip paths in segment order. Segments whose source image is
        missing are skipped and omitted.
    """
    images_dir = Path(images_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total = len(segments)
    paths: list[Path] = []
    for seg in segments:
        idx = seg["segment"]
        image_path = images_dir / f"segment_{idx:03d}.png"
        clip_path = output_dir / f"clip_{idx:03d}.mp4"

        if clip_path.exists():
            print(f" Segment {idx}/{total}: already exists, skipping")
        elif not image_path.exists():
            print(f" Segment {idx}: image not found at {image_path}, skipping")
            continue
        else:
            # Fall through to the next key when a value is missing, None or empty.
            prompt = (
                seg.get("video_prompt") or seg.get("scene") or seg.get("prompt") or ""
            )
            neg = seg.get("negative_prompt") or ""

            print(f" Segment {idx}/{total}: generating video clip...")
            t0 = time.perf_counter()
            generate_clip(image_path, prompt, clip_path, neg, seed=seed + idx)
            elapsed = time.perf_counter() - t0
            print(f" Saved {clip_path.name} ({elapsed:.1f}s)")

        # Reached for both reused and newly generated clips.
        paths.append(clip_path)
        if progress_callback is not None:
            progress_callback(idx, total)

    return paths
|
|
|
|
def run(
    data_dir: str | Path,
    seed: int = SEED,
    progress_callback=None,
) -> list[Path]:
    """Full video generation pipeline: read segments, generate clips, save.

    Args:
        data_dir: Run directory containing segments.json and images/.
            Clips are written to ``data_dir / "clips"``.
        seed: Base random seed.
        progress_callback: Optional ``callback(idx, total)``, forwarded to
            ``generate_all``.

    Returns:
        List of saved video clip paths.
    """
    data_dir = Path(data_dir)
    clips_dir = data_dir / "clips"

    # Decode explicitly as UTF-8 so non-ASCII prompts load the same way
    # regardless of the platform's default locale encoding.
    segments = json.loads((data_dir / "segments.json").read_text(encoding="utf-8"))

    paths = generate_all(
        segments,
        images_dir=data_dir / "images",
        output_dir=clips_dir,
        seed=seed,
        progress_callback=progress_callback,
    )

    print(f"\nGenerated {len(paths)} video clips in {clips_dir}")
    return paths
|
|
|
|
if __name__ == "__main__":
    import sys

    # The first positional argument is the run directory; extras are ignored.
    cli_args = sys.argv[1:]
    if not cli_args:
        for usage_line in (
            "Usage: python -m src.video_generator_hf <data_dir>",
            " e.g. python -m src.video_generator_hf data/Gone/run_001",
        ):
            print(usage_line)
        sys.exit(1)

    run(cli_args[0])
|
|