# Spaces:
# Runtime error
# Runtime error
import gradio as gr | |
import torch | |
from diffusers import LTXVideoTransformer3DModel, LTXVideoPipeline | |
from transformers import T5EncoderModel, T5Tokenizer | |
import spaces | |
import numpy as np | |
import tempfile | |
import os | |
import time | |
import logging | |
from PIL import Image | |
import cv2 | |
from fastapi import FastAPI, File, UploadFile, Form, HTTPException | |
from fastapi.responses import FileResponse | |
import uvicorn | |
import threading | |
import json | |
# --- Logging ---------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Model state -----------------------------------------------------------
# The generation pipeline; stays None until load_model() succeeds.
pipe = None
# Use CUDA when torch reports it available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def _try_enable(component, method_name):
    """Call ``component.<method_name>()`` if it exists; log and continue if it fails."""
    method = getattr(component, method_name, None)
    if method is None:
        return
    try:
        method()
    except Exception as e:
        # e.g. xformers not installed: the optimization is optional, so don't abort.
        logger.warning(f"Skipping optional optimization {method_name}: {e}")


def load_model():
    """Load the LTX-Video pipeline onto ``device`` and publish it as the global ``pipe``.

    Memory optimizations (VAE tiling/slicing, xformers attention) are applied
    best-effort. If one is unavailable it is logged and skipped instead of
    failing the whole load.

    Returns:
        True if the pipeline was loaded and moved to ``device``. False otherwise;
        the error is logged and the global ``pipe`` is left unchanged.

    NOTE(review): recent diffusers releases export the LTX pipelines as
    ``LTXPipeline`` / ``LTXImageToVideoPipeline`` / ``LTXConditionPipeline``.
    Confirm that ``LTXVideoPipeline`` is importable in the installed version.
    """
    global pipe
    try:
        logger.info("Loading LTX-Video model...")
        pipeline = LTXVideoPipeline.from_pretrained(
            "Lightricks/LTX-Video-0.9.7-dev",
            torch_dtype=torch.bfloat16,
            use_safetensors=True,
        )
        pipeline = pipeline.to(device)
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return False

    # Optional VAE memory savers.
    vae = getattr(pipeline, "vae", None)
    _try_enable(vae, "enable_tiling")
    _try_enable(vae, "enable_slicing")

    # LTX pipelines are expected to expose their denoiser as ``transformer``
    # (diffusers convention -- confirm). ``unet`` is kept only as a fallback.
    denoiser = getattr(pipeline, "transformer", None)
    if denoiser is None:
        denoiser = getattr(pipeline, "unet", None)
    _try_enable(denoiser, "enable_xformers_memory_efficient_attention")

    # Publish only once the pipeline is fully set up on the target device.
    pipe = pipeline
    logger.info("Model loaded successfully!")
    return True
def validate_inputs(prompt, duration, image=None):
    """Check user-supplied generation parameters.

    Args:
        prompt: Text prompt. It must be non-blank and at most 500 characters;
            ``None`` counts as missing.
        duration: Requested length in seconds. It must be a number within
            [3, 5]; ``None``, non-numbers and NaN are rejected.
        image: Optional file path (opened with PIL) or an image object
            exposing ``.size`` as ``(width, height)``. Each side must be at
            most 1024 px.

    Returns:
        A list of human-readable error messages. It is empty when all inputs
        are valid.
    """
    errors = []

    if not prompt or not prompt.strip():
        errors.append("Prompt is required")
    # Guard on truthiness so a None prompt does not crash len().
    if prompt and len(prompt) > 500:
        errors.append("Prompt must be less than 500 characters")

    # A chained comparison rejects NaN (every comparison is False).
    # TypeError covers None and non-numeric values.
    try:
        duration_ok = 3 <= duration <= 5
    except TypeError:
        duration_ok = False
    if not duration_ok:
        errors.append("Duration must be between 3 and 5 seconds")

    if image is not None:
        try:
            if isinstance(image, str):
                # Context manager closes the file handle PIL opens.
                with Image.open(image) as img:
                    width, height = img.size
            else:
                width, height = image.size
            if width > 1024 or height > 1024:
                errors.append("Image dimensions must be less than 1024x1024")
        except Exception as e:
            errors.append(f"Invalid image: {str(e)}")

    return errors
