BoxOfColors Claude Sonnet 4.6 committed on
Commit
7592f82
·
1 Parent(s): ebde550

refactor: replace _cpu_ctx with thread-local storage, deduplicate xregen wrappers, parallel downloads, quiet=True

Browse files

- Replace fragile function-attribute CPU→GPU context passing (_fn._cpu_ctx = {})
with thread-local storage (_tl.<name>_ctx) for thread safety under ZeroGPU
multi-user concurrency — 6 sites updated across generate_* and regen_* paths
- Add _xregen_dispatch() generator helper to deduplicate the pending-yield /
infer / splice-yield skeleton shared by xregen_taro, xregen_mmaudio,
xregen_hunyuan (~40 lines removed)
- Parallelize all 7 startup downloads with ThreadPoolExecutor (I/O-bound network
calls run concurrently, cutting Space cold-start time ~proportionally)
- Consolidate per-model scalar constants into MODEL_CONFIGS as single source of
truth; add _clamp_duration() / _estimate_gpu_duration() / _estimate_regen_duration()
helpers to eliminate repeated duration-clamping boilerplate
- Restore quiet=True in mux_video_audio (was temporarily quiet=False for debugging)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +214 -179
app.py CHANGED
@@ -17,6 +17,7 @@ import tempfile
17
  import random
18
  import threading
19
  import time
 
20
  from pathlib import Path
21
 
22
  import torch
@@ -35,69 +36,102 @@ CKPT_REPO_ID = "JackIsNotInTheBox/Generate_Audio_for_Video_Checkpoints"
35
  CACHE_DIR = "/tmp/model_ckpts"
36
  os.makedirs(CACHE_DIR, exist_ok=True)
37
 
38
- # ---- TARO checkpoints (in TARO/ subfolder of the HF repo) ----
39
- print("Downloading TARO checkpoints…")
40
- cavp_ckpt_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/cavp_epoch66.ckpt", cache_dir=CACHE_DIR)
41
- onset_ckpt_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/onset_model.ckpt", cache_dir=CACHE_DIR)
42
- taro_ckpt_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/taro_ckpt.pt", cache_dir=CACHE_DIR)
43
- print("TARO checkpoints downloaded.")
44
-
45
- # ---- MMAudio checkpoints (in MMAudio/ subfolder) ----
46
- # MMAudio normally auto-downloads from its own HF repo, but we
47
- # override the paths so it pulls from our consolidated repo instead.
48
  MMAUDIO_WEIGHTS_DIR = Path(CACHE_DIR) / "MMAudio" / "weights"
49
  MMAUDIO_EXT_DIR = Path(CACHE_DIR) / "MMAudio" / "ext_weights"
 
50
  MMAUDIO_WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
51
  MMAUDIO_EXT_DIR.mkdir(parents=True, exist_ok=True)
52
-
53
- print("Downloading MMAudio checkpoints…")
54
- mmaudio_model_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/mmaudio_large_44k_v2.pth", cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_WEIGHTS_DIR), local_dir_use_symlinks=False)
55
- mmaudio_vae_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/v1-44.pth", cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_EXT_DIR), local_dir_use_symlinks=False)
56
- mmaudio_synchformer_path = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/synchformer_state_dict.pth", cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_EXT_DIR), local_dir_use_symlinks=False)
57
- print("MMAudio checkpoints downloaded.")
58
-
59
- # ---- HunyuanVideoFoley checkpoints (in HunyuanFoley/ subfolder) ----
60
- HUNYUAN_MODEL_DIR = Path(CACHE_DIR) / "HunyuanFoley"
61
  HUNYUAN_MODEL_DIR.mkdir(parents=True, exist_ok=True)
62
 
63
- print("Downloading HunyuanVideoFoley checkpoints…")
64
- hf_hub_download(repo_id=CKPT_REPO_ID, filename="HunyuanVideo-Foley/hunyuanvideo_foley.pth", cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR), local_dir_use_symlinks=False)
65
- hf_hub_download(repo_id=CKPT_REPO_ID, filename="HunyuanVideo-Foley/vae_128d_48k.pth", cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR), local_dir_use_symlinks=False)
66
- hf_hub_download(repo_id=CKPT_REPO_ID, filename="HunyuanVideo-Foley/synchformer_state_dict.pth", cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR), local_dir_use_symlinks=False)
67
- print("HunyuanVideoFoley checkpoints downloaded.")
68
-
69
- # Pre-download CLAP model so from_pretrained() reads local cache inside the
70
- # ZeroGPU daemonic worker (spawning child processes there is not allowed).
71
- print("Pre-downloading CLAP model (laion/larger_clap_general)…")
72
- snapshot_download(repo_id="laion/larger_clap_general")
73
- print("CLAP model pre-downloaded.")
74
-
75
- # Pre-download MMAudio's CLIP model (apple/DFN5B-CLIP-ViT-H-14-384, ~3.95 GB).
76
- # open_clip.create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384')
77
- # fetches this at first use — inside the GPU window on cold workers — which
78
- # burns ~5-10s of the allocated ZeroGPU budget before inference even starts.
79
- print("Pre-downloading MMAudio CLIP model (apple/DFN5B-CLIP-ViT-H-14-384)…")
80
- snapshot_download(repo_id="apple/DFN5B-CLIP-ViT-H-14-384")
81
- print("MMAudio CLIP model pre-downloaded.")
82
-
83
- # Pre-download TARO's AudioLDM2 VAE + vocoder (cvssp/audioldm2).
84
- # AutoencoderKL.from_pretrained() and SpeechT5HifiGan.from_pretrained() fetch
85
- # this repo inside the GPU window on every cold worker start, burning GPU budget
86
- # before inference even begins. Pre-fetching here ensures the cache is warm.
87
- print("Pre-downloading AudioLDM2 (cvssp/audioldm2)…")
88
- snapshot_download(repo_id="cvssp/audioldm2")
89
- print("AudioLDM2 pre-downloaded.")
90
-
91
- # Pre-download MMAudio's BigVGAN vocoder (nvidia/bigvgan_v2_44khz_128band_512x, ~489MB).
92
- # This is fetched inside the GPU window on cold workers during MMAudio inference/regen.
93
- print("Pre-downloading BigVGAN vocoder (nvidia/bigvgan_v2_44khz_128band_512x)…")
94
- snapshot_download(repo_id="nvidia/bigvgan_v2_44khz_128band_512x")
95
- print("BigVGAN vocoder pre-downloaded.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  # ================================================================== #
98
  # SHARED CONSTANTS / HELPERS #
99
  # ================================================================== #
100
 
 
 
 
 
 
 
 
 
101
  MAX_SLOTS = 8 # max parallel generation slots shown in UI
102
  MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video ≤ ~64 s at 8 s/seg)
