Spaces:
Paused
Paused
| # aduc_framework/managers/vae_manager.py | |
| # | |
| # Versão 2.1.0 (Correção de Timestep no Decode) | |
| # Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos | |
| # | |
| # - Corrige um `AssertionError` na função `decode` ao não passar o argumento | |
| # `timestep` esperado pelo decodificador do VAE. | |
| # - Adiciona um `timestep` padrão (0.05) para a decodificação, garantindo | |
| # uma reconstrução de imagem limpa e estável. | |
| import torch | |
| import logging | |
| import gc | |
| import yaml | |
| from typing import List | |
| from PIL import Image | |
| import numpy as np | |
| from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode | |
| from ..tools.hardware_manager import hardware_manager | |
| logger = logging.getLogger(__name__) | |
| class VaeManager: | |
| """ | |
| Especialista VAE "Hot" e Persistente. | |
| Carrega o modelo VAE em uma GPU dedicada uma única vez e o mantém lá, | |
| pronto para processar requisições de encode/decode com latência mínima. | |
| """ | |
| def __init__(self): | |
| with open("config.yaml", 'r') as f: | |
| config = yaml.safe_load(f) | |
| gpus_required = config['specialists'].get('vae', {}).get('gpus_required', 0) | |
| if gpus_required > 0 and torch.cuda.is_available(): | |
| device_id = hardware_manager.allocate_gpus('VAE_Manager', gpus_required)[0] | |
| self.device = torch.device(device_id) | |
| logger.info(f"VaeManager: GPU dedicada '{device_id}' alocada.") | |
| else: | |
| self.device = torch.device('cpu') | |
| logger.warning("VaeManager: Nenhuma GPU dedicada foi alocada no config.yaml. Operando em modo CPU.") | |
| try: | |
| from ..managers.ltx_manager import ltx_manager_singleton | |
| self.vae = ltx_manager_singleton.workers[0].pipeline.vae | |
| except ImportError as e: | |
| logger.critical("Falha ao importar ltx_manager_singleton. Garanta que VaeManager seja importado DEPOIS de LtxManager.", exc_info=True) | |
| raise e | |
| self.vae.to(self.device) | |
| self.vae.eval() | |
| self.dtype = self.vae.dtype | |
| logger.info(f"VaeManager inicializado. Modelo VAE está 'quente' e pronto na {self.device} com dtype {self.dtype}.") | |
| def _preprocess_pil_image(self, pil_image: Image.Image, target_resolution: tuple) -> torch.Tensor: | |
| from PIL import ImageOps | |
| img = pil_image.convert("RGB") | |
| processed_img = ImageOps.fit(img, target_resolution, Image.Resampling.LANCZOS) | |
| image_np = np.array(processed_img).astype(np.float32) / 255.0 | |
| tensor = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze(0).unsqueeze(2) | |
| return (tensor * 2.0) - 1.0 | |
| def encode_batch(self, pil_images: List[Image.Image], target_resolution: tuple) -> List[torch.Tensor]: | |
| if not pil_images: | |
| return [] | |
| latents_list = [] | |
| for img in pil_images: | |
| pixel_tensor = self._preprocess_pil_image(img, target_resolution) | |
| pixel_tensor_gpu = pixel_tensor.to(self.device, dtype=self.dtype) | |
| latents = vae_encode(pixel_tensor_gpu, self.vae, vae_per_channel_normalize=True) | |
| latents_list.append(latents.cpu()) | |
| return latents_list | |
| def decode(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor: | |
| """Decodifica um tensor latente para o espaço de pixels.""" | |
| latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype) | |
| # --- CORREÇÃO APLICADA AQUI --- | |
| # O modelo espera um tensor de timestep, um para cada item no batch. | |
| num_items_in_batch = latent_tensor_gpu.shape[0] | |
| timestep_tensor = torch.tensor([decode_timestep] * num_items_in_batch, device=self.device, dtype=self.dtype) | |
| pixels = vae_decode( | |
| latent_tensor_gpu, | |
| self.vae, | |
| is_video=True, | |
| timestep=timestep_tensor, # Passando o tensor de timestep | |
| vae_per_channel_normalize=True | |
| ) | |
| # --- FIM DA CORREÇÃO --- | |
| return pixels.cpu() | |
| # --- Instância Singleton --- | |
| vae_manager_singleton = VaeManager() |