Spaces:
Paused
Paused
| import spaces | |
| import os | |
| import sys | |
| import uuid | |
| import shutil | |
| import gradio as gr | |
| import torch | |
| from omegaconf import OmegaConf | |
| from torchvision.io import write_video | |
| from einops import rearrange | |
| from huggingface_hub import snapshot_download | |
| from pipeline import ( | |
| CausalDiffusionInferencePipeline, | |
| CausalInferencePipeline, | |
| ) | |
| from utils.dataset import TextDataset | |
| from utils.misc import set_seed | |
| from demo_utils.memory import get_cuda_free_memory_gb, DynamicSwapInstaller | |
# -------------------------------------------------------------------
# Fetch every required checkpoint repository once, at Space start-up
# -------------------------------------------------------------------
# (repo_id, local_dir) pairs, fetched in the order listed.
_CHECKPOINT_REPOS = (
    ("Wan-AI/Wan2.1-T2V-1.3B", "./checkpoints/Wan2.1-T2V-1.3B"),
    ("KlingTeam/VideoReward", "./checkpoints/Videoreward"),
    # NOTE(review): this target path ends in ".pt" but is passed as a
    # local_dir, so it presumably becomes a directory holding the whole
    # repository snapshot rather than a single file -- confirm intent.
    ("gdhe17/Self-Forcing", "./checkpoints/ode_init.pt"),
    (
        "JaydenLu666/Reward-Forcing-T2V-1.3B",
        "./checkpoints/Reward-Forcing-T2V-1.3B",
    ),
)
for _repo_id, _local_dir in _CHECKPOINT_REPOS:
    snapshot_download(repo_id=_repo_id, local_dir=_local_dir)

# === Paths ===
CONFIG_PATH = "configs/reward_forcing.yaml"
CHECKPOINT_PATH = "checkpoints/Reward-Forcing-T2V-1.3B/rewardforcing.pt"
PROMPT_DIR = "prompts/gradio_inputs"
OUTPUT_ROOT = "videos"

# Make sure the working directories exist before any request arrives.
for _directory in (PROMPT_DIR, OUTPUT_ROOT):
    os.makedirs(_directory, exist_ok=True)
