# RefDecoder / app.py
# Last commit by xiangfan00: "Increase chunk time and log progress" (7baf5d2)
import gc
import html
import random
import sys
import time
import uuid
from pathlib import Path
from urllib.parse import quote
import gradio as gr
import imageio
import numpy as np
import ftfy
try:
    import spaces
except ImportError:
    class _SpacesShim:
        """Local stand-in for the ``spaces`` package when it is not installed.

        Only ``GPU`` is provided. It is a pass-through decorator so the
        ``@spaces.GPU(...)``-decorated functions in this file still run
        (on whatever device is available) outside a Space.
        """

        @staticmethod
        def GPU(*args, **kwargs):
            """Return a no-op decorator, or the function itself when used bare.

            Supports both ``@spaces.GPU(duration=...)`` and ``@spaces.GPU``;
            the bare form receives the decorated function as its only
            positional argument, which must be returned unchanged rather than
            wrapped in another decorator.
            """
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]

            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesShim()
import torch
from diffusers.pipelines.wan import pipeline_wan_i2v
from diffusers import AutoencoderKLWan as DiffusersWanVAE
from diffusers import WanImageToVideoPipeline
from huggingface_hub import hf_hub_download, snapshot_download
from transformers import CLIPVisionModel
from src.models.Wan.autoencoder_wanT import AutoencoderKLWan
from src.models.Wan.transformer_wan import WanDecoderTransformer
# Directory containing this file; made importable for modules resolved later.
# NOTE(review): the ``src.models`` imports above execute before this path
# tweak, so they presumably rely on the app directory already being on
# ``sys.path`` (e.g. as the working directory) — confirm.
ROOT = Path(__file__).resolve().parent
if str(ROOT) not in sys.path:
    sys.path[:0] = [str(ROOT)]
# --- Hub repositories and local output location ------------------------------
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
REFDECODER_REPO_ID = "Arrokothwhi/RefDecoder"
REFDECODER_CKPT_PATH_IN_REPO = "I2V_Wan2.1/model.pt"
OUTPUT_ROOT = ROOT / "gradio_outputs"

# --- Generation settings -----------------------------------------------------
# Negative prompt passed to every Wan I2V encode_prompt call.
NEGATIVE_PROMPT = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, "
    "images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, "
    "incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, "
    "misshapen limbs, fused fingers, still picture, messy background, three legs, many people "
    "in the background, walking backwards"
)
TARGET_AREA = 480 * 832  # pixel budget used when resizing the input image
FPS = 16  # playback rate of the written MP4s (also used for frame stepping)
NUM_FRAMES = 17
NUM_INFERENCE_STEPS = 50
GUIDANCE_SCALE = 5.0

# --- Device / precision ------------------------------------------------------
if torch.cuda.is_available():
    DEVICE = "cuda"
    PIPE_DTYPE = torch.bfloat16
else:
    DEVICE = "cpu"
    PIPE_DTYPE = torch.float32
# Some diffusers Wan builds reference a module-level `ftfy` during prompt
# cleaning. Bind it explicitly on that module so prompt encoding does not fail
# when the optional global was never initialized there.
setattr(pipeline_wan_i2v, "ftfy", ftfy)
def download_refdecoder_ckpt():
    """Fetch the RefDecoder checkpoint with ``hf_hub_download`` and return its local path."""
    print("[init] Downloading RefDecoder checkpoint metadata/file if needed")
    request = {
        "repo_id": REFDECODER_REPO_ID,
        "filename": REFDECODER_CKPT_PATH_IN_REPO,
    }
    local_path = hf_hub_download(**request)
    print(f"[init] RefDecoder checkpoint ready at: {local_path}")
    return local_path
def download_wan_weights():
    """Snapshot the full Wan I2V repository locally and return its directory."""
    print(f"[init] Downloading Wan I2V weights from {MODEL_ID}")
    snapshot_dir = snapshot_download(repo_id=MODEL_ID)
    print(f"[init] Wan I2V weights ready at: {snapshot_dir}")
    return snapshot_dir
# Resolve remote weights once at import time. The Wan snapshot directory is not
# kept: later loaders pass MODEL_ID to from_pretrained, presumably hitting the
# local Hub cache populated here.
REFDECODER_CKPT_LOCAL_PATH = download_refdecoder_ckpt()
download_wan_weights()
OUTPUT_ROOT.mkdir(exist_ok=True, parents=True)
def log_cuda_mem(tag):
    """Print a one-line CUDA memory report labelled ``[mem] {tag}``.

    Reports free/total device memory from ``torch.cuda.mem_get_info`` and the
    allocator's allocated/reserved totals, all in GiB. When CUDA is unavailable,
    or ``mem_get_info`` raises ``RuntimeError``, a short notice is printed
    instead.
    """
    prefix = f"[mem] {tag}: "
    if not torch.cuda.is_available():
        print(prefix + "CUDA not available")
        return
    try:
        free_bytes, total_bytes = torch.cuda.mem_get_info()
    except RuntimeError as exc:
        print(prefix + f"CUDA not currently leased ({exc})")
        return
    gib = 1024**3
    readings = (
        ("free", free_bytes),
        ("total", total_bytes),
        ("allocated", torch.cuda.memory_allocated()),
        ("reserved", torch.cuda.memory_reserved()),
    )
    print(prefix + ", ".join(f"{label}={amount / gib:.2f} GB" for label, amount in readings))
def get_module_dtype(module):
    """Return the dtype of *module*'s first parameter, or PIPE_DTYPE if it has none."""
    first_param = next(iter(module.parameters()), None)
    if first_param is None:
        return PIPE_DTYPE
    return first_param.dtype
def load_generation_pipe():
    """Load the Wan I2V pipeline (on CPU) with explicitly loaded CLIP encoder and VAE.

    All three components are loaded from MODEL_ID with ``torch_dtype=PIPE_DTYPE``;
    the image encoder and VAE are loaded first and injected into the pipeline.
    """
    dtype_kwargs = {"torch_dtype": PIPE_DTYPE}
    clip_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", **dtype_kwargs
    )
    wan_vae = DiffusersWanVAE.from_pretrained(MODEL_ID, subfolder="vae", **dtype_kwargs)
    return WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        vae=wan_vae,
        image_encoder=clip_encoder,
        **dtype_kwargs,
    )
def load_wan_vae():
    """Load a standalone Wan VAE (baseline decoder) in eval mode with PIPE_DTYPE."""
    return DiffusersWanVAE.from_pretrained(
        MODEL_ID,
        subfolder="vae",
        torch_dtype=PIPE_DTYPE,
    ).eval()