def frames_to_video(frames, output_path, fps=24):
    """Encode a sequence of RGB frames into an MP4 file with OpenCV.

    Args:
        frames: Sequence (or 4-D array) of ``(H, W, 3)`` uint8 RGB frames.
            The output size is taken from the first frame; OpenCV may drop
            frames of a different size.
        output_path: Destination file path.
        fps: Frames per second written into the container.

    Returns:
        True on success. False on failure; the error is logged.

    NOTE(review): the 'mp4v' codec is often not playable in browsers. Confirm
    that the Gradio video component displays it, or switch to an H.264 fourcc
    if the OpenCV build supports one.
    """
    # len() rather than truthiness: `not ndarray` is ambiguous for arrays.
    if frames is None or len(frames) == 0:
        logger.error("Error creating video: no frames to write")
        return False

    writer = None
    try:
        height, width = frames[0].shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        # Without this check a writer that failed to open silently ignores
        # every write and we would report success with no usable file.
        if not writer.isOpened():
            raise RuntimeError(f"could not open video writer for {output_path}")
        for frame in frames:
            # OpenCV expects BGR channel order.
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        return True
    except Exception as e:
        logger.error(f"Error creating video: {e}")
        return False
    finally:
        # Always release the writer so the file is finalized/closed, even on error.
        if writer is not None:
            writer.release()
def _num_frames_for_duration(duration, fps=24):
    """Return a frame count of the form 8k + 1 (as LTX-Video expects) for ``duration`` seconds."""
    return max(1, int(duration * fps) // 8) * 8 + 1


def _to_uint8_frames(frames):
    """Convert pipeline output frames into a uint8 numpy array of images.

    Accepts a torch tensor, a numpy array, or a list of PIL images / arrays.
    Float data is assumed to lie in [0, 1] (TODO confirm for the pipeline's
    output type). It is clipped before scaling so out-of-range values do not
    wrap around in the uint8 cast.
    """
    if torch.is_tensor(frames):
        # .float() first: numpy has no bfloat16, so .numpy() on bf16 fails.
        frames = frames.detach().float().cpu().numpy()
    elif isinstance(frames, (list, tuple)):
        # e.g. a list of PIL images (a common default pipeline output).
        frames = np.stack([np.asarray(frame) for frame in frames])
    else:
        frames = np.asarray(frames)
    if frames.dtype != np.uint8:
        frames = (np.clip(frames, 0.0, 1.0) * 255).round().astype(np.uint8)
    return frames


def generate_video_core(prompt, negative_prompt="", duration=4, image=None):
    """Run the global ``pipe`` and write the result to a temporary MP4 file.

    Args:
        prompt: Text prompt for the video.
        negative_prompt: Content to steer away from.
        duration: Requested length in seconds at 24 fps. The frame count is
            rounded to the 8k + 1 form LTX-Video expects, so the real length
            can differ slightly.
        image: Optional conditioning image (file path or PIL image). It is
            converted to RGB and resized to 768x512.

    Returns:
        ``(video_path, status_message)``.

    Raises:
        RuntimeError: Wraps any failure; the message starts with "Generation failed: ".

    NOTE(review): no ``@spaces.GPU`` decorator is applied here. If this runs on
    a ZeroGPU Space, GPU access presumably requires it -- confirm.
    NOTE(review): confirm the loaded pipeline class accepts an ``image`` kwarg;
    text-only LTX pipelines may not.
    """
    fps = 24
    start_time = time.perf_counter()
    try:
        num_frames = _num_frames_for_duration(duration, fps)

        generation_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "num_frames": num_frames,
            "height": 512,
            "width": 768,
            "num_inference_steps": 30,
            "guidance_scale": 7.5,
            "generator": torch.Generator(device=device).manual_seed(42),
        }

        if image is not None:
            if isinstance(image, str):
                # convert() loads the pixels, so the file can be closed right away.
                with Image.open(image) as opened:
                    image = opened.convert("RGB")
            elif image.mode != "RGB":
                image = image.convert("RGB")
            # Match the requested output resolution.
            image = image.resize((768, 512), Image.Resampling.LANCZOS)
            generation_kwargs["image"] = image

        logger.info(f"Starting generation with {num_frames} frames...")
        with torch.inference_mode():
            result = pipe(**generation_kwargs)

        # First (and only) video in the batch.
        frames = _to_uint8_frames(result.frames[0])

        # The caller (Gradio / FileResponse) reads the file after we return, so
        # the temp dir is not removed here.
        # NOTE(review): these directories are never cleaned up.
        video_path = os.path.join(tempfile.mkdtemp(), "generated_video.mp4")
        if not frames_to_video(frames, video_path, fps=fps):
            raise RuntimeError("Failed to create video file")

        generation_time = time.perf_counter() - start_time
        logger.info(f"Video generated successfully in {generation_time:.2f} seconds")
        return video_path, f"Generated in {generation_time:.2f}s"
    except Exception as e:
        logger.error(f"Error generating video: {e}")
        raise RuntimeError(f"Generation failed: {str(e)}") from e
