eeuuia committed on
Commit
cc8649a
·
verified ·
1 Parent(s): b42e494

Update api/ltx/vae_aduc_pipeline.py

Browse files
Files changed (1) hide show
  1. api/ltx/vae_aduc_pipeline.py +139 -152
api/ltx/vae_aduc_pipeline.py CHANGED
@@ -1,177 +1,164 @@
1
  # FILE: api/ltx/vae_aduc_pipeline.py
2
- # DESCRIPTION: A high-level client for submitting VAE-related jobs to the LTXAducManager pool.
3
- # It handles encoding media to latents, decoding latents to pixels, and creating ConditioningItems.
 
4
 
5
- import logging
6
- import time
7
- import torch
8
  import os
9
- import torchvision.transforms.functional as TVF
10
- from PIL import Image
11
- from typing import List, Union, Tuple, Literal, Optional
12
- from dataclasses import dataclass
13
- from pathlib import Path
14
  import sys
 
 
 
 
 
15
 
16
- # O cliente importa o MANAGER para submeter os trabalhos ao pool de workers.
17
- from api.ltx.ltx_aduc_manager import ltx_aduc_manager
18
-
19
- # --- Adiciona o path do LTX-Video para importações de baixo nível ---
20
- LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
21
def add_deps_to_path():
    """Put the resolved LTX-Video repository directory first on ``sys.path``.

    Does nothing when the resolved path is already present in ``sys.path``.
    """
    resolved_repo = str(LTX_VIDEO_REPO_DIR.resolve())
    if resolved_repo in sys.path:
        return
    sys.path.insert(0, resolved_repo)
25
- add_deps_to_path()
26
-
27
- # Importações para anotação de tipos e para as funções de trabalho (jobs).
28
- from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
29
- from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
30
- import ltx_video.pipelines.crf_compressor as crf_compressor
31
 
32
  # ==============================================================================
33
- # --- DEFINIÇÕES DE ESTRUTURA E HELPERS ---
34
  # ==============================================================================
35
-
36
@dataclass
class LatentConditioningItem:
    """
    Data structure used to pass conditioning latents between services.

    The latent tensor is meant to be kept on the CPU to save VRAM between
    processing steps.
    """
    # Encoded latent for one media item.
    latent_tensor: torch.Tensor
    # Frame number this conditioning item is associated with.
    media_frame_number: int
    # Strength associated with this conditioning item.
    conditioning_strength: float
45
-
46
def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, os.PathLike, Image.Image],
    target_height: int,
    target_width: int,
) -> torch.Tensor:
    """
    Load an image, center-crop it to the target aspect ratio, resize it and
    return a 5D pixel tensor normalized to [-1, 1], ready for VAE encoding.

    Args:
        image_input: A file path (``str`` or ``os.PathLike``) or a PIL image.
        target_height: Output height in pixels.
        target_width: Output width in pixels.

    Returns:
        The processed frame with two singleton dimensions inserted at
        positions 0 and 2. Presumably (1, 3, 1, target_height, target_width),
        assuming ``crf_compressor.compress`` preserves the frame shape
        (TODO confirm).

    Raises:
        ValueError: If ``image_input`` is neither a path nor a PIL image.
    """
    if isinstance(image_input, (str, os.PathLike)):
        # Open inside a context manager so the underlying file handle is
        # released as soon as the RGB copy has been created.
        with Image.open(image_input) as source:
            image = source.convert("RGB")
    elif isinstance(image_input, Image.Image):
        image = image_input.convert("RGB")
    else:
        raise ValueError("image_input must be a file path or a PIL Image object")

    # Center-crop to the target aspect ratio before resizing so the picture
    # is not distorted.
    source_width, source_height = image.size
    target_aspect = target_width / target_height
    if source_width / source_height > target_aspect:
        # Source is wider than the target: trim the left and right edges.
        crop_width = int(source_height * target_aspect)
        left = (source_width - crop_width) // 2
        image = image.crop((left, 0, left + crop_width, source_height))
    else:
        # Source is taller than (or matches) the target: trim top and bottom.
        crop_height = int(source_width / target_aspect)
        top = (source_height - crop_height) // 2
        image = image.crop((0, top, source_width, top + crop_height))

    image = image.resize((target_width, target_height), Image.Resampling.LANCZOS)

    # Convert to a CHW tensor; the compressor is called on the HWC layout,
    # then the result is permuted back to CHW.
    frame_tensor = TVF.to_tensor(image)
    frame_hwc = crf_compressor.compress(frame_tensor.permute(1, 2, 0))
    frame_tensor = frame_hwc.permute(2, 0, 1)

    # Rescale from [0, 1] to [-1, 1] and add the two singleton dimensions.
    frame_tensor = (frame_tensor * 2.0) - 1.0
    return frame_tensor.unsqueeze(0).unsqueeze(2)
 
 
 
