BoxOfColors Claude Sonnet 4.6 commited on
Commit
b60f330
·
1 Parent(s): 12556c0

feat: cross-model segment regen — regenerate any slot with TARO, MMAudio, or HunyuanFoley

Browse files

Popup now shows three model buttons (TARO / MMAudio / Hunyuan) instead of
one generic Regenerate. Clicking a different model's button fires one of
three new shared xregen_* Gradio endpoints (xregen_taro, xregen_mmaudio,
xregen_hunyuan) that accept slot_id as a plain string so results are
spliced back into the correct slot regardless of which tab generated it.

New _resample_to_slot_sr() helper resamples the incoming wav to match
the slot's original SR (TARO=16kHz, MMAudio=44.1kHz, Hunyuan=48kHz)
before _splice_and_save stitches it in, so any model can replace any
segment without sample-rate mismatch.

All new Gradio components/buttons use render=False to avoid adding to
the SSR component tree and triggering the 'Too many arguments' warning.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +237 -61
app.py CHANGED
@@ -1313,6 +1313,107 @@ MODEL_CONFIGS["mmaudio"]["regen_fn"] = regen_mmaudio_segment
1313
  MODEL_CONFIGS["hunyuan"]["regen_fn"] = regen_hunyuan_segment
1314
 
1315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1316
  # ================================================================== #
1317
  # SHARED UI HELPERS #
1318
  # ================================================================== #
@@ -1868,18 +1969,16 @@ _GLOBAL_JS = """
1868
 
1869
  // Fire regen for a given slot and segment by posting directly to the
1870
  // Gradio queue API — bypasses Svelte binding entirely.
1871
- function fireRegen(slot_id, seg_idx) {
1872
- // Determine tab prefix from slot_id (e.g. "taro_0" -> "taro")
1873
- const prefix = slot_id.split('_')[0];
 
 
1874
  const slotNum = parseInt(slot_id.split('_')[1], 10);
1875
 
1876
- // Build api_name for this slot's regen handler
1877
- const apiName = 'regen_' + prefix + '_' + slotNum;
1878
- const fnIndex = getFnIndex(apiName);
1879
- if (fnIndex === undefined) {
1880
- console.warn('[fireRegen] fn_index not found for api_name:', apiName, 'cache:', _fnIndexCache);
1881
- return;
1882
- }
1883
 
1884
  // Read state_json from the waveform container data-state attribute
1885
  const container = document.getElementById('wf_container_' + slot_id);
@@ -1889,48 +1988,58 @@ _GLOBAL_JS = """
1889
  return;
1890
  }
1891
 
1892
- // Read current input values from DOM by elem_id
1893
- let data;
1894
- if (prefix === 'taro') {
1895
- const video = null; // video is a file component — pass null, server uses its own state
1896
- data = [
1897
- seg_idx,
1898
- stateJson,
1899
- video,
1900
- readComponentValue('taro_seed'),
1901
- readComponentValue('taro_cfg'),
1902
- readComponentValue('taro_steps'),
1903
- readComponentValue('taro_mode'),
1904
- readComponentValue('taro_cf_dur'),
1905
- readComponentValue('taro_cf_db')
1906
- ];
1907
- } else if (prefix === 'mma') {
1908
- data = [
1909
- seg_idx,
1910
- stateJson,
1911
- null, // video
1912
- readComponentValue('mma_prompt'),
1913
- readComponentValue('mma_neg'),
1914
- readComponentValue('mma_seed'),
1915
- readComponentValue('mma_cfg'),
1916
- readComponentValue('mma_steps'),
1917
- readComponentValue('mma_cf_dur'),
1918
- readComponentValue('mma_cf_db')
1919
- ];
1920
  } else {
1921
- data = [
1922
- seg_idx,
1923
- stateJson,
1924
- null, // video
1925
- readComponentValue('hf_prompt'),
1926
- readComponentValue('hf_neg'),
1927
- readComponentValue('hf_seed'),
1928
- readComponentValue('hf_guidance'),
1929
- readComponentValue('hf_steps'),
1930
- readComponentValue('hf_size'),
1931
- readComponentValue('hf_cf_dur'),
1932
- readComponentValue('hf_cf_db')
1933
- ];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1934
  }