def generate_video_gradio(prompt, negative_prompt, duration, image):
    """Gradio click handler: validate the inputs, then generate the video.

    Always returns a ``(video_path, status_message)`` pair. On validation
    failure, a missing model, or any exception, ``video_path`` is None and the
    message explains why, so the UI shows it instead of raising.
    """
    try:
        problems = validate_inputs(prompt, duration, image)
        if problems:
            joined = "; ".join(problems)
            return None, f"Validation errors: {joined}"

        if pipe is None:
            return None, "Model not loaded. Please wait for initialization."

        result_path, status_text = generate_video_core(
            prompt, negative_prompt, duration, image
        )
        return result_path, status_text
    except Exception as exc:
        logger.error(f"Gradio generation error: {exc}")
        return None, f"Error: {str(exc)}"
# Create Gradio interface | |
def create_gradio_interface():
    """Build (but do not launch) the Gradio Blocks UI for the generator.

    The left column holds the optional image file, the prompt, the negative
    prompt, the duration slider and the generate button. The right column
    shows the resulting video and a read-only status box. The button calls
    ``generate_video_gradio``.
    """
    sample_rows = [
        ["A cat playing with a ball of yarn", "", 4, None],
        ["Ocean waves crashing on a beach at sunset", "", 3, None],
        ["A person walking through a forest", "blurry, low quality", 5, None],
    ]

    with gr.Blocks(title="LTX-Video Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎬 LTX-Video Generator")
        gr.Markdown("Generate 3-5 second videos using the LTX-Video model from Lightricks")

        with gr.Row():
            # Component creation order below determines on-screen layout.
            with gr.Column(scale=1):
                image_file = gr.File(
                    type="filepath",
                    file_types=[".png", ".jpg", ".jpeg"],
                    label="Input Image (Optional)",
                )
                prompt_box = gr.Textbox(
                    lines=3,
                    max_lines=5,
                    label="Prompt",
                    placeholder="Describe the video you want to generate...",
                )
                negative_box = gr.Textbox(
                    lines=2,
                    max_lines=3,
                    label="Negative Prompt (Optional)",
                    placeholder="What you don't want in the video...",
                )
                length_slider = gr.Slider(
                    label="Duration (seconds)",
                    minimum=3,
                    maximum=5,
                    step=0.5,
                    value=4,
                )
                run_button = gr.Button("🎬 Generate Video", variant="primary")
                # NOTE(review): this is fixed text, not a measured value; 30
                # inference steps at 768x512 may take considerably longer -- confirm.
                gr.Markdown("**Estimated time:** 4-6 seconds")

            with gr.Column(scale=1):
                video_out = gr.Video(label="Generated Video")
                status_box = gr.Textbox(interactive=False, label="Status")

        # Input order must match generate_video_gradio(prompt, negative_prompt, duration, image).
        run_button.click(
            fn=generate_video_gradio,
            inputs=[prompt_box, negative_box, length_slider, image_file],
            outputs=[video_out, status_box],
        )

        gr.Examples(
            examples=sample_rows,
            inputs=[prompt_box, negative_box, length_slider, image_file],
        )

    return demo
