# Test / api /ltx /vae_aduc_pipeline.py
# eeuuia's picture
# Upload 5 files
# 9a6b3d7 verified
# raw
# history blame
# 8.1 kB
# FILE: api/ltx/vae_aduc_pipeline.py
# DESCRIPTION: A high-level client for submitting VAE-related jobs to the LTXAducManager pool.
# It handles encoding media to latents, decoding latents to pixels, and creating ConditioningItems.
import logging
import time
import torch
import os
import torchvision.transforms.functional as TVF
from PIL import Image
from typing import List, Union, Tuple, Literal, Optional
from dataclasses import dataclass
from pathlib import Path
import sys
# O cliente importa o MANAGER para submeter os trabalhos ao pool de workers.
from api.ltx.ltx_aduc_manager import ltx_aduc_manager
# --- Adiciona o path do LTX-Video para importações de baixo nível ---
# Location of the LTX-Video checkout whose low-level modules are imported below.
LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")


def add_deps_to_path():
    """Prepend the LTX-Video repository to ``sys.path`` (idempotent)."""
    resolved = str(LTX_VIDEO_REPO_DIR.resolve())
    if resolved not in sys.path:
        sys.path.insert(0, resolved)


add_deps_to_path()
# Importações para anotação de tipos e para as funções de trabalho (jobs).
from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
import ltx_video.pipelines.crf_compressor as crf_compressor
# ==============================================================================
# --- DEFINIÇÕES DE ESTRUTURA E HELPERS ---
# ==============================================================================
@dataclass
class LatentConditioningItem:
    """
    Data structure for passing conditioned latents between services.
    The latent tensor is kept on the CPU to save VRAM between stages.
    """
    # Latent tensor produced by the VAE encoder (stored on CPU, per the note above).
    latent_tensor: torch.Tensor
    # Frame number in the target media that this latent conditions.
    media_frame_number: int
    # Blending weight for this conditioning item — presumably in [0, 1], not enforced here.
    conditioning_strength: float
def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, Image.Image],
    target_height: int,
    target_width: int,
) -> torch.Tensor:
    """
    Load and preprocess an image into a 5D pixel tensor normalized to [-1, 1],
    ready to be sent to the VAE for encoding.

    Args:
        image_input: Path to an image file, or an already-loaded PIL image.
        target_height: Output height in pixels.
        target_width: Output width in pixels.

    Returns:
        A tensor of shape (1, C, 1, target_height, target_width).

    Raises:
        ValueError: If ``image_input`` is neither a path nor a PIL image.
    """
    if isinstance(image_input, str):
        pil_image = Image.open(image_input).convert("RGB")
    elif isinstance(image_input, Image.Image):
        pil_image = image_input.convert("RGB")
    else:
        raise ValueError("image_input must be a file path or a PIL Image object")

    # Center-crop to the target aspect ratio so the later resize does not distort.
    src_w, src_h = pil_image.size
    target_ratio = target_width / target_height
    source_ratio = src_w / src_h
    if source_ratio > target_ratio:
        # Source is wider than the target: trim left/right edges.
        crop_w = int(src_h * target_ratio)
        left = (src_w - crop_w) // 2
        pil_image = pil_image.crop((left, 0, left + crop_w, src_h))
    else:
        # Source is taller than (or equal to) the target: trim top/bottom edges.
        crop_h = int(src_w / target_ratio)
        top = (src_h - crop_h) // 2
        pil_image = pil_image.crop((0, top, src_w, top + crop_h))
    pil_image = pil_image.resize((target_width, target_height), Image.Resampling.LANCZOS)

    # To tensor, CRF-compress in HWC layout, then rescale [0, 1] -> [-1, 1].
    chw = TVF.to_tensor(pil_image)  # (C, H, W) in [0, 1]
    hwc = crf_compressor.compress(chw.permute(1, 2, 0))
    chw = hwc.permute(2, 0, 1)
    chw = (chw * 2.0) - 1.0
    # Add batch and frame dimensions: (1, C, 1, H, W).
    return chw.unsqueeze(0).unsqueeze(2)
# ==============================================================================
# --- FUNÇÕES DE TRABALHO (Jobs a serem executados no Pool de VAE) ---
# ==============================================================================
def _job_encode_media(vae: CausalVideoAutoencoder, pixel_tensor: torch.Tensor) -> torch.Tensor:
    """Job: encode a pixel tensor into a latent tensor (result returned on CPU)."""
    # Match the VAE's device and dtype before encoding.
    gpu_pixels = pixel_tensor.to(vae.device, dtype=vae.dtype)
    encoded = vae_encode(gpu_pixels, vae, vae_per_channel_normalize=True)
    # Latents are shipped back to the CPU to free VRAM between jobs.
    return encoded.cpu()
def _job_decode_latent(vae: CausalVideoAutoencoder, latent_tensor: torch.Tensor) -> torch.Tensor:
    """Job: decode a latent tensor into a pixel tensor (result returned on CPU)."""
    # Match the VAE's device and dtype before decoding.
    gpu_latents = latent_tensor.to(vae.device, dtype=vae.dtype)
    decoded = vae_decode(gpu_latents, vae, is_video=True, vae_per_channel_normalize=True)
    # Pixels are shipped back to the CPU to free VRAM between jobs.
    return decoded.cpu()
