BoxOfColors Claude Sonnet 4.6 committed on
Commit
894b188
·
1 Parent(s): 8b87263

Remove channel-layout safety nets; assert stereo (2,T) everywhere

Browse files

All seg wavs are now guaranteed (2,T) by _save_seg_wavs/_to_stereo.
Remove the silent fallbacks that were masking shape bugs:

- _load_seg_wavs: assert (2,T) instead of silently squeezing (1,T)
- _splice_and_save: assert new_wav is (2,T) on entry; remove
_normalize_channel_layout call (function deleted)
- _resample_to_slot_sr: always call _to_stereo(); drop slot_wav_ref
channel-matching logic and defensive squeezes

Any shape violation now raises AssertionError immediately with the
offending shape, instead of silently producing wrong-shaped output.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +11 -48
app.py CHANGED
@@ -339,16 +339,14 @@ def _save_seg_wavs(wavs: list[np.ndarray], tmp_dir: str, prefix: str) -> list[st
339
 
340
  def _load_seg_wavs(paths: list[str]) -> list[np.ndarray]:
341
  """Load segment wav arrays from .npy file paths.
342
-
343
- Normalises (1, T) arrays → (T,) mono so that single-channel output from
344
- models like HunyuanFoley (DAC decoder emits shape (1, T)) never causes a
345
- shape mismatch in _cf_join when mixed with true stereo (2, T) arrays.
346
  """
347
  wavs = []
348
  for p in paths:
349
  w = np.load(p)
350
- if w.ndim == 2 and w.shape[0] == 1:
351
- w = w.squeeze(0) # (1, T) (T,) mono
352
  wavs.append(w)
353
  return wavs
354
 
@@ -1395,30 +1393,15 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
1395
  # ================================================================== #
1396
 
1397
 
1398
- def _normalize_channel_layout(wavs: list[np.ndarray]) -> list[np.ndarray]:
1399
- """Ensure all wavs in *wavs* share the same channel layout.
1400
-
1401
- Rule: stereo wins. If ANY segment is stereo (2, T), all mono (T,)
1402
- segments are duplicated to (2, T). This preserves MMAudio's genuine
1403
- stereo output even when the slot also contains TARO or HunyuanFoley
1404
- mono segments. (1, T) arrays should already be squeezed by
1405
- _load_seg_wavs, but we handle them defensively here too.)
1406
- """
1407
- # Squeeze any residual (1, T) to (T,) first
1408
- wavs = [w.squeeze(0) if (w.ndim == 2 and w.shape[0] == 1) else w for w in wavs]
1409
- has_stereo = any(w.ndim == 2 and w.shape[0] == 2 for w in wavs)
1410
- if not has_stereo:
1411
- return wavs
1412
- return [np.stack([w, w], axis=0) if w.ndim == 1 else w for w in wavs]
1413
-
1414
-
1415
  def _splice_and_save(new_wav, seg_idx, meta, slot_id):
1416
  """Replace wavs[seg_idx] with new_wav, re-stitch, re-save, re-mux.
1417
  Returns (video_path, audio_path, updated_meta, waveform_html).
 
1418
  """
 
 
1419
  wavs = _load_seg_wavs(meta["wav_paths"])
1420
  wavs[seg_idx]= new_wav
1421
- wavs = _normalize_channel_layout(wavs)
1422
  crossfade_s = float(meta["crossfade_s"])
1423
  crossfade_db = float(meta["crossfade_db"])
1424
  sr = int(meta["sr"])
@@ -1727,31 +1710,11 @@ MODEL_CONFIGS["hunyuan"]["regen_fn"] = regen_hunyuan_segment
1727
 
1728
  def _resample_to_slot_sr(wav: np.ndarray, src_sr: int, dst_sr: int,
1729
  slot_wav_ref: np.ndarray = None) -> np.ndarray:
1730
- """Resample *wav* from src_sr to dst_sr, then match channel layout to
1731
- *slot_wav_ref* (the first existing segment in the slot).
1732
-
1733
- Stereo wins: if either the new wav or the slot reference is stereo,
1734
- the mono side is duplicated to (2, T). This preserves MMAudio's
1735
- genuine stereo rather than averaging it down to mono.
1736
- (1, T) pseudo-stereo from HunyuanFoley's DAC is squeezed to mono first.
1737
- """
1738
  wav = _resample_to_target(wav, src_sr, dst_sr)
1739
-
1740
- # Squeeze (1, T) → (T,) before channel decision
1741
- if wav.ndim == 2 and wav.shape[0] == 1:
1742
- wav = wav.squeeze(0)
1743
-
1744
- if slot_wav_ref is not None:
1745
- # Squeeze slot ref too, defensively
1746
- ref = slot_wav_ref.squeeze(0) if (slot_wav_ref.ndim == 2 and slot_wav_ref.shape[0] == 1) else slot_wav_ref
1747
- slot_stereo = ref.ndim == 2 and ref.shape[0] == 2
1748
- wav_stereo = wav.ndim == 2 and wav.shape[0] == 2
1749
- if slot_stereo and not wav_stereo:
1750
- wav = np.stack([wav, wav], axis=0) # mono → stereo (C, T)
1751
- elif wav_stereo and not slot_stereo:
1752
- pass # keep new wav stereo; _normalize_channel_layout in
1753
- # _splice_and_save will upcast the existing mono segs
1754
- return wav
1755
 
1756
 
1757
  def _resolve_silent_video(meta: dict) -> str:
 
339
 
340
  def _load_seg_wavs(paths: list[str]) -> list[np.ndarray]:
341
  """Load segment wav arrays from .npy file paths.
342
+ All files on disk are expected to be stereo (2, T) — _save_seg_wavs
343
+ guarantees this. Raises AssertionError if any array has unexpected shape.
 
 
344
  """
345
  wavs = []
346
  for p in paths:
347
  w = np.load(p)
348
+ assert w.ndim == 2 and w.shape[0] == 2, \
349
+ f"Expected stereo (2, T) in {p}, got shape {w.shape}"
350
  wavs.append(w)
351
  return wavs
352
 
 
1393
  # ================================================================== #
1394
 
1395
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1396
  def _splice_and_save(new_wav, seg_idx, meta, slot_id):
1397
  """Replace wavs[seg_idx] with new_wav, re-stitch, re-save, re-mux.
1398
  Returns (video_path, audio_path, updated_meta, waveform_html).
1399
+ All wavs (loaded from disk and new_wav) must be stereo (2, T).
1400
  """
1401
+ assert new_wav.ndim == 2 and new_wav.shape[0] == 2, \
1402
+ f"new_wav must be stereo (2, T), got shape {new_wav.shape}"
1403
  wavs = _load_seg_wavs(meta["wav_paths"])
1404
  wavs[seg_idx]= new_wav
 
1405
  crossfade_s = float(meta["crossfade_s"])
1406
  crossfade_db = float(meta["crossfade_db"])
1407
  sr = int(meta["sr"])
 
1710
 
1711
  def _resample_to_slot_sr(wav: np.ndarray, src_sr: int, dst_sr: int,
1712
  slot_wav_ref: np.ndarray = None) -> np.ndarray:
1713
+ """Resample *wav* from src_sr to dst_sr and convert to stereo (2, T).
1714
+ slot_wav_ref is unused (kept for call-site compatibility) all wavs
1715
+ are now always stereo so no per-slot channel matching is needed."""
 
 
 
 
 
1716
  wav = _resample_to_target(wav, src_sr, dst_sr)
1717
+ return _to_stereo(wav)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1718
 
1719
 
1720
  def _resolve_silent_video(meta: dict) -> str: