Spaces:
Paused
Paused
Update api/ltx_server_refactored.py
Browse files- api/ltx_server_refactored.py +62 -1
api/ltx_server_refactored.py
CHANGED
|
@@ -459,4 +459,65 @@ class VideoService:
|
|
| 459 |
try:
|
| 460 |
overrides["guidance_scale"] = json.loads(ltx_configs["guidance_scale_list"])
|
| 461 |
overrides["stg_scale"] = json.loads(ltx_configs["stg_scale_list"])
|
| 462 |
-
except (json.JSONDecodeError, KeyError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
try:
|
| 460 |
overrides["guidance_scale"] = json.loads(ltx_configs["guidance_scale_list"])
|
| 461 |
overrides["stg_scale"] = json.loads(ltx_configs["stg_scale_list"])
|
| 462 |
+
except (json.JSONDecodeError, KeyError) as e:
|
| 463 |
+
logging.warning(f"Failed to parse custom guidance values: {e}. Falling back to defaults.")
|
| 464 |
+
|
| 465 |
+
if overrides:
|
| 466 |
+
logging.info(f"Applying '{preset}' guidance preset overrides.")
|
| 467 |
+
return overrides
|
| 468 |
+
|
| 469 |
+
def _save_and_log_video(self, pixel_tensor: torch.Tensor, base_filename: str) -> Path:
|
| 470 |
+
"""Saves a pixel tensor to an MP4 file and returns the final path."""
|
| 471 |
+
# Work in a temporary directory to handle atomic move
|
| 472 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
| 473 |
+
temp_path = os.path.join(temp_dir, f"{base_filename}.mp4")
|
| 474 |
+
video_encode_tool_singleton.save_video_from_tensor(
|
| 475 |
+
pixel_tensor, temp_path, fps=DEFAULT_FPS
|
| 476 |
+
)
|
| 477 |
+
final_path = RESULTS_DIR / f"{base_filename}.mp4"
|
| 478 |
+
shutil.move(temp_path, final_path)
|
| 479 |
+
logging.info(f"Video saved successfully to: {final_path}")
|
| 480 |
+
return final_path
|
| 481 |
+
|
| 482 |
+
def _apply_precision_policy(self):
|
| 483 |
+
"""Sets the autocast dtype based on the configuration file."""
|
| 484 |
+
precision = str(self.config.get("precision", "bfloat16")).lower()
|
| 485 |
+
if precision in ["float8_e4m3fn", "bfloat16"]:
|
| 486 |
+
self.runtime_autocast_dtype = torch.bfloat16
|
| 487 |
+
elif precision == "mixed_precision":
|
| 488 |
+
self.runtime_autocast_dtype = torch.float16
|
| 489 |
+
else:
|
| 490 |
+
self.runtime_autocast_dtype = torch.float32
|
| 491 |
+
logging.info(f"Runtime precision policy set for autocast: {self.runtime_autocast_dtype}")
|
| 492 |
+
|
| 493 |
+
def _align(self, dim: int, alignment: int = FRAMES_ALIGNMENT) -> int:
|
| 494 |
+
"""Aligns a dimension to the nearest multiple of `alignment`."""
|
| 495 |
+
return ((dim - 1) // alignment + 1) * alignment
|
| 496 |
+
|
| 497 |
+
def _calculate_aligned_frames(self, duration_s: float, min_frames: int = 1) -> int:
|
| 498 |
+
"""Calculates the total number of frames based on duration, ensuring alignment."""
|
| 499 |
+
num_frames = int(round(duration_s * DEFAULT_FPS))
|
| 500 |
+
aligned_frames = self._align(num_frames)
|
| 501 |
+
# Ensure it's at least 1 frame longer than the alignment for some ops, and respects min_frames
|
| 502 |
+
final_frames = max(aligned_frames + 1, min_frames)
|
| 503 |
+
return final_frames
|
| 504 |
+
|
| 505 |
+
def _resolve_seed(self, seed: Optional[int]) -> int:
|
| 506 |
+
"""Returns the given seed or generates a new random one."""
|
| 507 |
+
return random.randint(0, 2**32 - 1) if seed is None else int(seed)
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
# ==============================================================================
|
| 511 |
+
# --- SINGLETON INSTANTIATION ---
|
| 512 |
+
# ==============================================================================
|
| 513 |
+
# The service is instantiated once when the module is imported, ensuring a single
|
| 514 |
+
# instance manages the models and GPU resources throughout the application's life.
|
| 515 |
+
|
| 516 |
+
try:
|
| 517 |
+
video_generation_service = VideoService()
|
| 518 |
+
logging.info("Global VideoService instance created successfully.")
|
| 519 |
+
except Exception as e:
|
| 520 |
+
logging.critical(f"Failed to initialize VideoService: {e}")
|
| 521 |
+
traceback.print_exc()
|
| 522 |
+
# Exit if the core service fails to start, as the app is non-functional
|
| 523 |
+
sys.exit(1)
|