BoxOfColors committed on
Commit
47fd0ad
·
1 Parent(s): 166fb8e

xregen: use target model's optimal window centered on segment midpoint

Browse files

Each xregen_* now:
1. Computes optimal clip window centered on segment midpoint (clamped to video)
2. Runs _build_segments on that clip with target model's window size
3. Calls full generation GPU pipeline (same path as initial generation)
4. Stitches sub-segments with _stitch_wavs + contact-edge trimming
5. Returns (wav, sr, clip_start_s) so _xregen_splice can align to original grid

Handles all cases: target window > span (single inference), target window <
span (multiple sub-segments), video shorter than target window (clamped).

Files changed (1) hide show
  1. app.py +141 -26
app.py CHANGED
@@ -1617,13 +1617,64 @@ def _resample_to_slot_sr(wav: np.ndarray, src_sr: int, dst_sr: int,
1617
  return wav
1618
 
1619
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1620
  def _xregen_splice(new_wav_raw: np.ndarray, src_sr: int,
1621
- meta: dict, seg_idx: int, slot_id: str) -> tuple:
 
1622
  """Shared epilogue for all xregen_* functions: resample → splice → save.
1623
- Returns (video_path, waveform_html)."""
 
 
 
 
 
 
 
1624
  slot_sr = int(meta["sr"])
1625
  slot_wavs = _load_seg_wavs(meta["wav_paths"])
1626
  new_wav = _resample_to_slot_sr(new_wav_raw, src_sr, slot_sr, slot_wavs[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1627
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1628
  new_wav, seg_idx, meta, slot_id
1629
  )
@@ -1635,8 +1686,8 @@ def _xregen_dispatch(state_json: str, seg_idx: int, slot_id: str, infer_fn):
1635
 
1636
  Yields pending HTML immediately, then calls *infer_fn()* — a zero-argument
1637
  callable that runs model-specific CPU prep + GPU inference and returns
1638
- (wav_array, src_sr). For TARO, *infer_fn* should return the wav already
1639
- upsampled to 48 kHz; pass TARO_SR_OUT as src_sr.
1640
 
1641
  Yields:
1642
  First: (gr.update(), gr.update(value=pending_html)) — shown while GPU runs
@@ -1646,8 +1697,8 @@ def _xregen_dispatch(state_json: str, seg_idx: int, slot_id: str, infer_fn):
1646
  pending_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
1647
  yield gr.update(), gr.update(value=pending_html)
1648
 
1649
- new_wav_raw, src_sr = infer_fn()
1650
- video_path, waveform_html = _xregen_splice(new_wav_raw, src_sr, meta, seg_idx, slot_id)
1651
  yield gr.update(value=video_path), gr.update(value=waveform_html)
1652
 
1653
 
@@ -1655,16 +1706,35 @@ def xregen_taro(seg_idx, state_json, slot_id,
1655
  seed_val, cfg_scale, num_steps, mode,
1656
  crossfade_s, crossfade_db,
1657
  request: gr.Request = None):
1658
- """Cross-model regen: run TARO inference and splice into *slot_id*."""
1659
  seg_idx = int(seg_idx)
1660
  meta = json.loads(state_json)
1661
 
1662
  def _run():
1663
- # CAVP/onset features are loaded from disk paths inside the GPU fn
1664
- wav = _regen_taro_gpu(None, seg_idx, state_json,
1665
- seed_val, cfg_scale, num_steps, mode,
1666
- crossfade_s, crossfade_db, slot_id)
1667
- return _upsample_taro(wav), TARO_SR_OUT # 16 kHz → 48 kHz (CPU)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1668
 
1669
  yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
1670
 
@@ -1673,16 +1743,37 @@ def xregen_mmaudio(seg_idx, state_json, slot_id,
1673
  prompt, negative_prompt, seed_val,
1674
  cfg_strength, num_steps, crossfade_s, crossfade_db,
1675
  request: gr.Request = None):
1676
- """Cross-model regen: run MMAudio inference and splice into *slot_id*."""
1677
  seg_idx = int(seg_idx)
 
1678
 
1679
  def _run():
1680
- # Segment clip extraction happens inside _regen_mmaudio_gpu
1681
- wav, src_sr = _regen_mmaudio_gpu(None, seg_idx, state_json,
1682
- prompt, negative_prompt, seed_val,
1683
- cfg_strength, num_steps,
1684
- crossfade_s, crossfade_db, slot_id)
1685
- return wav, src_sr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1686
 
1687
  yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
1688
 
@@ -1692,16 +1783,40 @@ def xregen_hunyuan(seg_idx, state_json, slot_id,
1692
  guidance_scale, num_steps, model_size,
1693
  crossfade_s, crossfade_db,
1694
  request: gr.Request = None):
1695
- """Cross-model regen: run HunyuanFoley inference and splice into *slot_id*."""
1696
  seg_idx = int(seg_idx)
 
1697
 
1698
  def _run():
1699
- # Segment clip extraction happens inside _regen_hunyuan_gpu
1700
- wav, src_sr = _regen_hunyuan_gpu(None, seg_idx, state_json,
1701
- prompt, negative_prompt, seed_val,
1702
- guidance_scale, num_steps, model_size,
1703
- crossfade_s, crossfade_db, slot_id)
1704
- return wav, src_sr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1705
 
1706
  yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
1707
 
 
1617
  return wav
1618
 
1619
 
1620
+ def _xregen_clip_window(meta: dict, seg_idx: int, target_window_s: float) -> tuple:
1621
+ """Compute the video clip window for a cross-model regen.
1622
+
1623
+ Centers *target_window_s* on the original segment's midpoint, clamped to
1624
+ [0, total_dur_s]. Returns (clip_start, clip_end, clip_dur).
1625
+
1626
+ If the video is shorter than *target_window_s*, the full video is used
1627
+ (suboptimal but never breaks). If the segment span exceeds
1628
+ *target_window_s*, the returned window is still only *target_window_s*
1629
+ wide, centered on the segment midpoint; callers that must cover the
1630
+ full span should run _build_segments on it and generate sub-segments.
1631
+ """
1632
+ total_dur_s = float(meta["total_dur_s"])
1633
+ seg_start, seg_end = meta["segments"][seg_idx]
1634
+ seg_mid = (seg_start + seg_end) / 2.0
1635
+ half_win = target_window_s / 2.0
1636
+
1637
+ clip_start = max(0.0, seg_mid - half_win)
1638
+ clip_end = min(total_dur_s, seg_mid + half_win)
1639
+ # If clamped at one end, extend the other to preserve full window if possible
1640
+ if clip_start == 0.0:
1641
+ clip_end = min(total_dur_s, target_window_s)
1642
+ elif clip_end == total_dur_s:
1643
+ clip_start = max(0.0, total_dur_s - target_window_s)
1644
+ clip_dur = clip_end - clip_start
1645
+ return clip_start, clip_end, clip_dur
1646
+
1647
+
1648
  def _xregen_splice(new_wav_raw: np.ndarray, src_sr: int,
1649
+ meta: dict, seg_idx: int, slot_id: str,
1650
+ clip_start_s: float = None) -> tuple:
1651
  """Shared epilogue for all xregen_* functions: resample → splice → save.
1652
+ Returns (video_path, waveform_html).
1653
+
1654
+ *clip_start_s* is the absolute video time where new_wav_raw starts.
1655
+ When the clip was centered on the segment midpoint (not at seg_start),
1656
+ we need to shift the wav so _stitch_wavs can trim it correctly relative
1657
+ to the original segment's start. We do this by prepending silence so
1658
+ the wav's time origin aligns with the original segment's start.
1659
+ """
1660
  slot_sr = int(meta["sr"])
1661
  slot_wavs = _load_seg_wavs(meta["wav_paths"])
1662
  new_wav = _resample_to_slot_sr(new_wav_raw, src_sr, slot_sr, slot_wavs[0])
1663
+
1664
+ # If the clip started before the original segment start, prepend silence
1665
+ # so that sample index 0 of new_wav corresponds to seg_start in video time.
1666
+ if clip_start_s is not None:
1667
+ seg_start = meta["segments"][seg_idx][0]
1668
+ offset_s = seg_start - clip_start_s # positive = seg starts after clip start
1669
+ if offset_s < 0:
1670
+ # clip started after seg_start — prepend silence to align
1671
+ pad_samples = int(round(abs(offset_s) * slot_sr))
1672
+ silence = np.zeros(
1673
+ (new_wav.shape[0], pad_samples) if new_wav.ndim == 2 else pad_samples,
1674
+ dtype=new_wav.dtype,
1675
+ )
1676
+ new_wav = np.concatenate([silence, new_wav], axis=1 if new_wav.ndim == 2 else 0)
1677
+
1678
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1679
  new_wav, seg_idx, meta, slot_id
1680
  )
 
1686
 
1687
  Yields pending HTML immediately, then calls *infer_fn()* — a zero-argument
1688
  callable that runs model-specific CPU prep + GPU inference and returns
1689
+ (wav_array, src_sr, clip_start_s). For TARO, *infer_fn* should return
1690
+ the wav already upsampled to 48 kHz; pass TARO_SR_OUT as src_sr.
1691
 
1692
  Yields:
1693
  First: (gr.update(), gr.update(value=pending_html)) — shown while GPU runs
 
1697
  pending_html = _build_regen_pending_html(meta["segments"], seg_idx, slot_id, "")
1698
  yield gr.update(), gr.update(value=pending_html)
1699
 
1700
+ new_wav_raw, src_sr, clip_start_s = infer_fn()
1701
+ video_path, waveform_html = _xregen_splice(new_wav_raw, src_sr, meta, seg_idx, slot_id, clip_start_s)
1702
  yield gr.update(value=video_path), gr.update(value=waveform_html)
1703
 
1704
 
 
1706
  seed_val, cfg_scale, num_steps, mode,
1707
  crossfade_s, crossfade_db,
1708
  request: gr.Request = None):
