BoxOfColors Claude Sonnet 4.6 committed on
Commit
e7175d4
·
1 Parent(s): ac67bf3

fix: remove ctx_key from all function signatures — use fn-name-keyed global dict

Browse files

ctx_key as a function argument exposed it to Gradio's API endpoint discovery,
causing 'Too many arguments provided for the endpoint' errors and GPU task aborts.

Fix: remove ctx_key from all @spaces.GPU function signatures and their duration
callables. Store/retrieve context using _ctx_store(fn_name, data) /
_ctx_load(fn_name) — a global dict keyed by function name. This is safe because
ZeroGPU is synchronous (wrapper blocks until GPU fn returns), so only one call
per GPU function is in-flight at a time within a single process.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +50 -57
app.py CHANGED
@@ -124,35 +124,30 @@ print(f"[startup] All downloads done in {time.perf_counter() - _t_dl_start:.1f}s
124
  # SHARED CONSTANTS / HELPERS #
125
  # ================================================================== #
126
 
127
- # CPU → GPU context passing via UUID-keyed global store.
128
  #
129
- # ZeroGPU dispatches @spaces.GPU functions on its own worker thread, so
130
- # threading.local() doesn't work. Passing context as a function argument
131
- # is the right idea, but ZeroGPU validates args against the *duration*
132
- # callable's signature — any extra param not present in the duration fn
133
- # gets dropped or set to None before the GPU fn runs.
134
  #
135
- # Solution: add ctx_key="" to BOTH the duration fn AND the GPU fn.
136
- # The wrapper stores the context dict in _GPU_CTX[uuid] and passes the
137
- # uuid string as ctx_key. The GPU fn does _GPU_CTX.pop(ctx_key).
138
- # Since the dict is global (not thread-local), the GPU worker thread can
139
- # read it regardless of which thread wrote it. The uuid ensures
140
- # concurrent requests don't collide.
141
- import uuid as _uuid_mod
142
- _GPU_CTX: dict = {}
143
- _GPU_CTX_LOCK = threading.Lock()
144
-
145
- def _ctx_store(data: dict) -> str:
146
- """Store *data* in the global context dict; return the UUID key."""
147
- key = _uuid_mod.uuid4().hex
148
  with _GPU_CTX_LOCK:
149
- _GPU_CTX[key] = data
150
- return key
151
 
152
- def _ctx_load(key: str) -> dict:
153
- """Pop and return the context dict for *key*."""
154
  with _GPU_CTX_LOCK:
155
- return _GPU_CTX.pop(key, {})
156
 
157
  MAX_SLOTS = 8 # max parallel generation slots shown in UI
158
  MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video ≤ ~64 s at 8 s/seg)
@@ -577,7 +572,7 @@ def _taro_calc_max_samples(total_dur_s: float, num_steps: int, crossfade_s: floa
577
 
578
 
579
  def _taro_duration(video_file, seed_val, cfg_scale, num_steps, mode,
580
- crossfade_s, crossfade_db, num_samples, ctx_key=""):
581
  """Pre-GPU callable — must match _taro_gpu_infer's input order exactly."""
582
  return _estimate_gpu_duration("taro", int(num_samples), int(num_steps),
583
  video_file=video_file, crossfade_s=crossfade_s)
@@ -794,7 +789,7 @@ def _cpu_preprocess(video_file: str, model_dur: float,
794
 
795
  @spaces.GPU(duration=_taro_duration)
796
  def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
797
- crossfade_s, crossfade_db, num_samples, ctx_key=""):
798
  """GPU-only TARO inference — model loading + feature extraction + diffusion.
799
  Returns list of (wavs_list, onset_feats) per sample."""
800
  seed_val = int(seed_val)
@@ -810,7 +805,7 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
810
  from TARO.onset_util import extract_onset
811
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
812
 
813
- ctx = _ctx_load(ctx_key)
814
  tmp_dir = ctx["tmp_dir"]
815
  silent_video = ctx["silent_video"]
816
  segments = ctx["segments"]
@@ -882,14 +877,14 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
882
  tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
883
  video_file, TARO_MODEL_DUR, crossfade_s)
884
 
885
- ctx_key = _ctx_store({
886
  "tmp_dir": tmp_dir, "silent_video": silent_video,
887
  "segments": segments, "total_dur_s": total_dur_s,
888
  })