def load_refdecoder_module():
    """Construct the RefDecoder VAE and decoder transformer and load their weights.

    Both modules are built on CPU in eval mode. The checkpoint at
    ``REFDECODER_CKPT_LOCAL_PATH`` is read with ``map_location="cpu"``. Its
    tensors are taken from a ``"state_dict"`` or ``"module"`` entry when
    present; otherwise the checkpoint itself is treated as the state dict.
    Keys prefixed ``vae.`` / ``transformer.`` are routed, with the prefix
    stripped, to the matching module and loaded with ``strict=False``. The
    missing/unexpected key counts are printed so a partial load is visible.

    Returns:
        ``(vae, transformer)``.

    Raises:
        RuntimeError: if the checkpoint has no ``vae.``-prefixed or no
            ``transformer.``-prefixed keys. With ``strict=False`` such a module
            would otherwise silently keep its constructor-time weights.
    """
    vae = AutoencoderKLWan(
        dropout_p=0.0,
        use_reference=True,
    ).eval()
    transformer = WanDecoderTransformer(
        chunk=5,
        num_layers=10,
        num_heads=12,
        head_dim=128,
        reusing=True,
        pretrained=False,
    ).eval()

    checkpoint = torch.load(REFDECODER_CKPT_LOCAL_PATH, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint.get("module", checkpoint))

    vae_sd = {}
    transformer_sd = {}
    for key, value in state_dict.items():
        if key.startswith("vae."):
            vae_sd[key[len("vae.") :]] = value
        elif key.startswith("transformer."):
            transformer_sd[key[len("transformer.") :]] = value

    for name, module, sub_state in (
        ("vae", vae, vae_sd),
        ("transformer", transformer, transformer_sd),
    ):
        if not sub_state:
            raise RuntimeError(
                f"RefDecoder checkpoint {REFDECODER_CKPT_LOCAL_PATH} contains no "
                f"'{name}.'-prefixed weights; refusing to run with an unloaded {name}."
            )
        # strict=False is kept (the checkpoint may legitimately omit some
        # buffers), but report what did not match so it is not silent.
        result = module.load_state_dict(sub_state, strict=False)
        missing = getattr(result, "missing_keys", [])
        unexpected = getattr(result, "unexpected_keys", [])
        print(
            f"[init] RefDecoder {name}: loaded {len(sub_state)} tensors "
            f"(missing={len(missing)}, unexpected={len(unexpected)})"
        )
    return vae, transformer
# Preload every model on CPU at import time so each @spaces.GPU lease only pays
# for the CPU -> GPU transfer, not the full from_pretrained / checkpoint read.
GENERATION_PIPE, WAN_VAE = load_generation_pipe(), load_wan_vae()
REFDECODER_VAE, REFDECODER_TRANSFORMER = load_refdecoder_module()
def resize_image_for_wan(image, pipe, target_area=None):
    """Resize *image* to Wan-compatible dimensions near a fixed pixel budget.

    The output keeps the input aspect ratio and targets roughly ``target_area``
    pixels. Each side is rounded down to a multiple of the pipeline's spatial
    granularity (``vae_scale_factor_spatial * transformer.config.patch_size[1]``),
    but is never smaller than one such multiple.

    Args:
        image: PIL-style image (uses ``convert``, ``width``, ``height``, ``resize``).
        pipe: Pipeline exposing ``vae_scale_factor_spatial`` and
            ``transformer.config.patch_size``.
        target_area: Pixel budget; defaults to the module-level ``TARGET_AREA``.

    Returns:
        ``(resized_image, height, width)`` with ``height``/``width`` as Python ints.
    """
    if target_area is None:
        target_area = TARGET_AREA
    image = image.convert("RGB")
    aspect_ratio = image.height / image.width
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]

    def _snap(length):
        # Round to the nearest pixel, then down to a multiple of mod_value.
        # Clamp to one multiple: a very thin/wide image would otherwise yield a
        # zero-length side and fail in resize(). int() also guards against
        # numpy versions where round(np.float64) stays a numpy float.
        snapped = int(round(float(length))) // mod_value * mod_value
        return max(mod_value, snapped)

    height = _snap(np.sqrt(target_area * aspect_ratio))
    width = _snap(np.sqrt(target_area / aspect_ratio))
    resized = image.resize((width, height))
    return resized, height, width
def build_reference_frame(image, device):
    """Turn an H x W x 3 image into a float32 ``[1, 3, 1, H, W]`` tensor in ``[-1, 1]``.

    Pixel values are mapped from ``[0, 255]`` via ``(x / 255 - 0.5) * 2``; the
    channel axis comes first and singleton batch and time axes are inserted.
    """
    pixels = torch.from_numpy(np.asarray(image).astype(np.float32))
    chw = pixels.permute(2, 0, 1)
    normalized = (chw / 255.0 - 0.5) * 2.0
    return normalized[None, :, None].to(device=device, dtype=torch.float32)
def normalize_latent_shape(latents):
    """Coerce latents to a 5-D ``[B, C, T, H, W]`` tensor.

    A list is replaced by its first element, and a 4-D tensor gains a leading
    batch axis. Anything that is not 5-D afterwards raises ``ValueError``.
    """
    candidate = latents[0] if isinstance(latents, list) else latents
    if candidate.ndim == 4:
        candidate = candidate[None]
    if candidate.ndim == 5:
        return candidate
    raise ValueError(f"Expected latent shape [B,C,T,H,W], got {tuple(candidate.shape)}")
def gradio_file_url(path):
    """Return the Gradio ``/gradio_api/file=`` URL for *path* (percent-encoded, slashes kept)."""
    encoded_path = quote(str(path), safe="/")
    return "/gradio_api/file=" + encoded_path