# ==============================================================================
# --- A CLASSE CLIENTE (Interface Pública) ---
# ==============================================================================
class VaeAducPipeline:
    """
    High-level client that orchestrates all VAE-related tasks.
    It defines the business logic and submits the jobs to the LTXAducManager.
    """

    # Fallback used when callers pass target_resolution=None (the annotation
    # allows None, but the original code crashed on target_resolution[0]).
    DEFAULT_RESOLUTION: Tuple[int, int] = (512, 512)

    def __init__(self):
        logging.info("✅ VAE ADUC Pipeline (Client) initialized and ready to submit jobs.")

    def __call__(
        self,
        media: Union[torch.Tensor, List[Union[Image.Image, str]]],
        task: Literal['encode', 'decode', 'create_conditioning_items'],
        target_resolution: Optional[Tuple[int, int]] = (512, 512),
        conditioning_params: Optional[List[Tuple[int, float]]] = None
    ) -> Union[List[torch.Tensor], torch.Tensor, List[LatentConditioningItem]]:
        """
        Main entry point for running VAE tasks.

        Args:
            media: The input data (a latent tensor for 'decode'; a list of
                images/paths — or a single one — otherwise).
            task: The task to run ('encode', 'decode', 'create_conditioning_items').
            target_resolution: (height, width) used for preprocessing; ``None``
                falls back to ``DEFAULT_RESOLUTION``.
            conditioning_params: For 'create_conditioning_items', a list of
                (frame_number, strength) tuples, one per media item.

        Returns:
            The task result, always on the CPU.

        Raises:
            TypeError: If 'decode' is given a non-tensor input.
            ValueError: If the task is unknown or 'create_conditioning_items'
                inputs are inconsistent.
        """
        t0 = time.time()
        logging.info(f"VAE Client received a '{task}' job.")
        if target_resolution is None:
            target_resolution = self.DEFAULT_RESOLUTION

        if task == 'encode':
            result = self._encode(media, target_resolution)
        elif task == 'decode':
            result = self._decode(media)
        elif task == 'create_conditioning_items':
            result = self._create_conditioning_items(media, target_resolution, conditioning_params)
        else:
            raise ValueError(f"Tarefa desconhecida: '{task}'. Opções: 'encode', 'decode', 'create_conditioning_items'.")
        # t0 was previously captured but never used; report the elapsed time.
        logging.debug("VAE '%s' job finished in %.2fs.", task, time.time() - t0)
        return result

    @staticmethod
    def _to_pixel_tensors(media: List[Union[Image.Image, str]],
                          target_resolution: Tuple[int, int]) -> List[torch.Tensor]:
        """Preprocess each media item into a normalized 5D pixel tensor."""
        height, width = target_resolution
        return [load_image_to_tensor_with_resize_and_crop(m, height, width) for m in media]

    def _encode(self, media, target_resolution: Tuple[int, int]) -> List[torch.Tensor]:
        """Encode one or more images into latent tensors (one pool job per image)."""
        if not isinstance(media, list):
            media = [media]
        pixel_tensors = self._to_pixel_tensors(media, target_resolution)
        return [
            ltx_aduc_manager.submit_job(job_type='vae', job_func=_job_encode_media, pixel_tensor=pt)
            for pt in pixel_tensors
        ]

    @staticmethod
    def _decode(media) -> torch.Tensor:
        """Decode a single latent tensor into pixels via the VAE pool."""
        if not isinstance(media, torch.Tensor):
            raise TypeError("Para a tarefa 'decode', 'media' deve ser um único tensor latente.")
        return ltx_aduc_manager.submit_job(job_type='vae', job_func=_job_decode_latent, latent_tensor=media)

    def _create_conditioning_items(self, media, target_resolution: Tuple[int, int],
                                   conditioning_params) -> List[LatentConditioningItem]:
        """Encode each image and pair it with its (frame_number, strength) params."""
        if not isinstance(media, list) or not isinstance(conditioning_params, list) or len(media) != len(conditioning_params):
            raise ValueError("Para 'create_conditioning_items', 'media' e 'conditioning_params' devem ser listas de mesmo tamanho.")
        pixel_tensors = self._to_pixel_tensors(media, target_resolution)
        conditioning_items = []
        for pt, (frame_number, strength) in zip(pixel_tensors, conditioning_params):
            latent_tensor = ltx_aduc_manager.submit_job(job_type='vae', job_func=_job_encode_media, pixel_tensor=pt)
            conditioning_items.append(LatentConditioningItem(
                latent_tensor=latent_tensor,
                media_frame_number=frame_number,
                conditioning_strength=strength
            ))
        return conditioning_items
# --- CLIENT SINGLETON INSTANCE ---
try:
    vae_aduc_pipeline = VaeAducPipeline()
except Exception:
    # Boundary-level catch: log with the traceback and leave the singleton as
    # None so importers can detect the failure instead of crashing at import.
    logging.critical("CRITICAL: Failed to initialize the VaeAducPipeline client.", exc_info=True)
    vae_aduc_pipeline = None