889
 
890
  # ── GPU inference only ──
891
  results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
892
- crossfade_s, crossfade_db, num_samples, ctx_key)
893
 
894
  # ── CPU post-processing (no GPU needed) ──
895
  # Upsample 16kHz → 48kHz and normalise result tuples to (seg_wavs, ...)
@@ -938,8 +933,7 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
938
 
939
 
940
  def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
941
- cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples,
942
- ctx_key=""):
943
  """Pre-GPU callable — must match _mmaudio_gpu_infer's input order exactly."""
944
  return _estimate_gpu_duration("mmaudio", int(num_samples), int(num_steps),
945
  video_file=video_file, crossfade_s=crossfade_s)
@@ -947,8 +941,7 @@ def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
947
 
948
  @spaces.GPU(duration=_mmaudio_duration)
949
  def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
950
- cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples,
951
- ctx_key=""):
952
  """GPU-only MMAudio inference — model loading + flow-matching generation.
953
  Returns list of (seg_audios, sr) per sample."""
954
  _ensure_syspath("MMAudio")
@@ -963,7 +956,7 @@ def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
963
 
964
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
965
 
966
- ctx = _ctx_load(ctx_key)
967
  segments = ctx["segments"]
968
  seg_clip_paths = ctx["seg_clip_paths"]
969
 
@@ -1042,12 +1035,12 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
1042
  for i, (s, e) in enumerate(segments)
1043
  ]
1044
 
1045
- ctx_key = _ctx_store({"segments": segments, "seg_clip_paths": seg_clip_paths})
1046
 
1047
  # ── GPU inference only ──
1048
  results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1049
  cfg_strength, num_steps, crossfade_s, crossfade_db,
1050
- num_samples, ctx_key)
1051
 
1052
  # ── CPU post-processing ──
1053
  # Resample 44100 → 48000 and normalise tuples to (seg_wavs, ...)
@@ -1085,7 +1078,7 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
1085
 
1086
  def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
1087
  guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
1088
- num_samples, ctx_key=""):
1089
  """Pre-GPU callable — must match _hunyuan_gpu_infer's input order exactly."""
1090
  return _estimate_gpu_duration("hunyuan", int(num_samples), int(num_steps),
1091
  video_file=video_file, crossfade_s=crossfade_s)
@@ -1094,7 +1087,7 @@ def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
1094
  @spaces.GPU(duration=_hunyuan_duration)
1095
  def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1096
  guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
1097
- num_samples, ctx_key=""):
1098
  """GPU-only HunyuanFoley inference — model loading + feature extraction + denoising.
1099
  Returns list of (seg_wavs, sr, text_feats) per sample."""
1100
  _ensure_syspath("HunyuanVideo-Foley")
@@ -1113,7 +1106,7 @@ def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1113
 
1114
  model_dict, cfg = _load_hunyuan_model(device, model_size)
1115
 
1116
- ctx = _ctx_load(ctx_key)
1117
  segments = ctx["segments"]
1118
  total_dur_s = ctx["total_dur_s"]
1119
  dummy_seg_path = ctx["dummy_seg_path"]
@@ -1197,7 +1190,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
1197
  for i, (s, e) in enumerate(segments)
1198
  ]
1199
 
1200
- ctx_key = _ctx_store({
1201
  "segments": segments, "total_dur_s": total_dur_s,
1202
  "dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
1203
  })
@@ -1205,7 +1198,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
1205
  # ── GPU inference only ──
1206
  results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1207
  guidance_scale, num_steps, model_size,
1208
- crossfade_s, crossfade_db, num_samples, ctx_key)
1209
 
1210
  # ── CPU post-processing (no GPU needed) ──
1211
  def _hunyuan_extras(sample_idx, result, td):
@@ -1285,7 +1278,7 @@ def _splice_and_save(new_wav, seg_idx, meta, slot_id):
1285
 
1286
  def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
1287
  seed_val, cfg_scale, num_steps, mode,
1288
- crossfade_s, crossfade_db, slot_id=None, ctx_key=""):
1289
  # If cached CAVP/onset features exist, skip ~10s feature-extractor overhead
1290
  try:
1291
  meta = json.loads(seg_meta_json)