def build_compare_html(wan_video_path, ref_video_path):
    """Return an ``<iframe>`` snippet holding the side-by-side decoder comparison.

    The widget is a complete HTML document (markup, CSS and script) embedded
    through the iframe's ``srcdoc``. The Wan baseline video is the base layer.
    The RefDecoder video is overlaid and clipped by a range slider; buttons
    pause/play, step one frame (``1 / FPS`` seconds) and reset. When either path
    is falsy, an empty placeholder tile is rendered for that side and the
    playback buttons are disabled. The embedded script also posts its content
    height to the parent window as a ``compare-iframe-height`` message.

    Args:
        wan_video_path: Path to the Wan-VAE-decoded MP4, or ``None``.
        ref_video_path: Path to the RefDecoder-decoded MP4, or ``None``.

    Returns:
        HTML string for an ``iframe.compare-frame`` element.
    """
    # Unique DOM id so the embedded script targets only this widget instance
    # and the parent can match height messages to it.
    compare_id = f"compare-{uuid.uuid4().hex}"
    wan_url = gradio_file_url(wan_video_path) if wan_video_path else ""
    ref_url = gradio_file_url(ref_video_path) if ref_video_path else ""
    # Each layer is a muted, looping, autoplaying <video> when a URL exists,
    # otherwise a placeholder <div> carrying the same positioning classes.
    base_source = (
        f'<video class="compare-video compare-base" src="{wan_url}" autoplay muted loop playsinline></video>'
        if wan_url
        else '<div class="compare-video compare-base compare-placeholder"></div>'
    )
    overlay_source = (
        f'<video class="compare-video compare-overlay" src="{ref_url}" autoplay muted loop playsinline></video>'
        if ref_url
        else '<div class="compare-video compare-overlay compare-placeholder"></div>'
    )
    # Standalone document for the iframe. Literal CSS/JS braces are doubled
    # ({{ }}) because this is an f-string; compare_id, base_source,
    # overlay_source and FPS are interpolated.
    inner_doc = f"""
<!doctype html>
<html>
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<style>
html, body {{
margin: 0;
padding: 0;
background: transparent;
font-family: Manrope, Inter, system-ui, sans-serif;
}}
.compare-shell {{
display: flex;
flex-direction: column;
gap: 12px;
}}
.compare-topbar {{
display: flex;
justify-content: space-between;
align-items: center;
gap: 12px;
}}
.compare-chip {{
padding: 12px 22px;
border-radius: 999px;
background: rgba(31, 106, 82, 0.14);
color: #123a2d;
font-size: 22px;
font-weight: 800;
letter-spacing: 0.03em;
text-transform: uppercase;
box-shadow: inset 0 0 0 1px rgba(31, 106, 82, 0.12);
justify-self: start;
}}
.compare-chip-right {{
background: rgba(201, 111, 66, 0.16);
color: #6e3d23;
box-shadow: inset 0 0 0 1px rgba(201, 111, 66, 0.16);
justify-self: end;
}}
.compare-button {{
border: 0;
border-radius: 999px;
padding: 10px 22px;
background: #1f6a52;
color: white;
font-size: 16px;
font-weight: 700;
cursor: pointer;
justify-self: center;
}}
.compare-stage {{
position: relative;
width: 100%;
aspect-ratio: 16 / 9;
overflow: hidden;
border-radius: 22px;
background: #16120f;
border: 1px solid rgba(255,255,255,0.08);
}}
.compare-video {{
position: absolute;
inset: 0;
width: 100%;
height: 100%;
object-fit: contain;
background: #16120f;
}}
.compare-overlay {{
clip-path: inset(0 0 0 50%);
}}
.compare-placeholder {{
background:
linear-gradient(135deg, rgba(255,255,255,0.055), transparent 35%),
#16120f;
}}
.compare-divider {{
position: absolute;
top: 0;
bottom: 0;
left: 50%;
width: 2px;
background: rgba(255,255,255,0.96);
box-shadow: 0 0 0 1px rgba(31, 26, 20, 0.15);
transform: translateX(-1px);
pointer-events: none;
}}
.compare-divider::after {{
content: "";
position: absolute;
top: 50%;
left: 50%;
width: 18px;
height: 18px;
border-radius: 999px;
background: #fff;
border: 2px solid rgba(31, 26, 20, 0.18);
transform: translate(-50%, -50%);
}}
.compare-range {{
position: absolute;
inset: 0;
width: 100%;
height: 100%;
opacity: 0.01;
cursor: ew-resize;
margin: 0;
-webkit-appearance: none;
appearance: none;
}}
.compare-caption {{
color: #201a14;
font-size: 14px;
line-height: 1.5;
text-align: center;
}}
.compare-controls {{
display: flex;
justify-content: center;
align-items: center;
gap: 10px;
flex-wrap: wrap;
}}
.compare-controls .compare-button {{
padding: 9px 16px;
font-size: 14px;
}}
.compare-button-step {{
background: #2f5746;
}}
.compare-button-reset {{
background: #c96f42;
}}
.compare-button[disabled] {{
opacity: 0.55;
cursor: not-allowed;
}}
</style>
</head>
<body>
<div class="compare-shell" id="{compare_id}">
<div class="compare-topbar">
<div class="compare-chip">Wan Baseline</div>
<div class="compare-chip compare-chip-right">RefDecoder</div>
</div>
<div class="compare-stage">
{base_source}
{overlay_source}
<div class="compare-divider"></div>
<input class="compare-range" type="range" min="0" max="100" value="50" />
</div>
<div class="compare-controls">
<button class="compare-button compare-button-step" type="button" data-action="prev">− 1 Frame</button>
<button class="compare-button" type="button" data-action="toggle">Pause</button>
<button class="compare-button compare-button-step" type="button" data-action="next">+ 1 Frame</button>
<button class="compare-button compare-button-reset" type="button" data-action="reset">Reset Playback</button>
</div>
<div class="compare-caption">Drag the divider to compare the two decoders on the same latent video.</div>
</div>
<script>
(() => {{
const root = document.getElementById("{compare_id}");
const base = root.querySelector(".compare-base");
const overlay = root.querySelector(".compare-overlay");
const divider = root.querySelector(".compare-divider");
const slider = root.querySelector(".compare-range");
const button = root.querySelector('[data-action="toggle"]');
const prevBtn = root.querySelector('[data-action="prev"]');
const nextBtn = root.querySelector('[data-action="next"]');
const resetBtn = root.querySelector('[data-action="reset"]');
const stepButtons = [prevBtn, nextBtn, resetBtn];
const videos = Array.from(root.querySelectorAll("video"));
const FRAME_DELTA = 1 / {FPS};
const applySplit = () => {{
const value = Number(slider.value);
overlay.style.clipPath = `inset(0 0 0 ${{value}}%)`;
divider.style.left = `${{value}}%`;
}};
const syncVideo = (source, target) => {{
if (Math.abs((target.currentTime || 0) - (source.currentTime || 0)) > 0.08) {{
try {{ target.currentTime = source.currentTime; }} catch (e) {{}}
}}
}};
const playBoth = () => {{
videos.forEach((video) => video.play().catch(() => {{}}));
button.textContent = "Pause";
}};
const pauseBoth = () => {{
videos.forEach((video) => video.pause());
button.textContent = "Play";
}};
const bindSync = (primary, secondary) => {{
primary.addEventListener("play", () => secondary.play().catch(() => {{}}));
primary.addEventListener("pause", () => secondary.pause());
primary.addEventListener("seeking", () => syncVideo(primary, secondary));
primary.addEventListener("timeupdate", () => syncVideo(primary, secondary));
primary.addEventListener("ratechange", () => {{ secondary.playbackRate = primary.playbackRate; }});
}};
const stepFrame = (delta) => {{
if (!videos.length) return;
pauseBoth();
videos.forEach((video) => {{
const duration = isFinite(video.duration) ? video.duration : 0;
let nextTime = (video.currentTime || 0) + delta;
if (duration > 0) {{
nextTime = ((nextTime % duration) + duration) % duration;
}} else {{
nextTime = Math.max(0, nextTime);
}}
try {{ video.currentTime = nextTime; }} catch (e) {{}}
}});
}};
const resetPlayback = () => {{
pauseBoth();
videos.forEach((video) => {{
try {{ video.currentTime = 0; }} catch (e) {{}}
}});
}};
if (base.tagName === "VIDEO" && overlay.tagName === "VIDEO") {{
bindSync(base, overlay);
bindSync(overlay, base);
}} else {{
button.disabled = true;
button.textContent = "Play";
button.style.opacity = "0.55";
stepButtons.forEach((btn) => {{ if (btn) btn.disabled = true; }});
}}
videos.forEach((video) => {{
video.addEventListener("loadeddata", playBoth, {{ once: true }});
}});
button.addEventListener("click", () => {{
if (!videos.length || videos[0].paused) {{
playBoth();
}} else {{
pauseBoth();
}}
}});
if (prevBtn) prevBtn.addEventListener("click", () => stepFrame(-FRAME_DELTA));
if (nextBtn) nextBtn.addEventListener("click", () => stepFrame(FRAME_DELTA));
if (resetBtn) resetBtn.addEventListener("click", resetPlayback);
slider.addEventListener("input", applySplit);
applySplit();
const reportHeight = () => {{
const h = Math.ceil(root.getBoundingClientRect().height + 2);
parent.postMessage({{ type: "compare-iframe-height", id: "{compare_id}", height: h }}, "*");
}};
reportHeight();
window.addEventListener("load", reportHeight);
if (typeof ResizeObserver !== "undefined") {{
new ResizeObserver(reportHeight).observe(root);
}}
videos.forEach((video) => {{
video.addEventListener("loadedmetadata", reportHeight);
}});
}})();
</script>
</body>
</html>
"""
    # srcdoc is an HTML attribute, so the whole document is attribute-escaped;
    # the sandbox grants allow-scripts so the widget's script can run.
    return (
        '<iframe class="compare-frame" '
        'sandbox="allow-scripts allow-same-origin" '
        'scrolling="no" '
        'srcdoc="' + html.escape(inner_doc, quote=True) + '"></iframe>'
    )