1709
+ """Cross-model regen: run TARO on its optimal window, splice into *slot_id*."""
1710
  seg_idx = int(seg_idx)
1711
  meta = json.loads(state_json)
1712
 
1713
  def _run():
1714
+ clip_start, clip_end, clip_dur = _xregen_clip_window(meta, seg_idx, TARO_MODEL_DUR)
1715
+ tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
1716
+ clip_path = _extract_segment_clip(
1717
+ meta["silent_video"], clip_start, clip_dur,
1718
+ os.path.join(tmp_dir, "xregen_taro_clip.mp4"),
1719
+ )
1720
+ # Build a minimal fake-video meta so generate_taro can run on clip_path
1721
+ sub_segs = _build_segments(clip_dur, TARO_MODEL_DUR, float(crossfade_s))
1722
+ sub_meta_json = json.dumps({
1723
+ "segments": sub_segs, "silent_video": clip_path,
1724
+ "total_dur_s": clip_dur,
1725
+ })
1726
+ # Run full TARO generation pipeline on the clip
1727
+ _ctx_store("taro_gpu_infer", {
1728
+ "tmp_dir": tmp_dir, "silent_video": clip_path,
1729
+ "segments": sub_segs, "total_dur_s": clip_dur,
1730
+ })
1731
+ results = _taro_gpu_infer(clip_path, seed_val, cfg_scale, num_steps, mode,
1732
+ crossfade_s, crossfade_db, 1)
1733
+ wavs, _, _ = results[0]
1734
+ wavs = [_upsample_taro(w) for w in wavs]
1735
+ wav = _stitch_wavs(wavs, float(crossfade_s), float(crossfade_db),
1736
+ clip_dur, TARO_SR_OUT, sub_segs)
1737
+ return wav, TARO_SR_OUT, clip_start
1738
 