def reward_forcing_inference(
    prompt_txt_path: str,
    num_output_frames: int,
    use_ema: bool,
    output_root: str,
    progress: gr.Progress,
):
    """
    Run text-to-video inference for the prompts stored in a .txt file.

    Inline / simplified version of inference.py:
    - single GPU
    - text-to-video only
    - one .txt file = N prompts, but returns only the first generated video

    The config, pipeline and checkpoint are loaded from scratch on every
    call, and the per-run output folder is wiped before generation starts.

    Args:
        prompt_txt_path: Path of the prompt file handed to ``TextDataset``.
        num_output_frames: Size of the second axis of the sampled noise
            tensor (``[1, num_output_frames, 16, 60, 104]``); also used in
            the output folder name.
        use_ema: Accepted for interface compatibility. NOTE(review): the
            checkpoint is loaded as a single flat state dict, so this flag
            currently has no effect on which weights are used.
        output_root: Directory under which the per-run output folder is
            created.
        progress: Gradio progress tracker; called with fractions for the
            setup phases and used as a tqdm wrapper for the prompt loop.

    Returns:
        ``(output_path, logs)`` where ``output_path`` is the path of the
        first written ``.mp4`` (or ``None`` if nothing was written) and
        ``logs`` is the accumulated, newline-terminated log text.
    """
    # Local import kept inside the function, as in the original code.
    from torch.utils.data import DataLoader, SequentialSampler

    # Log fragments (each already newline-terminated); joined on return.
    log_parts = []

    # --------------------- Device & randomness ---------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(0)
    free_vram = get_cuda_free_memory_gb(device)
    log_parts.append(f"Free VRAM {free_vram} GB\n")
    # Below this threshold the text encoder is swapped dynamically instead
    # of being moved to the device permanently.
    low_memory = free_vram < 40
    torch.set_grad_enabled(False)

    # --------------------- Phase 1: model & config init ---------------------
    progress(0.05, desc="Init: loading config")
    log_parts.append("Loading config...\n")
    config = OmegaConf.load(CONFIG_PATH)
    default_config = OmegaConf.load("configs/default_config.yaml")
    # Values from CONFIG_PATH override the defaults.
    config = OmegaConf.merge(default_config, config)

    progress(0.15, desc="Init: creating pipeline")
    log_parts.append("Creating pipeline...\n")
    if hasattr(config, "denoising_step_list"):
        pipeline = CausalInferencePipeline(config, device=device)
    else:
        pipeline = CausalDiffusionInferencePipeline(config, device=device)

    progress(0.35, desc="Init: loading checkpoint")
    log_parts.append("Loading checkpoint weights...\n")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from sources you trust.
    state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
    pipeline.generator.load_state_dict(state_dict)
    # Name of the checkpoint's parent folder (text after the last "_"),
    # used as the innermost output sub-folder.
    checkpoint_step = os.path.basename(os.path.dirname(CHECKPOINT_PATH))
    checkpoint_step = checkpoint_step.split("_")[-1]

    progress(0.55, desc="Init: moving model to device")
    log_parts.append("Moving model to device...\n")
    pipeline = pipeline.to(dtype=torch.bfloat16)
    if low_memory:
        DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device)
    else:
        pipeline.text_encoder.to(device=device)
    pipeline.generator.to(device=device)
    pipeline.vae.to(device=device)

    # --------------------- Dataset setup ---------------------
    progress(0.65, desc="Preparing dataset")
    log_parts.append("Preparing dataset (TextDataset)...\n")
    dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
    num_prompts = len(dataset)
    log_parts.append(f"Number of prompts: {num_prompts}\n")
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        sampler=SequentialSampler(dataset),
        num_workers=0,
        drop_last=False,
    )

    # --------------------- Clean output folder ---------------------
    progress(0.7, desc="Cleaning output folder")
    output_folder = os.path.join(
        output_root, f"rewardforcing-{num_output_frames}f", checkpoint_step
    )
    # Any previous contents of this folder are deleted.
    shutil.rmtree(output_folder, ignore_errors=True)
    os.makedirs(output_folder, exist_ok=True)
    log_parts.append(f"Output directory: {output_folder}\n")

    # --------------------- Phase 2: inference loop ---------------------
    for _, batch_data in progress.tqdm(
        enumerate(dataloader),
        total=num_prompts,
        desc="Video generation",
        unit="prompt",
    ):
        # Unpack the dataset batch first, so that list-shaped batches are
        # handled before any key lookup.
        batch = batch_data[0] if isinstance(batch_data, list) else batch_data
        idx = batch["idx"].item()

        # TEXT-TO-VIDEO only
        prompt = batch["prompts"][0]
        extended_prompt = batch.get("extended_prompts", [None])[0]
        prompts = [extended_prompt] if extended_prompt else [prompt]

        sampled_noise = torch.randn(
            [1, num_output_frames, 16, 60, 104],
            device=device,
            dtype=torch.bfloat16,
        )
        log_parts.append(f"Generating for prompt: {prompt[:80]}...\n")

        # WAN2 inference (latents are requested but not used here)
        video, _latents = pipeline.inference(
            noise=sampled_noise,
            text_prompts=prompts,
            return_latents=True,
            initial_latent=None,
            low_memory=low_memory,
        )

        # Channels-last layout on the CPU, scaled by 255 as before.
        frames = 255.0 * rearrange(video, "b t c h w -> b t h w c").cpu()
        # write_video is documented to take a uint8 tensor: clamp so that
        # out-of-range values cannot wrap around, then cast explicitly
        # (truncating, as the implicit conversion did).
        frames = frames.clamp(0, 255).to(torch.uint8)
        pipeline.vae.model.clear_cache()

        if idx < num_prompts:
            # Only path separators are replaced in the file name.
            safe_name = prompt[:50].replace("/", "_").replace("\\", "_")
            output_path = os.path.join(output_folder, f"{safe_name}.mp4")
            write_video(output_path, frames[0], fps=16)
            log_parts.append(f"Saved video: {output_path}\n")
            progress(1.0, desc="Done")
            # Only the first generated video is returned.
            return output_path, "".join(log_parts)

    log_parts.append("[WARN] No video generated.\n")
    return None, "".join(log_parts)