1935
 
1936
  console.log('[fireRegen] calling api', apiName, 'fn_index', fnIndex, 'seg', seg_idx);
@@ -2068,19 +2177,24 @@ _GLOBAL_JS = """
2068
  _popup.style.cssText = 'display:none;position:fixed;z-index:99999;' +
2069
  'background:#2a2a2a;border:1px solid #555;border-radius:6px;' +
2070
  'padding:8px 12px;box-shadow:0 4px 16px rgba(0,0,0,.5);font-family:sans-serif;';
 
 
2071
  _popup.innerHTML =
2072
  '<div id="_wf_popup_lbl" style="color:#ccc;font-size:11px;margin-bottom:6px;white-space:nowrap;"></div>' +
2073
- '<button id="_wf_popup_btn" style="background:#1d6fa5;color:#fff;border:none;' +
2074
- 'border-radius:4px;padding:5px 14px;font-size:12px;cursor:pointer;width:100%;">&#10227; Regenerate</button>';
 
 
 
2075
  document.body.appendChild(_popup);
2076
- document.getElementById('_wf_popup_btn').onclick = function(e) {
2077
- e.stopPropagation(); // prevents the document bubble-phase listener below from firing
2078
- var slot = _pendingSlot, idx = _pendingIdx; // capture before hidePopup clears them
2079
- hidePopup();
2080
- if (slot !== null && idx !== null) {
2081
- fireRegen(slot, idx);
2082
- }
2083
- };
2084
  // Use bubble phase (false) so stopPropagation() on the button click prevents this from firing
2085
  document.addEventListener('click', function() { hidePopup(); }, false);
2086
  return _popup;
@@ -2308,5 +2422,67 @@ with gr.Blocks(title="Generate Audio for Video", css=_SLOT_CSS, js=_GLOBAL_JS) a
2308
  mma_video.change(fn=_sync, inputs=[mma_video], outputs=[taro_video, hf_video])
2309
  hf_video.change(fn=_sync, inputs=[hf_video], outputs=[taro_video, mma_video])
2310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2311
  print("[startup] app.py fully loaded — regen handlers registered, SSR disabled")
2312
  demo.queue(max_size=10).launch(ssr_mode=False, height=900, allowed_paths=["/tmp"])
 
1313
  MODEL_CONFIGS["hunyuan"]["regen_fn"] = regen_hunyuan_segment
1314
 
1315
 
1316
+ # ================================================================== #
1317
+ # CROSS-MODEL REGEN WRAPPERS #
1318
+ # ================================================================== #
1319
+ # Three shared endpoints — one per model — that can be called from #
1320
+ # *any* slot tab. slot_id is passed as plain string data so the #
1321
+ # result is applied back to the correct slot by the JS listener. #
1322
+ # The new segment is resampled to match the slot's existing SR before #
1323
+ # being handed to _splice_and_save, so TARO (16 kHz) / MMAudio #
1324
+ # (44.1 kHz) / Hunyuan (48 kHz) outputs can all be mixed freely. #
1325
+ # ================================================================== #
1326
+
1327
+ def _resample_to_slot_sr(wav: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
1328
+ """Resample *wav* from src_sr to dst_sr using torchaudio.
1329
+ Works for mono (T,) and stereo (C, T) numpy arrays."""
1330
+ if src_sr == dst_sr:
1331
+ return wav
1332
+ stereo = wav.ndim == 2
1333
+ t = torch.from_numpy(np.ascontiguousarray(wav))
1334
+ if not stereo:
1335
+ t = t.unsqueeze(0) # (1, T)
1336
+ t = torchaudio.functional.resample(t.float(), src_sr, dst_sr)
1337
+ if not stereo:
1338
+ t = t.squeeze(0) # (T,)
1339
+ return t.numpy()
1340
+
1341
+
1342
+ def xregen_taro(seg_idx, state_json, slot_id,
1343
+ seed_val, cfg_scale, num_steps, mode,
1344
+ crossfade_s, crossfade_db):
1345
+ """Cross-model regen: run TARO inference and splice into *slot_id*."""
1346
+ meta = json.loads(state_json)
1347
+ slot_sr = int(meta["sr"])
1348
+ new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
1349
+ seed_val, cfg_scale, num_steps, mode,
1350
+ crossfade_s, crossfade_db, slot_id)
1351
+ new_wav = _resample_to_slot_sr(new_wav_raw, TARO_SR, slot_sr)
1352
+ video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1353
+ new_wav, int(seg_idx), meta, slot_id
1354
+ )
1355
+ return gr.update(value=video_path), gr.update(value=waveform_html)
1356
+
1357
+
1358
+ def xregen_mmaudio(seg_idx, state_json, slot_id,
1359
+ prompt, negative_prompt, seed_val,
1360
+ cfg_strength, num_steps, crossfade_s, crossfade_db):
1361
+ """Cross-model regen: run MMAudio inference and splice into *slot_id*."""
1362
+ meta = json.loads(state_json)
1363
+ seg_idx = int(seg_idx)
1364
+ seg_start, seg_end = meta["segments"][seg_idx]
1365
+ seg_dur = seg_end - seg_start
1366
+ slot_sr = int(meta["sr"])
1367
+
1368
+ silent_video = meta["silent_video"]
1369
+ tmp_dir = tempfile.mkdtemp()
1370
+ seg_path = os.path.join(tmp_dir, "xregen_seg.mp4")
1371
+ ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
1372
+ seg_path, vcodec="copy", an=None
1373
+ ).run(overwrite_output=True, quiet=True)
1374
+ _regen_mmaudio_gpu._cpu_ctx = {"seg_path": seg_path}
1375
+
1376
+ new_wav_raw, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
1377
+ prompt, negative_prompt, seed_val,
1378
+ cfg_strength, num_steps,
1379
+ crossfade_s, crossfade_db, slot_id)
1380
+ new_wav = _resample_to_slot_sr(new_wav_raw, src_sr, slot_sr)
1381
+ video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1382
+ new_wav, seg_idx, meta, slot_id
1383
+ )
1384
+ return gr.update(value=video_path), gr.update(value=waveform_html)
1385
+
1386
+
1387
+ def xregen_hunyuan(seg_idx, state_json, slot_id,
1388
+ prompt, negative_prompt, seed_val,
1389
+ guidance_scale, num_steps, model_size,
1390
+ crossfade_s, crossfade_db):
1391
+ """Cross-model regen: run HunyuanFoley inference and splice into *slot_id*."""
1392
+ meta = json.loads(state_json)
1393
+ seg_idx = int(seg_idx)
1394
+ seg_start, seg_end = meta["segments"][seg_idx]
1395
+ seg_dur = seg_end - seg_start
1396
+ slot_sr = int(meta["sr"])
1397
+
1398
+ silent_video = meta["silent_video"]
1399
+ tmp_dir = tempfile.mkdtemp()
1400
+ seg_path = os.path.join(tmp_dir, "xregen_seg.mp4")
1401
+ ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
1402
+ seg_path, vcodec="copy", an=None
1403
+ ).run(overwrite_output=True, quiet=True)
1404
+ _regen_hunyuan_gpu._cpu_ctx = {"seg_path": seg_path}
1405
+
1406
+ new_wav_raw, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
1407
+ prompt, negative_prompt, seed_val,
1408
+ guidance_scale, num_steps, model_size,
1409
+ crossfade_s, crossfade_db, slot_id)
1410
+ new_wav = _resample_to_slot_sr(new_wav_raw, src_sr, slot_sr)
1411
+ video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1412
+ new_wav, seg_idx, meta, slot_id
1413
+ )
1414
+ return gr.update(value=video_path), gr.update(value=waveform_html)
1415
+
1416
+
1417
  # ================================================================== #
1418
  # SHARED UI HELPERS #
1419
  # ================================================================== #
 
1969
 
1970
  // Fire regen for a given slot and segment by posting directly to the
1971
  // Gradio queue API — bypasses Svelte binding entirely.
1972
+ // targetModel: 'taro' | 'mma' | 'hf' (which model to use for inference)
1973
+ // If targetModel matches the slot's own prefix, uses the per-slot regen_* endpoint.
1974
+ // Otherwise uses the shared xregen_* cross-model endpoint.
1975
+ function fireRegen(slot_id, seg_idx, targetModel) {
1976
+ const prefix = slot_id.split('_')[0]; // owning tab: 'taro'|'mma'|'hf'
1977
  const slotNum = parseInt(slot_id.split('_')[1], 10);
1978
 
1979
+ // Decide which endpoint to call
1980
+ const crossModel = (targetModel !== prefix);
1981
+ let apiName, data;
 
 
 
 
1982
 
1983
  // Read state_json from the waveform container data-state attribute
1984
  const container = document.getElementById('wf_container_' + slot_id);
 
1988
  return;
1989
  }
1990
 
1991
+ if (!crossModel) {
1992
+ // ── Same-model regen: per-slot endpoint, video passed as null ──
1993
+ apiName = 'regen_' + prefix + '_' + slotNum;
1994
+ if (prefix === 'taro') {
1995
+ data = [seg_idx, stateJson, null,
1996
+ readComponentValue('taro_seed'), readComponentValue('taro_cfg'),
1997
+ readComponentValue('taro_steps'), readComponentValue('taro_mode'),
1998
+ readComponentValue('taro_cf_dur'), readComponentValue('taro_cf_db')];
1999
+ } else if (prefix === 'mma') {
2000
+ data = [seg_idx, stateJson, null,
2001
+ readComponentValue('mma_prompt'), readComponentValue('mma_neg'),
2002
+ readComponentValue('mma_seed'), readComponentValue('mma_cfg'),
2003
+ readComponentValue('mma_steps'),
2004
+ readComponentValue('mma_cf_dur'), readComponentValue('mma_cf_db')];
2005
+ } else {
2006
+ data = [seg_idx, stateJson, null,
2007
+ readComponentValue('hf_prompt'), readComponentValue('hf_neg'),
2008
+ readComponentValue('hf_seed'), readComponentValue('hf_guidance'),
2009
+ readComponentValue('hf_steps'), readComponentValue('hf_size'),
2010
+ readComponentValue('hf_cf_dur'), readComponentValue('hf_cf_db')];
2011
+ }
 
 
 
 
 
 
 
2012
  } else {
2013
+ // ── Cross-model regen: shared xregen_* endpoint ──
2014
+ // slot_id is passed so the server knows which slot's state to splice into.
2015
+ // UI params are read from the target model's tab inputs.
2016
+ if (targetModel === 'taro') {
2017
+ apiName = 'xregen_taro';
2018
+ data = [seg_idx, stateJson, slot_id,
2019
+ readComponentValue('taro_seed'), readComponentValue('taro_cfg'),
2020
+ readComponentValue('taro_steps'), readComponentValue('taro_mode'),
2021
+ readComponentValue('taro_cf_dur'), readComponentValue('taro_cf_db')];
2022
+ } else if (targetModel === 'mma') {
2023
+ apiName = 'xregen_mmaudio';
2024
+ data = [seg_idx, stateJson, slot_id,
2025
+ readComponentValue('mma_prompt'), readComponentValue('mma_neg'),
2026
+ readComponentValue('mma_seed'), readComponentValue('mma_cfg'),
2027
+ readComponentValue('mma_steps'),
2028
+ readComponentValue('mma_cf_dur'), readComponentValue('mma_cf_db')];
2029
+ } else {
2030
+ apiName = 'xregen_hunyuan';
2031
+ data = [seg_idx, stateJson, slot_id,
2032
+ readComponentValue('hf_prompt'), readComponentValue('hf_neg'),
2033
+ readComponentValue('hf_seed'), readComponentValue('hf_guidance'),
2034
+ readComponentValue('hf_steps'), readComponentValue('hf_size'),
2035
+ readComponentValue('hf_cf_dur'), readComponentValue('hf_cf_db')];
2036
+ }
2037
+ }
2038
+
2039
+ const fnIndex = getFnIndex(apiName);
2040
+ if (fnIndex === undefined) {
2041
+ console.warn('[fireRegen] fn_index not found for api_name:', apiName, 'cache:', _fnIndexCache);
2042
+ return;
2043
  }
2044
 
2045
  console.log('[fireRegen] calling api', apiName, 'fn_index', fnIndex, 'seg', seg_idx);
 
2177
  _popup.style.cssText = 'display:none;position:fixed;z-index:99999;' +
2178
  'background:#2a2a2a;border:1px solid #555;border-radius:6px;' +
2179
  'padding:8px 12px;box-shadow:0 4px 16px rgba(0,0,0,.5);font-family:sans-serif;';
2180
+ var btnStyle = 'color:#fff;border:none;border-radius:4px;padding:5px 10px;' +
2181
+ 'font-size:11px;cursor:pointer;flex:1;';
2182
  _popup.innerHTML =
2183
  '<div id="_wf_popup_lbl" style="color:#ccc;font-size:11px;margin-bottom:6px;white-space:nowrap;"></div>' +
2184
+ '<div style="display:flex;gap:5px;">' +
2185
+ '<button id="_wf_popup_taro" style="background:#1d6fa5;' + btnStyle + '">&#10227; TARO</button>' +
2186
+ '<button id="_wf_popup_mma" style="background:#2d7a4a;' + btnStyle + '">&#10227; MMAudio</button>' +
2187
+ '<button id="_wf_popup_hf" style="background:#7a3d8c;' + btnStyle + '">&#10227; Hunyuan</button>' +
2188
+ '</div>';
2189
  document.body.appendChild(_popup);
2190
+ ['taro','mma','hf'].forEach(function(model) {
2191
+ document.getElementById('_wf_popup_' + model).onclick = function(e) {
2192
+ e.stopPropagation();
2193
+ var slot = _pendingSlot, idx = _pendingIdx;
2194
+ hidePopup();
2195
+ if (slot !== null && idx !== null) fireRegen(slot, idx, model);
2196
+ };
2197
+ });
2198
  // Use bubble phase (false) so stopPropagation() on the button click prevents this from firing
2199
  document.addEventListener('click', function() { hidePopup(); }, false);
2200
  return _popup;
 
2422
  mma_video.change(fn=_sync, inputs=[mma_video], outputs=[taro_video, hf_video])
2423
  hf_video.change(fn=_sync, inputs=[hf_video], outputs=[taro_video, mma_video])
2424
 
2425
+ # ---- Cross-model regen endpoints ----
2426
+ # render=False inputs/outputs: no DOM elements created, no SSR validation impact.
2427
+ # JS calls these via /gradio_api/queue/join using the api_name and applies
2428
+ # the returned video+waveform directly to the target slot's DOM elements.
2429
+ _xr_seg = gr.Textbox(value="0", render=False)
2430
+ _xr_state = gr.Textbox(value="", render=False)
2431
+ _xr_slot_id = gr.Textbox(value="", render=False)
2432
+ _xr_vid_out = gr.Video(render=False)
2433
+ _xr_wave_out = gr.HTML(render=False)
2434
+
2435
+ # TARO cross-model regen inputs: seg_idx, state_json, slot_id, seed, cfg, steps, mode, cf_dur, cf_db
2436
+ _xr_taro_seed = gr.Textbox(value="-1", render=False)
2437
+ _xr_taro_cfg = gr.Textbox(value="7.5", render=False)
2438
+ _xr_taro_steps = gr.Textbox(value="25", render=False)
2439
+ _xr_taro_mode = gr.Textbox(value="sde", render=False)
2440
+ _xr_taro_cfd = gr.Textbox(value="2", render=False)
2441
+ _xr_taro_cfdb = gr.Textbox(value="3", render=False)
2442
+ gr.Button(render=False).click(
2443
+ fn=xregen_taro,
2444
+ inputs=[_xr_seg, _xr_state, _xr_slot_id,
2445
+ _xr_taro_seed, _xr_taro_cfg, _xr_taro_steps,
2446
+ _xr_taro_mode, _xr_taro_cfd, _xr_taro_cfdb],
2447
+ outputs=[_xr_vid_out, _xr_wave_out],
2448
+ api_name="xregen_taro",
2449
+ )
2450
+
2451
+ # MMAudio cross-model regen inputs: seg_idx, state_json, slot_id, prompt, neg, seed, cfg, steps, cf_dur, cf_db
2452
+ _xr_mma_prompt = gr.Textbox(value="", render=False)
2453
+ _xr_mma_neg = gr.Textbox(value="", render=False)
2454
+ _xr_mma_seed = gr.Textbox(value="-1", render=False)
2455
+ _xr_mma_cfg = gr.Textbox(value="4.5", render=False)
2456
+ _xr_mma_steps = gr.Textbox(value="25", render=False)
2457
+ _xr_mma_cfd = gr.Textbox(value="2", render=False)
2458
+ _xr_mma_cfdb = gr.Textbox(value="3", render=False)
2459
+ gr.Button(render=False).click(
2460
+ fn=xregen_mmaudio,
2461
+ inputs=[_xr_seg, _xr_state, _xr_slot_id,
2462
+ _xr_mma_prompt, _xr_mma_neg, _xr_mma_seed,
2463
+ _xr_mma_cfg, _xr_mma_steps, _xr_mma_cfd, _xr_mma_cfdb],
2464
+ outputs=[_xr_vid_out, _xr_wave_out],
2465
+ api_name="xregen_mmaudio",
2466
+ )
2467
+
2468
+ # HunyuanFoley cross-model regen inputs: seg_idx, state_json, slot_id, prompt, neg, seed, guidance, steps, size, cf_dur, cf_db
2469
+ _xr_hf_prompt = gr.Textbox(value="", render=False)
2470
+ _xr_hf_neg = gr.Textbox(value="", render=False)
2471
+ _xr_hf_seed = gr.Textbox(value="-1", render=False)
2472
+ _xr_hf_guide = gr.Textbox(value="4.5", render=False)
2473
+ _xr_hf_steps = gr.Textbox(value="50", render=False)
2474
+ _xr_hf_size = gr.Textbox(value="xxl", render=False)
2475
+ _xr_hf_cfd = gr.Textbox(value="2", render=False)
2476
+ _xr_hf_cfdb = gr.Textbox(value="3", render=False)
2477
+ gr.Button(render=False).click(
2478
+ fn=xregen_hunyuan,
2479
+ inputs=[_xr_seg, _xr_state, _xr_slot_id,
2480
+ _xr_hf_prompt, _xr_hf_neg, _xr_hf_seed,
2481
+ _xr_hf_guide, _xr_hf_steps, _xr_hf_size,
2482
+ _xr_hf_cfd, _xr_hf_cfdb],
2483
+ outputs=[_xr_vid_out, _xr_wave_out],
2484
+ api_name="xregen_hunyuan",
2485
+ )
2486
+
2487
  print("[startup] app.py fully loaded — regen handlers registered, SSR disabled")
2488
  demo.queue(max_size=10).launch(ssr_mode=False, height=900, allowed_paths=["/tmp"])