85
 
86
  # ==============================================================================
87
- # --- FUNÇÕES DE TRABALHO (Jobs a serem executados no Pool de VAE) ---
88
  # ==============================================================================
89
 
90
def _job_encode_media(vae: CausalVideoAutoencoder, pixel_tensor: torch.Tensor) -> torch.Tensor:
    """Job that encodes a pixel tensor into a latent tensor.

    The input is moved to the VAE's device and dtype, encoded with
    per-channel normalization enabled, and the result is returned on the CPU.
    """
    pixels_on_device = pixel_tensor.to(vae.device, dtype=vae.dtype)
    encoded = vae_encode(pixels_on_device, vae, vae_per_channel_normalize=True)
    return encoded.cpu()
97
 
98
def _job_decode_latent(vae: CausalVideoAutoencoder, latent_tensor: torch.Tensor) -> torch.Tensor:
    """Job that decodes a latent tensor into a pixel tensor.

    The input is moved to the VAE's device and dtype, decoded as video with
    per-channel normalization enabled, and the result is returned on the CPU.
    """
    latents_on_device = latent_tensor.to(vae.device, dtype=vae.dtype)
    decoded = vae_decode(
        latents_on_device,
        vae,
        is_video=True,
        vae_per_channel_normalize=True,
    )
    return decoded.cpu()
105
 