103
 
@@ -351,7 +385,7 @@ def mux_video_audio(silent_video: str, audio_path: str, output_path: str,
351
  pix_fmt="yuv420p",
352
  acodec="aac", audio_bitrate="128k",
353
  movflags="+faststart",
354
- ).run(overwrite_output=True, quiet=False)
355
 
356
 
357
  # ------------------------------------------------------------------ #
@@ -417,65 +451,76 @@ def _cf_join(a: np.ndarray, b: np.ndarray,
417
  # latents_scale: [0.18215]*8 — AudioLDM2 VAE scale factor
418
  # ================================================================== #
419
 
420
- TARO_SR = 16000
421
- TARO_TRUNCATE = 131072
422
- TARO_FPS = 4
423
- TARO_TRUNCATE_FRAME = int(TARO_FPS * TARO_TRUNCATE / TARO_SR) # 32
 
 
 
 
 
 
 
 
 
 
424
  TARO_TRUNCATE_ONSET = 120
425
- TARO_MODEL_DUR = TARO_TRUNCATE / TARO_SR # 8.192 s
426
- TARO_SECS_PER_STEP = 0.025 # measured 0.023s/step on H200; was 0.05, tightened to halve GPU allocation
427
-
428
- TARO_LOAD_OVERHEAD = 15 # seconds: model load + CAVP feature extraction
429
- MMAUDIO_WINDOW = 8.0 # seconds — MMAudio's fixed generation window
430
- MMAUDIO_SECS_PER_STEP = 0.25 # measured 0.230s/step on H200 (8.3s video, 2 segs × 25 steps = 11.5s wall)
431
- MMAUDIO_LOAD_OVERHEAD = 30 # 15s warm + 15s model init; open_clip pre-downloaded at startup
432
- HUNYUAN_MAX_DUR = 15.0 # seconds — HunyuanFoley max video duration
433
- HUNYUAN_SECS_PER_STEP = 0.35 # measured 0.328s/step on H200 (8.3s video, 1 seg × 50 steps = 16.4s wall)
434
- HUNYUAN_LOAD_OVERHEAD = 55 # ~55s to load the 10GB XXL model weights into GPU
435
- GPU_DURATION_CAP = 300 # hard cap per call — never reserve more than this
436
 
437
- # ------------------------------------------------------------------ #
438
- # Model configuration registry — single source of truth for per-model #
439
- # constants used by duration estimation, segmentation, and UI. #
440
- # ------------------------------------------------------------------ #
441
  MODEL_CONFIGS = {
442
  "taro": {
443
- "window_s": TARO_MODEL_DUR, # 8.192 s
444
- "sr": TARO_SR, # 16000
445
- "secs_per_step": TARO_SECS_PER_STEP, # 0.025
446
- "load_overhead": TARO_LOAD_OVERHEAD, # 15
447
  "tab_prefix": "taro",
448
- "regen_fn": None, # set after function definitions (avoids forward-ref)
449
  "label": "TARO",
 
450
  },
451
  "mmaudio": {
452
- "window_s": MMAUDIO_WINDOW, # 8.0 s
453
- "sr": 48000, # resampled to 48kHz in post-processing
454
- "secs_per_step": MMAUDIO_SECS_PER_STEP, # 0.25
455
- "load_overhead": MMAUDIO_LOAD_OVERHEAD, # 15
456
  "tab_prefix": "mma",
457
- "regen_fn": None,
458
  "label": "MMAudio",
 
459
  },
460
  "hunyuan": {
461
- "window_s": HUNYUAN_MAX_DUR, # 15.0 s
462
  "sr": 48000,
463
- "secs_per_step": HUNYUAN_SECS_PER_STEP, # 0.35
464
- "load_overhead": HUNYUAN_LOAD_OVERHEAD, # 55
465
  "tab_prefix": "hf",
466
- "regen_fn": None,
467
  "label": "HunyuanFoley",
 
468
  },
469
  }
470
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
 
472
  def _estimate_gpu_duration(model_key: str, num_samples: int, num_steps: int,
473
  total_dur_s: float = None, crossfade_s: float = 0,
474
  video_file: str = None) -> int:
475
- """Generic GPU duration estimator used by all models.
476
 
477
- Computes: num_samples × n_segs × num_steps × secs_per_step + load_overhead
478
- Clamped to [60, GPU_DURATION_CAP].
479
  """
480
  cfg = MODEL_CONFIGS[model_key]
481
  try:
@@ -484,25 +529,18 @@ def _estimate_gpu_duration(model_key: str, num_samples: int, num_steps: int,
484
  n_segs = len(_build_segments(total_dur_s, cfg["window_s"], float(crossfade_s)))
485
  except Exception:
486
  n_segs = 1
487
- secs = int(num_samples) * n_segs * int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"]
488
- result = min(GPU_DURATION_CAP, max(60, int(secs)))
489
  print(f"[duration] {cfg['label']}: {int(num_samples)}samp × {n_segs}seg × "
490
- f"{int(num_steps)}steps → {secs:.0f}s → capped {result}s")
491
- return result
492
 
493
 
494
  def _estimate_regen_duration(model_key: str, num_steps: int) -> int:
495
- """Generic GPU duration estimator for single-segment regen.
496
-
497
- Floor is 20s — enough headroom above the 10s ZeroGPU abort threshold
498
- for any model on a warm worker. Cold-start spin-up happens *before*
499
- the timer starts so raising the floor does not help with cold-start aborts.
500
- """
501
  cfg = MODEL_CONFIGS[model_key]
502
  secs = int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"]
503
- result = min(GPU_DURATION_CAP, max(60, int(secs)))
504
- print(f"[duration] {cfg['label']} regen: 1 seg × {int(num_steps)} steps → {secs:.0f}s → capped {result}s")
505
- return result
506
 
507
  _TARO_CACHE_MAXLEN = 16 # evict oldest entries beyond this limit
508
  _TARO_INFERENCE_CACHE: dict = {} # keyed by (video_file, seed, cfg, steps, mode, crossfade_s)
@@ -750,8 +788,8 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
750
  from TARO.onset_util import extract_onset
751
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
752
 
753
- # Use pre-computed CPU results from the wrapper
754
- ctx = _taro_gpu_infer._cpu_ctx
755
  tmp_dir = ctx["tmp_dir"]
756
  silent_video = ctx["silent_video"]
757
  segments = ctx["segments"]
@@ -810,9 +848,6 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
810
 
811
  return results
812
 
813
- # Attach a context slot for the CPU wrapper to pass pre-computed data
814
- _taro_gpu_infer._cpu_ctx = {}
815
-
816
 
817
  def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
818
  crossfade_s, crossfade_db, num_samples):
@@ -826,8 +861,8 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
826
  tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
827
  video_file, TARO_MODEL_DUR, crossfade_s)
