eeuuia committed on
Commit ecd2b0d · verified · 1 Parent(s): ecd3981

Update api/ltx/ltx_aduc_manager.py

Files changed (1)
  1. api/ltx/ltx_aduc_manager.py +36 -22
api/ltx/ltx_aduc_manager.py CHANGED
@@ -9,10 +9,11 @@ from pathlib import Path
 import threading
 import queue
 import time
+import yaml
 from typing import List, Optional, Callable, Any, Tuple
 
 # Imports for the builders and the gpu_manager
-from api.ltx.ltx.ltx_utils import get_main_ltx_pipeline, get_main_vae
+from api.ltx.ltx_utils import get_main_ltx_pipeline, get_main_vae
 from managers.gpu_manager import gpu_manager
 
 # --- Adds the LTX-Video path for type imports ---
@@ -63,24 +64,42 @@ class LTXMainWorker(BaseWorker):
     def __init__(self, worker_id: int, device: torch.device):
         super().__init__(worker_id, device)
         self.pipeline: Optional[LTXVideoPipeline] = None
+        self.autocast_dtype: torch.dtype = torch.float32
 
     def _load_models(self):
         logging.info(f"[LTXWorker-{self.worker_id}] Loading models to CPU...")
         self.pipeline = get_main_ltx_pipeline()
+        self._set_precision_policy()
         logging.info(f"[LTXWorker-{self.worker_id}] Moving pipeline to {self.device}...")
         self.pipeline.to(self.device)
 
+    def _set_precision_policy(self):
+        """Determines the dtype for torch.autocast based on the config."""
+        try:
+            config_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
+            with open(config_path, "r") as file:
+                config = yaml.safe_load(file)
+            precision = str(config.get("precision", "bfloat16")).lower()
+            if precision in ["float8_e4m3fn", "bfloat16"]:
+                self.autocast_dtype = torch.bfloat16
+            elif precision == "mixed_precision":
+                self.autocast_dtype = torch.float16
+            logging.info(f"[LTXWorker-{self.worker_id}] Autocast precision policy set to {self.autocast_dtype}")
+        except Exception as e:
+            logging.warning(f"[LTXWorker-{self.worker_id}] Could not set precision policy from config. Defaulting to float32. Error: {e}")
+            self.autocast_dtype = torch.float32
+
     def execute(self, job_func: Callable, args: tuple, kwargs: dict) -> Any:
-        """Executes a job, managing the 'busy' state."""
         self.is_busy = True
         logging.info(f"Worker {self.worker_id} (LTX) starting job: {job_func.__name__}")
         try:
-            result = job_func(self.pipeline, *args, **kwargs)
+            # Passes its own pipeline instance and the dtype to the job function
+            result = job_func(self.pipeline, self.autocast_dtype, *args, **kwargs)
             logging.info(f"Worker {self.worker_id} (LTX) finished job successfully.")
             return result
         except Exception as e:
             logging.error(f"Worker {self.worker_id} (LTX) job failed!", exc_info=True)
-            self.is_healthy = False  # A job failure marks the worker as unhealthy
+            self.is_healthy = False
             raise
         finally:
             self.is_busy = False
@@ -99,7 +118,6 @@ class VAEWorker(BaseWorker):
         self.vae.eval()
 
     def execute(self, job_func: Callable, args: tuple, kwargs: dict) -> Any:
-        """Executes a job, managing the 'busy' state."""
         self.is_busy = True
         logging.info(f"Worker {self.worker_id} (VAE) starting job: {job_func.__name__}")
         try:
@@ -138,7 +156,6 @@ class LTXAducManager:
 
         self._initialize_workers()
 
-        # Starts consumer threads to process the queues
         self.ltx_dispatcher = threading.Thread(target=self._dispatch_jobs, args=(self.ltx_job_queue, self.ltx_workers), daemon=True)
         self.vae_dispatcher = threading.Thread(target=self._dispatch_jobs, args=(self.vae_job_queue, self.vae_workers), daemon=True)
         self.health_monitor = threading.Thread(target=self._health_check_loop, daemon=True)
@@ -152,17 +169,16 @@ class LTXAducManager:
 
     def _initialize_workers(self):
         """Creates and starts the workers based on the allocated GPUs."""
-        # Assuming gpu_manager now has get_ltx_devices() and get_seedvr_devices() that return lists
-        ltx_gpus = gpu_manager.get_ltx_device()  # Adjust if the name differs
-        vae_gpus = gpu_manager.get_ltx_vae_device()  # Adjust if the name differs
+        ltx_gpus = [gpu_manager.get_ltx_device().index]  # Assuming the getter returns a device object
+        vae_gpus = [gpu_manager.get_ltx_vae_device().index]
 
         with self.pool_lock:
-            for i, device_id in enumerate([ltx_gpus]):  # Assuming it returns a list
+            for i, device_id in enumerate(ltx_gpus):
                 worker = LTXMainWorker(worker_id=i, device=torch.device(f"cuda:{device_id}"))
                 self.ltx_workers.append(worker)
                 worker.start()
 
-            for i, device_id in enumerate([vae_gpus]):  # Assuming it returns a list
+            for i, device_id in enumerate(vae_gpus):
                 worker = VAEWorker(worker_id=i, device=torch.device(f"cuda:{device_id}"))
                 self.vae_workers.append(worker)
                 worker.start()
@@ -170,6 +186,8 @@ class LTXAducManager:
     def _get_available_worker(self, worker_pool: List[BaseWorker]) -> Optional[BaseWorker]:
         """Finds a healthy, idle worker in the pool."""
         with self.pool_lock:
+            # Simple round-robin strategy to distribute the load
+            # A more sophisticated strategy could check GPU load
             for worker in worker_pool:
                 healthy, busy = worker.get_status()
                 if healthy and not busy:
@@ -184,7 +202,7 @@ class LTXAducManager:
             while worker is None:
                 worker = self._get_available_worker(worker_pool)
                 if worker is None:
-                    time.sleep(0.1)  # Wait for a worker to become free
+                    time.sleep(0.1)
 
             try:
                 result = worker.execute(job_func, args, kwargs)
@@ -200,36 +218,32 @@ class LTXAducManager:
         with self.pool_lock:
            for i, worker in enumerate(self.ltx_workers):
                 if not worker.is_alive() or not worker.is_healthy:
-                    logging.warning(f"LTX Worker {worker.worker_id} on {worker.device} is UNHEALTHY. Restarting...")
+                    logging.warning(f"LTX Worker {worker.worker_id} on {worker.device} is UNHEALTHY or dead. Restarting...")
                     new_worker = LTXMainWorker(worker.worker_id, worker.device)
                     self.ltx_workers[i] = new_worker
                     new_worker.start()
-            # Repeat the loop for VAE workers
+
             for i, worker in enumerate(self.vae_workers):
                 if not worker.is_alive() or not worker.is_healthy:
-                    logging.warning(f"VAE Worker {worker.worker_id} on {worker.device} is UNHEALTHY. Restarting...")
+                    logging.warning(f"VAE Worker {worker.worker_id} on {worker.device} is UNHEALTHY or dead. Restarting...")
                     new_worker = VAEWorker(worker.worker_id, worker.device)
                     self.vae_workers[i] = new_worker
                     new_worker.start()
 
     def submit_job(self, job_type: str, job_func: Callable, *args, **kwargs) -> Any:
-        """
-        Public entry point for submitting a job to the pool.
-        This function is synchronous: it waits for the result.
-        """
+        """Public entry point for synchronously submitting a job to the pool."""
         if job_type not in ['ltx', 'vae']:
             raise ValueError("Invalid job_type. Must be 'ltx' or 'vae'.")
 
         job_queue = self.ltx_job_queue if job_type == 'ltx' else self.vae_job_queue
-        future = queue.Queue()  # We use a queue as a 'future' to get the result back
+        future = queue.Queue(1)
 
         job_queue.put((job_func, args, kwargs, future))
 
-        # Blocks and waits for the result to be placed in the 'future' by the dispatcher
         result = future.get()
 
         if isinstance(result, Exception):
-            raise result  # If the job failed, re-raise the exception on the main thread
+            raise result
 
         return result
 
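
This commit changes the worker contract: LTX job functions now receive the worker's pipeline and its `autocast_dtype` as the first two positional arguments, and callers obtain results synchronously through `submit_job`, which re-raises any exception returned via the queue-backed 'future'. Below is a minimal sketch of a compatible job function and call site; `generate_latents`, the pipeline call signature, and the `ltx_aduc_manager` instance name are illustrative assumptions, not part of the commit.

```python
import torch

def generate_latents(pipeline, autocast_dtype: torch.dtype, prompt: str, num_frames: int = 97):
    # The worker invokes job_func(pipeline, autocast_dtype, *args, **kwargs),
    # so the first two parameters are injected by LTXMainWorker.execute.
    # Run the pipeline under the precision policy the worker derived from its YAML config.
    with torch.autocast(device_type="cuda", dtype=autocast_dtype):
        # Assumed pipeline call signature for illustration only.
        return pipeline(prompt=prompt, num_frames=num_frames)

# Hypothetical call site: submit_job blocks until the dispatcher places the
# result (or the job's exception) into the single-slot future queue.
# result = ltx_aduc_manager.submit_job("ltx", generate_latents, prompt="a red fox", num_frames=97)
```

Because `submit_job` re-raises exceptions placed in the future, a failed job surfaces at the call site as a normal Python exception instead of a silent sentinel value.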