Spaces:
Running on Zero
refactor: replace _cpu_ctx with thread-local storage, deduplicate xregen wrappers, parallel downloads, quiet=True
- Replace fragile function-attribute CPU→GPU context passing (_fn._cpu_ctx = {})
with thread-local storage (_tl.<name>_ctx) for thread safety under ZeroGPU
multi-user concurrency — 6 sites updated across generate_* and regen_* paths
- Add _xregen_dispatch() generator helper to deduplicate the pending-yield /
infer / splice-yield skeleton shared by xregen_taro, xregen_mmaudio,
xregen_hunyuan (~40 lines removed)
- Parallelize all 7 startup downloads with ThreadPoolExecutor (I/O-bound network
calls run concurrently, cutting Space cold-start time ~proportionally)
- Consolidate per-model scalar constants into MODEL_CONFIGS as single source of
truth; add _clamp_duration() / _estimate_gpu_duration() / _estimate_regen_duration()
helpers to eliminate repeated duration-clamping boilerplate
- Restore quiet=True in mux_video_audio (was temporarily quiet=False for debugging)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
|
@@ -17,6 +17,7 @@ import tempfile
|
|
| 17 |
import random
|
| 18 |
import threading
|
| 19 |
import time
|
|
|
|
| 20 |
from pathlib import Path
|
| 21 |
|
| 22 |
import torch
|
|
@@ -35,69 +36,102 @@ CKPT_REPO_ID = "JackIsNotInTheBox/Generate_Audio_for_Video_Checkpoints"
|
|
| 35 |
CACHE_DIR = "/tmp/model_ckpts"
|
| 36 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 37 |
|
| 38 |
-
# ----
|
| 39 |
-
print("Downloading TARO checkpoints…")
|
| 40 |
-
cavp_ckpt_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/cavp_epoch66.ckpt", cache_dir=CACHE_DIR)
|
| 41 |
-
onset_ckpt_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/onset_model.ckpt", cache_dir=CACHE_DIR)
|
| 42 |
-
taro_ckpt_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/taro_ckpt.pt", cache_dir=CACHE_DIR)
|
| 43 |
-
print("TARO checkpoints downloaded.")
|
| 44 |
-
|
| 45 |
-
# ---- MMAudio checkpoints (in MMAudio/ subfolder) ----
|
| 46 |
-
# MMAudio normally auto-downloads from its own HF repo, but we
|
| 47 |
-
# override the paths so it pulls from our consolidated repo instead.
|
| 48 |
MMAUDIO_WEIGHTS_DIR = Path(CACHE_DIR) / "MMAudio" / "weights"
|
| 49 |
MMAUDIO_EXT_DIR = Path(CACHE_DIR) / "MMAudio" / "ext_weights"
|
|
|
|
| 50 |
MMAUDIO_WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 51 |
MMAUDIO_EXT_DIR.mkdir(parents=True, exist_ok=True)
|
| 52 |
-
|
| 53 |
-
print("Downloading MMAudio checkpoints…")
|
| 54 |
-
mmaudio_model_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/mmaudio_large_44k_v2.pth", cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_WEIGHTS_DIR), local_dir_use_symlinks=False)
|
| 55 |
-
mmaudio_vae_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/v1-44.pth", cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_EXT_DIR), local_dir_use_symlinks=False)
|
| 56 |
-
mmaudio_synchformer_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/synchformer_state_dict.pth", cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_EXT_DIR), local_dir_use_symlinks=False)
|
| 57 |
-
print("MMAudio checkpoints downloaded.")
|
| 58 |
-
|
| 59 |
-
# ---- HunyuanVideoFoley checkpoints (in HunyuanFoley/ subfolder) ----
|
| 60 |
-
HUNYUAN_MODEL_DIR = Path(CACHE_DIR) / "HunyuanFoley"
|
| 61 |
HUNYUAN_MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
#
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
print("
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
# ================================================================== #
|
| 98 |
# SHARED CONSTANTS / HELPERS #
|
| 99 |
# ================================================================== #
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
MAX_SLOTS = 8 # max parallel generation slots shown in UI
|
| 102 |
MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video ≤ ~64 s at 8 s/seg)
|
| 103 |
|
|
@@ -351,7 +385,7 @@ def mux_video_audio(silent_video: str, audio_path: str, output_path: str,
|
|
| 351 |
pix_fmt="yuv420p",
|
| 352 |
acodec="aac", audio_bitrate="128k",
|
| 353 |
movflags="+faststart",
|
| 354 |
-
).run(overwrite_output=True, quiet=
|
| 355 |
|
| 356 |
|
| 357 |
# ------------------------------------------------------------------ #
|
|
@@ -417,65 +451,76 @@ def _cf_join(a: np.ndarray, b: np.ndarray,
|
|
| 417 |
# latents_scale: [0.18215]*8 — AudioLDM2 VAE scale factor
|
| 418 |
# ================================================================== #
|
| 419 |
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
TARO_TRUNCATE_ONSET = 120
|
| 425 |
-
TARO_MODEL_DUR
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
TARO_LOAD_OVERHEAD = 15 # seconds: model load + CAVP feature extraction
|
| 429 |
-
MMAUDIO_WINDOW = 8.0 # seconds — MMAudio's fixed generation window
|
| 430 |
-
MMAUDIO_SECS_PER_STEP = 0.25 # measured 0.230s/step on H200 (8.3s video, 2 segs × 25 steps = 11.5s wall)
|
| 431 |
-
MMAUDIO_LOAD_OVERHEAD = 30 # 15s warm + 15s model init; open_clip pre-downloaded at startup
|
| 432 |
-
HUNYUAN_MAX_DUR = 15.0 # seconds — HunyuanFoley max video duration
|
| 433 |
-
HUNYUAN_SECS_PER_STEP = 0.35 # measured 0.328s/step on H200 (8.3s video, 1 seg × 50 steps = 16.4s wall)
|
| 434 |
-
HUNYUAN_LOAD_OVERHEAD = 55 # ~55s to load the 10GB XXL model weights into GPU
|
| 435 |
-
GPU_DURATION_CAP = 300 # hard cap per call — never reserve more than this
|
| 436 |
|
| 437 |
-
# ------------------------------------------------------------------ #
|
| 438 |
-
# Model configuration registry — single source of truth for per-model #
|
| 439 |
-
# constants used by duration estimation, segmentation, and UI. #
|
| 440 |
-
# ------------------------------------------------------------------ #
|
| 441 |
MODEL_CONFIGS = {
|
| 442 |
"taro": {
|
| 443 |
-
"window_s": TARO_MODEL_DUR,
|
| 444 |
-
"sr": TARO_SR,
|
| 445 |
-
"secs_per_step":
|
| 446 |
-
"load_overhead":
|
| 447 |
"tab_prefix": "taro",
|
| 448 |
-
"regen_fn": None, # set after function definitions (avoids forward-ref)
|
| 449 |
"label": "TARO",
|
|
|
|
| 450 |
},
|
| 451 |
"mmaudio": {
|
| 452 |
-
"window_s":
|
| 453 |
-
"sr": 48000,
|
| 454 |
-
"secs_per_step":
|
| 455 |
-
"load_overhead":
|
| 456 |
"tab_prefix": "mma",
|
| 457 |
-
"regen_fn": None,
|
| 458 |
"label": "MMAudio",
|
|
|
|
| 459 |
},
|
| 460 |
"hunyuan": {
|
| 461 |
-
"window_s":
|
| 462 |
"sr": 48000,
|
| 463 |
-
"secs_per_step":
|
| 464 |
-
"load_overhead":
|
| 465 |
"tab_prefix": "hf",
|
| 466 |
-
"regen_fn": None,
|
| 467 |
"label": "HunyuanFoley",
|
|
|
|
| 468 |
},
|
| 469 |
}
|
| 470 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
|
| 472 |
def _estimate_gpu_duration(model_key: str, num_samples: int, num_steps: int,
|
| 473 |
total_dur_s: float = None, crossfade_s: float = 0,
|
| 474 |
video_file: str = None) -> int:
|
| 475 |
-
"""
|
| 476 |
|
| 477 |
-
|
| 478 |
-
Clamped to [60, GPU_DURATION_CAP].
|
| 479 |
"""
|
| 480 |
cfg = MODEL_CONFIGS[model_key]
|
| 481 |
try:
|
|
@@ -484,25 +529,18 @@ def _estimate_gpu_duration(model_key: str, num_samples: int, num_steps: int,
|
|
| 484 |
n_segs = len(_build_segments(total_dur_s, cfg["window_s"], float(crossfade_s)))
|
| 485 |
except Exception:
|
| 486 |
n_segs = 1
|
| 487 |
-
secs
|
| 488 |
-
result = min(GPU_DURATION_CAP, max(60, int(secs)))
|
| 489 |
print(f"[duration] {cfg['label']}: {int(num_samples)}samp × {n_segs}seg × "
|
| 490 |
-
f"{int(num_steps)}steps → {secs:.0f}s → capped
|
| 491 |
-
return
|
| 492 |
|
| 493 |
|
| 494 |
def _estimate_regen_duration(model_key: str, num_steps: int) -> int:
|
| 495 |
-
"""
|
| 496 |
-
|
| 497 |
-
Floor is 20s — enough headroom above the 10s ZeroGPU abort threshold
|
| 498 |
-
for any model on a warm worker. Cold-start spin-up happens *before*
|
| 499 |
-
the timer starts so raising the floor does not help with cold-start aborts.
|
| 500 |
-
"""
|
| 501 |
cfg = MODEL_CONFIGS[model_key]
|
| 502 |
secs = int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"]
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
return result
|
| 506 |
|
| 507 |
_TARO_CACHE_MAXLEN = 16 # evict oldest entries beyond this limit
|
| 508 |
_TARO_INFERENCE_CACHE: dict = {} # keyed by (video_file, seed, cfg, steps, mode, crossfade_s)
|
|
@@ -750,8 +788,8 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
|
|
| 750 |
from TARO.onset_util import extract_onset
|
| 751 |
from TARO.samplers import euler_sampler, euler_maruyama_sampler
|
| 752 |
|
| 753 |
-
# Use pre-computed CPU results
|
| 754 |
-
ctx =
|
| 755 |
tmp_dir = ctx["tmp_dir"]
|
| 756 |
silent_video = ctx["silent_video"]
|
| 757 |
segments = ctx["segments"]
|
|
@@ -810,9 +848,6 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
|
|
| 810 |
|
| 811 |
return results
|
| 812 |
|
| 813 |
-
# Attach a context slot for the CPU wrapper to pass pre-computed data
|
| 814 |
-
_taro_gpu_infer._cpu_ctx = {}
|
| 815 |
-
|
| 816 |
|
| 817 |
def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
|
| 818 |
crossfade_s, crossfade_db, num_samples):
|
|
@@ -826,8 +861,8 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
|
|
| 826 |
tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
|
| 827 |
video_file, TARO_MODEL_DUR, crossfade_s)
|
| 828 |
|
| 829 |
-
# Pass pre-computed CPU results to the GPU function via
|
| 830 |
-
|
| 831 |
"tmp_dir": tmp_dir, "silent_video": silent_video,
|
| 832 |
"segments": segments, "total_dur_s": total_dur_s,
|
| 833 |
}
|
|
@@ -906,7 +941,7 @@ def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
|
| 906 |
|
| 907 |
net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
|
| 908 |
|
| 909 |
-
ctx =
|
| 910 |
segments = ctx["segments"]
|
| 911 |
seg_clip_paths = ctx["seg_clip_paths"]
|
| 912 |
|
|
@@ -966,8 +1001,6 @@ def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
|
| 966 |
|
| 967 |
return results
|
| 968 |
|
| 969 |
-
_mmaudio_gpu_infer._cpu_ctx = {}
|
| 970 |
-
|
| 971 |
|
| 972 |
def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
|
| 973 |
cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
|
|
@@ -987,7 +1020,7 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
|
|
| 987 |
for i, (s, e) in enumerate(segments)
|
| 988 |
]
|
| 989 |
|
| 990 |
-
|
| 991 |
"segments": segments, "seg_clip_paths": seg_clip_paths,
|
| 992 |
}
|
| 993 |
|
|
@@ -1057,7 +1090,7 @@ def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
|
| 1057 |
|
| 1058 |
model_dict, cfg = _load_hunyuan_model(device, model_size)
|
| 1059 |
|
| 1060 |
-
ctx =
|
| 1061 |
segments = ctx["segments"]
|
| 1062 |
total_dur_s = ctx["total_dur_s"]
|
| 1063 |
dummy_seg_path = ctx["dummy_seg_path"]
|
|
@@ -1115,8 +1148,6 @@ def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
|
|
| 1115 |
|
| 1116 |
return results
|
| 1117 |
|
| 1118 |
-
_hunyuan_gpu_infer._cpu_ctx = {}
|
| 1119 |
-
|
| 1120 |
|
| 1121 |
def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
|
| 1122 |
guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
|
|
@@ -1143,7 +1174,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
|
|
| 1143 |
for i, (s, e) in enumerate(segments)
|
| 1144 |
]
|
| 1145 |
|
| 1146 |
-
|
| 1147 |
"segments": segments, "total_dur_s": total_dur_s,
|
| 1148 |
"dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
|
| 1149 |
}
|
|
@@ -1182,7 +1213,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
|
|
| 1182 |
|
| 1183 |
def _preload_taro_regen_ctx(meta: dict) -> dict:
|
| 1184 |
"""Pre-load TARO CAVP/onset features on CPU for regen.
|
| 1185 |
-
Returns a dict
|
| 1186 |
cavp_path = meta.get("cavp_path", "")
|
| 1187 |
onset_path = meta.get("onset_path", "")
|
| 1188 |
ctx = {}
|
|
@@ -1194,7 +1225,7 @@ def _preload_taro_regen_ctx(meta: dict) -> dict:
|
|
| 1194 |
|
| 1195 |
def _preload_hunyuan_regen_ctx(meta: dict, seg_path: str) -> dict:
|
| 1196 |
"""Pre-load HunyuanFoley text features + segment path on CPU for regen.
|
| 1197 |
-
Returns a dict
|
| 1198 |
ctx = {"seg_path": seg_path}
|
| 1199 |
text_feats_path = meta.get("text_feats_path", "")
|
| 1200 |
if text_feats_path and os.path.exists(text_feats_path):
|
|
@@ -1285,7 +1316,7 @@ def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
|
|
| 1285 |
from TARO.samplers import euler_sampler, euler_maruyama_sampler
|
| 1286 |
|
| 1287 |
# Use pre-loaded features from CPU wrapper (avoids np.load inside GPU window)
|
| 1288 |
-
ctx =
|
| 1289 |
if "cavp" in ctx and "onset" in ctx:
|
| 1290 |
print("[TARO regen] Using pre-loaded CAVP + onset features (CPU cache hit)")
|
| 1291 |
cavp_feats = ctx["cavp"]
|
|
@@ -1323,7 +1354,7 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
|
|
| 1323 |
seg_idx = int(seg_idx)
|
| 1324 |
|
| 1325 |
# CPU: pre-load cached features so np.load doesn't happen inside GPU window
|
| 1326 |
-
|
| 1327 |
|
| 1328 |
# GPU: inference only
|
| 1329 |
new_wav = _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
|
|
@@ -1365,7 +1396,7 @@ def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
|
|
| 1365 |
sr = seq_cfg.sampling_rate
|
| 1366 |
|
| 1367 |
# Use pre-extracted segment clip from the CPU wrapper
|
| 1368 |
-
seg_path =
|
| 1369 |
assert seg_path, "[MMAudio regen] seg_path not set — wrapper must pre-extract segment clip"
|
| 1370 |
|
| 1371 |
rng = torch.Generator(device=device)
|
|
@@ -1391,8 +1422,6 @@ def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
|
|
| 1391 |
new_wav = new_wav[:, :seg_samples]
|
| 1392 |
return new_wav, sr
|
| 1393 |
|
| 1394 |
-
_regen_mmaudio_gpu._cpu_ctx = {}
|
| 1395 |
-
|
| 1396 |
|
| 1397 |
def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
|
| 1398 |
prompt, negative_prompt, seed_val,
|
|
@@ -1409,7 +1438,7 @@ def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
|
|
| 1409 |
meta["silent_video"], seg_start, seg_dur,
|
| 1410 |
os.path.join(tmp_dir, "regen_seg.mp4"),
|
| 1411 |
)
|
| 1412 |
-
|
| 1413 |
|
| 1414 |
# GPU: inference only
|
| 1415 |
new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
|
|
@@ -1458,12 +1487,11 @@ def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
|
|
| 1458 |
|
| 1459 |
set_global_seed(random.randint(0, 2**32 - 1))
|
| 1460 |
|
| 1461 |
-
# Use pre-extracted segment clip from wrapper
|
| 1462 |
-
|
|
|
|
| 1463 |
assert seg_path, "[HunyuanFoley regen] seg_path not set — wrapper must pre-extract segment clip"
|
| 1464 |
|
| 1465 |
-
# Use pre-loaded text_feats from CPU wrapper (avoids torch.load inside GPU window)
|
| 1466 |
-
ctx = _regen_hunyuan_gpu._cpu_ctx
|
| 1467 |
if "text_feats" in ctx:
|
| 1468 |
print("[HunyuanFoley regen] Using pre-loaded text features (CPU cache hit)")
|
| 1469 |
from hunyuanvideo_foley.utils.feature_utils import encode_video_features
|
|
@@ -1486,8 +1514,6 @@ def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
|
|
| 1486 |
new_wav = new_wav[:, :seg_samples]
|
| 1487 |
return new_wav, sr
|
| 1488 |
|
| 1489 |
-
_regen_hunyuan_gpu._cpu_ctx = {}
|
| 1490 |
-
|
| 1491 |
|
| 1492 |
def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
|
| 1493 |
prompt, negative_prompt, seed_val,
|
|
@@ -1505,7 +1531,7 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
|
|
| 1505 |
meta["silent_video"], seg_start, seg_dur,
|
| 1506 |
os.path.join(tmp_dir, "regen_seg.mp4"),
|
| 1507 |
)
|
| 1508 |
-
|
| 1509 |
|
| 1510 |
# GPU: inference only
|
| 1511 |
new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
|
|
@@ -1575,28 +1601,43 @@ def _xregen_splice(new_wav_raw: np.ndarray, src_sr: int,
|
|
| 1575 |
return video_path, waveform_html
|
| 1576 |
|
| 1577 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1578 |
def xregen_taro(seg_idx, state_json, slot_id,
|
| 1579 |
seed_val, cfg_scale, num_steps, mode,
|
| 1580 |
crossfade_s, crossfade_db,
|
| 1581 |
request: gr.Request = None):
|
| 1582 |
"""Cross-model regen: run TARO inference and splice into *slot_id*."""
|
| 1583 |
-
meta = json.loads(state_json)
|
| 1584 |
seg_idx = int(seg_idx)
|
|
|
|
| 1585 |
|
| 1586 |
-
|
| 1587 |
-
|
| 1588 |
-
|
| 1589 |
-
|
| 1590 |
-
|
| 1591 |
-
|
| 1592 |
|
| 1593 |
-
|
| 1594 |
-
seed_val, cfg_scale, num_steps, mode,
|
| 1595 |
-
crossfade_s, crossfade_db, slot_id)
|
| 1596 |
-
# Upsample 16kHz → 48kHz (sinc, CPU)
|
| 1597 |
-
new_wav_raw = _upsample_taro(new_wav_raw)
|
| 1598 |
-
video_path, waveform_html = _xregen_splice(new_wav_raw, TARO_SR_OUT, meta, seg_idx, slot_id)
|
| 1599 |
-
yield gr.update(value=video_path), gr.update(value=waveform_html)
|
| 1600 |
|
| 1601 |
|
| 1602 |
def xregen_mmaudio(seg_idx, state_json, slot_id,
|
|
@@ -1604,26 +1645,23 @@ def xregen_mmaudio(seg_idx, state_json, slot_id,
|
|
| 1604 |
cfg_strength, num_steps, crossfade_s, crossfade_db,
|
| 1605 |
request: gr.Request = None):
|
| 1606 |
"""Cross-model regen: run MMAudio inference and splice into *slot_id*."""
|
| 1607 |
-
meta = json.loads(state_json)
|
| 1608 |
seg_idx = int(seg_idx)
|
|
|
|
| 1609 |
seg_start, seg_end = meta["segments"][seg_idx]
|
| 1610 |
|
| 1611 |
-
|
| 1612 |
-
|
| 1613 |
-
|
| 1614 |
-
|
| 1615 |
-
|
| 1616 |
-
|
| 1617 |
-
|
| 1618 |
-
|
| 1619 |
-
|
|
|
|
|
|
|
| 1620 |
|
| 1621 |
-
|
| 1622 |
-
prompt, negative_prompt, seed_val,
|
| 1623 |
-
cfg_strength, num_steps,
|
| 1624 |
-
crossfade_s, crossfade_db, slot_id)
|
| 1625 |
-
video_path, waveform_html = _xregen_splice(new_wav_raw, src_sr, meta, seg_idx, slot_id)
|
| 1626 |
-
yield gr.update(value=video_path), gr.update(value=waveform_html)
|
| 1627 |
|
| 1628 |
|
| 1629 |
def xregen_hunyuan(seg_idx, state_json, slot_id,
|
|
@@ -1632,26 +1670,23 @@ def xregen_hunyuan(seg_idx, state_json, slot_id,
|
|
| 1632 |
crossfade_s, crossfade_db,
|
| 1633 |
request: gr.Request = None):
|
| 1634 |
"""Cross-model regen: run HunyuanFoley inference and splice into *slot_id*."""
|
| 1635 |
-
meta = json.loads(state_json)
|
| 1636 |
seg_idx = int(seg_idx)
|
|
|
|
| 1637 |
seg_start, seg_end = meta["segments"][seg_idx]
|
| 1638 |
|
| 1639 |
-
|
| 1640 |
-
|
| 1641 |
-
|
| 1642 |
-
|
| 1643 |
-
|
| 1644 |
-
|
| 1645 |
-
|
| 1646 |
-
|
| 1647 |
-
|
| 1648 |
-
|
| 1649 |
-
|
| 1650 |
-
|
| 1651 |
-
|
| 1652 |
-
crossfade_s, crossfade_db, slot_id)
|
| 1653 |
-
video_path, waveform_html = _xregen_splice(new_wav_raw, src_sr, meta, seg_idx, slot_id)
|
| 1654 |
-
yield gr.update(value=video_path), gr.update(value=waveform_html)
|
| 1655 |
|
| 1656 |
|
| 1657 |
# ================================================================== #
|
|
|
|
| 17 |
import random
|
| 18 |
import threading
|
| 19 |
import time
|
| 20 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 21 |
from pathlib import Path
|
| 22 |
|
| 23 |
import torch
|
|
|
|
| 36 |
CACHE_DIR = "/tmp/model_ckpts"
|
| 37 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 38 |
|
| 39 |
+
# ---- Local directories that must exist before parallel downloads start ----
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
MMAUDIO_WEIGHTS_DIR = Path(CACHE_DIR) / "MMAudio" / "weights"
|
| 41 |
MMAUDIO_EXT_DIR = Path(CACHE_DIR) / "MMAudio" / "ext_weights"
|
| 42 |
+
HUNYUAN_MODEL_DIR = Path(CACHE_DIR) / "HunyuanFoley"
|
| 43 |
MMAUDIO_WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 44 |
MMAUDIO_EXT_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
HUNYUAN_MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
| 46 |
|
| 47 |
+
# ------------------------------------------------------------------ #
|
| 48 |
+
# Parallel checkpoint + model downloads #
|
| 49 |
+
# All downloads are I/O-bound (network), so running them in threads #
|
| 50 |
+
# cuts Space cold-start time roughly proportional to the number of #
|
| 51 |
+
# independent groups (previously sequential, now concurrent). #
|
| 52 |
+
# hf_hub_download / snapshot_download are thread-safe. #
|
| 53 |
+
# ------------------------------------------------------------------ #
|
| 54 |
+
|
| 55 |
+
def _dl_taro():
    """Fetch the three TARO checkpoint files from the consolidated repo.

    Returns the local paths as a tuple: (cavp_ckpt, onset_ckpt, taro_ckpt).
    """
    paths = tuple(
        hf_hub_download(repo_id=CKPT_REPO_ID, filename=f"TARO/{fname}", cache_dir=CACHE_DIR)
        for fname in ("cavp_epoch66.ckpt", "onset_model.ckpt", "taro_ckpt.pt")
    )
    print("TARO checkpoints downloaded.")
    return paths