def gradio_generate(
    prompt: str, duration: str, use_ema: bool, progress=gr.Progress(track_tqdm=True)
):
    """
    Triggered by Gradio:
    - writes prompt to a .txt file
    - performs inference
    - returns video + logs

    Args:
        prompt: Free-form text prompt from the UI textbox.
        duration: Duration label from the UI; "5s (21 frames)" selects 21
            output frames, any other value selects 120.
        use_ema: Forwarded unchanged to ``reward_forcing_inference``.
        progress: Gradio progress tracker, forwarded to the inference call.

    Returns:
        ``(video_path, logs)`` for the video and log output components.

    Raises:
        gr.Error: If the prompt is empty/blank, or if no video file was
            produced.
    """
    # The prompt file holds N prompts (presumably one per line -- confirm
    # against TextDataset) and only the first one is rendered, so collapse
    # every whitespace run, including newlines from the multi-line textbox,
    # into single spaces to keep the whole prompt together.
    single_line_prompt = " ".join(prompt.split()) if prompt else ""
    if not single_line_prompt:
        raise gr.Error("Please enter a text prompt 🙂")

    # Duration → number of frames
    num_output_frames = 21 if duration == "5s (21 frames)" else 120

    os.makedirs(PROMPT_DIR, exist_ok=True)
    # Short random id so that each request gets its own prompt file.
    prompt_id = uuid.uuid4().hex[:8]
    prompt_path = os.path.join(PROMPT_DIR, f"prompt_{prompt_id}.txt")
    with open(prompt_path, "w", encoding="utf-8") as f:
        f.write(single_line_prompt + "\n")

    try:
        video_path, logs = reward_forcing_inference(
            prompt_txt_path=prompt_path,
            num_output_frames=num_output_frames,
            use_ema=use_ema,
            output_root=OUTPUT_ROOT,
            progress=progress,
        )
    finally:
        # The prompt file is only needed for this request; remove it so
        # that files do not pile up in PROMPT_DIR. Best effort: a failed
        # cleanup must not mask the inference result or its exception.
        try:
            os.remove(prompt_path)
        except OSError:
            pass

    if video_path is None or not os.path.exists(video_path):
        raise gr.Error("No video generated. Check logs for details.")
    return video_path, logs
# -------------------------------------------------------------------
# Gradio UI — updated title + intro text
# -------------------------------------------------------------------
with gr.Blocks(title="Reward Forcing — Text-to-Video Demo") as demo:
    # Intro text shown at the top of the page.
    gr.Markdown(
        """
        # 🎬 Reward Forcing — Text-to-Video Demo
        Generate short videos from text prompts using a model trained with the **Reward Forcing** method.
        Reward Forcing is a recent research technique that improves how well a video model follows a written description
        by guiding training with learned reward signals. You can learn more here:
        https://reward-forcing.github.io
        👉 Type a prompt, click **Generate**, and the video will appear below.
        Longer and more detailed prompts usually produce better results.
        > ⏳ The first run may take a little longer while the model loads — generation is faster afterwards.
        """
    )

    # Inputs
    with gr.Row():
        prompt_in = gr.Textbox(
            label="Prompt",
            placeholder="A cinematic shot of late-summer wheat fields moving in the wind...",
            lines=4,
        )
    with gr.Row():
        # These labels are compared verbatim in gradio_generate; keep the
        # two places in sync when changing them.
        duration = gr.Radio(
            ["5s (21 frames)", "30s (120 frames)"],
            value="5s (21 frames)",
            label="Duration",
        )
        use_ema = gr.Checkbox(value=True, label="Use EMA weights (--use_ema)")

    generate_btn = gr.Button("🚀 Generate Video", variant="primary")

    # Outputs
    with gr.Row():
        video_out = gr.Video(label="Generated Video")
        logs_out = gr.Textbox(label="Logs", lines=12, interactive=False)

    # Wire the button to the generation callback.
    generate_btn.click(
        fn=gradio_generate,
        inputs=[prompt_in, duration, use_ema],
        outputs=[video_out, logs_out],
    )

# Enable request queuing for the app.
demo.queue()

if __name__ == "__main__":
    demo.launch()