1739
  yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
1740
 
 
1743
  prompt, negative_prompt, seed_val,
1744
  cfg_strength, num_steps, crossfade_s, crossfade_db,
1745
  request: gr.Request = None):
1746
+ """Cross-model regen: run MMAudio on its optimal window, splice into *slot_id*."""
1747
  seg_idx = int(seg_idx)
1748
+ meta = json.loads(state_json)
1749
 
1750
  def _run():
1751
+ clip_start, clip_end, clip_dur = _xregen_clip_window(meta, seg_idx, MMAUDIO_WINDOW)
1752
+ tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
1753
+ clip_path = _extract_segment_clip(
1754
+ meta["silent_video"], clip_start, clip_dur,
1755
+ os.path.join(tmp_dir, "xregen_mmaudio_clip.mp4"),
1756
+ )
1757
+ sub_segs = _build_segments(clip_dur, MMAUDIO_WINDOW, float(crossfade_s))
1758
+ seg_clip_paths = [
1759
+ _extract_segment_clip(
1760
+ clip_path, s, e - s,
1761
+ os.path.join(tmp_dir, f"xregen_mma_sub_{i}.mp4"),
1762
+ )
1763
+ for i, (s, e) in enumerate(sub_segs)
1764
+ ]
1765
+ _ctx_store("mmaudio_gpu_infer", {
1766
+ "segments": sub_segs, "seg_clip_paths": seg_clip_paths,
1767
+ })
1768
+ results = _mmaudio_gpu_infer(clip_path, prompt, negative_prompt, seed_val,
1769
+ cfg_strength, num_steps, crossfade_s, crossfade_db, 1)
1770
+ seg_wavs, sr = results[0]
1771
+ wav = _stitch_wavs(seg_wavs, float(crossfade_s), float(crossfade_db),
1772
+ clip_dur, sr, sub_segs)
1773
+ if sr != TARGET_SR:
1774
+ wav = _resample_to_target(wav, sr)
1775
+ sr = TARGET_SR
1776
+ return wav, sr, clip_start
1777
 
