Update video_service.py
video_service.py  +19 -8  CHANGED
@@ -311,8 +311,11 @@ class VideoService:
 
         return pipeline, latent_upsampler
 
-    # Precision: promote FP8->BF16 and set the autocast dtype
+    # Precision: promote FP8->BF16 and set the autocast dtype (safe version)
     def _promote_fp8_weights_to_bf16(self, module):
+        # Only promote if this is really an nn.Module; Pipeline objects are not nn.Module
+        if not isinstance(module, torch.nn.Module):
+            return
         f8 = getattr(torch, "float8_e4m3fn", None)
         if f8 is None:
             return
@@ -329,24 +332,32 @@ class VideoService:
                     b.data = b.data.to(torch.bfloat16)
                 except Exception:
                     pass
-
+
     def _apply_precision_policy(self):
         prec = str(self.config.get("precision", "")).lower()
         self.runtime_autocast_dtype = torch.float32
         if prec == "float8_e4m3fn":
-            # FP8
-            if hasattr(torch, "float8_e4m3fn"):
-                self._promote_fp8_weights_to_bf16(self.pipeline)
-                if self.latent_upsampler:
-                    self._promote_fp8_weights_to_bf16(self.latent_upsampler)
+            # FP8: native LTX kernels may be active; by default, do not promote weights
             self.runtime_autocast_dtype = torch.bfloat16
+            force_promote = os.getenv("LTXV_FORCE_BF16_ON_FP8", "0") == "1"
+            if force_promote and hasattr(torch, "float8_e4m3fn"):
+                # Promote only real modules; ignore Pipeline objects
+                try:
+                    self._promote_fp8_weights_to_bf16(self.pipeline)
+                except Exception:
+                    pass
+                try:
+                    if self.latent_upsampler:
+                        self._promote_fp8_weights_to_bf16(self.latent_upsampler)
+                except Exception:
+                    pass
         elif prec == "bfloat16":
            self.runtime_autocast_dtype = torch.bfloat16
        elif prec == "mixed_precision":
            self.runtime_autocast_dtype = torch.float16
        else:
            self.runtime_autocast_dtype = torch.float32
-
+
     def _prepare_conditioning_tensor(self, filepath, height, width, padding_values):
         tensor = load_image_to_tensor_with_resize_and_crop(filepath, height, width)
         tensor = torch.nn.functional.pad(tensor, padding_values)
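
For context, a minimal self-contained sketch of what the guarded promotion pass presumably does. Only the isinstance guard, the f8 availability check, and the b.data conversion are visible in the hunks above; the parameter/buffer iteration and the standalone function name promote_fp8_weights_to_bf16 are illustrative assumptions.

import torch

def promote_fp8_weights_to_bf16(module):
    # Only promote real nn.Module instances; Pipeline wrappers are skipped
    if not isinstance(module, torch.nn.Module):
        return
    f8 = getattr(torch, "float8_e4m3fn", None)
    if f8 is None:
        return  # this PyTorch build has no FP8 dtype
    # Assumed iteration: walk parameters and buffers, upcasting FP8 tensors in place
    for p in module.parameters(recurse=True):
        try:
            if p.dtype == f8:
                p.data = p.data.to(torch.bfloat16)
        except Exception:
            pass
    for b in module.buffers(recurse=True):
        try:
            if b.dtype == f8:
                b.data = b.data.to(torch.bfloat16)
        except Exception:
            pass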
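The design choice in the second hunk: under the float8_e4m3fn policy the service now leaves FP8 weights untouched by default, so native LTX FP8 kernels can keep operating on FP8 tensors, and only sets bfloat16 as the autocast dtype. Promotion to BF16 is an explicit opt-in via the LTXV_FORCE_BF16_ON_FP8 environment variable (the new code reads it with os.getenv, which assumes video_service.py imports os outside the shown hunks). A hedged usage sketch, assuming the variable is set before the service initializes:

import os

# Opt-in must happen before _apply_precision_policy() reads the variable.
os.environ["LTXV_FORCE_BF16_ON_FP8"] = "1"

# Resulting autocast dtypes per configured "precision" value:
#   "float8_e4m3fn"   -> torch.bfloat16 (weights promoted only with the opt-in above)
#   "bfloat16"        -> torch.bfloat16
#   "mixed_precision" -> torch.float16
#   anything else     -> torch.float32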