import gc
import html
import random
import sys
import time
import uuid
from pathlib import Path
from urllib.parse import quote

import ftfy
import gradio as gr
import imageio
import numpy as np

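# Fallback for runs outside HF Spaces: turn `spaces.GPU(...)` into a no-op
# decorator so the same file also works on a plain CUDA box.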
try:
    import spaces
except ImportError:
    class _SpacesShim:
        def GPU(*args, **kwargs):
            def decorator(fn):
                return fn
            return decorator

    spaces = _SpacesShim()

import torch
from diffusers import AutoencoderKLWan as DiffusersWanVAE
from diffusers import WanImageToVideoPipeline
from diffusers.pipelines.wan import pipeline_wan_i2v
from huggingface_hub import hf_hub_download, snapshot_download
from transformers import CLIPVisionModel

ROOT = Path(__file__).resolve().parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

# Local modules; imported after ROOT is on sys.path so they resolve regardless
# of the working directory.
from src.models.Wan.autoencoder_wanT import AutoencoderKLWan
from src.models.Wan.transformer_wan import WanDecoderTransformer

| MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers" | |
| REFDECODER_REPO_ID = "Arrokothwhi/RefDecoder" | |
| REFDECODER_CKPT_PATH_IN_REPO = "I2V_Wan2.1/model.pt" | |
| OUTPUT_ROOT = ROOT / "gradio_outputs" | |
| NEGATIVE_PROMPT = ( | |
| "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, " | |
| "images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, " | |
| "incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, " | |
| "misshapen limbs, fused fingers, still picture, messy background, three legs, many people " | |
| "in the background, walking backwards" | |
| ) | |
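# Generation settings: a target output area of 480 * 832 pixels, 17 frames at
# 16 fps, and 50 denoising steps with classifier-free guidance scale 5.0.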
TARGET_AREA = 480 * 832
FPS = 16
NUM_FRAMES = 17
NUM_INFERENCE_STEPS = 50
GUIDANCE_SCALE = 5.0
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
PIPE_DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32

# Some diffusers Wan builds reference a module-level `ftfy` during prompt
# cleaning. Set it explicitly so the Space doesn't fail if that global was
# never initialized.
pipeline_wan_i2v.ftfy = ftfy


def download_refdecoder_ckpt():
    print("[init] Downloading RefDecoder checkpoint metadata/file if needed")
    ckpt_path = hf_hub_download(
        repo_id=REFDECODER_REPO_ID,
        filename=REFDECODER_CKPT_PATH_IN_REPO,
    )
    print(f"[init] RefDecoder checkpoint ready at: {ckpt_path}")
    return ckpt_path


def download_wan_weights():
    print(f"[init] Downloading Wan I2V weights from {MODEL_ID}")
    repo_dir = snapshot_download(repo_id=MODEL_ID)
    print(f"[init] Wan I2V weights ready at: {repo_dir}")
    return repo_dir


REFDECODER_CKPT_LOCAL_PATH = download_refdecoder_ckpt()
download_wan_weights()
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)


def log_cuda_mem(tag):
    if not torch.cuda.is_available():
        print(f"[mem] {tag}: CUDA not available")
        return
    try:
        free_bytes, total_bytes = torch.cuda.mem_get_info()
    except RuntimeError as exc:
        print(f"[mem] {tag}: CUDA not currently leased ({exc})")
        return
    allocated_bytes = torch.cuda.memory_allocated()
    reserved_bytes = torch.cuda.memory_reserved()
    print(
        f"[mem] {tag}: "
        f"free={free_bytes / 1024**3:.2f} GB, "
        f"total={total_bytes / 1024**3:.2f} GB, "
        f"allocated={allocated_bytes / 1024**3:.2f} GB, "
        f"reserved={reserved_bytes / 1024**3:.2f} GB"
    )


def get_module_dtype(module):
    try:
        return next(module.parameters()).dtype
    except StopIteration:
        return PIPE_DTYPE


def load_generation_pipe():
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID,
        subfolder="image_encoder",
        torch_dtype=PIPE_DTYPE,
    )
    vae = DiffusersWanVAE.from_pretrained(
        MODEL_ID,
        subfolder="vae",
        torch_dtype=PIPE_DTYPE,
    )
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        image_encoder=image_encoder,
        torch_dtype=PIPE_DTYPE,
    )
    return pipe


def load_wan_vae():
    vae = DiffusersWanVAE.from_pretrained(
        MODEL_ID,
        subfolder="vae",
        torch_dtype=PIPE_DTYPE,
    )
    vae.eval()
    return vae


def load_refdecoder_module():
    vae = AutoencoderKLWan(
        dropout_p=0.0,
        use_reference=True,
    ).eval()
    transformer = WanDecoderTransformer(
        chunk=5,
        num_layers=10,
        num_heads=12,
        head_dim=128,
        reusing=True,
        pretrained=False,
    ).eval()
    checkpoint = torch.load(REFDECODER_CKPT_LOCAL_PATH, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint.get("module", checkpoint))
    vae_sd = {}
    transformer_sd = {}
    for key, value in state_dict.items():
        if key.startswith("vae."):
            vae_sd[key[len("vae.") :]] = value
        elif key.startswith("transformer."):
            transformer_sd[key[len("transformer.") :]] = value
    vae.load_state_dict(vae_sd, strict=False)
    transformer.load_state_dict(transformer_sd, strict=False)
    return vae, transformer


# Preload all models on CPU at init so each @spaces.GPU lease only pays for the
# CPU -> GPU transfer, not the full from_pretrained / checkpoint read.
GENERATION_PIPE = load_generation_pipe()
WAN_VAE = load_wan_vae()
REFDECODER_VAE, REFDECODER_TRANSFORMER = load_refdecoder_module()


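# Wan requires spatial dims divisible by the VAE's spatial stride times the
# transformer patch size; choose the largest such H/W near TARGET_AREA while
# preserving the input aspect ratio.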
def resize_image_for_wan(image, pipe):
    image = image.convert("RGB")
    aspect_ratio = image.height / image.width
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
    height = round(np.sqrt(TARGET_AREA * aspect_ratio)) // mod_value * mod_value
    width = round(np.sqrt(TARGET_AREA / aspect_ratio)) // mod_value * mod_value
    resized = image.resize((width, height))
    return resized, height, width


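# RefDecoder is conditioned on the raw input frame, scaled to [-1, 1] and
# shaped [B=1, C, T=1, H, W].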
def build_reference_frame(image, device):
    ref_array = np.asarray(image).astype(np.float32)
    ref_tensor = torch.from_numpy(ref_array).permute(2, 0, 1)
    ref_tensor = (ref_tensor / 255.0 - 0.5) * 2.0
    return ref_tensor.unsqueeze(0).unsqueeze(2).to(device=device, dtype=torch.float32)


def normalize_latent_shape(latents):
    if isinstance(latents, list):
        latents = latents[0]
    if latents.ndim == 4:
        latents = latents.unsqueeze(0)
    if latents.ndim != 5:
        raise ValueError(f"Expected latent shape [B,C,T,H,W], got {tuple(latents.shape)}")
    return latents


def gradio_file_url(path):
    return f"/gradio_api/file={quote(str(path), safe='/')}"


def build_compare_html(wan_video_path, ref_video_path):
    compare_id = f"compare-{uuid.uuid4().hex}"
    wan_url = gradio_file_url(wan_video_path) if wan_video_path else ""
    ref_url = gradio_file_url(ref_video_path) if ref_video_path else ""
    base_source = (
        f'<video class="compare-video compare-base" src="{wan_url}" autoplay muted loop playsinline></video>'
        if wan_url
        else '<div class="compare-video compare-base compare-placeholder"></div>'
    )
    overlay_source = (
        f'<video class="compare-video compare-overlay" src="{ref_url}" autoplay muted loop playsinline></video>'
        if ref_url
        else '<div class="compare-video compare-overlay compare-placeholder"></div>'
    )
    inner_doc = f"""
<!doctype html>
<html>
<head>
  <meta charset="utf-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1" />
  <style>
    html, body {{
      margin: 0;
      padding: 0;
      background: transparent;
      font-family: Manrope, Inter, system-ui, sans-serif;
    }}
    .compare-shell {{
      display: flex;
      flex-direction: column;
      gap: 12px;
    }}
    .compare-topbar {{
      display: flex;
      justify-content: space-between;
      align-items: center;
      gap: 12px;
    }}
    .compare-chip {{
      padding: 12px 22px;
      border-radius: 999px;
      background: rgba(31, 106, 82, 0.14);
      color: #123a2d;
      font-size: 22px;
      font-weight: 800;
      letter-spacing: 0.03em;
      text-transform: uppercase;
      box-shadow: inset 0 0 0 1px rgba(31, 106, 82, 0.12);
      justify-self: start;
    }}
    .compare-chip-right {{
      background: rgba(201, 111, 66, 0.16);
      color: #6e3d23;
      box-shadow: inset 0 0 0 1px rgba(201, 111, 66, 0.16);
      justify-self: end;
    }}
    .compare-button {{
      border: 0;
      border-radius: 999px;
      padding: 10px 22px;
      background: #1f6a52;
      color: white;
      font-size: 16px;
      font-weight: 700;
      cursor: pointer;
      justify-self: center;
    }}
    .compare-stage {{
      position: relative;
      width: 100%;
      aspect-ratio: 16 / 9;
      overflow: hidden;
      border-radius: 22px;
      background: #16120f;
      border: 1px solid rgba(255,255,255,0.08);
    }}
    .compare-video {{
      position: absolute;
      inset: 0;
      width: 100%;
      height: 100%;
      object-fit: contain;
      background: #16120f;
    }}
    .compare-overlay {{
      clip-path: inset(0 0 0 50%);
    }}
    .compare-placeholder {{
      background:
        linear-gradient(135deg, rgba(255,255,255,0.055), transparent 35%),
        #16120f;
    }}
    .compare-divider {{
      position: absolute;
      top: 0;
      bottom: 0;
      left: 50%;
      width: 2px;
      background: rgba(255,255,255,0.96);
      box-shadow: 0 0 0 1px rgba(31, 26, 20, 0.15);
      transform: translateX(-1px);
      pointer-events: none;
    }}
    .compare-divider::after {{
      content: "";
      position: absolute;
      top: 50%;
      left: 50%;
      width: 18px;
      height: 18px;
      border-radius: 999px;
      background: #fff;
      border: 2px solid rgba(31, 26, 20, 0.18);
      transform: translate(-50%, -50%);
    }}
    .compare-range {{
      position: absolute;
      inset: 0;
      width: 100%;
      height: 100%;
      opacity: 0.01;
      cursor: ew-resize;
      margin: 0;
      -webkit-appearance: none;
      appearance: none;
    }}
    .compare-caption {{
      color: #201a14;
      font-size: 14px;
      line-height: 1.5;
      text-align: center;
    }}
    .compare-controls {{
      display: flex;
      justify-content: center;
      align-items: center;
      gap: 10px;
      flex-wrap: wrap;
    }}
    .compare-controls .compare-button {{
      padding: 9px 16px;
      font-size: 14px;
    }}
    .compare-button-step {{
      background: #2f5746;
    }}
    .compare-button-reset {{
      background: #c96f42;
    }}
    .compare-button[disabled] {{
      opacity: 0.55;
      cursor: not-allowed;
    }}
  </style>
</head>
<body>
  <div class="compare-shell" id="{compare_id}">
    <div class="compare-topbar">
      <div class="compare-chip">Wan Baseline</div>
      <div class="compare-chip compare-chip-right">RefDecoder</div>
    </div>
    <div class="compare-stage">
      {base_source}
      {overlay_source}
      <div class="compare-divider"></div>
      <input class="compare-range" type="range" min="0" max="100" value="50" />
    </div>
    <div class="compare-controls">
      <button class="compare-button compare-button-step" type="button" data-action="prev">− 1 Frame</button>
      <button class="compare-button" type="button" data-action="toggle">Pause</button>
      <button class="compare-button compare-button-step" type="button" data-action="next">+ 1 Frame</button>
      <button class="compare-button compare-button-reset" type="button" data-action="reset">Reset Playback</button>
    </div>
    <div class="compare-caption">Drag the divider to compare the two decoders on the same latent video.</div>
  </div>
  <script>
    (() => {{
      const root = document.getElementById("{compare_id}");
      const base = root.querySelector(".compare-base");
      const overlay = root.querySelector(".compare-overlay");
      const divider = root.querySelector(".compare-divider");
      const slider = root.querySelector(".compare-range");
      const button = root.querySelector('[data-action="toggle"]');
      const prevBtn = root.querySelector('[data-action="prev"]');
      const nextBtn = root.querySelector('[data-action="next"]');
      const resetBtn = root.querySelector('[data-action="reset"]');
      const stepButtons = [prevBtn, nextBtn, resetBtn];
      const videos = Array.from(root.querySelectorAll("video"));
      const FRAME_DELTA = 1 / {FPS};
      const applySplit = () => {{
        const value = Number(slider.value);
        overlay.style.clipPath = `inset(0 0 0 ${{value}}%)`;
        divider.style.left = `${{value}}%`;
      }};
      const syncVideo = (source, target) => {{
        if (Math.abs((target.currentTime || 0) - (source.currentTime || 0)) > 0.08) {{
          try {{ target.currentTime = source.currentTime; }} catch (e) {{}}
        }}
      }};
      const playBoth = () => {{
        videos.forEach((video) => video.play().catch(() => {{}}));
        button.textContent = "Pause";
      }};
      const pauseBoth = () => {{
        videos.forEach((video) => video.pause());
        button.textContent = "Play";
      }};
      const bindSync = (primary, secondary) => {{
        primary.addEventListener("play", () => secondary.play().catch(() => {{}}));
        primary.addEventListener("pause", () => secondary.pause());
        primary.addEventListener("seeking", () => syncVideo(primary, secondary));
        primary.addEventListener("timeupdate", () => syncVideo(primary, secondary));
        primary.addEventListener("ratechange", () => {{ secondary.playbackRate = primary.playbackRate; }});
      }};
      const stepFrame = (delta) => {{
        if (!videos.length) return;
        pauseBoth();
        videos.forEach((video) => {{
          const duration = isFinite(video.duration) ? video.duration : 0;
          let nextTime = (video.currentTime || 0) + delta;
          if (duration > 0) {{
            nextTime = ((nextTime % duration) + duration) % duration;
          }} else {{
            nextTime = Math.max(0, nextTime);
          }}
          try {{ video.currentTime = nextTime; }} catch (e) {{}}
        }});
      }};
      const resetPlayback = () => {{
        pauseBoth();
        videos.forEach((video) => {{
          try {{ video.currentTime = 0; }} catch (e) {{}}
        }});
      }};
      if (base.tagName === "VIDEO" && overlay.tagName === "VIDEO") {{
        bindSync(base, overlay);
        bindSync(overlay, base);
      }} else {{
        button.disabled = true;
        button.textContent = "Play";
        button.style.opacity = "0.55";
        stepButtons.forEach((btn) => {{ if (btn) btn.disabled = true; }});
      }}
      videos.forEach((video) => {{
        video.addEventListener("loadeddata", playBoth, {{ once: true }});
      }});
      button.addEventListener("click", () => {{
        if (!videos.length || videos[0].paused) {{
          playBoth();
        }} else {{
          pauseBoth();
        }}
      }});
      if (prevBtn) prevBtn.addEventListener("click", () => stepFrame(-FRAME_DELTA));
      if (nextBtn) nextBtn.addEventListener("click", () => stepFrame(FRAME_DELTA));
      if (resetBtn) resetBtn.addEventListener("click", resetPlayback);
      slider.addEventListener("input", applySplit);
      applySplit();
      const reportHeight = () => {{
        const h = Math.ceil(root.getBoundingClientRect().height + 2);
        parent.postMessage({{ type: "compare-iframe-height", id: "{compare_id}", height: h }}, "*");
      }};
      reportHeight();
      window.addEventListener("load", reportHeight);
      if (typeof ResizeObserver !== "undefined") {{
        new ResizeObserver(reportHeight).observe(root);
      }}
      videos.forEach((video) => {{
        video.addEventListener("loadedmetadata", reportHeight);
      }});
    }})();
  </script>
</body>
</html>
"""
    return (
        '<iframe class="compare-frame" '
        'sandbox="allow-scripts allow-same-origin" '
        'scrolling="no" '
        'srcdoc="' + html.escape(inner_doc, quote=True) + '"></iframe>'
    )


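# Convert a [-1, 1] video tensor of shape [B, C, T, H, W] into uint8 frames
# and write an mp4 at FPS.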
def save_video_tensor(video_tensor, output_path):
    video = (video_tensor / 2 + 0.5).clamp(0, 1)
    video = video.squeeze(0).permute(1, 2, 3, 0).detach().cpu().float().numpy()
    video = (video * 255).astype(np.uint8)
    imageio.mimwrite(output_path, video, fps=FPS, quality=10)
    return str(output_path)


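# The pipeline keeps Wan latents normalized per channel; both decoders expect
# un-normalized values, so apply x * std + mean before decoding.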
def decode_with_wan_vae(latents, vae):
    vae_dtype = get_module_dtype(vae)
    latents = latents.to(device=DEVICE, dtype=vae_dtype)
    latents_mean = torch.tensor(vae.config.latents_mean, device=DEVICE, dtype=vae_dtype).view(1, -1, 1, 1, 1)
    latents_std = torch.tensor(vae.config.latents_std, device=DEVICE, dtype=vae_dtype).view(1, -1, 1, 1, 1)
    latents = latents * latents_std + latents_mean
    with torch.no_grad():
        video = vae.decode(latents, return_dict=False)[0]
    return video


def decode_with_refdecoder(latents, reference_frame, vae, transformer):
    decode_dtype = get_module_dtype(vae)
    latents = latents.to(device=DEVICE, dtype=decode_dtype)
    latents_mean = torch.tensor(
        vae.config.latents_mean,
        device=DEVICE,
        dtype=decode_dtype,
    ).view(1, -1, 1, 1, 1)
    latents_std = torch.tensor(
        vae.config.latents_std,
        device=DEVICE,
        dtype=decode_dtype,
    ).view(1, -1, 1, 1, 1)
    latents = latents * latents_std + latents_mean
    reference_frame = reference_frame.to(device=DEVICE, dtype=decode_dtype)
    with torch.no_grad():
        video = vae.decode(
            latents,
            transformer,
            return_dict=True,
            reference_frame=reference_frame,
            skip=False,
            window_size=-1,
        ).sample
    if hasattr(vae, "clear_cache"):
        vae.clear_cache()
    return video


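# Denoising is split into equal chunks so each GPU lease stays short; the
# latent state round-trips through CPU between chunks. With 50 steps and
# 4 chunks, CHUNK_BOUNDARIES evaluates to (12, 25, 37, 50).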
_NUM_DENOISING_CHUNKS = 4
CHUNK_BOUNDARIES = tuple(
    NUM_INFERENCE_STEPS * (i + 1) // _NUM_DENOISING_CHUNKS
    for i in range(_NUM_DENOISING_CHUNKS)
)
assert CHUNK_BOUNDARIES[-1] == NUM_INFERENCE_STEPS


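# Standard classifier-free guidance: one transformer pass conditioned on the
# prompt ("cond"), one on the negative prompt ("uncond"), combined as
# uncond + GUIDANCE_SCALE * (cond - uncond).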
def _run_diffusion_steps(
    latents,
    condition,
    prompt_embeds,
    negative_prompt_embeds,
    image_embeds,
    timesteps,
    start_step,
    end_step,
    transformer_dtype,
):
    """Inlined Wan 2.1 I2V denoising loop. Runs steps [start_step, end_step)."""
    transformer = GENERATION_PIPE.transformer
    scheduler = GENERATION_PIPE.scheduler
    with torch.no_grad():
        for i in range(start_step, end_step):
            step_start = time.perf_counter()
            t = timesteps[i]
            latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
            timestep = t.expand(latents.shape[0])
            with transformer.cache_context("cond"):
                noise_pred = transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    encoder_hidden_states_image=image_embeds,
                    return_dict=False,
                )[0]
            with transformer.cache_context("uncond"):
                noise_uncond = transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=negative_prompt_embeds,
                    encoder_hidden_states_image=image_embeds,
                    return_dict=False,
                )[0]
            noise_pred = noise_uncond + GUIDANCE_SCALE * (noise_pred - noise_uncond)
            latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
            step_secs = time.perf_counter() - step_start
            print(
                f"[diffusion] step {i + 1}/{NUM_INFERENCE_STEPS} "
                f"(t={float(t):.1f}, {step_secs:.2f}s)",
                flush=True,
            )
    return latents


# GPU-leased stages. On ZeroGPU each decorated call gets a short GPU lease;
# the shim above makes these decorators no-ops when running locally.
@spaces.GPU()
def generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width):
    """Encode prompt+image, prepare initial latents and condition. NO denoising.

    Loads only the encoders + VAE to GPU (not the 14B transformer). Returns a
    CPU-resident state dict consumable by generate_latents_chunk_on_gpu.
    """
    log_cuda_mem("start generate_latents_setup_on_gpu")
    text_encoder = GENERATION_PIPE.text_encoder
    image_encoder = GENERATION_PIPE.image_encoder
    vae = GENERATION_PIPE.vae
    text_encoder.to(DEVICE)
    image_encoder.to(DEVICE)
    vae.to(DEVICE)
    try:
        transformer_dtype = GENERATION_PIPE.transformer.dtype
        prompt_embeds, negative_prompt_embeds = GENERATION_PIPE.encode_prompt(
            prompt=prompt,
            negative_prompt=NEGATIVE_PROMPT,
            do_classifier_free_guidance=True,
            num_videos_per_prompt=1,
            max_sequence_length=512,
            device=DEVICE,
        )
        prompt_embeds = prompt_embeds.to(transformer_dtype)
        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
        image_embeds = GENERATION_PIPE.encode_image(resized_image, DEVICE)
        image_embeds = image_embeds.repeat(1, 1, 1).to(transformer_dtype)
        image_tensor = GENERATION_PIPE.video_processor.preprocess(
            resized_image, height=height, width=width
        ).to(DEVICE, dtype=torch.float32)
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
        latents, condition = GENERATION_PIPE.prepare_latents(
            image_tensor,
            1,
            GENERATION_PIPE.vae.config.z_dim,
            height,
            width,
            NUM_FRAMES,
            torch.float32,
            DEVICE,
            generator,
            None,
            None,
        )
        state = {
            "prompt_embeds": prompt_embeds.detach().cpu(),
            "negative_prompt_embeds": negative_prompt_embeds.detach().cpu(),
            "image_embeds": image_embeds.detach().cpu(),
            "condition": condition.detach().cpu(),
            "latents": latents.detach().cpu(),
            "step_idx": 0,
        }
    finally:
        text_encoder.to("cpu")
        image_encoder.to("cpu")
        vae.to("cpu")
        log_cuda_mem("end generate_latents_setup_on_gpu")
    return state


@spaces.GPU()
def generate_latents_chunk_on_gpu(state, end_step):
    """Run denoising steps from state['step_idx'] to end_step. Only the transformer is moved to GPU."""
    log_cuda_mem(f"start latents chunk -> step {end_step}")
    transformer = GENERATION_PIPE.transformer
    transformer.to(DEVICE)
    try:
        GENERATION_PIPE.scheduler.set_timesteps(NUM_INFERENCE_STEPS, device=DEVICE)
        timesteps = GENERATION_PIPE.scheduler.timesteps
        transformer_dtype = transformer.dtype
        latents = state["latents"].to(DEVICE)
        condition = state["condition"].to(DEVICE)
        prompt_embeds = state["prompt_embeds"].to(DEVICE)
        negative_prompt_embeds = state["negative_prompt_embeds"].to(DEVICE)
        image_embeds = state["image_embeds"].to(DEVICE)
        latents = _run_diffusion_steps(
            latents,
            condition,
            prompt_embeds,
            negative_prompt_embeds,
            image_embeds,
            timesteps,
            state["step_idx"],
            end_step,
            transformer_dtype,
        )
        state["latents"] = latents.detach().cpu()
        state["step_idx"] = end_step
    finally:
        transformer.to("cpu")
        log_cuda_mem(f"end latents chunk -> step {end_step}")
    return state


@spaces.GPU()
def decode_wan_on_gpu(latents):
    log_cuda_mem("start decode_wan_on_gpu")
    WAN_VAE.to(DEVICE)
    try:
        video = decode_with_wan_vae(latents, WAN_VAE)
        video = video.detach().cpu()
    finally:
        WAN_VAE.to("cpu")
        log_cuda_mem("after wan decode")
    return video


@spaces.GPU()
def decode_refdecoder_on_gpu(latents, reference_frame):
    log_cuda_mem("start decode_refdecoder_on_gpu")
    REFDECODER_VAE.to(DEVICE)
    REFDECODER_TRANSFORMER.to(DEVICE)
    try:
        video = decode_with_refdecoder(
            latents,
            reference_frame,
            REFDECODER_VAE,
            REFDECODER_TRANSFORMER,
        )
        video = video.detach().cpu()
    finally:
        REFDECODER_VAE.to("cpu")
        REFDECODER_TRANSFORMER.to("cpu")
        log_cuda_mem("after refdecoder decode")
    return video


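# Full request path: resize -> encode (setup lease) -> chunked denoising ->
# decode the same latents twice (Wan VAE, then RefDecoder) -> write mp4s and
# build the side-by-side comparison widget.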
def generate_and_decode(image, prompt, seed, progress=gr.Progress()):
    if image is None:
        raise gr.Error("Please upload an input image.")
    if DEVICE != "cuda":
        raise gr.Error("This demo expects a CUDA GPU to run Wan I2V generation.")
    request_start = time.perf_counter()
    prompt = prompt.strip() if prompt else ""
    seed = int(seed) if seed is not None else random.randint(0, 2**32 - 1)
    run_dir = OUTPUT_ROOT / f"refdecoder_demo_{uuid.uuid4().hex}"
    run_dir.mkdir(parents=True, exist_ok=True)
    # 1 setup chunk (encoders + VAE) + len(CHUNK_BOUNDARIES) denoising chunks.
    total_chunks = 1 + len(CHUNK_BOUNDARIES)
    progress(0.0, desc=f"Generating latents (1/{total_chunks})")
    t0 = time.perf_counter()
    resized_image, height, width = resize_image_for_wan(image, GENERATION_PIPE)
    state = generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width)
    for chunk_idx, end_step in enumerate(CHUNK_BOUNDARIES, start=2):
        progress(
            0.8 * (chunk_idx - 1) / total_chunks,
            desc=f"Generating latents ({chunk_idx}/{total_chunks})",
        )
        state = generate_latents_chunk_on_gpu(state, end_step)
    latents = normalize_latent_shape(state["latents"])
    latent_secs = time.perf_counter() - t0
    print(f"[timing] latent generation: {latent_secs:.2f}s")
    reference_frame = build_reference_frame(resized_image, "cpu")
    latent_path = run_dir / "wan_latents.pt"
    torch.save(
        {
            "latents": latents,
            "height": height,
            "width": width,
            "prompt": prompt,
            "seed": seed,
        },
        latent_path,
    )
    progress(0.8, desc="Decoding Wan baseline")
    t0 = time.perf_counter()
    wan_video = decode_wan_on_gpu(latents)
    wan_secs = time.perf_counter() - t0
    print(f"[timing] wan decode: {wan_secs:.2f}s")
    wan_video_path = save_video_tensor(wan_video, run_dir / "wan_vae.mp4")
    del wan_video
    gc.collect()
    progress(0.9, desc="Decoding RefDecoder")
    t0 = time.perf_counter()
    ref_video = decode_refdecoder_on_gpu(latents, reference_frame)
    ref_secs = time.perf_counter() - t0
    print(f"[timing] refdecoder decode: {ref_secs:.2f}s")
    ref_video_path = save_video_tensor(ref_video, run_dir / "refdecoder.mp4")
    del ref_video
    gc.collect()
    compare_html = build_compare_html(wan_video_path, ref_video_path)
    total_secs = time.perf_counter() - request_start
    print(
        f"[timing] request total: {total_secs:.2f}s "
        f"(latents={latent_secs:.2f}s, wan={wan_secs:.2f}s, ref={ref_secs:.2f}s)"
    )
    return (
        gr.update(value=compare_html, visible=True),
        wan_video_path,
        ref_video_path,
        "",
        gr.update(value=wan_video_path, interactive=True),
        gr.update(value=ref_video_path, interactive=True),
    )


| CUSTOM_CSS = """ | |
| @import url('https://fonts.googleapis.com/css2?family=Manrope:wght@400;500;600;700;800&display=swap'); | |
| :root { | |
| --page-bg: #f4f1e8; | |
| --card-bg: rgba(255, 252, 246, 0.92); | |
| --card-border: rgba(50, 43, 32, 0.12); | |
| --accent: #1f6a52; | |
| --accent-2: #c96f42; | |
| --text-main: #201a14; | |
| --text-soft: #201a14; | |
| --ui-font: "Manrope", "Inter", "Segoe UI", sans-serif; | |
| } | |
| .gradio-container { | |
| background: | |
| radial-gradient(circle at top left, rgba(201, 111, 66, 0.18), transparent 26%), | |
| radial-gradient(circle at top right, rgba(31, 106, 82, 0.16), transparent 28%), | |
| linear-gradient(180deg, #f8f4ec 0%, var(--page-bg) 100%); | |
| font-family: var(--ui-font); | |
| } | |
| .app-shell { | |
| max-width: 1320px; | |
| margin: 0 auto; | |
| } | |
| .hero-card, | |
| .panel-card, | |
| .output-card { | |
| background: var(--card-bg); | |
| border: 1px solid var(--card-border); | |
| border-radius: 24px; | |
| box-shadow: 0 18px 50px rgba(49, 39, 26, 0.08); | |
| } | |
| .hero-card { | |
| padding: 28px 30px 20px 30px; | |
| margin-bottom: 18px; | |
| } | |
| .hero-kicker { | |
| display: inline-block; | |
| padding: 6px 12px; | |
| border-radius: 999px; | |
| background: rgba(31, 106, 82, 0.10); | |
| color: var(--accent); | |
| font-size: 12px; | |
| font-weight: 700; | |
| letter-spacing: 0.08em; | |
| text-transform: uppercase; | |
| } | |
| .hero-title { | |
| margin: 14px 0 8px 0; | |
| font-size: 42px; | |
| line-height: 1.05; | |
| font-weight: 800; | |
| color: var(--text-main); | |
| } | |
| .hero-copy { | |
| margin: 0; | |
| max-width: 840px; | |
| color: var(--text-soft); | |
| font-size: 17px; | |
| line-height: 1.6; | |
| font-family: var(--ui-font); | |
| } | |
| .panel-card, | |
| .output-card { | |
| padding: 18px; | |
| } | |
| .panel-card { | |
| overflow: hidden; | |
| } | |
| .section-title { | |
| margin: 0 0 6px 0; | |
| color: var(--text-main); | |
| font-size: 22px; | |
| font-weight: 750; | |
| } | |
| .section-copy { | |
| margin: 0 0 14px 0; | |
| color: var(--text-soft); | |
| font-size: 14px; | |
| line-height: 1.55; | |
| font-family: var(--ui-font); | |
| } | |
| .compare-note { | |
| padding: 12px 14px; | |
| border-radius: 16px; | |
| background: rgba(201, 111, 66, 0.08); | |
| color: #6a4128; | |
| font-size: 14px; | |
| line-height: 1.5; | |
| margin-bottom: 14px; | |
| } | |
| #generate-btn { | |
| min-height: 108px; | |
| height: 100%; | |
| width: 100%; | |
| font-size: 16px; | |
| font-weight: 700; | |
| background: linear-gradient(135deg, var(--accent) 0%, #154f3d 100%); | |
| border: none; | |
| } | |
| #generate-btn:hover { | |
| filter: brightness(1.04); | |
| } | |
| .output-grid { | |
| gap: 14px; | |
| } | |
| .compare-shell { | |
| display: flex; | |
| flex-direction: column; | |
| gap: 12px; | |
| } | |
| .compare-frame { | |
| width: 100%; | |
| /* aspect-ratio is a tight fallback for the brief moment before the parent | |
| JS estimator (and then the iframe's own postMessage) sets the height. */ | |
| aspect-ratio: 16 / 11; | |
| border: 0; | |
| background: transparent; | |
| overflow: hidden; | |
| display: block; | |
| transition: height 120ms ease; | |
| } | |
| .compare-topbar { | |
| display: flex; | |
| justify-content: space-between; | |
| align-items: center; | |
| gap: 12px; | |
| } | |
| .compare-chip { | |
| padding: 8px 12px; | |
| border-radius: 999px; | |
| background: rgba(31, 106, 82, 0.08); | |
| color: var(--text-main); | |
| font-size: 12px; | |
| font-weight: 700; | |
| letter-spacing: 0.04em; | |
| text-transform: uppercase; | |
| } | |
| .compare-chip-right { | |
| background: rgba(201, 111, 66, 0.10); | |
| } | |
| .compare-stage { | |
| position: relative; | |
| width: 100%; | |
| aspect-ratio: 16 / 9; | |
| overflow: hidden; | |
| border-radius: 22px; | |
| background: #16120f; | |
| border: 1px solid rgba(255,255,255,0.08); | |
| } | |
| .compare-video { | |
| position: absolute; | |
| inset: 0; | |
| width: 100%; | |
| height: 100%; | |
| object-fit: contain; | |
| background: #16120f; | |
| } | |
| .compare-overlay { | |
| clip-path: inset(0 0 0 50%); | |
| } | |
| .compare-divider { | |
| position: absolute; | |
| top: 0; | |
| bottom: 0; | |
| left: 50%; | |
| width: 2px; | |
| background: rgba(255,255,255,0.96); | |
| box-shadow: 0 0 0 1px rgba(31, 26, 20, 0.15); | |
| transform: translateX(-1px); | |
| pointer-events: none; | |
| } | |
| .compare-divider::after { | |
| content: ""; | |
| position: absolute; | |
| top: 50%; | |
| left: 50%; | |
| width: 18px; | |
| height: 18px; | |
| border-radius: 999px; | |
| background: #fff; | |
| border: 2px solid rgba(31, 26, 20, 0.18); | |
| transform: translate(-50%, -50%); | |
| } | |
| .compare-range { | |
| position: absolute; | |
| inset: 0; | |
| width: 100%; | |
| height: 100%; | |
| opacity: 0; | |
| cursor: ew-resize; | |
| } | |
| .compare-caption { | |
| color: var(--text-soft); | |
| font-size: 14px; | |
| line-height: 1.5; | |
| font-family: var(--ui-font); | |
| } | |
| .compare-panel { | |
| padding-bottom: 34px; | |
| } | |
| .seed-action-row { | |
| align-items: stretch; | |
| } | |
| .seed-action-row > .gradio-column { | |
| min-width: 0; | |
| } | |
| .run-status { | |
| margin-top: 8px; | |
| color: var(--text-soft); | |
| font-size: 13px; | |
| line-height: 1.4; | |
| min-height: 1.4em; | |
| } | |
| .run-status p { | |
| margin: 0; | |
| } | |
| .download-row { | |
| margin-top: 12px; | |
| gap: 12px; | |
| justify-content: center; | |
| flex-wrap: wrap; | |
| } | |
| .download-row button { | |
| border: 0 !important; | |
| border-radius: 999px !important; | |
| padding: 10px 22px !important; | |
| font-size: 14px !important; | |
| font-weight: 700 !important; | |
| box-shadow: none !important; | |
| min-height: 0 !important; | |
| } | |
| button.download-baseline { | |
| background: var(--accent) !important; | |
| color: #fff !important; | |
| } | |
| button.download-ref { | |
| background: var(--accent-2) !important; | |
| color: #fff !important; | |
| } | |
| .download-row button:hover:not([disabled]):not(:disabled) { | |
| filter: brightness(1.05); | |
| } | |
| button.download-baseline[disabled], | |
| button.download-baseline:disabled { | |
| background: rgba(31, 106, 82, 0.14) !important; | |
| color: #123a2d !important; | |
| box-shadow: inset 0 0 0 1px rgba(31, 106, 82, 0.12) !important; | |
| opacity: 1 !important; | |
| cursor: not-allowed; | |
| } | |
| button.download-ref[disabled], | |
| button.download-ref:disabled { | |
| background: rgba(201, 111, 66, 0.16) !important; | |
| color: #6e3d23 !important; | |
| box-shadow: inset 0 0 0 1px rgba(201, 111, 66, 0.16) !important; | |
| opacity: 1 !important; | |
| cursor: not-allowed; | |
| } | |
| """ | |
with gr.Blocks(title="RefDecoder I2V Demo", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    with gr.Column(elem_classes="app-shell"):
        gr.HTML("""
        <script>
        (() => {
          if (window.__refdecoderResizeBound) return;
          window.__refdecoderResizeBound = true;
          const STAGE_RATIO = 9 / 16;
          const CHROME = 160;
          const observed = new WeakSet();
          const estimateHeight = (iframe) => {
            if (iframe.dataset.exactSized === "1") return;
            const w = iframe.getBoundingClientRect().width;
            if (w > 0) {
              iframe.style.height = Math.round(w * STAGE_RATIO + CHROME) + "px";
            }
          };
          const trackIframe = (iframe) => {
            if (observed.has(iframe)) return;
            observed.add(iframe);
            estimateHeight(iframe);
            new ResizeObserver(() => estimateHeight(iframe)).observe(iframe);
          };
          document.querySelectorAll("iframe.compare-frame").forEach(trackIframe);
          new MutationObserver((mutations) => {
            for (const m of mutations) {
              for (const n of m.addedNodes) {
                if (n.nodeType !== 1) continue;
                if (n.matches && n.matches("iframe.compare-frame")) trackIframe(n);
                const inner = n.querySelectorAll && n.querySelectorAll("iframe.compare-frame");
                if (inner) inner.forEach(trackIframe);
              }
            }
          }).observe(document.body, { childList: true, subtree: true });
          window.addEventListener("message", (e) => {
            if (!e.data || e.data.type !== "compare-iframe-height") return;
            const h = Math.max(200, Number(e.data.height) || 0);
            document.querySelectorAll("iframe.compare-frame").forEach((f) => {
              if (f.contentWindow === e.source) {
                f.style.height = h + "px";
                f.dataset.exactSized = "1";
              }
            });
          });
        })();
        </script>
        <div class="hero-card">
          <div class="hero-title">RefDecoder I2V Demo</div>
          <p class="hero-copy">
            Upload one image, optionally add a prompt, and compare two decoders on the same Wan latent video.
            The app generates latents once, then renders them with Wan's original VAE and with RefDecoder.
          </p>
        </div>
        """)
        with gr.Column(elem_classes=["panel-card", "compare-panel"]):
            gr.HTML("""
            <div class="section-title">Inputs</div>
            <div class="section-copy">
              Upload a reference image, optionally add a prompt, and compare the decoders below.
            </div>
            """)
            with gr.Row(equal_height=True):
                with gr.Column(scale=3):
                    image_input = gr.Image(
                        label="Input Image",
                        type="pil",
                        height=180,
                    )
                with gr.Column(scale=5):
                    prompt_input = gr.Textbox(
                        label="Prompt",
                        lines=2,
                        placeholder="A woman turns toward the camera as her hair moves in the wind...",
                    )
                    with gr.Row(equal_height=True, elem_classes="seed-action-row"):
                        with gr.Column(scale=1):
                            seed_input = gr.Number(
                                label="Seed",
                                value=None,
                                precision=0,
                                info="Optional",
                            )
                        with gr.Column(scale=1):
                            run_button = gr.Button(
                                "Generate Comparison",
                                variant="primary",
                                elem_id="generate-btn",
                            )
            status_md = gr.Markdown(value="", elem_classes="run-status")
        with gr.Column(elem_classes="panel-card"):
            gr.HTML("""
            <div class="section-title">Decoder Comparison</div>
            <div class="section-copy">
              Left side shows Wan Baseline. Right side shows RefDecoder. Drag the divider across the frame to compare them.
            </div>
            """)
            compare_output = gr.HTML(value=build_compare_html(None, None))
            with gr.Row(elem_classes="download-row"):
                wan_download_btn = gr.DownloadButton(
                    label="Download Baseline",
                    value=None,
                    interactive=False,
                    elem_classes="download-baseline",
                )
                ref_download_btn = gr.DownloadButton(
                    label="Download RefDecoder",
                    value=None,
                    interactive=False,
                    elem_classes="download-ref",
                )

    wan_video_hidden = gr.Video(visible=False)
    ref_video_hidden = gr.Video(visible=False)

    def reset_for_new_run():
        return (
            "",
            gr.update(value=None, interactive=False),
            gr.update(value=None, interactive=False),
        )

    run_button.click(
        fn=reset_for_new_run,
        inputs=None,
        outputs=[status_md, wan_download_btn, ref_download_btn],
        queue=False,
        show_progress="hidden",
    ).then(
        fn=generate_and_decode,
        inputs=[image_input, prompt_input, seed_input],
        outputs=[
            compare_output,
            wan_video_hidden,
            ref_video_hidden,
            status_md,
            wan_download_btn,
            ref_download_btn,
        ],
    )


if __name__ == "__main__":
    demo.queue(max_size=2).launch(allowed_paths=[str(OUTPUT_ROOT)])