@@ -1305,7 +1298,7 @@ def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
1305
  @spaces.GPU(duration=_taro_regen_duration)
1306
  def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
1307
  seed_val, cfg_scale, num_steps, mode,
1308
- crossfade_s, crossfade_db, slot_id=None, ctx_key=""):
1309
  """GPU-only TARO regen — returns new_wav for a single segment."""
1310
  meta = json.loads(seg_meta_json)
1311
  seg_idx = int(seg_idx)
@@ -1372,7 +1365,7 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
1372
  def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
1373
  prompt, negative_prompt, seed_val,
1374
  cfg_strength, num_steps, crossfade_s, crossfade_db,
1375
- slot_id=None, ctx_key=""):
1376
  return _estimate_regen_duration("mmaudio", int(num_steps))
1377
 
1378
 
@@ -1380,7 +1373,7 @@ def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
1380
  def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1381
  prompt, negative_prompt, seed_val,
1382
  cfg_strength, num_steps, crossfade_s, crossfade_db,
1383
- slot_id=None, ctx_key=""):
1384
  """GPU-only MMAudio regen — returns (new_wav, sr) for a single segment."""
1385
  meta = json.loads(seg_meta_json)
1386
  seg_idx = int(seg_idx)
@@ -1396,7 +1389,7 @@ def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1396
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
1397
  sr = seq_cfg.sampling_rate
1398
 
1399
- seg_path = _ctx_load(ctx_key).get("seg_path")
1400
  assert seg_path, "[MMAudio regen] seg_path not set — wrapper must pre-extract segment clip"
1401
 
1402
  rng = torch.Generator(device=device)
@@ -1438,13 +1431,13 @@ def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
1438
  meta["silent_video"], seg_start, seg_dur,
1439
  os.path.join(tmp_dir, "regen_seg.mp4"),
1440
  )
1441
- ctx_key = _ctx_store({"seg_path": seg_path})
1442
 
1443
  # GPU: inference only
1444
  new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1445
  prompt, negative_prompt, seed_val,
1446
  cfg_strength, num_steps, crossfade_s, crossfade_db,
1447
- slot_id, ctx_key)
1448
 
1449
  # Resample to 48kHz if needed (MMAudio outputs at 44100 Hz)
1450
  if sr != TARGET_SR:
@@ -1463,7 +1456,7 @@ def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
1463
  def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
1464
  prompt, negative_prompt, seed_val,
1465
  guidance_scale, num_steps, model_size,
1466
- crossfade_s, crossfade_db, slot_id=None, ctx_key=""):
1467
  return _estimate_regen_duration("hunyuan", int(num_steps))
1468
 
1469
 
@@ -1471,7 +1464,7 @@ def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
1471
  def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1472
  prompt, negative_prompt, seed_val,
1473
  guidance_scale, num_steps, model_size,
1474
- crossfade_s, crossfade_db, slot_id=None, ctx_key=""):
1475
  """GPU-only HunyuanFoley regen — returns (new_wav, sr) for a single segment."""
1476
  meta = json.loads(seg_meta_json)
1477
  seg_idx = int(seg_idx)
@@ -1488,7 +1481,7 @@ def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1488
 
1489
  set_global_seed(random.randint(0, 2**32 - 1))
1490
 
1491
- ctx = _ctx_load(ctx_key)
1492
  seg_path = ctx.get("seg_path")
1493
  assert seg_path, "[HunyuanFoley regen] seg_path not set — wrapper must pre-extract segment clip"
1494
 
@@ -1532,7 +1525,7 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1532
  meta["silent_video"], seg_start, seg_dur,
1533
  os.path.join(tmp_dir, "regen_seg.mp4"),
1534
  )
1535
- ctx_key = _ctx_store({
1536
  "seg_path": seg_path,
1537
  "text_feats_path": meta.get("text_feats_path", ""),
1538
  })
@@ -1541,7 +1534,7 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1541
  new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1542
  prompt, negative_prompt, seed_val,
1543
  guidance_scale, num_steps, model_size,
1544
- crossfade_s, crossfade_db, slot_id, ctx_key)
1545
 
1546
  meta["sr"] = sr
1547
 