def save_video_tensor(video_tensor, output_path):
    """Write a ``[1, C, T, H, W]`` video tensor in ``[-1, 1]`` to an MP4.

    Values are mapped to ``[0, 1]`` via ``x / 2 + 0.5`` and clamped. The batch
    axis is dropped and the tensor is reordered to ``[T, H, W, C]`` frames,
    which are written with imageio at ``FPS``.

    Args:
        video_tensor: Decoded video tensor (any float dtype, any device).
        output_path: Destination file path.

    Returns:
        ``str(output_path)``.
    """
    # Upcast before the affine map so bf16 decoder output is not rescaled in
    # reduced precision.
    video = (video_tensor.detach().float() / 2 + 0.5).clamp(0, 1)
    frames = video.squeeze(0).permute(1, 2, 3, 0).cpu().numpy()
    # Round to the nearest level; a bare astype(uint8) truncates, which biases
    # every pixel down by up to one level.
    frames = np.round(frames * 255).astype(np.uint8)
    imageio.mimwrite(output_path, frames, fps=FPS, quality=10)
    return str(output_path)
def decode_with_wan_vae(latents, vae):
    """Decode normalized Wan latents to pixels with the baseline Wan VAE.

    Latents are moved to DEVICE in the VAE's parameter dtype and de-normalized
    with the VAE config's per-channel statistics (``z * std + mean``) before
    ``vae.decode`` runs under ``torch.no_grad``.
    """
    dtype = get_module_dtype(vae)

    def _channel_stat(values):
        # Broadcast a per-channel list over [B, C, T, H, W].
        return torch.tensor(values, device=DEVICE, dtype=dtype).view(1, -1, 1, 1, 1)

    z = latents.to(device=DEVICE, dtype=dtype)
    channel_mean = _channel_stat(vae.config.latents_mean)
    channel_std = _channel_stat(vae.config.latents_std)
    z = z * channel_std + channel_mean
    with torch.no_grad():
        return vae.decode(z, return_dict=False)[0]
def decode_with_refdecoder(latents, reference_frame, vae, transformer):
    """Decode normalized Wan latents with RefDecoder, conditioned on a reference frame.

    Latents are moved to DEVICE in the RefDecoder VAE's parameter dtype and
    de-normalized with its config statistics (``z * std + mean``). The reference
    frame is cast to the same device/dtype, and ``vae.decode`` is called with the
    decoder transformer (``skip=False``, ``window_size=-1``) under
    ``torch.no_grad``.

    Returns:
        The ``.sample`` tensor from ``vae.decode``.
    """
    decode_dtype = get_module_dtype(vae)
    latents = latents.to(device=DEVICE, dtype=decode_dtype)
    latents_mean = torch.tensor(
        vae.config.latents_mean,
        device=DEVICE,
        dtype=decode_dtype,
    ).view(1, -1, 1, 1, 1)
    latents_std = torch.tensor(
        vae.config.latents_std,
        device=DEVICE,
        dtype=decode_dtype,
    ).view(1, -1, 1, 1, 1)
    latents = latents * latents_std + latents_mean
    reference_frame = reference_frame.to(device=DEVICE, dtype=decode_dtype)
    try:
        with torch.no_grad():
            video = vae.decode(
                latents,
                transformer,
                return_dict=True,
                reference_frame=reference_frame,
                skip=False,
                window_size=-1,
            ).sample
    finally:
        # The VAE is a shared module-level object. Clear whatever per-call cache
        # it keeps even when decode raises, so a failed request does not leak
        # stale state into the next one.
        if hasattr(vae, "clear_cache"):
            vae.clear_cache()
    return video