828
 
829
- # Pass pre-computed CPU results to the GPU function via context
830
- _taro_gpu_infer._cpu_ctx = {
831
  "tmp_dir": tmp_dir, "silent_video": silent_video,
832
  "segments": segments, "total_dur_s": total_dur_s,
833
  }
@@ -906,7 +941,7 @@ def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
906
 
907
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
908
 
909
- ctx = _mmaudio_gpu_infer._cpu_ctx
910
  segments = ctx["segments"]
911
  seg_clip_paths = ctx["seg_clip_paths"]
912
 
@@ -966,8 +1001,6 @@ def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
966
 
967
  return results
968
 
969
- _mmaudio_gpu_infer._cpu_ctx = {}
970
-
971
 
972
  def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
973
  cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
@@ -987,7 +1020,7 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
987
  for i, (s, e) in enumerate(segments)
988
  ]
989
 
990
- _mmaudio_gpu_infer._cpu_ctx = {
991
  "segments": segments, "seg_clip_paths": seg_clip_paths,
992
  }
993
 
@@ -1057,7 +1090,7 @@ def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1057
 
1058
  model_dict, cfg = _load_hunyuan_model(device, model_size)
1059
 
1060
- ctx = _hunyuan_gpu_infer._cpu_ctx
1061
  segments = ctx["segments"]
1062
  total_dur_s = ctx["total_dur_s"]
1063
  dummy_seg_path = ctx["dummy_seg_path"]
@@ -1115,8 +1148,6 @@ def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1115
 
1116
  return results
1117
 
1118
- _hunyuan_gpu_infer._cpu_ctx = {}
1119
-
1120
 
1121
  def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
1122
  guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
@@ -1143,7 +1174,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
1143
  for i, (s, e) in enumerate(segments)
1144
  ]
1145
 
1146
- _hunyuan_gpu_infer._cpu_ctx = {
1147
  "segments": segments, "total_dur_s": total_dur_s,
1148
  "dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
1149
  }
@@ -1182,7 +1213,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
1182
 
1183
  def _preload_taro_regen_ctx(meta: dict) -> dict:
1184
  """Pre-load TARO CAVP/onset features on CPU for regen.
1185
- Returns a dict suitable for _regen_taro_gpu._cpu_ctx."""
1186
  cavp_path = meta.get("cavp_path", "")
1187
  onset_path = meta.get("onset_path", "")
1188
  ctx = {}
@@ -1194,7 +1225,7 @@ def _preload_taro_regen_ctx(meta: dict) -> dict:
1194
 
1195
  def _preload_hunyuan_regen_ctx(meta: dict, seg_path: str) -> dict:
1196
  """Pre-load HunyuanFoley text features + segment path on CPU for regen.
1197
- Returns a dict suitable for _regen_hunyuan_gpu._cpu_ctx."""
1198
  ctx = {"seg_path": seg_path}
1199
  text_feats_path = meta.get("text_feats_path", "")
1200
  if text_feats_path and os.path.exists(text_feats_path):
@@ -1285,7 +1316,7 @@ def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
1285
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
1286
 
1287
  # Use pre-loaded features from CPU wrapper (avoids np.load inside GPU window)
1288
- ctx = _regen_taro_gpu._cpu_ctx
1289
  if "cavp" in ctx and "onset" in ctx:
1290
  print("[TARO regen] Using pre-loaded CAVP + onset features (CPU cache hit)")
1291
  cavp_feats = ctx["cavp"]
@@ -1323,7 +1354,7 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
1323
  seg_idx = int(seg_idx)
1324
 
1325
  # CPU: pre-load cached features so np.load doesn't happen inside GPU window
1326
- _regen_taro_gpu._cpu_ctx = _preload_taro_regen_ctx(meta)
1327
 
1328
  # GPU: inference only
1329
  new_wav = _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
@@ -1365,7 +1396,7 @@ def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1365
  sr = seq_cfg.sampling_rate
1366
 
1367
  # Use pre-extracted segment clip from the CPU wrapper
1368
- seg_path = _regen_mmaudio_gpu._cpu_ctx.get("seg_path")
1369
  assert seg_path, "[MMAudio regen] seg_path not set — wrapper must pre-extract segment clip"
1370
 
1371
  rng = torch.Generator(device=device)
@@ -1391,8 +1422,6 @@ def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1391
  new_wav = new_wav[:, :seg_samples]
1392
  return new_wav, sr
1393
 
