# ttv / app.py — Hugging Face Space entry point.
# (Recovered page-scrape header: "orderlymirror — Update app.py — 77f9377 verified")
import spaces
import torch
import gradio as gr
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from PIL import Image
# ────────────────────────────────────────────────────────────
# 1. Load & optimize the CogVideoX pipeline with CPU offload
# ────────────────────────────────────────────────────────────
# Loaded once at import time so every Gradio request reuses the same
# weights. bfloat16 halves weight memory relative to fp32.
pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX1.5-5B",
    torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()  # auto move submodules between CPU/GPU
pipe.vae.enable_slicing()        # slice VAE decode for extra VRAM savings
# ────────────────────────────────────────────────────────────
# 2. Resolution parsing & sanitization
# ────────────────────────────────────────────────────────────
def make_divisible_by_8(x: int) -> int:
    """Floor *x* to the nearest multiple of 8 (VAE-friendly dimension)."""
    return x - (x % 8)
def parse_resolution(res_str: str):
    """
    Translate a label such as "480p" into a (height, width) pair.

    Width is derived from an ~16:9 aspect ratio, and both dimensions
    are floored to multiples of 8 so they suit the VAE.
    """
    base_h = int(res_str.rstrip("p"))
    base_w = int(base_h * 16 / 9)
    # Inline divisible-by-8 floor on each dimension.
    return base_h - base_h % 8, base_w - base_w % 8
# ────────────────────────────────────────────────────────────
# 3. GPU‑decorated video generation function
# ────────────────────────────────────────────────────────────
@spaces.GPU(duration=180)  # allow up to 180s of GPU time
def generate_video(
    prompt: str,
    steps: int,
    frames: int,
    fps: int,
    resolution: str
) -> str:
    """
    Synthesize a short video from *prompt* and return the MP4 file path.

    The diffusion model renders at its native size; frames are then
    Lanczos-resampled to the user-selected resolution and encoded at *fps*.
    """
    # Resolution the user asked for (model output is resized to this).
    out_h, out_w = parse_resolution(resolution)

    # Run diffusion at the model's native resolution.
    result = pipe(
        prompt=prompt,
        num_inference_steps=steps,
        num_frames=frames,
    )
    native_frames = result.frames[0]  # PIL images at native size

    # Downstream resize to the requested dimensions.
    scaled = [
        img.resize((out_w, out_h), Image.LANCZOS)
        for img in native_frames
    ]

    # Encode as MP4 (H.264) at the chosen frame rate.
    return export_to_video(scaled, "generated.mp4", fps=fps)
# ────────────────────────────────────────────────────────────
# 4. Build the Gradio interface with interactive controls
# ────────────────────────────────────────────────────────────
# ────────────────────────────────────────────────────────────
# 4. Build the Gradio interface with interactive controls
# ────────────────────────────────────────────────────────────
with gr.Blocks(title="Textual Imagination: A text to video synthesis") as demo:
    # Intro banner above the controls.
    gr.Markdown(
        """
# 🎞️ Textual Imagination: A text to video synthesis
Generate videos from text prompts.
Adjust inference steps, frame count, fps, and resolution below.
"""
    )
    with gr.Row():
        # Left column: all generation controls.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                lines=2
            )
            steps_slider = gr.Slider(
                minimum=1, maximum=100, step=1, value=50,
                label="Inference Steps"
            )
            # NOTE(review): CogVideoX models typically expect frame counts of
            # the form 8k+1 (default 161); arbitrary values from this slider
            # may be rejected by the pipeline — confirm and snap if needed.
            frames_slider = gr.Slider(
                minimum=16, maximum=320, step=1, value=161,
                label="Total Frames"
            )
            fps_slider = gr.Slider(
                minimum=1, maximum=60, step=1, value=16,
                label="Frames per Second (FPS)"
            )
            res_dropdown = gr.Dropdown(
                choices=["360p", "480p", "720p", "1080p"],
                value="480p",
                label="Resolution"
            )
            gen_button = gr.Button("Generate Video")
        # Right column: rendered MP4 output.
        with gr.Column():
            video_output = gr.Video(
                label="Generated Video",
                format="mp4"
            )
    # Wire the button: inputs map positionally onto generate_video's params.
    gen_button.click(
        fn=generate_video,
        inputs=[prompt_input, steps_slider, frames_slider, fps_slider, res_dropdown],
        outputs=video_output
    )
# ────────────────────────────────────────────────────────────
# 5. Launch: disable SSR so Gradio blocks and stays alive
# ────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Bind on all interfaces at 7860, the standard HF Spaces port.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False  # disable SSR so the server blocks and stays alive
    )