# Denoising is split across several short @spaces.GPU calls (see
# generate_and_decode). Each entry of CHUNK_BOUNDARIES is the exclusive end
# step of one chunk; steps run as [previous boundary, boundary).
_NUM_DENOISING_CHUNKS = 4
CHUNK_BOUNDARIES = tuple(
    NUM_INFERENCE_STEPS * (i + 1) // _NUM_DENOISING_CHUNKS
    for i in range(_NUM_DENOISING_CHUNKS)
)
# Validate with a real exception (``assert`` is stripped under ``python -O``).
# Also reject empty chunks, which would spend a GPU lease running zero steps
# (happens when NUM_INFERENCE_STEPS < _NUM_DENOISING_CHUNKS).
if CHUNK_BOUNDARIES[-1] != NUM_INFERENCE_STEPS or any(
    later <= earlier
    for earlier, later in zip((0,) + CHUNK_BOUNDARIES, CHUNK_BOUNDARIES)
):
    raise ValueError(
        f"Invalid denoising chunk boundaries {CHUNK_BOUNDARIES} for "
        f"{NUM_INFERENCE_STEPS} steps split into {_NUM_DENOISING_CHUNKS} chunks"
    )
def _run_diffusion_steps(
    latents,
    condition,
    prompt_embeds,
    negative_prompt_embeds,
    image_embeds,
    timesteps,
    start_step,
    end_step,
    transformer_dtype,
):
    """Inlined Wan 2.1 I2V denoising loop. Runs steps [start_step, end_step).

    For each step, the current latents are concatenated with the image
    condition along dim 1 and cast to ``transformer_dtype``. The transformer is
    evaluated twice: with the prompt under cache context ``"cond"``, and with
    the negative prompt under ``"uncond"``. The two predictions are combined
    with classifier-free guidance (``GUIDANCE_SCALE``) and passed to the
    pipeline scheduler's ``step``. Each step's timestep and wall time are
    printed. Returns the updated latents.
    """
    denoiser = GENERATION_PIPE.transformer
    sampler = GENERATION_PIPE.scheduler

    def _predict(model_input, batch_t, text_embeds, cache_name):
        with denoiser.cache_context(cache_name):
            return denoiser(
                hidden_states=model_input,
                timestep=batch_t,
                encoder_hidden_states=text_embeds,
                encoder_hidden_states_image=image_embeds,
                return_dict=False,
            )[0]

    with torch.no_grad():
        for step in range(start_step, end_step):
            tic = time.perf_counter()
            t = timesteps[step]
            model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
            batch_t = t.expand(latents.shape[0])
            cond_pred = _predict(model_input, batch_t, prompt_embeds, "cond")
            uncond_pred = _predict(model_input, batch_t, negative_prompt_embeds, "uncond")
            guided = uncond_pred + GUIDANCE_SCALE * (cond_pred - uncond_pred)
            latents = sampler.step(guided, t, latents, return_dict=False)[0]
            elapsed = time.perf_counter() - tic
            print(
                f"[diffusion] step {step + 1}/{NUM_INFERENCE_STEPS} "
                f"(t={float(t):.1f}, {elapsed:.2f}s)",
                flush=True,
            )
    return latents
@spaces.GPU(duration=60)
def generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width):
    """Encode prompt+image, prepare initial latents and condition. NO denoising.

    Only the text encoder, image encoder and VAE are moved to the GPU (the 14B
    transformer stays on CPU). They are moved back to CPU in a ``finally``
    clause. The returned dict holds CPU copies of ``prompt_embeds``,
    ``negative_prompt_embeds``, ``image_embeds``, ``condition`` and ``latents``,
    plus ``step_idx = 0``. It is the input expected by
    generate_latents_chunk_on_gpu.
    """
    log_cuda_mem("start generate_latents_setup_on_gpu")
    pipe = GENERATION_PIPE
    staged_modules = (pipe.text_encoder, pipe.image_encoder, pipe.vae)
    for module in staged_modules:
        module.to(DEVICE)
    try:
        embed_dtype = pipe.transformer.dtype
        cond_embeds, uncond_embeds = pipe.encode_prompt(
            prompt=prompt,
            negative_prompt=NEGATIVE_PROMPT,
            do_classifier_free_guidance=True,
            num_videos_per_prompt=1,
            max_sequence_length=512,
            device=DEVICE,
        )
        cond_embeds = cond_embeds.to(embed_dtype)
        uncond_embeds = uncond_embeds.to(embed_dtype)
        clip_embeds = pipe.encode_image(resized_image, DEVICE).repeat(1, 1, 1).to(embed_dtype)
        pixel_input = pipe.video_processor.preprocess(
            resized_image, height=height, width=width
        ).to(DEVICE, dtype=torch.float32)
        rng = torch.Generator(device=DEVICE).manual_seed(seed)
        init_latents, condition = pipe.prepare_latents(
            pixel_input,
            1,
            pipe.vae.config.z_dim,
            height,
            width,
            NUM_FRAMES,
            torch.float32,
            DEVICE,
            rng,
            None,
            None,
        )
        snapshot = (
            ("prompt_embeds", cond_embeds),
            ("negative_prompt_embeds", uncond_embeds),
            ("image_embeds", clip_embeds),
            ("condition", condition),
            ("latents", init_latents),
        )
        state = {key: tensor.detach().cpu() for key, tensor in snapshot}
        state["step_idx"] = 0
    finally:
        for module in staged_modules:
            module.to("cpu")
    log_cuda_mem("end generate_latents_setup_on_gpu")
    return state
# Multistep-scheduler attributes that carry solver history from one step to
# the next (names as in diffusers' UniPCMultistepScheduler, which the Wan
# pipeline presumably uses — confirm). ``set_timesteps`` resets this history,
# and each @spaces.GPU call may run in a separate worker. Without persisting it
# in ``state``, every chunk would restart the solver from first order and the
# chunked run would diverge from an uninterrupted run. Attributes a scheduler
# does not have are skipped, so single-step schedulers are unaffected.
_SCHEDULER_HISTORY_ATTRS = (
    "model_outputs",
    "timestep_list",
    "lower_order_nums",
    "last_sample",
    "this_order",
)


def _move_scheduler_value(value, device):
    """Return *value* with tensors (also inside lists/tuples) moved to *device*."""
    if torch.is_tensor(value):
        return value.to(device)
    if isinstance(value, (list, tuple)):
        return type(value)(_move_scheduler_value(item, device) for item in value)
    return value