1394
- _regen_mmaudio_gpu._cpu_ctx = {}
1395
-
1396
 
1397
  def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
1398
  prompt, negative_prompt, seed_val,
@@ -1409,7 +1438,7 @@ def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
1409
  meta["silent_video"], seg_start, seg_dur,
1410
  os.path.join(tmp_dir, "regen_seg.mp4"),
1411
  )
1412
- _regen_mmaudio_gpu._cpu_ctx = {"seg_path": seg_path}
1413
 
1414
  # GPU: inference only
1415
  new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
@@ -1458,12 +1487,11 @@ def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1458
 
1459
  set_global_seed(random.randint(0, 2**32 - 1))
1460
 
1461
- # Use pre-extracted segment clip from wrapper
1462
- seg_path = _regen_hunyuan_gpu._cpu_ctx.get("seg_path")
 
1463
  assert seg_path, "[HunyuanFoley regen] seg_path not set — wrapper must pre-extract segment clip"
1464
 
1465
- # Use pre-loaded text_feats from CPU wrapper (avoids torch.load inside GPU window)
1466
- ctx = _regen_hunyuan_gpu._cpu_ctx
1467
  if "text_feats" in ctx:
1468
  print("[HunyuanFoley regen] Using pre-loaded text features (CPU cache hit)")
1469
  from hunyuanvideo_foley.utils.feature_utils import encode_video_features
@@ -1486,8 +1514,6 @@ def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1486
  new_wav = new_wav[:, :seg_samples]
1487
  return new_wav, sr
1488
 
1489
- _regen_hunyuan_gpu._cpu_ctx = {}
1490
-
1491
 
1492
  def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1493
  prompt, negative_prompt, seed_val,
@@ -1505,7 +1531,7 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1505
  meta["silent_video"], seg_start, seg_dur,
1506
  os.path.join(tmp_dir, "regen_seg.mp4"),
1507
  )
1508
- _regen_hunyuan_gpu._cpu_ctx = _preload_hunyuan_regen_ctx(meta, seg_path)
1509
 
1510
  # GPU: inference only
1511
  new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
@@ -1575,28 +1601,43 @@ def _xregen_splice(new_wav_raw: np.ndarray, src_sr: int,
1575
  return video_path, waveform_html
1576
 
1577
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1578
  def xregen_taro(seg_idx, state_json, slot_id,
1579
  seed_val, cfg_scale, num_steps, mode,
1580
  crossfade_s, crossfade_db,
1581
  request: gr.Request = None):
1582
  """Cross-model regen: run TARO inference and splice into *slot_id*."""
1583
- meta = json.loads(state_json)
1584
  seg_idx = int(seg_idx)
 
1585
 
1586
- # Show pending waveform immediately
1587
- pending_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
1588
- yield gr.update(), gr.update(value=pending_html)
1589
-
1590
- # CPU: pre-load cached features so np.load doesn't happen inside GPU window
1591
- _regen_taro_gpu._cpu_ctx = _preload_taro_regen_ctx(meta)
1592
 
1593
- new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
1594
- seed_val, cfg_scale, num_steps, mode,
1595
- crossfade_s, crossfade_db, slot_id)
1596
- # Upsample 16kHz → 48kHz (sinc, CPU)
1597
- new_wav_raw = _upsample_taro(new_wav_raw)
1598
- video_path, waveform_html = _xregen_splice(new_wav_raw, TARO_SR_OUT, meta, seg_idx, slot_id)
1599
- yield gr.update(value=video_path), gr.update(value=waveform_html)
1600
 
1601
 
1602
  def xregen_mmaudio(seg_idx, state_json, slot_id,
@@ -1604,26 +1645,23 @@ def xregen_mmaudio(seg_idx, state_json, slot_id,
1604
  cfg_strength, num_steps, crossfade_s, crossfade_db,
1605
  request: gr.Request = None):
1606
  """Cross-model regen: run MMAudio inference and splice into *slot_id*."""
1607
- meta = json.loads(state_json)
1608
  seg_idx = int(seg_idx)
 
1609
  seg_start, seg_end = meta["segments"][seg_idx]
1610
 
1611
- # Show pending waveform immediately
1612
- pending_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
1613
- yield gr.update(), gr.update(value=pending_html)
1614
-
1615
- seg_path = _extract_segment_clip(
1616
- meta["silent_video"], seg_start, seg_end - seg_start,
1617
- os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1618
- )
1619
- _regen_mmaudio_gpu._cpu_ctx = {"seg_path": seg_path}
 
 
1620
 
1621
- new_wav_raw, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
1622
- prompt, negative_prompt, seed_val,
1623
- cfg_strength, num_steps,
1624
- crossfade_s, crossfade_db, slot_id)
1625
- video_path, waveform_html = _xregen_splice(new_wav_raw, src_sr, meta, seg_idx, slot_id)
1626
- yield gr.update(value=video_path), gr.update(value=waveform_html)
1627
 
1628
 
1629
  def xregen_hunyuan(seg_idx, state_json, slot_id,
@@ -1632,26 +1670,23 @@ def xregen_hunyuan(seg_idx, state_json, slot_id,
1632
  crossfade_s, crossfade_db,
1633
  request: gr.Request = None):
1634
  """Cross-model regen: run HunyuanFoley inference and splice into *slot_id*."""
1635
- meta = json.loads(state_json)
1636
  seg_idx = int(seg_idx)
 
1637
  seg_start, seg_end = meta["segments"][seg_idx]
1638
 
1639
- # Show pending waveform immediately
1640
- pending_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
1641
- yield gr.update(), gr.update(value=pending_html)
1642
-
1643
- seg_path = _extract_segment_clip(
1644
- meta["silent_video"], seg_start, seg_end - seg_start,
1645
- os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1646
- )
1647
- _regen_hunyuan_gpu._cpu_ctx = _preload_hunyuan_regen_ctx(meta, seg_path)
1648
-
1649
- new_wav_raw, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
1650
- prompt, negative_prompt, seed_val,
1651
- guidance_scale, num_steps, model_size,
1652
- crossfade_s, crossfade_db, slot_id)
1653
- video_path, waveform_html = _xregen_splice(new_wav_raw, src_sr, meta, seg_idx, slot_id)
1654
- yield gr.update(value=video_path), gr.update(value=waveform_html)
1655
 