1778
  yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
1779
 
 
1783
  guidance_scale, num_steps, model_size,
1784
  crossfade_s, crossfade_db,
1785
  request: gr.Request = None):
1786
+ """Cross-model regen: run HunyuanFoley on its optimal window, splice into *slot_id*."""
1787
  seg_idx = int(seg_idx)
1788
+ meta = json.loads(state_json)
1789
 
1790
  def _run():
1791
+ clip_start, clip_end, clip_dur = _xregen_clip_window(meta, seg_idx, HUNYUAN_MAX_DUR)
1792
+ tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
1793
+ clip_path = _extract_segment_clip(
1794
+ meta["silent_video"], clip_start, clip_dur,
1795
+ os.path.join(tmp_dir, "xregen_hunyuan_clip.mp4"),
1796
+ )
1797
+ sub_segs = _build_segments(clip_dur, HUNYUAN_MAX_DUR, float(crossfade_s))
1798
+ seg_clip_paths = [
1799
+ _extract_segment_clip(
1800
+ clip_path, s, e - s,
1801
+ os.path.join(tmp_dir, f"xregen_hny_sub_{i}.mp4"),
1802
+ )
1803
+ for i, (s, e) in enumerate(sub_segs)
1804
+ ]
1805
+ dummy_seg_path = _extract_segment_clip(
1806
+ clip_path, 0, min(clip_dur, HUNYUAN_MAX_DUR),
1807
+ os.path.join(tmp_dir, "xregen_hny_dummy.mp4"),
1808
+ )
1809
+ _ctx_store("hunyuan_gpu_infer", {
1810
+ "segments": sub_segs, "total_dur_s": clip_dur,
1811
+ "dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
1812
+ })
1813
+ results = _hunyuan_gpu_infer(clip_path, prompt, negative_prompt, seed_val,
1814
+ guidance_scale, num_steps, model_size,
1815
+ crossfade_s, crossfade_db, 1)
1816
+ seg_wavs, sr, _ = results[0]
1817
+ wav = _stitch_wavs(seg_wavs, float(crossfade_s), float(crossfade_db),
1818
+ clip_dur, sr, sub_segs)
1819
+ return wav, sr, clip_start
1820
 
1821
  yield from _xregen_dispatch(state_json, seg_idx, slot_id, _run)
1822