106
- # ==============================================================================
107
- # --- A CLASSE CLIENTE (Interface Pública) ---
108
- # ==============================================================================
109
-
110
class VaeAducPipeline:
    """
    High-level client that orchestrates all VAE-related tasks.

    It holds the business logic (input validation and preprocessing) and
    submits the actual work to ``ltx_aduc_manager`` as ``'vae'`` jobs.
    """

    # Resolution used when the caller explicitly passes target_resolution=None.
    _DEFAULT_RESOLUTION = (512, 512)

    def __init__(self):
        logging.info("✅ VAE ADUC Pipeline (Client) initialized and ready to submit jobs.")

    def _encode_media_list(
        self,
        media: List[Union[Image.Image, str]],
        target_resolution: Tuple[int, int],
    ) -> List[torch.Tensor]:
        """Preprocess every media item, then submit one encode job per item."""
        height, width = target_resolution
        # All inputs are preprocessed before the first job is submitted, so an
        # invalid item fails before any job has been queued.
        pixel_tensors = [
            load_image_to_tensor_with_resize_and_crop(item, height, width)
            for item in media
        ]
        return [
            ltx_aduc_manager.submit_job(job_type='vae', job_func=_job_encode_media, pixel_tensor=pixels)
            for pixels in pixel_tensors
        ]

    def __call__(
        self,
        media: Union[torch.Tensor, List[Union[Image.Image, str]]],
        task: Literal['encode', 'decode', 'create_conditioning_items'],
        target_resolution: Optional[Tuple[int, int]] = (512, 512),
        conditioning_params: Optional[List[Tuple[int, float]]] = None
    ) -> Union[List[torch.Tensor], torch.Tensor, List[LatentConditioningItem]]:
        """
        Main entry point for running VAE tasks.

        Args:
            media: The input data. A list of images/paths (a single item is
                wrapped into a list) for 'encode', a single latent tensor for
                'decode', and a list of images/paths for
                'create_conditioning_items'.
            task: The task to run ('encode', 'decode',
                'create_conditioning_items').
            target_resolution: The (height, width) used for preprocessing.
                ``None`` falls back to (512, 512).
            conditioning_params: For 'create_conditioning_items', a list of
                (frame_number, strength) tuples, one per media item.

        Returns:
            Whatever ``ltx_aduc_manager.submit_job`` returns for the job(s):
            a list for 'encode', a single result for 'decode', or a list of
            ``LatentConditioningItem`` for 'create_conditioning_items'.
            The job functions in this file return CPU tensors.

        Raises:
            TypeError: If ``media`` is not a tensor for the 'decode' task.
            ValueError: If ``task`` is unknown, or if ``media`` and
                ``conditioning_params`` are not lists of equal length for
                'create_conditioning_items'.
        """
        logging.info(f"VAE Client received a '{task}' job.")

        # The annotation allows None, so honor it instead of failing on the
        # subscript further down.
        if target_resolution is None:
            target_resolution = self._DEFAULT_RESOLUTION

        if task == 'encode':
            if not isinstance(media, list):
                media = [media]
            return self._encode_media_list(media, target_resolution)

        if task == 'decode':
            if not isinstance(media, torch.Tensor):
                raise TypeError("Para a tarefa 'decode', 'media' deve ser um único tensor latente.")
            return ltx_aduc_manager.submit_job(job_type='vae', job_func=_job_decode_latent, latent_tensor=media)

        if task == 'create_conditioning_items':
            if not isinstance(media, list) or not isinstance(conditioning_params, list) or len(media) != len(conditioning_params):
                raise ValueError("Para 'create_conditioning_items', 'media' e 'conditioning_params' devem ser listas de mesmo tamanho.")

            latents = self._encode_media_list(media, target_resolution)
            return [
                LatentConditioningItem(
                    latent_tensor=latent_tensor,
                    media_frame_number=frame_number,
                    conditioning_strength=strength,
                )
                for latent_tensor, (frame_number, strength) in zip(latents, conditioning_params)
            ]

        raise ValueError(f"Tarefa desconhecida: '{task}'. Opções: 'encode', 'decode', 'create_conditioning_items'.")
171
 
172
# --- CLIENT SINGLETON INSTANCE ---
try:
    vae_aduc_pipeline = VaeAducPipeline()
except Exception:
    # Construction failed: log with the traceback and bind the name to None
    # so that importing this module does not raise.
    logging.critical("CRITICAL: Failed to initialize the VaeAducPipeline client.", exc_info=True)
    vae_aduc_pipeline = None
 
1
  # FILE: api/ltx/vae_aduc_pipeline.py
2
+ # DESCRIPTION: A dedicated, "hot" VAE service specialist.
3
+ # It holds the VAE model on a dedicated GPU and provides high-level services
4
+ # for encoding images/tensors into conditioning items and decoding latents back to pixels.
5
 
 
 
 
6
  import os
 
 
 
 
 
7
  import sys
8
+ import time
9
+ import logging
10
+ import threading
11
+ from pathlib import Path
12
+ from typing import List, Union, Tuple
13
 
14
+ import torch
15
+ import numpy as np
16
+ from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  # ==============================================================================
19
+ # --- IMPORTAÇÕES DA ARQUITETURA E DO LTX ---
20
  # ==============================================================================
21
+ try:
22
+ from api.ltx.ltx_aduc_manager import LatentConditioningItem
23
+ from api.managers.gpu_manager import gpu_manager
24
+ # Adiciona o path para as bibliotecas do LTX
25
+ LTX_VIDEO_REPO_DIR = Path("/data/LTX-Video")
26
+ if str(LTX_VIDEO_REPO_DIR.resolve()) not in sys.path:
27
+ sys.path.insert(0, str(LTX_VIDEO_REPO_DIR.resolve()))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
30
+ from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode
31
+ # Nossos data classes customizados para condicionamento, importados do pool manager
32
+ except ImportError as e:
33
+ raise ImportError(f"A crucial import failed for VaeServer. Check dependencies. Error: {e}")
34
 
35
  # ==============================================================================
36
+ # --- CLASSE DO SERVIÇO VAE ---
37
  # ==============================================================================
