import base64
import io
import logging
import os
import tempfile
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
from diffusers import LTXImageToVideoPipeline, LTXPipeline
from moviepy.editor import ImageSequenceClip
from PIL import Image

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Inference endpoint handler that generates MP4 videos with LTX Video.

    Supports text-to-video and, when a Base64 image is supplied in the
    request, image-to-video generation.
    """

    # Default configuration
    DEFAULT_FPS = 24
    DEFAULT_DURATION = 4  # seconds
    DEFAULT_NUM_FRAMES = (DEFAULT_DURATION * DEFAULT_FPS) + 1  # 97 frames
    DEFAULT_NUM_STEPS = 25
    DEFAULT_WIDTH = 768
    DEFAULT_HEIGHT = 512

    # Constraints
    MAX_WIDTH = 1280
    MAX_HEIGHT = 720
    MAX_FRAMES = 257

    ENABLE_CPU_OFFLOAD = True
    EXPERIMENTAL_STUFF = False

    def __init__(self, path: str = ""):
        """Initialize the LTX Video handler with both text-to-video and image-to-video pipelines.

        Args:
            path (str): Path to the model weights directory
        """
        # Class attributes must be reached through ``self``: names defined in
        # the class body are not visible as bare names inside methods.
        if self.EXPERIMENTAL_STUFF:
            torch.backends.cuda.matmul.allow_tf32 = True

        # Load both pipelines with bfloat16 precision as recommended in docs
        self.text_to_video = LTXPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        )
        self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        )

        if self.ENABLE_CPU_OFFLOAD:
            # Model CPU offload manages device placement itself; moving both
            # pipelines to CUDA first would only hold two full copies on the
            # GPU at once before they get offloaded again.
            self.text_to_video.enable_model_cpu_offload()
            self.image_to_video.enable_model_cpu_offload()
        else:
            self.text_to_video = self.text_to_video.to("cuda")
            self.image_to_video = self.image_to_video.to("cuda")

    def _validate_and_adjust_resolution(self, width: int, height: int) -> Tuple[int, int]:
        """Validate and adjust resolution to meet constraints.

        Args:
            width (int): Requested width
            height (int): Requested height

        Returns:
            Tuple[int, int]: Adjusted (width, height)
        """
        # Round to nearest multiple of 32
        width = round(width / 32) * 32
        height = round(height / 32) * 32

        # Enforce maximum dimensions
        width = min(width, self.MAX_WIDTH)
        height = min(height, self.MAX_HEIGHT)

        # Enforce minimum dimensions
        width = max(width, 32)
        height = max(height, 32)

        return width, height

    def _validate_and_adjust_frames(
        self,
        num_frames: Optional[int] = None,
        fps: Optional[int] = None,
    ) -> Tuple[int, int]:
        """Validate and adjust frame count and FPS to meet constraints.

        Args:
            num_frames (Optional[int]): Requested number of frames
            fps (Optional[int]): Requested frames per second

        Returns:
            Tuple[int, int]: Adjusted (num_frames, fps)
        """
        # Use defaults if not provided; a non-positive FPS is meaningless
        # (it would yield a zero/negative duration), so it also falls back.
        if not fps or fps <= 0:
            fps = self.DEFAULT_FPS
        num_frames = num_frames or self.DEFAULT_NUM_FRAMES

        # Adjust frames to be in format 8k + 1
        k = (num_frames - 1) // 8
        num_frames = (k * 8) + 1

        # Enforce maximum frame count
        num_frames = min(num_frames, self.MAX_FRAMES)

        # Enforce minimum frame count: negative requests would otherwise
        # come out of the 8k + 1 adjustment as a negative frame count.
        num_frames = max(num_frames, 1)

        return num_frames, fps

    def _create_video_file(self, frames: torch.Tensor, fps: int = DEFAULT_FPS) -> bytes:
        """Convert frames to an MP4 video file.

        Args:
            frames (torch.Tensor): Generated frames tensor. Dimension 1 is
                read as the frame count and, after dropping dimension 0,
                the tensor is permuted from (frame, channel, height, width)
                to (frame, height, width, channel).
            fps (int): Frames per second for the output video

        Returns:
            bytes: MP4 video file content
        """
        # Log frame information
        num_frames = frames.shape[1]
        duration = num_frames / fps
        logger.info(f"Creating video with {num_frames} frames at {fps} FPS (duration: {duration:.2f} seconds)")

        # Convert tensor to numpy array
        video_np = frames.squeeze(0).permute(0, 2, 3, 1).cpu().float().numpy()
        video_np = (video_np * 255).astype(np.uint8)

        # Get dimensions
        _, height, width, _ = video_np.shape
        logger.info(f"Video dimensions: {width}x{height}")

        try:
            # A private temporary directory replaces tempfile.mktemp(), which
            # is deprecated and race-prone (the returned name can be claimed
            # by another process before it is used). The directory and the
            # file inside it are removed automatically on exit.
            with tempfile.TemporaryDirectory() as tmp_dir:
                output_path = os.path.join(tmp_dir, "output.mp4")

                # Create video clip and write to file
                clip = ImageSequenceClip(list(video_np), fps=fps)
                clip.write_videofile(output_path, codec="libx264", audio=False)

                # Read the video file
                with open(output_path, "rb") as f:
                    return f.read()
        finally:
            # Clear memory
            del video_np
            torch.cuda.empty_cache()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process the input data and generate video using LTX.

        Args:
            data (Dict[str, Any]): Input data containing:
                - inputs (str): Text description for video generation
                - image (Optional[str]): Base64 encoded image for image-to-video generation
                - width (Optional[int]): Video width (default: 768)
                - height (Optional[int]): Video height (default: 512)
                - num_frames (Optional[int]): Number of frames (default: 97)
                - fps (Optional[int]): Frames per second (default: 24)
                - num_inference_steps (Optional[int]): Number of inference steps (default: 25)
                - guidance_scale (Optional[float]): Guidance scale (default: 7.5)
                - seed (Optional[int]): Random seed; -1 or missing means unseeded (default: -1)

        Returns:
            Dict[str, Any]: Dictionary containing:
                - video: video encoded in Base64 (h.264 MP4 video). This is a data-uri (prefixed with "data:").
                - content-type: MIME type of the video (right now always "video/mp4")
                - metadata: Dictionary with actual values used for generation

        Raises:
            ValueError: If no prompt is provided in the 'inputs' field.
            RuntimeError: If image decoding, generation or video encoding fails.
        """
        # Get inputs from request data
        prompt = data.pop("inputs", None)
        if not prompt:
            raise ValueError("No prompt provided in the 'inputs' field")

        # Get and validate resolution
        width = data.pop("width", self.DEFAULT_WIDTH)
        height = data.pop("height", self.DEFAULT_HEIGHT)
        width, height = self._validate_and_adjust_resolution(width, height)

        # Get and validate frames and FPS
        num_frames = data.pop("num_frames", self.DEFAULT_NUM_FRAMES)
        fps = data.pop("fps", self.DEFAULT_FPS)
        num_frames, fps = self._validate_and_adjust_frames(num_frames, fps)

        # Get other parameters with defaults
        guidance_scale = data.pop("guidance_scale", 7.5)
        num_inference_steps = data.pop("num_inference_steps", self.DEFAULT_NUM_STEPS)

        # -1 (or an explicit null) means "do not seed"
        seed = data.pop("seed", -1)
        seed = None if seed is None or int(seed) == -1 else int(seed)

        logger.info(f"Generating video with prompt: '{prompt}'")
        logger.info(f"Parameters: size={width}x{height}, num_frames={num_frames}, fps={fps}")
        logger.info(f"Additional params: guidance_scale={guidance_scale}, num_inference_steps={num_inference_steps}")

        try:
            with torch.no_grad():
                generation_kwargs = {
                    "prompt": prompt,
                    "height": height,
                    "width": width,
                    "num_frames": num_frames,
                    "guidance_scale": guidance_scale,
                    "num_inference_steps": num_inference_steps,
                    "output_type": "pt",
                }

                if seed is not None:
                    # The seed was previously parsed but never handed to the
                    # pipeline, so seeded requests were not reproducible.
                    generation_kwargs["generator"] = torch.Generator(device="cpu").manual_seed(seed)

                # Check if image is provided for image-to-video generation
                image_data = data.get("image")
                if image_data:
                    # Decode base64 image
                    image_bytes = base64.b64decode(image_data)
                    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
                    logger.info("Using image-to-video generation mode")
                    generation_kwargs["image"] = image
                    output = self.image_to_video(**generation_kwargs).frames
                else:
                    logger.info("Using text-to-video generation mode")
                    output = self.text_to_video(**generation_kwargs).frames

                # Convert frames to video file
                video_content = self._create_video_file(output, fps=fps)

                # Encode video to base64
                video_base64 = base64.b64encode(video_content).decode('utf-8')

                content_type = "video/mp4"

                # Add MP4 data URI prefix
                video_data_uri = f"data:{content_type};base64,{video_base64}"

                return {
                    "video": video_data_uri,
                    "content-type": content_type,
                    "metadata": {
                        "width": width,
                        "height": height,
                        "num_frames": num_frames,
                        "fps": fps,
                        "duration": num_frames / fps,
                        "num_inference_steps": num_inference_steps,
                        "seed": seed,
                    },
                }

        except Exception as e:
            # Top-level boundary: log with traceback and re-raise as the
            # RuntimeError callers expect, keeping the original cause chained.
            logger.exception(f"Error generating video: {str(e)}")
            raise RuntimeError(f"Error generating video: {str(e)}") from e