1656
 
1657
  # ================================================================== #
 
17
  import random
18
  import threading
19
  import time
20
+ from concurrent.futures import ThreadPoolExecutor, as_completed
21
  from pathlib import Path
22
 
23
  import torch
 
36
  CACHE_DIR = "/tmp/model_ckpts"
37
  os.makedirs(CACHE_DIR, exist_ok=True)
38
 
39
+ # ---- Local directories that must exist before parallel downloads start ----
 
 
 
 
 
 
 
 
 
40
  MMAUDIO_WEIGHTS_DIR = Path(CACHE_DIR) / "MMAudio" / "weights"
41
  MMAUDIO_EXT_DIR = Path(CACHE_DIR) / "MMAudio" / "ext_weights"
42
+ HUNYUAN_MODEL_DIR = Path(CACHE_DIR) / "HunyuanFoley"
43
  MMAUDIO_WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
44
  MMAUDIO_EXT_DIR.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
 
 
 
45
  HUNYUAN_MODEL_DIR.mkdir(parents=True, exist_ok=True)
46
 
47
+ # ------------------------------------------------------------------ #
48
+ # Parallel checkpoint + model downloads #
49
+ # All downloads are I/O-bound (network), so running them in threads #
50
+ # cuts Space cold-start time roughly proportional to the number of #
51
+ # independent groups (previously sequential, now concurrent). #
52
+ # hf_hub_download / snapshot_download are thread-safe. #
53
+ # ------------------------------------------------------------------ #
54
+
55
def _dl_taro():
    """Fetch the three TARO checkpoints from the consolidated HF repo.

    Returns:
        tuple[str, str, str]: local paths (cavp, onset, taro).
    """
    cavp = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/cavp_epoch66.ckpt", cache_dir=CACHE_DIR)
    onset = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/onset_model.ckpt", cache_dir=CACHE_DIR)
    taro = hf_hub_download(repo_id=CKPT_REPO_ID, filename="TARO/taro_ckpt.pt", cache_dir=CACHE_DIR)
    print("TARO checkpoints downloaded.")
    return cavp, onset, taro
62
+
63
def _dl_mmaudio():
    """Fetch MMAudio weights into their expected local directories.

    Returns:
        tuple[str, str, str]: local paths (model, vae, synchformer).
    """
    model = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/mmaudio_large_44k_v2.pth",
                            cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_WEIGHTS_DIR), local_dir_use_symlinks=False)
    vae = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/v1-44.pth",
                          cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_EXT_DIR), local_dir_use_symlinks=False)
    sync = hf_hub_download(repo_id=CKPT_REPO_ID, filename="MMAudio/synchformer_state_dict.pth",
                           cache_dir=CACHE_DIR, local_dir=str(MMAUDIO_EXT_DIR), local_dir_use_symlinks=False)
    print("MMAudio checkpoints downloaded.")
    return model, vae, sync
73
+
74
def _dl_hunyuan():
    """Fetch the three HunyuanVideoFoley weight files into HUNYUAN_MODEL_DIR."""
    for fname in ("HunyuanVideo-Foley/hunyuanvideo_foley.pth",
                  "HunyuanVideo-Foley/vae_128d_48k.pth",
                  "HunyuanVideo-Foley/synchformer_state_dict.pth"):
        hf_hub_download(repo_id=CKPT_REPO_ID, filename=fname,
                        cache_dir=CACHE_DIR, local_dir=str(HUNYUAN_MODEL_DIR), local_dir_use_symlinks=False)
    print("HunyuanVideoFoley checkpoints downloaded.")
83
+
84
def _dl_clap():
    """Pre-download CLAP so from_pretrained() hits local cache inside the ZeroGPU worker.

    Spawning child processes inside the daemonic ZeroGPU worker is not
    allowed, so the cache must be warmed here at startup.
    """
    snapshot_download(repo_id="laion/larger_clap_general")
    print("CLAP model pre-downloaded.")
88
+
89
def _dl_clip():
    """Pre-download MMAudio's CLIP backbone (apple/DFN5B-CLIP-ViT-H-14-384, ~3.95 GB).

    open_clip fetches this model at first use — inside the GPU window on
    cold workers — which burns ~5-10 s of the allocated ZeroGPU budget
    before inference even starts; warming the cache here avoids that.
    """
    snapshot_download(repo_id="apple/DFN5B-CLIP-ViT-H-14-384")
    print("MMAudio CLIP model pre-downloaded.")
93
+
94
def _dl_audioldm2():
    """Pre-download AudioLDM2 (cvssp/audioldm2) — TARO's VAE + vocoder.

    AutoencoderKL.from_pretrained() and SpeechT5HifiGan.from_pretrained()
    fetch this repo inside the GPU window on every cold worker start,
    burning GPU budget before inference begins; pre-fetching keeps the
    cache warm.
    """
    snapshot_download(repo_id="cvssp/audioldm2")
    print("AudioLDM2 pre-downloaded.")
98
+
99
def _dl_bigvgan():
    """Pre-download MMAudio's BigVGAN vocoder (nvidia/bigvgan_v2_44khz_128band_512x, ~489 MB).

    Otherwise fetched inside the GPU window on cold workers during MMAudio
    inference/regen.
    """
    snapshot_download(repo_id="nvidia/bigvgan_v2_44khz_128band_512x")
    print("BigVGAN vocoder pre-downloaded.")
103
+
104
print("[startup] Starting parallel checkpoint + model downloads…")
_t_dl_start = time.perf_counter()
# All seven download groups are independent and I/O-bound, so fan them all
# out at once — one worker per group.
_dl_jobs = (_dl_taro, _dl_mmaudio, _dl_hunyuan, _dl_clap, _dl_clip,
            _dl_audioldm2, _dl_bigvgan)