38
 
39
class VaeServer:
    """
    Singleton service that wraps the LTX VAE for encoding and decoding.

    The VAE is taken from the pipeline exposed by ``ltx_pool_manager`` and
    moved to the device returned by ``gpu_manager.get_ltx_vae_device()``.
    Every ``VaeServer()`` call returns the same instance; initialization runs
    once, guarded by a class-level lock.
    """
    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        """Create the single shared instance on first use and return it."""
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """
        Initialize the shared instance once; later calls return immediately.

        Raises:
            RuntimeError: If ``ltx_pool_manager`` or its pipeline is not
                available yet.
            Exception: Any other error raised while fetching the VAE is
                logged and re-raised.
        """
        if self._initialized:
            return
        with self._lock:
            # Re-check under the lock: another thread may have finished
            # initialization while this one was waiting.
            if self._initialized:
                return

            logging.info("⚙️ Initializing VaeServer Singleton...")
            t0 = time.time()

            # 1. Get the dedicated VAE device from the central GPU manager.
            self.device = gpu_manager.get_ltx_vae_device()

            # 2. Reuse the VAE already held by the pool manager's pipeline
            #    instead of loading the model a second time.
            try:
                from api.ltx.ltx_aduc_manager import ltx_pool_manager
                if ltx_pool_manager is None or ltx_pool_manager.get_pipeline() is None:
                    raise RuntimeError("LTXPoolManager is not initialized yet. VaeServer must be initialized after.")
                self.vae = ltx_pool_manager.get_pipeline().vae
            except Exception as e:
                logging.critical(f"Failed to get VAE from LTXPoolManager. Error: {e}", exc_info=True)
                raise

            # 3. Make sure the VAE is on the right device and in eval mode.
            self.vae.to(self.device)
            self.vae.eval()
            self.dtype = self.vae.dtype

            self._initialized = True
            logging.info(f"✅ VaeServer ready. VAE model is 'hot' on {self.device} with dtype {self.dtype}. Startup time: {time.time() - t0:.2f}s")

    def _cleanup_gpu(self):
        """Release cached CUDA memory on the VAE device, if it is a CUDA device."""
        if not torch.cuda.is_available():
            return
        device = torch.device(self.device)
        if device.type != "cuda":
            # torch.cuda.device() rejects non-CUDA devices. This method runs
            # in `finally` blocks, so raising here would mask the real result
            # or error of the caller.
            return
        with torch.cuda.device(device):
            torch.cuda.empty_cache()

    def _preprocess_input(self, item: Union[Image.Image, torch.Tensor], target_resolution: Tuple[int, int]) -> torch.Tensor:
        """
        Turn a PIL image or a tensor into the 5D pixel tensor passed to the
        VAE encoder.

        Args:
            item: A PIL image, or a tensor shaped CHW or 1xCHW.
            target_resolution: Size passed to ``ImageOps.fit`` for PIL
                inputs. Tensor inputs are NOT resized.

        Returns:
            The input with singleton dimensions inserted at positions 0 and 2
            (B, C, F, H, W), mapped through ``x * 2 - 1``.

        Raises:
            ValueError: If a tensor input is not CHW or a single-item batch.
            TypeError: If ``item`` is neither a PIL image nor a tensor.
        """
        if isinstance(item, Image.Image):
            from PIL import ImageOps
            img = item.convert("RGB")
            # NOTE(review): ImageOps.fit expects size as (width, height).
            # Confirm callers pass target_resolution in that order.
            processed_img = ImageOps.fit(img, target_resolution, Image.Resampling.LANCZOS)
            image_np = np.array(processed_img).astype(np.float32) / 255.0
            tensor = torch.from_numpy(image_np).permute(2, 0, 1)  # HWC -> CHW
        elif isinstance(item, torch.Tensor):
            # NOTE(review): tensor inputs are used as-is, so they are
            # presumably expected to already be in [0, 1] - confirm.
            if item.ndim == 4 and item.shape[0] == 1:
                tensor = item.squeeze(0)
            elif item.ndim == 3:
                tensor = item
            else:
                # Report the full shape: a 4D tensor with batch > 1 also ends
                # up here, and "got 4" alone would be misleading.
                raise ValueError(f"Input tensor must be CHW or a single-item BCHW batch, but got shape {tuple(item.shape)}")
        else:
            raise TypeError(f"Input must be a PIL Image or a torch.Tensor, but got {type(item)}")

        # Expand to 5D (B, C, F, H, W) and rescale [0, 1] -> [-1, 1].
        tensor_5d = tensor.unsqueeze(0).unsqueeze(2)
        return (tensor_5d * 2.0) - 1.0

    @torch.no_grad()
    def generate_conditioning_items(
        self,
        media_items: List[Union[Image.Image, torch.Tensor]],
        target_frames: List[int],
        strengths: List[float],
        target_resolution: Tuple[int, int]
    ) -> List[LatentConditioningItem]:
        """
        Convert a list of images (PIL images or pixel tensors) into a list of
        ``LatentConditioningItem``.

        Args:
            media_items: Images or pixel tensors to encode.
            target_frames: Frame number for each media item.
            strengths: Conditioning strength for each media item.
            target_resolution: Resolution forwarded to ``_preprocess_input``.

        Returns:
            One ``LatentConditioningItem`` per input, each holding its latent
            tensor on the CPU.

        Raises:
            ValueError: If the three input lists differ in length.
        """
        t0 = time.time()
        logging.info(f"VaeServer: Generating {len(media_items)} latent conditioning items...")

        if not (len(media_items) == len(target_frames) == len(strengths)):
            raise ValueError("Input lists for conditioning items must have the same length.")

        conditioning_items = []
        try:
            for item, frame, strength in zip(media_items, target_frames, strengths):
                pixel_tensor = self._preprocess_input(item, target_resolution)
                pixel_tensor_gpu = pixel_tensor.to(self.device, dtype=self.dtype)
                latents = vae_encode(pixel_tensor_gpu, self.vae, vae_per_channel_normalize=True)
                # Move each latent to the CPU right away.
                conditioning_items.append(LatentConditioningItem(latents.cpu(), frame, strength))

            logging.info(f"VaeServer: Generated {len(conditioning_items)} items in {time.time() - t0:.2f}s.")
            return conditioning_items
        finally:
            # Always release cached VRAM, including when encoding fails.
            self._cleanup_gpu()

    @torch.no_grad()
    def decode_to_pixels(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        """
        Decode a latent tensor into a pixel tensor, returned on the CPU.

        Args:
            latent_tensor: Latents to decode; dimension 0 is used as the
                batch size.
            decode_timestep: Timestep value passed to ``vae_decode`` for
                every item of the batch.

        Returns:
            The decoded pixel tensor on the CPU.
        """
        t0 = time.time()
        try:
            latent_tensor_gpu = latent_tensor.to(self.device, dtype=self.dtype)
            num_items_in_batch = latent_tensor_gpu.shape[0]
            # One timestep entry per batch item.
            timestep_tensor = torch.full(
                (num_items_in_batch,), decode_timestep, device=self.device, dtype=self.dtype
            )

            pixels = vae_decode(
                latent_tensor_gpu, self.vae, is_video=True,
                timestep=timestep_tensor, vae_per_channel_normalize=True
            )
            logging.info(f"VaeServer: Decoded latents with shape {latent_tensor.shape} in {time.time() - t0:.2f}s.")
            return pixels.cpu()
        finally:
            # Always release cached VRAM, including when decoding fails.
            self._cleanup_gpu()
153
+
154
# --- Singleton instance ---
try:
    # Initialization depends on LTXPoolManager, which provides the VAE.
    from api.ltx.ltx_aduc_manager import ltx_pool_manager
    if not ltx_pool_manager:
        raise RuntimeError("LTXPoolManager failed to initialize, cannot start VaeServer.")
    vae_server_singleton = VaeServer()
except Exception:
    # Any failure (including the RuntimeError above) is logged with its
    # traceback and the name is bound to None so the import does not raise.
    logging.critical("CRITICAL: Failed to initialize VaeServer singleton.", exc_info=True)
    vae_server_singleton = None