|
| 62 |
+
|
| 63 |
+
def _dl_mmaudio():
    """Download MMAudio .pth files and return their local paths.

    Returns:
        (model_path, vae_path, synchformer_path) — local filesystem paths.
    """
    def _fetch(filename, local_dir):
        # One place for the shared kwargs instead of repeating them three times.
        # NOTE(review): local_dir_use_symlinks is deprecated (and ignored) in
        # recent huggingface_hub releases; kept for older-version compatibility.
        return hf_hub_download(repo_id=CKPT_REPO_ID, filename=filename,
                               cache_dir=CACHE_DIR, local_dir=str(local_dir),
                               local_dir_use_symlinks=False)

    m = _fetch("MMAudio/mmaudio_large_44k_v2.pth", MMAUDIO_WEIGHTS_DIR)
    v = _fetch("MMAudio/v1-44.pth", MMAUDIO_EXT_DIR)
    s = _fetch("MMAudio/synchformer_state_dict.pth", MMAUDIO_EXT_DIR)
    print("MMAudio checkpoints downloaded.")
    return m, v, s
|
| 73 |
+
|
| 74 |
+
def _dl_hunyuan():
    """Download the HunyuanVideoFoley .pth files into HUNYUAN_MODEL_DIR.

    Paths are not returned — callers locate the files via HUNYUAN_MODEL_DIR.
    """
    for filename in (
        "HunyuanVideo-Foley/hunyuanvideo_foley.pth",
        "HunyuanVideo-Foley/vae_128d_48k.pth",
        "HunyuanVideo-Foley/synchformer_state_dict.pth",
    ):
        # NOTE(review): local_dir_use_symlinks is deprecated (and ignored) in
        # recent huggingface_hub releases; kept for older-version compatibility.
        hf_hub_download(repo_id=CKPT_REPO_ID, filename=filename,
                        cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR),
                        local_dir_use_symlinks=False)
    print("HunyuanVideoFoley checkpoints downloaded.")
