"""Gradio demo that turns a text prompt into a short video.

The first frame is produced with the SDXL base + refiner pipelines and is then
animated with the I2VGen-XL image-to-video pipeline.
"""

import gradio as gr
import torch
import torchvision
from diffusers import I2VGenXLPipeline, DiffusionPipeline
from torchvision.transforms.functional import to_tensor
from PIL import Image

from utils import create_progress_updater

# Model loading is expensive, so it lives under Gradio's NO_RELOAD guard and is
# not repeated when the app is hot-reloaded during development.
if gr.NO_RELOAD:
    # Denoising step budgets for the SDXL (base + refiner) and I2VGen-XL stages.
    n_sdxl_steps = 50
    n_i2v_steps = 50
    # Fraction of the SDXL schedule handled by the base model; the refiner
    # picks up from this point (denoising_end / denoising_start below).
    high_noise_frac = 0.8
    negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
    # torch.manual_seed seeds the default RNG and returns that generator; the
    # same object is reused for every request, so its state advances per call.
    generator = torch.manual_seed(8888)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Denominator used for every progress report in generate().
    total_steps = n_sdxl_steps + n_i2v_steps

    print("Device:", device)

    base = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    )
    # The refiner shares the base model's second text encoder and VAE instead
    # of loading its own copies.
    refiner = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        text_encoder_2=base.text_encoder_2,
        vae=base.vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    )
    pipeline = I2VGenXLPipeline.from_pretrained(
        "ali-vilab/i2vgen-xl",
        torch_dtype=torch.float16,
        variant="fp16",
    )

    # Move to the detected device rather than a hard-coded "cuda", so the
    # CPU fallback computed above is actually honoured.
    # NOTE(review): weights are loaded as float16; half-precision inference on
    # CPU may be unsupported or very slow -- confirm if CPU fallback matters.
    base.to(device)
    refiner.to(device)
    pipeline.to(device)

    base.unet = torch.compile(base.unet, mode="reduce-overhead", fullgraph=True)
    refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
    pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)


def generate(prompt: str, progress=gr.Progress()):
    """Generate a video for *prompt* and return the path of the written file.

    Stages:
      1. SDXL base denoises up to ``high_noise_frac`` and returns latents.
      2. SDXL refiner continues from ``high_noise_frac`` to a final image.
      3. I2VGen-XL animates that image into 64 frames.
      4. The frames are written to ``video.mp4`` at 16 fps.

    Args:
        prompt: Text prompt used for all three pipelines.
        progress: Gradio progress tracker, updated throughout the run.

    Returns:
        The string ``"video.mp4"``, the path of the encoded video.
    """
    progress((0, 100), desc="Starting..")

    # Stage 1: base model, stopped early and left in latent space for the refiner.
    image = base(
        prompt=prompt,
        num_inference_steps=n_sdxl_steps,
        denoising_end=high_noise_frac,
        output_type="latent",
        callback_on_step_end=create_progress_updater(
            start=0,
            total=total_steps,
            desc="Generating first frame...",
            progress=progress,
        ),
    ).images[0]

    # Stage 2: refiner resumes where the base model stopped.
    progress((n_sdxl_steps * high_noise_frac, total_steps), desc="Refining first frame...")
    image = refiner(
        prompt=prompt,
        num_inference_steps=n_sdxl_steps,
        denoising_start=high_noise_frac,
        image=image,
        callback_on_step_end=create_progress_updater(
            start=n_sdxl_steps * high_noise_frac,
            total=total_steps,
            desc="Refining first frame...",
            progress=progress,
        ),
    ).images[0]

    # NOTE(review): the video pipeline is handed a tensor rather than the PIL
    # image -- confirm I2VGenXLPipeline accepts this input form.
    image = to_tensor(image)

    # Stage 3: image-to-video. The step count comes from n_i2v_steps so it stays
    # in sync with total_steps used for progress reporting.
    progress((n_sdxl_steps + 1, total_steps), desc="Generating video...")
    frames: list[Image.Image] = pipeline(
        prompt=prompt,
        image=image,
        num_inference_steps=n_i2v_steps,
        negative_prompt=negative_prompt,
        guidance_scale=9.0,
        generator=generator,
        decode_chunk_size=4,
        num_frames=64,
    ).frames[0]

    # Stage 4: convert each frame to a uint8 (H, W, C) tensor and stack them
    # into the (T, H, W, C) layout that write_video expects.
    progress((total_steps - 1, total_steps), desc="Finalizing...")
    frames = [to_tensor(frame.convert("RGB")).mul(255).byte().permute(1, 2, 0) for frame in frames]
    frames = torch.stack(frames)
    torchvision.io.write_video("video.mp4", frames, fps=16)

    return "video.mp4"


app = gr.Interface(
    fn=generate,
    inputs=["text"],
    outputs=gr.Video(),
)

if __name__ == "__main__":
    app.launch()