# FastAPI setup
app = FastAPI(title="LTX-Video API", description="Generate videos using LTX-Video model")


@app.post("/generate_video")
async def api_generate_video(
    prompt: str = Form(..., description="Text prompt for video generation"),
    negative_prompt: str = Form("", description="Negative prompt (optional)"),
    duration: float = Form(4.0, description="Duration in seconds (3-5)"),
    image: UploadFile = File(None, description="Input image (optional)")
):
    """Generate a video from multipart form data and return it as an MP4 download.

    Responses:
        200: The generated ``video/mp4`` file.
        400: Validation failed; ``detail`` is ``{"errors": [...]}``.
        503: The model is not loaded.
        500: Generation failed; ``detail`` is the error text.

    NOTE(review): ``generate_video_core`` is a blocking call made directly
    inside this ``async`` handler, so the server's event loop is stalled for the
    whole generation. Consider offloading it to a worker thread, but keep
    GPU access serialized.
    """
    import shutil

    upload_dir = None
    try:
        image_path = None
        if image:
            # Save the uploaded image temporarily. Never use the client-supplied
            # filename as a path component (e.g. "../../x"); keep only its
            # extension.
            upload_dir = tempfile.mkdtemp()
            suffix = os.path.splitext(os.path.basename(image.filename or ""))[1]
            image_path = os.path.join(upload_dir, f"input_image{suffix}")
            content = await image.read()
            with open(image_path, "wb") as f:
                f.write(content)

        errors = validate_inputs(prompt, duration, image_path)
        if errors:
            raise HTTPException(status_code=400, detail={"errors": errors})
        if pipe is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        video_path, _status = generate_video_core(prompt, negative_prompt, duration, image_path)

        return FileResponse(
            video_path,
            media_type="video/mp4",
            filename=f"generated_video_{int(time.time())}.mp4"
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"API generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # The input image is fully loaded during generation, so its temp dir
        # can go once this handler finishes (the video file is separate).
        if upload_dir is not None:
            shutil.rmtree(upload_dir, ignore_errors=True)
@app.get("/")
async def root():
    """Return a short JSON description of the HTTP API.

    Lists the available endpoints and a sample ``curl`` call. The example
    targets port 7861, where ``run_api`` serves this app; port 7860 is used by
    the Gradio UI in ``main``.
    """
    return {
        "message": "LTX-Video API",
        "endpoints": {
            "/generate_video": "POST - Generate video",
            "/docs": "GET - API documentation"
        },
        "curl_example": """
curl -X POST "http://localhost:7861/generate_video" \\
-F "prompt=A cat playing with a ball" \\
-F "duration=4" \\
-F "negative_prompt=blurry" \\
-F "image=@your_image.jpg" \\
--output generated_video.mp4
"""
    }
def run_api():
    """Serve the FastAPI ``app`` with uvicorn on 0.0.0.0:7861.

    Blocks until the server shuts down, which is why ``main`` runs it in a
    background thread.
    """
    server_options = {
        "host": "0.0.0.0",
        "port": 7861,
        "log_level": "info",
    }
    uvicorn.run(app, **server_options)
def main():
    """Load the model, start the HTTP API thread, then serve the Gradio UI.

    If the model fails to load, this logs an error and returns without
    starting either server.

    NOTE(review): if deployed as a Hugging Face Space, presumably only port 7860
    is routed publicly, so the API on 7861 may be unreachable from outside.
    Confirm, or mount the Gradio app on the FastAPI app instead.
    """
    logger.info("Initializing LTX-Video Generator...")
    if not load_model():
        logger.error("Failed to load model. Exiting.")
        return

    demo = create_gradio_interface()

    # daemon=True: the API thread does not keep the process alive after the
    # Gradio server (running in the main thread) exits.
    threading.Thread(target=run_api, daemon=True).start()
    logger.info("API server started on http://localhost:7861")

    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=False,
    )
    demo.launch(**launch_options)


if __name__ == "__main__":
    main()