|
| 83 |
+
|
| 84 |
+
def _dl_clap():
    """Warm the local HF cache with CLAP so from_pretrained() inside the
    ZeroGPU worker resolves to already-downloaded files."""
    repo = "laion/larger_clap_general"
    snapshot_download(repo_id=repo)
    print("CLAP model pre-downloaded.")
|
| 88 |
+
|
| 89 |
+
def _dl_clip():
    """Fetch MMAudio's CLIP backbone (~3.95 GB) at startup so the download
    does not eat into the GPU-window time budget later."""
    repo = "apple/DFN5B-CLIP-ViT-H-14-384"
    snapshot_download(repo_id=repo)
    print("MMAudio CLIP model pre-downloaded.")
|
| 93 |
+
|
| 94 |
+
def _dl_audioldm2():
    """Prefetch the AudioLDM2 VAE/vocoder snapshot that TARO's
    from_pretrained() calls resolve against."""
    repo = "cvssp/audioldm2"
    snapshot_download(repo_id=repo)
    print("AudioLDM2 pre-downloaded.")
|
| 98 |
+
|
| 99 |
+
def _dl_bigvgan():
    """Prefetch the BigVGAN vocoder snapshot (~489 MB) used by MMAudio."""
    repo = "nvidia/bigvgan_v2_44khz_128band_512x"
    snapshot_download(repo_id=repo)
    print("BigVGAN vocoder pre-downloaded.")