@spaces.GPU(duration=60)
def generate_latents_chunk_on_gpu(state, end_step):
    """Run denoising steps from state['step_idx'] to end_step. Only transformer is moved to GPU.

    ``state`` is the dict from generate_latents_setup_on_gpu, or from a
    previous call to this function. It is updated in place and returned with:
    ``latents`` advanced to ``end_step``, ``step_idx = end_step``, and
    ``scheduler_history`` holding a CPU copy of the scheduler's multistep
    history, which the next chunk re-applies after ``set_timesteps``.
    """
    log_cuda_mem(f"start latents chunk -> step {end_step}")
    transformer = GENERATION_PIPE.transformer
    scheduler = GENERATION_PIPE.scheduler
    transformer.to(DEVICE)
    try:
        scheduler.set_timesteps(NUM_INFERENCE_STEPS, device=DEVICE)
        # Restore solver history from the previous chunk, if any, so this
        # chunk continues the same trajectory instead of a fresh solver start.
        for name, value in state.get("scheduler_history", {}).items():
            setattr(scheduler, name, _move_scheduler_value(value, DEVICE))
        timesteps = scheduler.timesteps
        transformer_dtype = transformer.dtype
        latents = state["latents"].to(DEVICE)
        condition = state["condition"].to(DEVICE)
        prompt_embeds = state["prompt_embeds"].to(DEVICE)
        negative_prompt_embeds = state["negative_prompt_embeds"].to(DEVICE)
        image_embeds = state["image_embeds"].to(DEVICE)
        latents = _run_diffusion_steps(
            latents,
            condition,
            prompt_embeds,
            negative_prompt_embeds,
            image_embeds,
            timesteps,
            state["step_idx"],
            end_step,
            transformer_dtype,
        )
        state["latents"] = latents.detach().cpu()
        state["scheduler_history"] = {
            name: _move_scheduler_value(getattr(scheduler, name), "cpu")
            for name in _SCHEDULER_HISTORY_ATTRS
            if hasattr(scheduler, name)
        }
        state["step_idx"] = end_step
    finally:
        transformer.to("cpu")
    log_cuda_mem(f"end latents chunk -> step {end_step}")
    return state
@spaces.GPU(duration=20)
def decode_wan_on_gpu(latents):
    """Decode latents with the baseline Wan VAE on the GPU; return a CPU tensor.

    The VAE is moved to DEVICE for the call and back to CPU in ``finally``.
    """
    log_cuda_mem("start decode_wan_on_gpu")
    WAN_VAE.to(DEVICE)
    try:
        decoded = decode_with_wan_vae(latents, WAN_VAE).detach().cpu()
    finally:
        WAN_VAE.to("cpu")
    log_cuda_mem("after wan decode")
    return decoded
@spaces.GPU(duration=25)
def decode_refdecoder_on_gpu(latents, reference_frame):
    """Decode latents with RefDecoder on the GPU; return a CPU tensor.

    The RefDecoder VAE and transformer are moved to DEVICE for the call and
    back to CPU (same order) in ``finally``.
    """
    log_cuda_mem("start decode_refdecoder_on_gpu")
    decoder_parts = (REFDECODER_VAE, REFDECODER_TRANSFORMER)
    for part in decoder_parts:
        part.to(DEVICE)
    try:
        decoded = decode_with_refdecoder(latents, reference_frame, *decoder_parts)
        decoded = decoded.detach().cpu()
    finally:
        for part in decoder_parts:
            part.to("cpu")
    log_cuda_mem("after refdecoder decode")
    return decoded
