Carlexxx
feat: ✨ aBINC 2.2
fb56537
raw
history blame
4.12 kB
# aduc_framework/managers/vae_manager.py
#
# Versão 2.1.0 (Correção de Timestep no Decode)
# Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
#
# - Corrige um `AssertionError` na função `decode` ao não passar o argumento
# `timestep` esperado pelo decodificador do VAE.
# - Adiciona um `timestep` padrão (0.05) para a decodificação, garantindo
# uma reconstrução de imagem limpa e estável.
import torch
import logging
import gc
import yaml
from typing import List
from PIL import Image
import numpy as np
from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
from ..tools.hardware_manager import hardware_manager
logger = logging.getLogger(__name__)
class VaeManager:
"""
Especialista VAE "Hot" e Persistente.
Carrega o modelo VAE em uma GPU dedicada uma única vez e o mantém lá,
pronto para processar requisições de encode/decode com latência mínima.
"""
def __init__(self):
with open("config.yaml", 'r') as f:
config = yaml.safe_load(f)
gpus_required = config['specialists'].get('vae', {}).get('gpus_required', 0)
if gpus_required > 0 and torch.cuda.is_available():
device_id = hardware_manager.allocate_gpus('VAE_Manager', gpus_required)[0]
self.device = torch.device(device_id)
logger.info(f"VaeManager: GPU dedicada '{device_id}' alocada.")
else:
self.device = torch.device('cpu')
logger.warning("VaeManager: Nenhuma GPU dedicada foi alocada no config.yaml. Operando em modo CPU.")
try:
from ..managers.ltx_manager import ltx_manager_singleton
self.vae = ltx_manager_singleton.workers[0].pipeline.vae
except ImportError as e:
logger.critical("Falha ao importar ltx_manager_singleton. Garanta que VaeManager seja importado DEPOIS de LtxManager.", exc_info=True)
raise e
self.vae.to(self.device)
self.vae.eval()
self.dtype = self.vae.dtype
logger.info(f"VaeManager inicializado. Modelo VAE está 'quente' e pronto na {self.device} com dtype {self.dtype}.")
def _preprocess_pil_image(self, pil_image: Image.Image, target_resolution: tuple) -> torch.Tensor:
from PIL import ImageOps
img = pil_image.convert("RGB")
processed_img = ImageOps.fit(img, target_resolution, Image.Resampling.LANCZOS)
image_np = np.array(processed_img).astype(np.float32) / 255.0
tensor = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze(0).unsqueeze(2)
return (tensor * 2.0) - 1.0
@torch.no_grad()
def encode_batch(self, pil_images: List[Image.Image], target_resolution: tuple) -> List[torch.Tensor]:
if not pil_images:
return []
latents_list = []
for img in pil_images:
pixel_tensor = self._preprocess_pil_image(img, target_resolution)
pixel_tensor_gpu = pixel_tensor.to(self.device, dtype=self.dtype)
latents = vae_encode(pixel_tensor_gpu, self.vae, vae_per_channel_normalize=True)
latents_list.append(latents.cpu())
return latents_list
@torch.no_grad()
def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
"""Decodifica um tensor latente para o espaço de pixels."""
latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
# --- CORREÇÃO APLICADA AQUI ---
# O modelo espera um tensor de timestep, um para cada item no batch.
num_items_in_batch = latent_tensor_gpu.shape[0]
timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device, dtype=self.dtype)
pixels = vae_decode(
latent_tensor_gpu,
self.vae,
is_video=True,
timestep=timestep_tensor, # Passando o tensor de timestep
vae_per_channel_normalize=True
)
# --- FIM DA CORREÇÃO ---
return pixels.cpu()
# --- Instância Singleton ---
vae_manager_singleton = VaeManager()