|
| 103 |
+
|
| 104 |
+
print("[startup] Starting parallel checkpoint + model downloads…")
_t_dl_start = time.perf_counter()
# All seven download groups are network-bound, so run them concurrently.
_download_fns = {
    "taro": _dl_taro,
    "mmaudio": _dl_mmaudio,
    "hunyuan": _dl_hunyuan,
    "clap": _dl_clap,
    "clip": _dl_clip,
    "audioldm2": _dl_audioldm2,
    "bigvgan": _dl_bigvgan,
}
with ThreadPoolExecutor(max_workers=7) as _pool:
    _futs = {name: _pool.submit(fn) for name, fn in _download_fns.items()}
    # Surface the first download failure as soon as it happens.
    for _fut in as_completed(_futs.values()):
        _fut.result()

# Unpack the paths the rest of the module references by name.
cavp_ckpt_path, onset_ckpt_path, taro_ckpt_path = _futs["taro"].result()
mmaudio_model_path, mmaudio_vae_path, mmaudio_synchformer_path = _futs["mmaudio"].result()
print(f"[startup] All downloads done in {time.perf_counter() - _t_dl_start:.1f}s")
|
| 122 |
|
| 123 |
# ================================================================== #
|
| 124 |
# SHARED CONSTANTS / HELPERS #
|
| 125 |
# ================================================================== #
|
| 126 |
|
| 127 |
+
# Thread-local storage for CPU → GPU context passing.
|
| 128 |
+
# Replaces the fragile function-attribute pattern (_fn._cpu_ctx = {...}).
|
| 129 |
+
# Each wrapper writes its context under a unique key before calling the
|
| 130 |
+
# @spaces.GPU function; the GPU function reads it back. Using thread-local
|
| 131 |
+
# storage means concurrent requests on different threads don't clobber
|
| 132 |
+
# each other's context — the function-attribute approach was not thread-safe.
|
| 133 |
+
_tl = threading.local()
|
| 134 |
+
|
| 135 |
MAX_SLOTS = 8 # max parallel generation slots shown in UI
|
| 136 |
MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video ≤ ~64 s at 8 s/seg)
|
| 137 |
|
|
|
|
| 385 |
pix_fmt="yuv420p",
|
| 386 |
acodec="aac", audio_bitrate="128k",
|
| 387 |
movflags="+faststart",
|
| 388 |
+
).run(overwrite_output=True, quiet=True)
|
| 389 |
|
| 390 |
|
| 391 |
# ------------------------------------------------------------------ #
|
|
|
|
| 451 |
# latents_scale: [0.18215]*8 — AudioLDM2 VAE scale factor
|
| 452 |
# ================================================================== #
|
| 453 |
|
| 454 |
+
# ================================================================== #
|
| 455 |
+
# MODEL CONSTANTS & CONFIGURATION REGISTRY #
|
| 456 |
+
# ================================================================== #
|
| 457 |
+
# All per-model numeric constants live here — MODEL_CONFIGS is the #
|
| 458 |
+
# single source of truth consumed by duration estimation, segmentation,#
|
| 459 |
+
# and the UI. Standalone names kept only where other code references #
|
| 460 |
+
# them by name (TARO geometry, TARGET_SR, GPU_DURATION_CAP). #
|
| 461 |
+
# ================================================================== #
|
| 462 |
+
|
| 463 |
+
# TARO geometry — referenced directly in _taro_infer_segment
|
| 464 |
+
TARO_SR = 16000
|
| 465 |
+
TARO_TRUNCATE = 131072
|
| 466 |
+
TARO_FPS = 4
|
| 467 |
+
TARO_TRUNCATE_FRAME = int(TARO_FPS * TARO_TRUNCATE / TARO_SR) # 32
|
| 468 |
TARO_TRUNCATE_ONSET = 120
|
| 469 |
+
TARO_MODEL_DUR = TARO_TRUNCATE / TARO_SR # 8.192 s
|
| 470 |
+
|
| 471 |
+
GPU_DURATION_CAP = 300 # hard cap per @spaces.GPU call — never reserve more than this
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 473 |
MODEL_CONFIGS = {
|
| 474 |
"taro": {
|
| 475 |
+
"window_s": TARO_MODEL_DUR, # 8.192 s
|
| 476 |
+
"sr": TARO_SR, # 16000 (output resampled to TARGET_SR)
|
| 477 |
+
"secs_per_step": 0.025, # measured 0.023 s/step on H200
|
| 478 |
+
"load_overhead": 15, # model load + CAVP feature extraction
|
| 479 |
"tab_prefix": "taro",
|
|
|
|
| 480 |
"label": "TARO",
|
| 481 |
+
"regen_fn": None, # set after function definitions (avoids forward-ref)
|
| 482 |
},
|
| 483 |
"mmaudio": {
|
| 484 |
+
"window_s": 8.0, # MMAudio's fixed generation window
|
| 485 |
+
"sr": 48000, # resampled from 44100 in post-processing
|
| 486 |
+
"secs_per_step": 0.25, # measured 0.230 s/step on H200
|
| 487 |
+
"load_overhead": 30, # 15s warm + 15s model init
|
| 488 |
"tab_prefix": "mma",
|
|
|
|
| 489 |
"label": "MMAudio",
|
| 490 |
+
"regen_fn": None,
|
| 491 |
},
|
| 492 |
"hunyuan": {
|
| 493 |
+
"window_s": 15.0, # HunyuanFoley max video duration
|
| 494 |
"sr": 48000,
|
| 495 |
+
"secs_per_step": 0.35, # measured 0.328 s/step on H200
|
| 496 |
+
"load_overhead": 55, # ~55s to load the 10 GB XXL weights
|
| 497 |
"tab_prefix": "hf",
|
|
|
|
| 498 |
"label": "HunyuanFoley",
|
| 499 |
+
"regen_fn": None,
|
| 500 |
},
|
| 501 |
}
|
| 502 |
|
| 503 |
+
# Convenience aliases used only in the TARO inference path
|
| 504 |
+
TARO_SECS_PER_STEP = MODEL_CONFIGS["taro"]["secs_per_step"]
|
| 505 |
+
MMAUDIO_WINDOW = MODEL_CONFIGS["mmaudio"]["window_s"]
|
| 506 |
+
MMAUDIO_SECS_PER_STEP = MODEL_CONFIGS["mmaudio"]["secs_per_step"]
|
| 507 |
+
HUNYUAN_MAX_DUR = MODEL_CONFIGS["hunyuan"]["window_s"]
|
| 508 |
+
HUNYUAN_SECS_PER_STEP = MODEL_CONFIGS["hunyuan"]["secs_per_step"]
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
def _clamp_duration(secs: float, label: str) -> int:
    """Clamp a raw GPU-seconds estimate to [60, GPU_DURATION_CAP] and log it."""
    clamped = int(secs)
    if clamped < 60:
        clamped = 60          # floor: never reserve less than a minute
    if clamped > GPU_DURATION_CAP:
        clamped = GPU_DURATION_CAP  # hard cap per @spaces.GPU call
    print(f"[duration] {label}: {secs:.0f}s raw → {clamped}s reserved")
    return clamped
