EuuIia committed · verified
Commit c825b23 · 1 Parent(s): 33de423

Update video_service.py

Files changed (1): video_service.py (+19 -8)
video_service.py CHANGED
@@ -311,8 +311,11 @@ class VideoService:
 
         return pipeline, latent_upsampler
 
-    # Precision: promote FP8->BF16 and set the autocast dtype
+    # Precision: promote FP8->BF16 and set the autocast dtype (safe version)
     def _promote_fp8_weights_to_bf16(self, module):
+        # Only promote when this is really an nn.Module; pipelines are not nn.Module
+        if not isinstance(module, torch.nn.Module):
+            return
         f8 = getattr(torch, "float8_e4m3fn", None)
         if f8 is None:
             return
@@ -329,24 +332,32 @@
                     b.data = b.data.to(torch.bfloat16)
             except Exception:
                 pass
-
+
     def _apply_precision_policy(self):
         prec = str(self.config.get("precision", "")).lower()
         self.runtime_autocast_dtype = torch.float32
         if prec == "float8_e4m3fn":
-            # Experimental FP8: promote weights to BF16 and standardize autocast on BF16
-            if hasattr(torch, "float8_e4m3fn"):
-                self._promote_fp8_weights_to_bf16(self.pipeline)
-                if self.latent_upsampler:
-                    self._promote_fp8_weights_to_bf16(self.latent_upsampler)
+            # FP8: native LTX kernels may be active; by default, do not promote weights
             self.runtime_autocast_dtype = torch.bfloat16
+            force_promote = os.getenv("LTXV_FORCE_BF16_ON_FP8", "0") == "1"
+            if force_promote and hasattr(torch, "float8_e4m3fn"):
+                # Promote only real modules; ignore Pipeline objects
+                try:
+                    self._promote_fp8_weights_to_bf16(self.pipeline)
+                except Exception:
+                    pass
+                try:
+                    if self.latent_upsampler:
+                        self._promote_fp8_weights_to_bf16(self.latent_upsampler)
+                except Exception:
+                    pass
         elif prec == "bfloat16":
             self.runtime_autocast_dtype = torch.bfloat16
         elif prec == "mixed_precision":
             self.runtime_autocast_dtype = torch.float16
         else:
             self.runtime_autocast_dtype = torch.float32
-
+
     def _prepare_conditioning_tensor(self, filepath, height, width, padding_values):
         tensor = load_image_to_tensor_with_resize_and_crop(filepath, height, width)
         tensor = torch.nn.functional.pad(tensor, padding_values)
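
For reference, a minimal runnable sketch of the pattern this diff settles on: an isinstance guard so Pipeline objects are skipped, an FP8 capability check, and a promotion pass that only runs when the opt-in environment variable LTXV_FORCE_BF16_ON_FP8 is set. The parameter/buffer loops are assumed on my part, since the hunks elide the middle of `_promote_fp8_weights_to_bf16` and only show the buffer conversion line; the `nn.Linear` stand-in is hypothetical.

```python
import os
import torch
import torch.nn as nn

def promote_fp8_weights_to_bf16(module):
    # Skip anything that is not a real nn.Module (Pipeline wrappers are not)
    if not isinstance(module, nn.Module):
        return
    f8 = getattr(torch, "float8_e4m3fn", None)
    if f8 is None:
        return  # this PyTorch build has no FP8 dtypes
    # Assumed body: convert FP8 parameters and buffers to BF16 in place
    for p in module.parameters():
        if p.dtype == f8:
            p.data = p.data.to(torch.bfloat16)
    for b in module.buffers():
        if b.dtype == f8:
            b.data = b.data.to(torch.bfloat16)

if __name__ == "__main__":
    model = nn.Linear(8, 8)  # hypothetical stand-in for a pipeline sub-module
    # Opt-in gate mirroring _apply_precision_policy: promote only on request
    if os.getenv("LTXV_FORCE_BF16_ON_FP8", "0") == "1":
        promote_fp8_weights_to_bf16(model)
    print(next(model.parameters()).dtype)
```

Running it as `LTXV_FORCE_BF16_ON_FP8=1 python sketch.py` exercises the promotion path; without the variable, the service now leaves FP8 weights untouched and only sets `runtime_autocast_dtype` to `torch.bfloat16`.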