@@ -1658,11 +1651,11 @@ def xregen_mmaudio(seg_idx, state_json, slot_id,
1658
  meta["silent_video"], seg_start, seg_end - seg_start,
1659
  os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1660
  )
1661
- ctx_key = _ctx_store({"seg_path": seg_path})
1662
  wav, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
1663
  prompt, negative_prompt, seed_val,
1664
  cfg_strength, num_steps,
1665
- crossfade_s, crossfade_db, slot_id, ctx_key)
1666
  return wav, src_sr
1667
 
1668
  yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
@@ -1683,14 +1676,14 @@ def xregen_hunyuan(seg_idx, state_json, slot_id,
1683
  meta["silent_video"], seg_start, seg_end - seg_start,
1684
  os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1685
  )
1686
- ctx_key = _ctx_store({
1687
  "seg_path": seg_path,
1688
  "text_feats_path": meta.get("text_feats_path", ""),
1689
  })
1690
  wav, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
1691
  prompt, negative_prompt, seed_val,
1692
  guidance_scale, num_steps, model_size,
1693
- crossfade_s, crossfade_db, slot_id, ctx_key)
1694
  return wav, src_sr
1695
 
1696
  yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
 
124
  # SHARED CONSTANTS / HELPERS #
125
  # ================================================================== #
126
 
127
+ # CPU → GPU context passing via function-name-keyed global store.
128
  #
129
+ # Problem: ZeroGPU runs @spaces.GPU functions on its own worker thread, so
130
+ # threading.local() is invisible to the GPU worker. Passing ctx as a
131
+ # function argument exposes it to Gradio's API endpoint, causing
132
+ # "Too many arguments" errors.
 
133
  #
134
+ # Solution: store context in a plain global dict keyed by function name.
135
+ # A per-key Lock serialises concurrent callers for the same function
136
+ # (ZeroGPU is already synchronous — the wrapper blocks until the GPU fn
137
+ # returns — so in practice only one call per GPU fn is in-flight at a time).
138
+ # The global dict is readable from any thread.
139
+ _GPU_CTX: dict = {}
140
+ _GPU_CTX_LOCK = threading.Lock()
141
+
142
+ def _ctx_store(fn_name: str, data: dict) -> None:
143
+ """Store *data* under *fn_name* key (overwrites previous)."""
 
 
 
144
  with _GPU_CTX_LOCK:
145
+ _GPU_CTX[fn_name] = data
 
146
 
147
+ def _ctx_load(fn_name: str) -> dict:
148
+ """Pop and return the context dict stored under *fn_name*."""
149
  with _GPU_CTX_LOCK:
150
+ return _GPU_CTX.pop(fn_name, {})
151
 
152
  MAX_SLOTS = 8 # max parallel generation slots shown in UI
153
  MAX_SEGS = 8 # max segments per slot (same as MAX_SLOTS; video ≤ ~64 s at 8 s/seg)
 
572
 
573
 
574
  def _taro_duration(video_file, seed_val, cfg_scale, num_steps, mode,
575
+ crossfade_s, crossfade_db, num_samples):
576
  """Pre-GPU callable — must match _taro_gpu_infer's input order exactly."""
577
  return _estimate_gpu_duration("taro", int(num_samples), int(num_steps),
578
  video_file=video_file, crossfade_s=crossfade_s)
 
789
 
790
  @spaces.GPU(duration=_taro_duration)
791
  def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
792
+ crossfade_s, crossfade_db, num_samples):
793
  """GPU-only TARO inference — model loading + feature extraction + diffusion.
794
  Returns list of (wavs_list, onset_feats) per sample."""
795
  seed_val = int(seed_val)
 
805
  from TARO.onset_util import extract_onset
806
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
807
 
808
+ ctx = _ctx_load("taro_gpu_infer")
809
  tmp_dir = ctx["tmp_dir"]
810
  silent_video = ctx["silent_video"]
811
  segments = ctx["segments"]
 
877
  tmp_dir, silent_video, total_dur_s, segments = _cpu_preprocess(
878
  video_file, TARO_MODEL_DUR, crossfade_s)
879
 
880
+ _ctx_store("taro_gpu_infer", {
881
  "tmp_dir": tmp_dir, "silent_video": silent_video,
882
  "segments": segments, "total_dur_s": total_dur_s,
883
  })