|
| 516 |
+
|
| 517 |
|
| 518 |
def _estimate_gpu_duration(model_key: str, num_samples: int, num_steps: int,
|
| 519 |
total_dur_s: float = None, crossfade_s: float = 0,
|
| 520 |
video_file: str = None) -> int:
|
| 521 |
+
"""Estimate GPU seconds for a full generation call.
|
| 522 |
|
| 523 |
+
Formula: num_samples × n_segs × num_steps × secs_per_step + load_overhead
|
|
|
|
| 524 |
"""
|
| 525 |
cfg = MODEL_CONFIGS[model_key]
|
| 526 |
try:
|
|
|
|
| 529 |
n_segs = len(_build_segments(total_dur_s, cfg["window_s"], float(crossfade_s)))
|
| 530 |
except Exception:
|
| 531 |
n_segs = 1
|
| 532 |
+
secs = int(num_samples) * n_segs * int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"]
|
|
|
|
| 533 |
print(f"[duration] {cfg['label']}: {int(num_samples)}samp × {n_segs}seg × "
|
| 534 |
+
f"{int(num_steps)}steps → {secs:.0f}s → capped ", end="")
|
| 535 |
+
return _clamp_duration(secs, cfg["label"])
|
| 536 |
|
| 537 |
|
| 538 |
def _estimate_regen_duration(model_key: str, num_steps: int) -> int:
    """Estimate GPU seconds for a single-segment regen call.

    Args:
        model_key: key into MODEL_CONFIGS ("taro" | "mmaudio" | "hunyuan").
        num_steps: sampler step count for the one regenerated segment.

    Returns:
        Reservation in whole seconds, clamped by _clamp_duration.
    """
    cfg = MODEL_CONFIGS[model_key]
    steps = int(num_steps)
    # One segment: steps × per-step cost + fixed model-load overhead.
    secs = steps * cfg["secs_per_step"] + cfg["load_overhead"]
    # Fold the step count into the label and let _clamp_duration emit the single
    # log line. The previous print(..., end="") here produced a doubled
    # "[duration]" prefix on one line and interleaved badly across threads.
    return _clamp_duration(secs, f"{cfg['label']} regen ({steps} steps)")