with ThreadPoolExecutor(max_workers=len(_dl_jobs)) as _pool:
    _futs = {_fn: _pool.submit(_fn) for _fn in _dl_jobs}
    # Surface any download exception as soon as its future finishes.
    for _f in as_completed(_futs.values()):
        _f.result()

# Unpack the path-returning groups now that everything has succeeded.
cavp_ckpt_path, onset_ckpt_path, taro_ckpt_path = _futs[_dl_taro].result()
mmaudio_model_path, mmaudio_vae_path, mmaudio_synchformer_path = _futs[_dl_mmaudio].result()
print(f"[startup] All downloads done in {time.perf_counter() - _t_dl_start:.1f}s")
122
 
123
  # ================================================================== #
124
  # SHARED CONSTANTS / HELPERS #
125
  # ================================================================== #
126
 
127
+ # Thread-local storage for CPU → GPU context passing.
128
+ # Replaces the fragile function-attribute pattern (_fn._cpu_ctx = {...}).
129
+ # Each wrapper writes its context under a unique key before calling the
130
+ # @spaces.GPU function; the GPU function reads it back. Using thread-local
131
+ # storage means concurrent requests on different threads don't clobber
132
+ # each other's context — the function-attribute approach was not thread-safe.
133
+ _tl = threading.local()
134
+
135
  MAX_SLOTS = 8 # max parallel generation slots shown in UI
136
  MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video ≤ ~64 s at 8 s/seg)
137
 
 
385
  pix_fmt="yuv420p",
386
  acodec="aac", audio_bitrate="128k",
387
  movflags="+faststart",
388
+ ).run(overwrite_output=True, quiet=True)
389
 
390
 
391
  # ------------------------------------------------------------------ #
 
451
  # latents_scale: [0.18215]*8 — AudioLDM2 VAE scale factor
452
  # ================================================================== #
453
 
454
+ # ================================================================== #
455
+ # MODEL CONSTANTS & CONFIGURATION REGISTRY #
456
+ # ================================================================== #
457
+ # All per-model numeric constants live here — MODEL_CONFIGS is the #
458
+ # single source of truth consumed by duration estimation, segmentation,#
459
+ # and the UI. Standalone names kept only where other code references #
460
+ # them by name (TARO geometry, TARGET_SR, GPU_DURATION_CAP). #
461
+ # ================================================================== #
462
+
463
+ # TARO geometry — referenced directly in _taro_infer_segment
464
+ TARO_SR = 16000
465
+ TARO_TRUNCATE = 131072
466
+ TARO_FPS = 4
467
+ TARO_TRUNCATE_FRAME = int(TARO_FPS * TARO_TRUNCATE / TARO_SR) # 32
468
  TARO_TRUNCATE_ONSET = 120
469
+ TARO_MODEL_DUR = TARO_TRUNCATE / TARO_SR # 8.192 s
470
+
471
+ GPU_DURATION_CAP = 300 # hard cap per @spaces.GPU call — never reserve more than this
 
 
 
 
 
 
 
 
472
 
 
 
 
 
473
  MODEL_CONFIGS = {
474
  "taro": {
475
+ "window_s": TARO_MODEL_DUR, # 8.192 s
476
+ "sr": TARO_SR, # 16000 (output resampled to TARGET_SR)
477
+ "secs_per_step": 0.025, # measured 0.023 s/step on H200
478
+ "load_overhead": 15, # model load + CAVP feature extraction
479
  "tab_prefix": "taro",
 
480
  "label": "TARO",
481
+ "regen_fn": None, # set after function definitions (avoids forward-ref)
482
  },
483
  "mmaudio": {
484
+ "window_s": 8.0, # MMAudio's fixed generation window
485
+ "sr": 48000, # resampled from 44100 in post-processing
486
+ "secs_per_step": 0.25, # measured 0.230 s/step on H200
487
+ "load_overhead": 30, # 15s warm + 15s model init
488
  "tab_prefix": "mma",
 
489
  "label": "MMAudio",
490
+ "regen_fn": None,
491
  },
492
  "hunyuan": {
493
+ "window_s": 15.0, # HunyuanFoley max video duration
494
  "sr": 48000,
495
+ "secs_per_step": 0.35, # measured 0.328 s/step on H200
496
+ "load_overhead": 55, # ~55s to load the 10 GB XXL weights
497
  "tab_prefix": "hf",
 
498
  "label": "HunyuanFoley",
499
+ "regen_fn": None,
500
  },
501
  }
502
 
503
+ # Convenience aliases used only in the TARO inference path
504
+ TARO_SECS_PER_STEP = MODEL_CONFIGS["taro"]["secs_per_step"]
505
+ MMAUDIO_WINDOW = MODEL_CONFIGS["mmaudio"]["window_s"]
506
+ MMAUDIO_SECS_PER_STEP = MODEL_CONFIGS["mmaudio"]["secs_per_step"]
507
+ HUNYUAN_MAX_DUR = MODEL_CONFIGS["hunyuan"]["window_s"]
508
+ HUNYUAN_SECS_PER_STEP = MODEL_CONFIGS["hunyuan"]["secs_per_step"]
509
+
510
+
511
+ def _clamp_duration(secs: float, label: str) -> int:
512
+ """Clamp a raw GPU-seconds estimate to [60, GPU_DURATION_CAP] and log it."""
513
+ result = min(GPU_DURATION_CAP, max(60, int(secs)))
514
+ print(f"[duration] {label}: {secs:.0f}s raw → {result}s reserved")
515
+ return result
516
+
517
 
518
  def _estimate_gpu_duration(model_key: str, num_samples: int, num_steps: int,
519
  total_dur_s: float = None, crossfade_s: float = 0,
520
  video_file: str = None) -> int:
521
+ """Estimate GPU seconds for a full generation call.
522
 
523
+ Formula: num_samples × n_segs × num_steps × secs_per_step + load_overhead
 
524
  """
525
  cfg = MODEL_CONFIGS[model_key]