def generate_and_decode(image, prompt, seed, progress=gr.Progress()):
    """Run one demo request: generate Wan I2V latents once, decode them twice.

    Latents come from one setup GPU call plus one call per entry of
    ``CHUNK_BOUNDARIES``. They are saved to ``wan_latents.pt`` in a fresh run
    directory, then decoded with the Wan VAE and with RefDecoder; both MP4s are
    written to the same directory.

    Args:
        image: Uploaded PIL image; required.
        prompt: Optional prompt (stripped; falsy becomes ``""``).
        seed: Optional seed; a random value in ``[0, 2**32 - 1]`` when ``None``.
        progress: Gradio progress tracker.

    Returns:
        Six values for the click handler's outputs: comparison HTML update,
        baseline and RefDecoder video paths, an empty status string, and two
        download-button updates.

    Raises:
        gr.Error: when no image is given or DEVICE is not CUDA.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")
    if DEVICE != "cuda":
        raise gr.Error("This demo expects a CUDA GPU to run Wan I2V generation.")

    request_start = time.perf_counter()
    prompt = "" if not prompt else prompt.strip()
    seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
    run_dir = OUTPUT_ROOT / f"refdecoder_demo_{uuid.uuid4().hex}"
    run_dir.mkdir(parents=True, exist_ok=True)

    def _timed(label, fn, *args):
        # Call fn(*args), print "[timing] <label>: Ns", return (result, seconds).
        started = time.perf_counter()
        result = fn(*args)
        elapsed = time.perf_counter() - started
        print(f"[timing] {label}: {elapsed:.2f}s")
        return result, elapsed

    # One setup call (encoders + VAE) plus one call per denoising chunk.
    total_chunks = 1 + len(CHUNK_BOUNDARIES)
    progress(0.0, desc=f"Generating latents (1/{total_chunks})")
    latent_start = time.perf_counter()
    resized_image, height, width = resize_image_for_wan(image, GENERATION_PIPE)
    state = generate_latents_setup_on_gpu(resized_image, prompt, seed, height, width)
    for chunk_number, end_step in enumerate(CHUNK_BOUNDARIES, start=2):
        progress(
            0.8 * (chunk_number - 1) / total_chunks,
            desc=f"Generating latents ({chunk_number}/{total_chunks})",
        )
        state = generate_latents_chunk_on_gpu(state, end_step)
    latents = normalize_latent_shape(state["latents"])
    latent_secs = time.perf_counter() - latent_start
    print(f"[timing] latent generation: {latent_secs:.2f}s")

    reference_frame = build_reference_frame(resized_image, "cpu")
    torch.save(
        {
            "latents": latents,
            "height": height,
            "width": width,
            "prompt": prompt,
            "seed": seed,
        },
        run_dir / "wan_latents.pt",
    )

    progress(0.8, desc="Decoding Wan baseline")
    wan_video, wan_secs = _timed("wan decode", decode_wan_on_gpu, latents)
    wan_video_path = save_video_tensor(wan_video, run_dir / "wan_vae.mp4")
    del wan_video
    gc.collect()

    progress(0.9, desc="Decoding RefDecoder")
    ref_video, ref_secs = _timed(
        "refdecoder decode", decode_refdecoder_on_gpu, latents, reference_frame
    )
    ref_video_path = save_video_tensor(ref_video, run_dir / "refdecoder.mp4")
    del ref_video
    gc.collect()

    compare_html = build_compare_html(wan_video_path, ref_video_path)
    total_secs = time.perf_counter() - request_start
    print(
        f"[timing] request total: {total_secs:.2f}s "
        f"(latents={latent_secs:.2f}s, wan={wan_secs:.2f}s, ref={ref_secs:.2f}s)"
    )
    return (
        gr.update(value=compare_html, visible=True),
        wan_video_path,
        ref_video_path,
        "",
        gr.update(value=wan_video_path, interactive=True),
        gr.update(value=ref_video_path, interactive=True),
    )
# Page-level stylesheet passed to ``gr.Blocks(css=...)``.
# NOTE(review): most ``.compare-*`` rules below (shell, topbar, chip, stage,
# video, divider, range, caption) match markup that build_compare_html places
# inside the comparison iframe's own document, which ships its own <style>.
# Parent-page CSS does not reach into an iframe, so those rules look unused
# here. ``.compare-frame`` (the iframe element) and ``.compare-panel`` (the
# inputs column) do apply; ``.compare-note`` has no matching element in this
# file. Kept as-is pending confirmation.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Manrope:wght@400;500;600;700;800&display=swap');
:root {
--page-bg: #f4f1e8;
--card-bg: rgba(255, 252, 246, 0.92);
--card-border: rgba(50, 43, 32, 0.12);
--accent: #1f6a52;
--accent-2: #c96f42;
--text-main: #201a14;
--text-soft: #201a14;
--ui-font: "Manrope", "Inter", "Segoe UI", sans-serif;
}
.gradio-container {
background:
radial-gradient(circle at top left, rgba(201, 111, 66, 0.18), transparent 26%),
radial-gradient(circle at top right, rgba(31, 106, 82, 0.16), transparent 28%),
linear-gradient(180deg, #f8f4ec 0%, var(--page-bg) 100%);
font-family: var(--ui-font);
}
.app-shell {
max-width: 1320px;
margin: 0 auto;
}
.hero-card,
.panel-card,
.output-card {
background: var(--card-bg);
border: 1px solid var(--card-border);
border-radius: 24px;
box-shadow: 0 18px 50px rgba(49, 39, 26, 0.08);
}
.hero-card {
padding: 28px 30px 20px 30px;
margin-bottom: 18px;
}
.hero-kicker {
display: inline-block;
padding: 6px 12px;
border-radius: 999px;
background: rgba(31, 106, 82, 0.10);
color: var(--accent);
font-size: 12px;
font-weight: 700;
letter-spacing: 0.08em;
text-transform: uppercase;
}
.hero-title {
margin: 14px 0 8px 0;
font-size: 42px;
line-height: 1.05;
font-weight: 800;
color: var(--text-main);
}
.hero-copy {
margin: 0;
max-width: 840px;
color: var(--text-soft);
font-size: 17px;
line-height: 1.6;
font-family: var(--ui-font);
}
.panel-card,
.output-card {
padding: 18px;
}
.panel-card {
overflow: hidden;
}
.section-title {
margin: 0 0 6px 0;
color: var(--text-main);
font-size: 22px;
font-weight: 750;
}
.section-copy {
margin: 0 0 14px 0;
color: var(--text-soft);
font-size: 14px;
line-height: 1.55;
font-family: var(--ui-font);
}
.compare-note {
padding: 12px 14px;
border-radius: 16px;
background: rgba(201, 111, 66, 0.08);
color: #6a4128;
font-size: 14px;
line-height: 1.5;
margin-bottom: 14px;
}
#generate-btn {
min-height: 108px;
height: 100%;
width: 100%;
font-size: 16px;
font-weight: 700;
background: linear-gradient(135deg, var(--accent) 0%, #154f3d 100%);
border: none;
}
#generate-btn:hover {
filter: brightness(1.04);
}
.output-grid {
gap: 14px;
}
.compare-shell {
display: flex;
flex-direction: column;
gap: 12px;
}
.compare-frame {
width: 100%;
/* aspect-ratio is a tight fallback for the brief moment before the parent
JS estimator (and then the iframe's own postMessage) sets the height. */
aspect-ratio: 16 / 11;
border: 0;
background: transparent;
overflow: hidden;
display: block;
transition: height 120ms ease;
}
.compare-topbar {
display: flex;
justify-content: space-between;
align-items: center;
gap: 12px;
}
.compare-chip {
padding: 8px 12px;
border-radius: 999px;
background: rgba(31, 106, 82, 0.08);
color: var(--text-main);
font-size: 12px;
font-weight: 700;
letter-spacing: 0.04em;
text-transform: uppercase;
}
.compare-chip-right {
background: rgba(201, 111, 66, 0.10);
}
.compare-stage {
position: relative;
width: 100%;
aspect-ratio: 16 / 9;
overflow: hidden;
border-radius: 22px;
background: #16120f;
border: 1px solid rgba(255,255,255,0.08);
}
.compare-video {
position: absolute;
inset: 0;
width: 100%;
height: 100%;
object-fit: contain;
background: #16120f;
}
.compare-overlay {
clip-path: inset(0 0 0 50%);
}
.compare-divider {
position: absolute;
top: 0;
bottom: 0;
left: 50%;
width: 2px;
background: rgba(255,255,255,0.96);
box-shadow: 0 0 0 1px rgba(31, 26, 20, 0.15);
transform: translateX(-1px);
pointer-events: none;
}
.compare-divider::after {
content: "";
position: absolute;
top: 50%;
left: 50%;
width: 18px;
height: 18px;
border-radius: 999px;
background: #fff;
border: 2px solid rgba(31, 26, 20, 0.18);
transform: translate(-50%, -50%);
}
.compare-range {
position: absolute;
inset: 0;
width: 100%;
height: 100%;
opacity: 0;
cursor: ew-resize;
}
.compare-caption {
color: var(--text-soft);
font-size: 14px;
line-height: 1.5;
font-family: var(--ui-font);
}
.compare-panel {
padding-bottom: 34px;
}
.seed-action-row {
align-items: stretch;
}
.seed-action-row > .gradio-column {
min-width: 0;
}
.run-status {
margin-top: 8px;
color: var(--text-soft);
font-size: 13px;
line-height: 1.4;
min-height: 1.4em;
}
.run-status p {
margin: 0;
}
.download-row {
margin-top: 12px;
gap: 12px;
justify-content: center;
flex-wrap: wrap;
}
.download-row button {
border: 0 !important;
border-radius: 999px !important;
padding: 10px 22px !important;
font-size: 14px !important;
font-weight: 700 !important;
box-shadow: none !important;
min-height: 0 !important;
}
button.download-baseline {
background: var(--accent) !important;
color: #fff !important;
}
button.download-ref {
background: var(--accent-2) !important;
color: #fff !important;
}
.download-row button:hover:not([disabled]):not(:disabled) {
filter: brightness(1.05);
}
button.download-baseline[disabled],
button.download-baseline:disabled {
background: rgba(31, 106, 82, 0.14) !important;
color: #123a2d !important;
box-shadow: inset 0 0 0 1px rgba(31, 106, 82, 0.12) !important;
opacity: 1 !important;
cursor: not-allowed;
}
button.download-ref[disabled],
button.download-ref:disabled {
background: rgba(201, 111, 66, 0.16) !important;
color: #6e3d23 !important;
box-shadow: inset 0 0 0 1px rgba(201, 111, 66, 0.16) !important;
opacity: 1 !important;
cursor: not-allowed;
}
"""
# Parent-page script for the comparison iframe. It first estimates each
# ``iframe.compare-frame`` height from its width, then applies the exact height
# the iframe reports via a ``compare-iframe-height`` postMessage.
# It is injected through ``gr.Blocks(head=...)`` rather than placed inside a
# ``gr.HTML`` component: gr.HTML content is inserted as HTML markup, and
# <script> elements inserted that way are never executed, so the listener
# previously never ran.
_COMPARE_IFRAME_HEAD = """
<script>
(() => {
  if (window.__refdecoderResizeBound) return;
  window.__refdecoderResizeBound = true;
  const STAGE_RATIO = 9 / 16;
  const CHROME = 160;
  const observed = new WeakSet();
  const estimateHeight = (iframe) => {
    if (iframe.dataset.exactSized === "1") return;
    const w = iframe.getBoundingClientRect().width;
    if (w > 0) {
      iframe.style.height = Math.round(w * STAGE_RATIO + CHROME) + "px";
    }
  };
  const trackIframe = (iframe) => {
    if (observed.has(iframe)) return;
    observed.add(iframe);
    estimateHeight(iframe);
    new ResizeObserver(() => estimateHeight(iframe)).observe(iframe);
  };
  window.addEventListener("message", (e) => {
    if (!e.data || e.data.type !== "compare-iframe-height") return;
    const h = Math.max(200, Number(e.data.height) || 0);
    document.querySelectorAll("iframe.compare-frame").forEach((f) => {
      if (f.contentWindow === e.source) {
        f.style.height = h + "px";
        f.dataset.exactSized = "1";
      }
    });
  });
  const startTracking = () => {
    document.querySelectorAll("iframe.compare-frame").forEach(trackIframe);
    new MutationObserver((mutations) => {
      for (const m of mutations) {
        for (const n of m.addedNodes) {
          if (n.nodeType !== 1) continue;
          if (n.matches && n.matches("iframe.compare-frame")) trackIframe(n);
          const inner = n.querySelectorAll && n.querySelectorAll("iframe.compare-frame");
          if (inner) inner.forEach(trackIframe);
        }
      }
    }).observe(document.body, { childList: true, subtree: true });
  };
  // Head scripts can run before <body> exists; wait for it if needed.
  if (document.readyState === "loading") {
    document.addEventListener("DOMContentLoaded", startTracking, { once: true });
  } else {
    startTracking();
  }
})();
</script>
"""