|
|
|
|
| 544 |
|
| 545 |
_TARO_CACHE_MAXLEN = 16 # evict oldest entries beyond this limit
|
| 546 |
_TARO_INFERENCE_CACHE: dict = {} # keyed by (video_file, seed, cfg, steps, mode, crossfade_s)
|
|
|
|
| 788 |
from TARO.onset_util import extract_onset
|
| 789 |
from TARO.samplers import euler_sampler, euler_maruyama_sampler
|
| 790 |
|
| 791 |
+
# Use pre-computed CPU results passed via thread-local storage
|
| 792 |
+
ctx = _tl.taro_gen_ctx
|
| 793 |
tmp_dir = ctx["tmp_dir"]
|
| 794 |
silent_video = ctx["silent_video"]
|
| 795 |
segments = ctx["segments"]
|
|
|
|
| 848 |
|
| 849 |
return results
|
| 850 |
|
|
|
|
|
|
|
|
|
|
| 851 |
|
| 852 |
def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
|
| 853 |
crossfade_s, crossfade_db, num_samples):
|
|
|
|
| 861 |
tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
|
| 862 |
video_file, TARO_MODEL_DUR, crossfade_s)
|
| 863 |
|
| 864 |
+
# Pass pre-computed CPU results to the GPU function via thread-local storage
|
| 865 |
+
_tl.taro_gen_ctx = {
|
| 866 |
"tmp_dir": tmp_dir, "silent_video": silent_video,
|
| 867 |
"segments": segments, "total_dur_s": total_dur_s,
|
| 868 |
}
|
|
|
|
| 941 |
|
| 942 |
net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
|
| 943 |
|
| 944 |
+
ctx = _tl.mmaudio_gen_ctx
|
| 945 |
segments = ctx["segments"]
|
| 946 |
seg_clip_paths = ctx["seg_clip_paths"]
|
| 947 |
|
|
|
|
| 1001 |
|
| 1002 |
return results
|
| 1003 |
|
|
|
|
|
|
|
| 1004 |
|
| 1005 |
def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
|
| 1006 |
cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
|
|
|
|
| 1020 |
for i, (s, e) in enumerate(segments)
|
| 1021 |
]
|
| 1022 |
|
| 1023 |
+
_tl.mmaudio_gen_ctx = {
|
| 1024 |
"segments": segments, "seg_clip_paths": seg_clip_paths,
|
| 1025 |
}
|
| 1026 |
|
|
|
|
| 1090 |
|
| 1091 |
model_dict, cfg = _load_hunyuan_model(device, model_size)
|
| 1092 |
|
| 1093 |
+
ctx = _tl.hunyuan_gen_ctx
|
| 1094 |
segments = ctx["segments"]
|
| 1095 |
total_dur_s = ctx["total_dur_s"]
|
| 1096 |
dummy_seg_path = ctx["dummy_seg_path"]
|
|
|
|
| 1148 |
|
| 1149 |
return results
|
| 1150 |
|
|
|
|
|
|
|
| 1151 |
|
| 1152 |
def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
|
| 1153 |
guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
|
|
|
|
| 1174 |
for i, (s, e) in enumerate(segments)
|
| 1175 |
]
|
| 1176 |
|
| 1177 |
+
_tl.hunyuan_gen_ctx = {
|
| 1178 |
"segments": segments, "total_dur_s": total_dur_s,
|
| 1179 |
"dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
|
| 1180 |
}
|
|
|
|
| 1213 |
|
| 1214 |
def _preload_taro_regen_ctx(meta: dict) -> dict:
|
| 1215 |
"""Pre-load TARO CAVP/onset features on CPU for regen.
|
| 1216 |
+
Returns a dict for _tl.taro_regen_ctx (thread-local storage)."""
|
| 1217 |
cavp_path = meta.get("cavp_path", "")
|
| 1218 |
onset_path = meta.get("onset_path", "")
|
| 1219 |
ctx = {}
|
|
|
|
| 1225 |
|
| 1226 |
def _preload_hunyuan_regen_ctx(meta: dict, seg_path: str) -> dict:
|
| 1227 |
"""Pre-load HunyuanFoley text features + segment path on CPU for regen.
|
| 1228 |
+
Returns a dict for _tl.hunyuan_regen_ctx (thread-local storage)."""
|
| 1229 |
ctx = {"seg_path": seg_path}
|
| 1230 |
text_feats_path = meta.get("text_feats_path", "")
|
| 1231 |
if text_feats_path and os.path.exists(text_feats_path):
|
|
|
|
| 1316 |
from TARO.samplers import euler_sampler, euler_maruyama_sampler
|
| 1317 |
|
| 1318 |
# Use pre-loaded features from CPU wrapper (avoids np.load inside GPU window)
|
| 1319 |
+
ctx = getattr(_tl, "taro_regen_ctx", {})
|
| 1320 |
if "cavp" in ctx and "onset" in ctx:
|
| 1321 |
print("[TARO regen] Using pre-loaded CAVP + onset features (CPU cache hit)")
|
| 1322 |
cavp_feats = ctx["cavp"]
|
|
|
|
| 1354 |
seg_idx = int(seg_idx)
|
| 1355 |
|
| 1356 |
# CPU: pre-load cached features so np.load doesn't happen inside GPU window
|
| 1357 |
+
_tl.taro_regen_ctx = _preload_taro_regen_ctx(meta)
|
| 1358 |
|
| 1359 |
# GPU: inference only
|
| 1360 |
new_wav = _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
|
|
|
|
| 1396 |
sr = seq_cfg.sampling_rate
|
| 1397 |
|
| 1398 |
# Use pre-extracted segment clip from the CPU wrapper
|
| 1399 |
+
seg_path = getattr(_tl, "mmaudio_regen_ctx", {}).get("seg_path")
|
| 1400 |
assert seg_path, "[MMAudio regen] seg_path not set — wrapper must pre-extract segment clip"
|
| 1401 |
|
| 1402 |
rng = torch.Generator(device=device)
|
|
|
|
| 1422 |
new_wav = new_wav[:, :seg_samples]
|
| 1423 |
return new_wav, sr
|
| 1424 |
|
|
|
|
|
|
|
| 1425 |
|
| 1426 |
def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
|
| 1427 |
prompt, negative_prompt, seed_val,
|
|
|
|
| 1438 |
meta["silent_video"], seg_start, seg_dur,
|
| 1439 |
os.path.join(tmp_dir, "regen_seg.mp4"),
|
| 1440 |
)
|
| 1441 |
+
_tl.mmaudio_regen_ctx = {"seg_path": seg_path}
|
| 1442 |
|
| 1443 |
# GPU: inference only
|
| 1444 |
new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
|
|
|
|
| 1487 |
|
| 1488 |
set_global_seed(random.randint(0, 2**32 - 1))
|
| 1489 |
|
| 1490 |
+
# Use pre-extracted segment clip + text_feats from CPU wrapper
|
| 1491 |
+
ctx = getattr(_tl, "hunyuan_regen_ctx", {})
|
| 1492 |
+
seg_path = ctx.get("seg_path")
|
| 1493 |
assert seg_path, "[HunyuanFoley regen] seg_path not set — wrapper must pre-extract segment clip"
|
| 1494 |
|
|
|
|
|
|
|
| 1495 |
if "text_feats" in ctx:
|
| 1496 |
print("[HunyuanFoley regen] Using pre-loaded text features (CPU cache hit)")
|
| 1497 |
from hunyuanvideo_foley.utils.feature_utils import encode_video_features
|
|
|
|
| 1514 |
new_wav = new_wav[:, :seg_samples]
|
| 1515 |
return new_wav, sr
|
| 1516 |
|
|
|
|
|
|
|
| 1517 |
|
| 1518 |
def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
|
| 1519 |
prompt, negative_prompt, seed_val,
|
|
|
|
| 1531 |
meta["silent_video"], seg_start, seg_dur,
|
| 1532 |
os.path.join(tmp_dir, "regen_seg.mp4"),
|
| 1533 |
)
|
| 1534 |
+
_tl.hunyuan_regen_ctx = _preload_hunyuan_regen_ctx(meta, seg_path)
|
| 1535 |
|
| 1536 |
# GPU: inference only
|
| 1537 |
new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
|
|
|
|
| 1601 |
return video_path, waveform_html
|
| 1602 |
|
| 1603 |
|
| 1604 |
+
def _xregen_dispatch(state_json: str, seg_idx: int, slot_id: str, infer_fn):
    """Common generator scaffold shared by every ``xregen_*`` wrapper.

    Immediately yields a "pending" UI update, then invokes *infer_fn* — a
    no-argument callable encapsulating the model-specific CPU prep plus GPU
    inference — which must return ``(wav_array, src_sr)``.  TARO wrappers
    are expected to hand back audio already upsampled to 48 kHz and pass
    TARO_SR_OUT as the source rate.

    Yields:
        ``(gr.update(), gr.update(value=pending_html))`` first, while the
        GPU call is in flight; then
        ``(gr.update(value=video_path), gr.update(value=waveform_html))``
        once the new take has been spliced into *slot_id*.
    """
    state = json.loads(state_json)
    placeholder = _build_regen_pending_html(state["segments"], seg_idx, slot_id, "")
    # Step 1: surface the pending state so the user sees progress right away.
    yield gr.update(), gr.update(value=placeholder)

    # Step 2: run the model-specific inference, then splice the new audio in.
    wav, src_sr = infer_fn()
    video_path, waveform_html = _xregen_splice(wav, src_sr, state, seg_idx, slot_id)
    yield gr.update(value=video_path), gr.update(value=waveform_html)