884
 
885
  # ── GPU inference only ──
886
  results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
887
+ crossfade_s, crossfade_db, num_samples)
888
 
889
  # ── CPU post-processing (no GPU needed) ──
890
  # Upsample 16kHz → 48kHz and normalise result tuples to (seg_wavs, ...)
 
933
 
934
 
935
  def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
936
+ cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
 
937
  """Pre-GPU callable — must match _mmaudio_gpu_infer's input order exactly."""
938
  return _estimate_gpu_duration("mmaudio", int(num_samples), int(num_steps),
939
  video_file=video_file, crossfade_s=crossfade_s)
 
941
 
942
  @spaces.GPU(duration=_mmaudio_duration)
943
  def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
944
+ cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
 
945
  """GPU-only MMAudio inference — model loading + flow-matching generation.
946
  Returns list of (seg_audios, sr) per sample."""
947
  _ensure_syspath("MMAudio")
 
956
 
957
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
958
 
959
+ ctx = _ctx_load("mmaudio_gpu_infer")
960
  segments = ctx["segments"]
961
  seg_clip_paths = ctx["seg_clip_paths"]
962
 
 
1035
  for i, (s, e) in enumerate(segments)
1036
  ]
1037
 
1038
+ _ctx_store("mmaudio_gpu_infer", {"segments": segments, "seg_clip_paths": seg_clip_paths})
1039
 
1040
  # ── GPU inference only ──
1041
  results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1042
  cfg_strength, num_steps, crossfade_s, crossfade_db,
1043
+ num_samples)
1044
 
1045
  # ── CPU post-processing ──
1046
  # Resample 44100 → 48000 and normalise tuples to (seg_wavs, ...)
 
1078
 
1079
  def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
1080
  guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
1081
+ num_samples):
1082
  """Pre-GPU callable — must match _hunyuan_gpu_infer's input order exactly."""
1083
  return _estimate_gpu_duration("hunyuan", int(num_samples), int(num_steps),
1084
  video_file=video_file, crossfade_s=crossfade_s)
 
1087
  @spaces.GPU(duration=_hunyuan_duration)
1088
  def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1089
  guidance_scale, num_steps, model_size, crossfade_s, crossfade_db,
1090
+ num_samples):
1091
  """GPU-only HunyuanFoley inference — model loading + feature extraction + denoising.
1092
  Returns list of (seg_wavs, sr, text_feats) per sample."""
1093
  _ensure_syspath("HunyuanVideo-Foley")
 
1106
 
1107
  model_dict, cfg = _load_hunyuan_model(device, model_size)
1108
 
1109
+ ctx = _ctx_load("hunyuan_gpu_infer")
1110
  segments = ctx["segments"]
1111
  total_dur_s = ctx["total_dur_s"]
1112
  dummy_seg_path = ctx["dummy_seg_path"]
 
1190
  for i, (s, e) in enumerate(segments)
1191
  ]
1192
 
1193
+ _ctx_store("hunyuan_gpu_infer", {
1194
  "segments": segments, "total_dur_s": total_dur_s,
1195
  "dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
1196
  })
 
1198
  # ── GPU inference only ──
1199
  results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1200
  guidance_scale, num_steps, model_size,
1201
+ crossfade_s, crossfade_db, num_samples)
1202
 
1203
  # ── CPU post-processing (no GPU needed) ──
1204
  def _hunyuan_extras(sample_idx, result, td):
 
1278
 
1279
  def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
1280
  seed_val, cfg_scale, num_steps, mode,
1281
+ crossfade_s, crossfade_db, slot_id=None):
1282
  # If cached CAVP/onset features exist, skip ~10s feature-extractor overhead
1283
  try:
1284
  meta = json.loads(seg_meta_json)
 
1298
  @spaces.GPU(duration=_taro_regen_duration)
1299
  def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
1300
  seed_val, cfg_scale, num_steps, mode,
1301
+ crossfade_s, crossfade_db, slot_id=None):
1302
  """GPU-only TARO regen — returns new_wav for a single segment."""
1303
  meta = json.loads(seg_meta_json)
1304
  seg_idx = int(seg_idx)
 
1365
  def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
1366
  prompt, negative_prompt, seed_val,
1367
  cfg_strength, num_steps, crossfade_s, crossfade_db,
