import spaces
import torch
import gradio as gr
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from PIL import Image

# ────────────────────────────────────────────────────────────
# 1. Load & optimize the CogVideoX pipeline with CPU offload
# ────────────────────────────────────────────────────────────
pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX1.5-5B",
    torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()  # automatically shuttle submodules between CPU and GPU
pipe.vae.enable_slicing()        # slice VAE decoding for extra VRAM savings
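
# (Optional) If VRAM is still tight, diffusers exposes further memory savers.
# They are left as comments here because they trade generation speed for memory:
#   pipe.vae.enable_tiling()              # decode the VAE output in tiles
#   pipe.enable_sequential_cpu_offload()  # per-layer offload; use instead of enable_model_cpu_offload()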

# ────────────────────────────────────────────────────────────
# 2. Resolution parsing & sanitization
# ────────────────────────────────────────────────────────────
def make_divisible_by_8(x: int) -> int:
    return (x // 8) * 8


def parse_resolution(res_str: str):
    """
    Convert strings like "480p" into (height, width), both divisible by 8,
    while preserving a ~16:9 aspect ratio.
    """
    h = int(res_str.rstrip("p"))
    w = int(h * 16 / 9)
    return make_divisible_by_8(h), make_divisible_by_8(w)
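
# Quick sanity check of how the dropdown choices below map to frame sizes:
#   parse_resolution("360p")  -> (360, 640)
#   parse_resolution("480p")  -> (480, 848)    # 853 rounded down to a multiple of 8
#   parse_resolution("720p")  -> (720, 1280)
#   parse_resolution("1080p") -> (1080, 1920)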

# ────────────────────────────────────────────────────────────
# 3. GPU-decorated video generation function
# ────────────────────────────────────────────────────────────
@spaces.GPU(duration=180)  # allow up to 180 s of GPU time per call
def generate_video(
    prompt: str,
    steps: int,
    frames: int,
    fps: int,
    resolution: str
) -> str:
    # 3.1 Determine the user-requested target resolution
    target_h, target_w = parse_resolution(resolution)

    # 3.2 Run the diffusion pipeline at the model's native resolution
    output = pipe(
        prompt=prompt,
        num_inference_steps=steps,
        num_frames=frames,
    )
    video_frames = output.frames[0]  # list of PIL Images at native size

    # 3.3 Resize frames to the user-specified resolution
    resized_frames = [
        frame.resize((target_w, target_h), Image.LANCZOS)
        for frame in video_frames
    ]

    # 3.4 Export to MP4 (H.264) at the chosen FPS
    video_path = export_to_video(resized_frames, "generated.mp4", fps=fps)
    return video_path
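
# Example call (hypothetical values, for illustration only; in the running app
# the arguments come from the Gradio widgets defined below):
#   generate_video("a red panda surfing a wave", steps=50, frames=81, fps=16,
#                  resolution="480p")  # returns the path to "generated.mp4"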

# ────────────────────────────────────────────────────────────
# 4. Build the Gradio interface with interactive controls
# ────────────────────────────────────────────────────────────
with gr.Blocks(title="Textual Imagination: A text to video synthesis") as demo:
    gr.Markdown(
        """
        # Textual Imagination: A text to video synthesis
        Generate videos from text prompts.
        Adjust inference steps, frame count, FPS, and resolution below.
        """
    )
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                lines=2
            )
            steps_slider = gr.Slider(
                minimum=1, maximum=100, step=1, value=50,
                label="Inference Steps"
            )
            frames_slider = gr.Slider(
                minimum=16, maximum=320, step=1, value=161,
                label="Total Frames"
            )
            fps_slider = gr.Slider(
                minimum=1, maximum=60, step=1, value=16,
                label="Frames per Second (FPS)"
            )
            res_dropdown = gr.Dropdown(
                choices=["360p", "480p", "720p", "1080p"],
                value="480p",
                label="Resolution"
            )
            gen_button = gr.Button("Generate Video")
        with gr.Column():
            video_output = gr.Video(
                label="Generated Video",
                format="mp4"
            )

    gen_button.click(
        fn=generate_video,
        inputs=[prompt_input, steps_slider, frames_slider, fps_slider, res_dropdown],
        outputs=video_output
    )

# ────────────────────────────────────────────────────────────
# 5. Launch: disable SSR so Gradio blocks and stays alive
# ────────────────────────────────────────────────────────────
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False
    )