|
| 1623 |
+
|
| 1624 |
+
|
| 1625 |
def xregen_taro(seg_idx, state_json, slot_id,
                seed_val, cfg_scale, num_steps, mode,
                crossfade_s, crossfade_db,
                request: gr.Request = None):
    """Cross-model regen: generate a TARO take for one segment and splice
    the result into *slot_id*."""
    seg_idx = int(seg_idx)
    parsed_meta = json.loads(state_json)

    def _infer():
        # Hand pre-computed CPU features to the GPU function through
        # thread-local storage, then run inference for this segment only.
        _tl.taro_regen_ctx = _preload_taro_regen_ctx(parsed_meta)
        raw_wav = _regen_taro_gpu(None, seg_idx, state_json,
                                  seed_val, cfg_scale, num_steps, mode,
                                  crossfade_s, crossfade_db, slot_id)
        # TARO outputs 16 kHz audio; upsample on CPU to 48 kHz for splicing.
        return _upsample_taro(raw_wav), TARO_SR_OUT

    yield from _xregen_dispatch(state_json, seg_idx, slot_id, _infer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1641 |
|
| 1642 |
|
| 1643 |
def xregen_mmaudio(seg_idx, state_json, slot_id,
|
|
|
|
| 1645 |
cfg_strength, num_steps, crossfade_s, crossfade_db,
|
| 1646 |
request: gr.Request = None):
|
| 1647 |
"""Cross-model regen: run MMAudio inference and splice into *slot_id*."""
|
|
|
|
| 1648 |
seg_idx = int(seg_idx)
|
| 1649 |
+
meta = json.loads(state_json)
|
| 1650 |
seg_start, seg_end = meta["segments"][seg_idx]
|
| 1651 |
|
| 1652 |
+
def _run():
|
| 1653 |
+
seg_path = _extract_segment_clip(
|
| 1654 |
+
meta["silent_video"], seg_start, seg_end - seg_start,
|
| 1655 |
+
os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
|
| 1656 |
+
)
|
| 1657 |
+
_tl.mmaudio_regen_ctx = {"seg_path": seg_path}
|
| 1658 |
+
wav, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
|
| 1659 |
+
prompt, negative_prompt, seed_val,
|
| 1660 |
+
cfg_strength, num_steps,
|
| 1661 |
+
crossfade_s, crossfade_db, slot_id)
|
| 1662 |
+
return wav, src_sr
|
| 1663 |
|
| 1664 |
+
yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1665 |
|
| 1666 |
|
| 1667 |
def xregen_hunyuan(seg_idx, state_json, slot_id,
|
|
|
|
| 1670 |
crossfade_s, crossfade_db,
|
| 1671 |
request: gr.Request = None):
|
| 1672 |
"""Cross-model regen: run HunyuanFoley inference and splice into *slot_id*."""
|
|
|
|
| 1673 |
seg_idx = int(seg_idx)
|
| 1674 |
+
meta = json.loads(state_json)
|
| 1675 |
seg_start, seg_end = meta["segments"][seg_idx]
|
| 1676 |
|
| 1677 |
+
def _run():
|
| 1678 |
+
seg_path = _extract_segment_clip(
|
| 1679 |
+
meta["silent_video"], seg_start, seg_end - seg_start,
|
| 1680 |
+
os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
|
| 1681 |
+
)
|
| 1682 |
+
_tl.hunyuan_regen_ctx = _preload_hunyuan_regen_ctx(meta, seg_path)
|
| 1683 |
+
wav, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
|
| 1684 |
+
prompt, negative_prompt, seed_val,
|
| 1685 |
+
guidance_scale, num_steps, model_size,
|
| 1686 |
+
crossfade_s, crossfade_db, slot_id)
|
| 1687 |
+
return wav, src_sr
|
| 1688 |
+
|
| 1689 |
+
yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
|
|
|
|
|
|
|
|
|
|
| 1690 |
|
| 1691 |
|
| 1692 |
# ================================================================== #
|