526
  try:
 
529
  n_segs = len(_build_segments(total_dur_s, cfg["window_s"], float(crossfade_s)))
530
  except Exception:
531
  n_segs = 1
532
+ secs = int(num_samples) * n_segs * int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"]
 
533
  print(f"[duration] {cfg['label']}: {int(num_samples)}samp × {n_segs}seg × "
534
+ f"{int(num_steps)}steps → {secs:.0f}s → capped ", end="")
535
+ return _clamp_duration(secs, cfg["label"])
536
 
537
 
538
def _estimate_regen_duration(model_key: str, num_steps: int) -> int:
    """Estimate GPU seconds to reserve for a single-segment regen call.

    Formula: num_steps × secs_per_step + load_overhead, then clamped by
    _clamp_duration to [60, GPU_DURATION_CAP].

    Args:
        model_key: key into MODEL_CONFIGS ("taro" / "mmaudio" / "hunyuan").
        num_steps: sampler step count for the regenerated segment.

    Returns:
        Whole seconds to pass as the @spaces.GPU duration.
    """
    cfg = MODEL_CONFIGS[model_key]
    secs = int(num_steps) * cfg["secs_per_step"] + cfg["load_overhead"]
    # Bug fix: the previous print(..., end="") left an unterminated partial
    # line that _clamp_duration's own full "[duration] … raw → … reserved"
    # line was appended to, garbling the log with a duplicated prefix.
    # Log the step breakdown on its own line instead.
    print(f"[duration] {cfg['label']} regen: 1 seg × {int(num_steps)} steps")
    return _clamp_duration(secs, f"{cfg['label']} regen")
 
544
 
545
  _TARO_CACHE_MAXLEN = 16 # evict oldest entries beyond this limit
546
  _TARO_INFERENCE_CACHE: dict = {} # keyed by (video_file, seed, cfg, steps, mode, crossfade_s)
 
788
  from TARO.onset_util import extract_onset
789
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
790
 
791
+ # Use pre-computed CPU results passed via thread-local storage
792
+ ctx = _tl.taro_gen_ctx
793
  tmp_dir = ctx["tmp_dir"]
794
  silent_video = ctx["silent_video"]
795
  segments = ctx["segments"]
 
848
 
849
  return results
850
 
 
 
 
851
 
852
  def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
853
  crossfade_s, crossfade_db, num_samples):
 
861
  tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
862
  video_file, TARO_MODEL_DUR, crossfade_s)
863
 
864
+ # Pass pre-computed CPU results to the GPU function via thread-local storage
865
+ _tl.taro_gen_ctx = {
866
  "tmp_dir": tmp_dir, "silent_video": silent_video,
867
  "segments": segments, "total_dur_s": total_dur_s,
868
  }
 
941
 
942
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
943
 
944
+ ctx = _tl.mmaudio_gen_ctx
945
  segments = ctx["segments"]
946
  seg_clip_paths = ctx["seg_clip_paths"]
947
 
 
1001
 
1002
  return results
1003
 
 
 
1004
 
1005
  def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
1006
  cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
 
1020
  for i, (s, e) in enumerate(segments)
1021
  ]
1022
 
1023
+ _tl.mmaudio_gen_ctx = {
1024
  "segments": segments, "seg_clip_paths": seg_clip_paths,
1025
  }
1026
 
 
1090
 
1091
  model_dict, cfg = _load_hunyuan_model(device, model_size)
1092
 
1093
+ ctx = _tl.hunyuan_gen_ctx
1094
  segments = ctx["segments"]
1095
  total_dur_s = ctx["total_dur_s"]
1096
  dummy_seg_path = ctx["dummy_seg_path"]
 
1148
 
1149
  return results
1150
 
 
 
1151
 
1152
  def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
1153
  guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
 
1174
  for i, (s, e) in enumerate(segments)
1175
  ]
1176
 
1177
+ _tl.hunyuan_gen_ctx = {
1178
  "segments": segments, "total_dur_s": total_dur_s,
1179
  "dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
1180
  }
 
1213
 
1214
  def _preload_taro_regen_ctx(meta: dict) -> dict:
1215
  """Pre-load TARO CAVP/onset features on CPU for regen.
1216
+ Returns a dict for _tl.taro_regen_ctx (thread-local storage)."""
1217
  cavp_path = meta.get("cavp_path", "")
1218
  onset_path = meta.get("onset_path", "")
1219
  ctx = {}
 
1225
 
1226
  def _preload_hunyuan_regen_ctx(meta: dict, seg_path: str) -> dict:
1227
  """Pre-load HunyuanFoley text features + segment path on CPU for regen.
1228
+ Returns a dict for _tl.hunyuan_regen_ctx (thread-local storage)."""
1229
  ctx = {"seg_path": seg_path}
1230
  text_feats_path = meta.get("text_feats_path", "")
1231
  if text_feats_path and os.path.exists(text_feats_path):
 
1316
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
1317
 
1318
  # Use pre-loaded features from CPU wrapper (avoids np.load inside GPU window)
1319
+ ctx = getattr(_tl, "taro_regen_ctx", {})
1320
  if "cavp" in ctx and "onset" in ctx:
1321
  print("[TARO regen] Using pre-loaded CAVP + onset features (CPU cache hit)")
1322
  cavp_feats = ctx["cavp"]
 
1354
  seg_idx = int(seg_idx)
1355
 
1356
  # CPU: pre-load cached features so np.load doesn't happen inside GPU window
1357
+ _tl.taro_regen_ctx = _preload_taro_regen_ctx(meta)
1358
 
1359
  # GPU: inference only
1360
  new_wav = _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
 
1396
  sr = seq_cfg.sampling_rate
1397
 
1398
  # Use pre-extracted segment clip from the CPU wrapper
1399
+ seg_path = getattr(_tl, "mmaudio_regen_ctx", {}).get("seg_path")
1400
  assert seg_path, "[MMAudio regen] seg_path not set — wrapper must pre-extract segment clip"
1401
 
