|
from typing import Dict, Any, Union, Optional, Tuple |
|
import torch |
|
from diffusers import LTXPipeline, LTXImageToVideoPipeline |
|
from PIL import Image |
|
import base64 |
|
import io |
|
import tempfile |
|
import numpy as np |
|
from moviepy.editor import ImageSequenceClip |
|
import os |
|
import logging |
|
|
|
|
|
# Named logger used throughout this module.
logger = logging.getLogger(__name__)

# Emit INFO-level records when the hosting process has not configured logging
# itself (basicConfig does nothing if the root logger already has handlers).
logging.basicConfig(level="INFO")
|
|
|
class EndpointHandler:
    """Inference-endpoint handler that generates MP4 videos with LTX-Video.

    A request containing only a text prompt uses the text-to-video pipeline;
    a request that also carries a base64-encoded image uses the
    image-to-video pipeline. The video is returned as a base64 ``data:`` URI
    together with the generation parameters that were actually applied.
    """

    # --- Generation defaults ---------------------------------------------
    DEFAULT_FPS = 24
    DEFAULT_DURATION = 4  # seconds
    # 4 s * 24 fps + 1 = 97 frames, already of the 8k + 1 form enforced below.
    DEFAULT_NUM_FRAMES = (DEFAULT_DURATION * DEFAULT_FPS) + 1
    DEFAULT_NUM_STEPS = 25
    DEFAULT_WIDTH = 768
    DEFAULT_HEIGHT = 512
    DEFAULT_GUIDANCE_SCALE = 7.5

    # --- Limits on user requests -----------------------------------------
    # The validators snap these down further when needed: resolutions to a
    # multiple of 32 (so MAX_HEIGHT = 720 yields 704) and frame counts to 8k + 1.
    MAX_WIDTH = 1280
    MAX_HEIGHT = 720
    MAX_FRAMES = 257

    # --- Feature flags ---------------------------------------------------
    # Use diffusers model CPU offload instead of keeping pipelines on CUDA.
    ENABLE_CPU_OFFLOAD = True
    # Enables TF32 CUDA matmuls at startup.
    EXPERIMENTAL_STUFF = False

    def __init__(self, path: str = ""):
        """Initialize the LTX Video handler with both text-to-video and image-to-video pipelines.

        Args:
            path (str): Path to the model weights directory.
        """
        # The flags are class attributes, so they must be read through
        # ``self``; bare names would raise NameError inside this method.
        if self.EXPERIMENTAL_STUFF:
            torch.backends.cuda.matmul.allow_tf32 = True

        self.text_to_video = LTXPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        )
        self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        )

        for pipeline in (self.text_to_video, self.image_to_video):
            if self.ENABLE_CPU_OFFLOAD:
                # Offload manages device placement itself; moving the whole
                # pipeline to CUDA first would put both pipelines on the GPU
                # at the same time before offload takes over.
                pipeline.enable_model_cpu_offload()
            else:
                pipeline.to("cuda")

    def _validate_and_adjust_resolution(self, width: int, height: int) -> Tuple[int, int]:
        """Snap a requested resolution to multiples of 32 within the configured limits.

        Args:
            width (int): Requested width (numeric strings are accepted).
            height (int): Requested height (numeric strings are accepted).

        Returns:
            Tuple[int, int]: Adjusted (width, height); both are multiples of 32,
            at least 32, and no larger than MAX_WIDTH / MAX_HEIGHT.
        """
        # Largest multiples of 32 that fit inside the maxima. Clamping to the
        # raw maxima would break the multiple-of-32 invariant (720 % 32 != 0).
        max_width = max(32, (self.MAX_WIDTH // 32) * 32)
        max_height = max(32, (self.MAX_HEIGHT // 32) * 32)

        # Round to the nearest multiple of 32.
        width = round(float(width) / 32) * 32
        height = round(float(height) / 32) * 32

        width = min(max(width, 32), max_width)
        height = min(max(height, 32), max_height)
        return width, height

    def _validate_and_adjust_frames(self, num_frames: Optional[int] = None, fps: Optional[int] = None) -> Tuple[int, int]:
        """Validate and adjust frame count and FPS to meet constraints.

        Args:
            num_frames (Optional[int]): Requested number of frames; missing or
                zero uses DEFAULT_NUM_FRAMES.
            fps (Optional[int]): Requested frames per second; missing, zero or
                negative uses DEFAULT_FPS.

        Returns:
            Tuple[int, int]: Adjusted (num_frames, fps). ``num_frames`` has the
            form 8k + 1, is at least 1 and does not exceed MAX_FRAMES.
        """
        # A non-positive FPS would make the reported duration negative/undefined.
        if not fps or fps <= 0:
            fps = self.DEFAULT_FPS

        num_frames = int(num_frames) if num_frames else self.DEFAULT_NUM_FRAMES

        # Largest 8k + 1 value not above MAX_FRAMES, so clamping keeps the form.
        max_frames = ((max(self.MAX_FRAMES, 1) - 1) // 8) * 8 + 1

        # Snap down to 8k + 1 with k >= 0, so at least one frame is requested
        # even for negative input.
        num_frames = ((max(num_frames, 1) - 1) // 8) * 8 + 1

        return min(num_frames, max_frames), fps

    def _create_video_file(self, frames: torch.Tensor, fps: int = DEFAULT_FPS) -> bytes:
        """Encode generated frames as an H.264 MP4 file and return its bytes.

        Args:
            frames (torch.Tensor): Pipeline output with ``output_type="pt"``;
                assumed to be (batch, frames, channels, height, width) with
                values in [0, 1] — only the first batch item is encoded.
            fps (int): Frames per second for the output video.

        Returns:
            bytes: MP4 video file content.
        """
        video = frames[0]
        num_frames = video.shape[0]
        duration = num_frames / fps
        logger.info(
            "Creating video with %s frames at %s FPS (duration: %.2f seconds)",
            num_frames, fps, duration,
        )

        # (F, C, H, W) -> (F, H, W, C); .float() because numpy has no bfloat16.
        video_np = video.permute(0, 2, 3, 1).cpu().float().numpy()
        # Clip before scaling so out-of-range values cannot wrap around in uint8.
        video_np = (np.clip(video_np, 0.0, 1.0) * 255).round().astype(np.uint8)

        _, height, width, _ = video_np.shape
        logger.info("Video dimensions: %sx%s", width, height)

        # mkstemp creates the file atomically (mktemp is deprecated and racy);
        # close our descriptor so the encoder can overwrite the path.
        fd, output_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)

        clip = None
        try:
            clip = ImageSequenceClip(list(video_np), fps=fps)
            clip.write_videofile(output_path, codec="libx264", audio=False)

            with open(output_path, "rb") as f:
                return f.read()
        finally:
            if clip is not None:
                clip.close()
            if os.path.exists(output_path):
                os.remove(output_path)
            del video_np
            torch.cuda.empty_cache()

    @staticmethod
    def _get_param(data: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``data[key]``, or ``default`` when the key is missing or null."""
        value = data.get(key)
        return default if value is None else value

    @staticmethod
    def _decode_image(image_data: str) -> "Image.Image":
        """Decode a base64 string (optionally a ``data:`` URI) into an RGB PIL image."""
        if image_data.startswith("data:"):
            # Drop the "data:<mime>;base64," header and keep only the payload.
            image_data = image_data.split(",", 1)[1]
        image_bytes = base64.b64decode(image_data)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process the input data and generate video using LTX.

        Args:
            data (Dict[str, Any]): Input data containing:
                - inputs (str): Text description for video generation (required)
                - image (Optional[str]): Base64 encoded image (optionally a
                  ``data:`` URI); when present, image-to-video mode is used
                - width (Optional[int]): Video width (default: 768)
                - height (Optional[int]): Video height (default: 512)
                - num_frames (Optional[int]): Number of frames (default: 97)
                - fps (Optional[int]): Frames per second (default: 24)
                - num_inference_steps (Optional[int]): Number of inference steps (default: 25)
                - guidance_scale (Optional[float]): Guidance scale (default: 7.5)
                - seed (Optional[int]): Random seed; -1 or missing means unseeded

        Returns:
            Dict[str, Any]: Dictionary containing:
                - video: video encoded in Base64 (h.264 MP4 video). This is a
                  data-uri (prefixed with "data:").
                - content-type: MIME type of the video (right now always "video/mp4")
                - metadata: Dictionary with the values actually used for
                  generation (after clamping/snapping)

        Raises:
            ValueError: If no prompt is given in ``inputs`` or a numeric
                parameter cannot be converted.
            RuntimeError: If image decoding, generation or encoding fails.
        """
        prompt = data.get("inputs", None)
        if not prompt:
            raise ValueError("No prompt provided in the 'inputs' field")

        width, height = self._validate_and_adjust_resolution(
            self._get_param(data, "width", self.DEFAULT_WIDTH),
            self._get_param(data, "height", self.DEFAULT_HEIGHT),
        )
        num_frames, fps = self._validate_and_adjust_frames(
            self._get_param(data, "num_frames", self.DEFAULT_NUM_FRAMES),
            self._get_param(data, "fps", self.DEFAULT_FPS),
        )
        guidance_scale = float(self._get_param(data, "guidance_scale", self.DEFAULT_GUIDANCE_SCALE))
        num_inference_steps = int(self._get_param(data, "num_inference_steps", self.DEFAULT_NUM_STEPS))
        seed = int(self._get_param(data, "seed", -1))
        seed = None if seed == -1 else seed

        logger.info("Generating video with prompt: '%s'", prompt)
        logger.info("Parameters: size=%sx%s, num_frames=%s, fps=%s", width, height, num_frames, fps)
        logger.info(
            "Additional params: guidance_scale=%s, num_inference_steps=%s",
            guidance_scale, num_inference_steps,
        )

        try:
            generation_kwargs: Dict[str, Any] = {
                "prompt": prompt,
                "height": height,
                "width": width,
                "num_frames": num_frames,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "output_type": "pt",
            }
            if seed is not None:
                # Previously the seed was parsed but never used. A CPU
                # generator is the device-independent choice recommended by
                # diffusers for reproducible sampling.
                generation_kwargs["generator"] = torch.Generator(device="cpu").manual_seed(seed)

            with torch.no_grad():
                image_data = data.get("image")
                if image_data:
                    image = self._decode_image(image_data)
                    logger.info("Using image-to-video generation mode")
                    generation_kwargs["image"] = image
                    output = self.image_to_video(**generation_kwargs).frames
                else:
                    logger.info("Using text-to-video generation mode")
                    output = self.text_to_video(**generation_kwargs).frames

                video_content = self._create_video_file(output, fps=fps)

            video_base64 = base64.b64encode(video_content).decode("utf-8")
            content_type = "video/mp4"
            video_data_uri = f"data:{content_type};base64,{video_base64}"

            return {
                "video": video_data_uri,
                "content-type": content_type,
                "metadata": {
                    "width": width,
                    "height": height,
                    "num_frames": num_frames,
                    "fps": fps,
                    "duration": num_frames / fps,
                    "num_inference_steps": num_inference_steps,
                    "seed": seed,
                },
            }

        except Exception as e:
            # Log with traceback and keep the original exception as __cause__.
            logger.exception("Error generating video: %s", e)
            raise RuntimeError(f"Error generating video: {str(e)}") from e