1368
+ slot_id=None):
1369
  return _estimate_regen_duration("mmaudio", int(num_steps))
1370
 
1371
 
 
1373
  def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1374
  prompt, negative_prompt, seed_val,
1375
  cfg_strength, num_steps, crossfade_s, crossfade_db,
1376
+ slot_id=None):
1377
  """GPU-only MMAudio regen — returns (new_wav, sr) for a single segment."""
1378
  meta = json.loads(seg_meta_json)
1379
  seg_idx = int(seg_idx)
 
1389
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
1390
  sr = seq_cfg.sampling_rate
1391
 
1392
+ seg_path = _ctx_load("regen_mmaudio_gpu").get("seg_path")
1393
  assert seg_path, "[MMAudio regen] seg_path not set — wrapper must pre-extract segment clip"
1394
 
1395
  rng = torch.Generator(device=device)
 
1431
  meta["silent_video"], seg_start, seg_dur,
1432
  os.path.join(tmp_dir, "regen_seg.mp4"),
1433
  )
1434
+ _ctx_store("regen_mmaudio_gpu", {"seg_path": seg_path})
1435
 
1436
  # GPU: inference only
1437
  new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1438
  prompt, negative_prompt, seed_val,
1439
  cfg_strength, num_steps, crossfade_s, crossfade_db,
1440
+ slot_id)
1441
 
1442
  # Resample to 48kHz if needed (MMAudio outputs at 44100 Hz)
1443
  if sr != TARGET_SR:
 
1456
  def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
1457
  prompt, negative_prompt, seed_val,
1458
  guidance_scale, num_steps, model_size,
1459
+ crossfade_s, crossfade_db, slot_id=None):
1460
  return _estimate_regen_duration("hunyuan", int(num_steps))
1461
 
1462
 
 
1464
  def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1465
  prompt, negative_prompt, seed_val,
1466
  guidance_scale, num_steps, model_size,
1467
+ crossfade_s, crossfade_db, slot_id=None):
1468
  """GPU-only HunyuanFoley regen — returns (new_wav, sr) for a single segment."""
1469
  meta = json.loads(seg_meta_json)
1470
  seg_idx = int(seg_idx)
 
1481
 
1482
  set_global_seed(random.randint(0, 2**32 - 1))
1483
 
1484
+ ctx = _ctx_load("regen_hunyuan_gpu")
1485
  seg_path = ctx.get("seg_path")
1486
  assert seg_path, "[HunyuanFoley regen] seg_path not set — wrapper must pre-extract segment clip"
1487
 
 
1525
  meta["silent_video"], seg_start, seg_dur,
1526
  os.path.join(tmp_dir, "regen_seg.mp4"),
1527
  )
1528
+ _ctx_store("regen_hunyuan_gpu", {
1529
  "seg_path": seg_path,
1530
  "text_feats_path": meta.get("text_feats_path", ""),
1531
  })
 
1534
  new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1535
  prompt, negative_prompt, seed_val,
1536
  guidance_scale, num_steps, model_size,
1537
+ crossfade_s, crossfade_db, slot_id)
1538
 
1539
  meta["sr"] = sr
1540
 
 
1651
  meta["silent_video"], seg_start, seg_end - seg_start,
1652
  os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1653
  )
1654
+ _ctx_store("regen_mmaudio_gpu", {"seg_path": seg_path})
1655
  wav, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
1656
  prompt, negative_prompt, seed_val,
1657
  cfg_strength, num_steps,
1658
+ crossfade_s, crossfade_db, slot_id)
1659
  return wav, src_sr
1660
 
1661
  yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
 
1676
  meta["silent_video"], seg_start, seg_end - seg_start,
1677
  os.path.join(tempfile.mkdtemp(), "xregen_seg.mp4"),
1678
  )
1679
+ _ctx_store("regen_hunyuan_gpu", {
1680
  "seg_path": seg_path,
1681
  "text_feats_path": meta.get("text_feats_path", ""),
1682
  })
1683
  wav, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
1684
  prompt, negative_prompt, seed_val,
1685
  guidance_scale, num_steps, model_size,
1686
+ crossfade_s, crossfade_db, slot_id)
1687
  return wav, src_sr
1688
 
1689
  yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)