with gr.Blocks(
    title="RefDecoder I2V Demo",
    theme=gr.themes.Soft(),
    css=CUSTOM_CSS,
    head=_COMPARE_IFRAME_HEAD,
) as demo:
    with gr.Column(elem_classes="app-shell"):
        gr.HTML("""
<div class="hero-card">
<div class="hero-title">RefDecoder I2V Demo</div>
<p class="hero-copy">
Upload one image, optionally add a prompt, and compare two decoders on the same Wan latent video.
The app generates latents once, then renders them with Wan's original VAE and with RefDecoder.
</p>
</div>
""")
        with gr.Column(elem_classes=["panel-card", "compare-panel"]):
            gr.HTML("""
<div class="section-title">Inputs</div>
<div class="section-copy">
Upload a reference image, optionally add a prompt, and compare the decoders below.
</div>
""")
            with gr.Row(equal_height=True):
                with gr.Column(scale=3):
                    image_input = gr.Image(
                        label="Input Image",
                        type="pil",
                        height=180,
                    )
                with gr.Column(scale=5):
                    prompt_input = gr.Textbox(
                        label="Prompt",
                        lines=2,
                        placeholder="A woman turns toward the camera as her hair moves in the wind...",
                    )
                    with gr.Row(equal_height=True, elem_classes="seed-action-row"):
                        with gr.Column(scale=1):
                            seed_input = gr.Number(
                                label="Seed",
                                value=None,
                                precision=0,
                                info="Optional",
                            )
                        with gr.Column(scale=1):
                            run_button = gr.Button(
                                "Generate Comparison",
                                variant="primary",
                                elem_id="generate-btn",
                            )
                    status_md = gr.Markdown(value="", elem_classes="run-status")
        with gr.Column(elem_classes="panel-card"):
            gr.HTML("""
<div class="section-title">Decoder Comparison</div>
<div class="section-copy">
Left side shows Wan Baseline. Right side shows RefDecoder. Drag the divider across the frame to compare them.
</div>
""")
            compare_output = gr.HTML(value=build_compare_html(None, None))
            with gr.Row(elem_classes="download-row"):
                wan_download_btn = gr.DownloadButton(
                    label="Download Baseline",
                    value=None,
                    interactive=False,
                    elem_classes="download-baseline",
                )
                ref_download_btn = gr.DownloadButton(
                    label="Download RefDecoder",
                    value=None,
                    interactive=False,
                    elem_classes="download-ref",
                )
        # Hidden outputs that also receive the two MP4 paths from generate_and_decode.
        wan_video_hidden = gr.Video(visible=False)
        ref_video_hidden = gr.Video(visible=False)

    def reset_for_new_run():
        """Clear the status line and disable both download buttons before a run."""
        return (
            "",
            gr.update(value=None, interactive=False),
            gr.update(value=None, interactive=False),
        )

    # Reset the UI immediately (unqueued), then run the queued generation.
    run_button.click(
        fn=reset_for_new_run,
        inputs=None,
        outputs=[status_md, wan_download_btn, ref_download_btn],
        queue=False,
        show_progress="hidden",
    ).then(
        fn=generate_and_decode,
        inputs=[image_input, prompt_input, seed_input],
        outputs=[
            compare_output,
            wan_video_hidden,
            ref_video_hidden,
            status_md,
            wan_download_btn,
            ref_download_btn,
        ],
    )
if __name__ == "__main__":
    # At most two requests wait in the queue. OUTPUT_ROOT is allow-listed
    # because the comparison iframe links to the per-run MP4s there through
    # /gradio_api/file= URLs.
    queued_demo = demo.queue(max_size=2)
    queued_demo.launch(allowed_paths=[str(OUTPUT_ROOT)])