1402
  rng = torch.Generator(device=device)
 
1422
  new_wav = new_wav[:, :seg_samples]
1423
  return new_wav, sr
1424
 
 
 
1425
 
1426
  def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
1427
  prompt, negative_prompt, seed_val,
 
1438
  meta["silent_video"], seg_start, seg_dur,
1439
  os.path.join(tmp_dir, "regen_seg.mp4"),
1440
  )
1441
+ _tl.mmaudio_regen_ctx = {"seg_path": seg_path}
1442
 
1443
  # GPU: inference only
1444
  new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
 
1487
 
1488
  set_global_seed(random.randint(0, 2**32 - 1))
1489
 
1490
+ # Use pre-extracted segment clip + text_feats from CPU wrapper
1491
+ ctx = getattr(_tl, "hunyuan_regen_ctx", {})
1492
+ seg_path = ctx.get("seg_path")
1493
  assert seg_path, "[HunyuanFoley regen] seg_path not set — wrapper must pre-extract segment clip"
1494
 
 
 
1495
  if "text_feats" in ctx:
1496
  print("[HunyuanFoley regen] Using pre-loaded text features (CPU cache hit)")
1497
  from hunyuanvideo_foley.utils.feature_utils import encode_video_features
 
1514
  new_wav = new_wav[:, :seg_samples]
1515
  return new_wav, sr
1516
 
 
 
1517
 
1518
  def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1519
  prompt, negative_prompt, seed_val,
 
1531
  meta["silent_video"], seg_start, seg_dur,
1532
  os.path.join(tmp_dir, "regen_seg.mp4"),
1533
  )
1534
+ _tl.hunyuan_regen_ctx = _preload_hunyuan_regen_ctx(meta, seg_path)
1535
 
1536
  # GPU: inference only
1537
  new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
 
1601
  return video_path, waveform_html
1602
 
1603
 
1604
def _xregen_dispatch(state_json: str, seg_idx: int, slot_id: str, infer_fn):
    """Shared generator skeleton for all xregen_* wrappers.

    Immediately yields a pending-state update, then invokes *infer_fn()* — a
    zero-argument callable performing the model-specific CPU prep plus GPU
    inference — which must return (wav_array, src_sr). For TARO, *infer_fn*
    should hand back the waveform already upsampled to 48 kHz, with
    TARO_SR_OUT as src_sr.

    Yields:
        First: (gr.update(), gr.update(value=pending_html)) — shown while GPU runs
        Second: (gr.update(value=video_path), gr.update(value=waveform_html))
    """
    meta = json.loads(state_json)
    pending = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
    yield gr.update(), gr.update(value=pending)

    wav, src_sr = infer_fn()
    out_video, out_html = _xregen_splice(wav, src_sr, meta, seg_idx, slot_id)
    yield gr.update(value=out_video), gr.update(value=out_html)
1623
+
1624
+
1625
def xregen_taro(seg_idx, state_json, slot_id,
                seed_val, cfg_scale, num_steps, mode,
                crossfade_s, crossfade_db,
                request: gr.Request = None):
    """Cross-model regen: run TARO inference and splice into *slot_id*."""
    idx = int(seg_idx)
    meta = json.loads(state_json)

    def _infer():
        # CPU prep: stash pre-loaded CAVP/onset features in thread-local
        # storage for the GPU function, then run inference.
        _tl.taro_regen_ctx = _preload_taro_regen_ctx(meta)
        raw_wav = _regen_taro_gpu(None, idx, state_json,
                                  seed_val, cfg_scale, num_steps, mode,
                                  crossfade_s, crossfade_db, slot_id)
        # TARO emits 16 kHz audio; upsample to the 48 kHz output rate on CPU.
        return _upsample_taro(raw_wav), TARO_SR_OUT

    yield from _xregen_dispatch(state_json, idx, slot_id, _infer)
 
 
 
 
 
 
1641
 
1642
 
1643
  def xregen_mmaudio(seg_idx, state_json, slot_id,
 
1645
  cfg_strength, num_steps, crossfade_s, crossfade_db,
1646
  request: gr.Request = None):
1647
  """Cross-model regen: run MMAudio inference and splice into *slot_id*."""
 
1648
  seg_idx = int(seg_idx)
1649
+ meta = json.loads(state_json)
1650
  seg_start, seg_end = meta["segments"][seg_idx]
1651
 
1652
+ def _run():
1653
+ seg_path = _extract_segment_clip(
1654
+ meta["silent_video"], seg_start, seg_end - seg_start,
1655
+ os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1656
+ )
1657
+ _tl.mmaudio_regen_ctx = {"seg_path": seg_path}
1658
+ wav, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
1659
+ prompt, negative_prompt, seed_val,
1660
+ cfg_strength, num_steps,
1661
+ crossfade_s, crossfade_db, slot_id)
1662
+ return wav, src_sr
1663
 
1664
+ yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
 
 
 
 
 
1665
 
1666
 
1667
  def xregen_hunyuan(seg_idx, state_json, slot_id,
 
1670
  crossfade_s, crossfade_db,
1671
  request: gr.Request = None):
1672
  """Cross-model regen: run HunyuanFoley inference and splice into *slot_id*."""
 
1673
  seg_idx = int(seg_idx)
1674
+ meta = json.loads(state_json)
1675
  seg_start, seg_end = meta["segments"][seg_idx]
1676
 
1677
+ def _run():
1678
+ seg_path = _extract_segment_clip(
1679
+ meta["silent_video"], seg_start, seg_end - seg_start,
1680
+ os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1681
+ )
1682
+ _tl.hunyuan_regen_ctx = _preload_hunyuan_regen_ctx(meta, seg_path)
1683
+ wav, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
1684
+ prompt, negative_prompt, seed_val,
1685
+ guidance_scale, num_steps, model_size,
1686
+ crossfade_s, crossfade_db, slot_id)
1687
+ return wav, src_sr
1688
+
1689
+ yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
 
 
 
1690
 
1691
 
1692
  # ================================================================== #