BoxOfColors commited on
Commit
585a112
·
1 Parent(s): 72afd74

Unify mono and stereo crossfade into a single _cf_join function

Browse files

Replace separate _crossfade_join (mono/TARO) and _cf_join_stereo (stereo/
MMAudio+HunyuanFoley) with one _cf_join that handles both shapes via a.ndim
check: stereo (C,T) uses axis=1 slicing, mono (T,) uses 1D indexing.
Update _stitch_wavs to accept an sr parameter and call _cf_join, and update
its call site in generate_taro to pass TARO_SR explicitly.

Files changed (1) hide show
  1. app.py +24 -31
app.py CHANGED
@@ -127,19 +127,26 @@ def _build_segments(total_dur_s: float, window_s: float, crossfade_s: float) ->
127
  return segments
128
 
129
 
130
- def _cf_join_stereo(a: np.ndarray, b: np.ndarray,
131
- crossfade_s: float, db_boost: float, sr: int) -> np.ndarray:
132
- """Equal-power crossfade join for stereo (C, T) numpy arrays."""
133
- cf = int(round(crossfade_s * sr))
134
- cf = min(cf, a.shape[1], b.shape[1])
 
 
 
135
  if cf <= 0:
136
- return np.concatenate([a, b], axis=1)
137
- gain = 10 ** (db_boost / 20.0)
138
- t = np.linspace(0.0, 1.0, cf, dtype=np.float32)
139
  fade_out = np.cos(t * np.pi / 2) # 1 → 0
140
  fade_in = np.sin(t * np.pi / 2) # 0 → 1
141
- overlap = a[:, -cf:] * fade_out * gain + b[:, :cf] * fade_in * gain
142
- return np.concatenate([a[:, :-cf], overlap, b[:, cf:]], axis=1)
 
 
 
 
143
 
144
 
145
  # ================================================================== #
@@ -237,26 +244,12 @@ def _taro_infer_segment(
237
  return wav[:seg_samples]
238
 
239
 
240
- def _crossfade_join(wav_a: np.ndarray, wav_b: np.ndarray,
241
- crossfade_s: float, db_boost: float) -> np.ndarray:
242
- cf = int(round(crossfade_s * TARO_SR))
243
- cf = min(cf, len(wav_a), len(wav_b))
244
- if cf <= 0:
245
- return np.concatenate([wav_a, wav_b])
246
- gain = 10 ** (db_boost / 20.0)
247
- # Equal-power fade: fade-out a, fade-in b over the overlap region
248
- t = np.linspace(0.0, 1.0, cf, dtype=np.float32)
249
- fade_out = np.cos(t * np.pi / 2) # 1 → 0
250
- fade_in = np.sin(t * np.pi / 2) # 0 → 1
251
- overlap = wav_a[-cf:] * fade_out * gain + wav_b[:cf] * fade_in * gain
252
- return np.concatenate([wav_a[:-cf], overlap, wav_b[cf:]])
253
-
254
-
255
- def _stitch_wavs(wavs: list, crossfade_s: float, db_boost: float, total_dur_s: float) -> np.ndarray:
256
  out = wavs[0]
257
  for nw in wavs[1:]:
258
- out = _crossfade_join(out, nw, crossfade_s, db_boost)
259
- return out[:int(round(total_dur_s * TARO_SR))]
260
 
261
 
262
  @spaces.GPU(duration=600)
@@ -360,7 +353,7 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
360
  wavs.append(wav)
361
  _TARO_INFERENCE_CACHE[cache_key] = {"wavs": wavs}
362
 
363
- final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s)
364
  audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
365
  sf.write(audio_path, final_wav, TARO_SR)
366
  video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
@@ -495,7 +488,7 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
495
  # Crossfade-stitch all segments using shared equal-power helper
496
  full_wav = seg_audios[0]
497
  for nw in seg_audios[1:]:
498
- full_wav = _cf_join_stereo(full_wav, nw, MMA_CF_S, MMA_CF_DB, sr)
499
  full_wav = full_wav[:, : int(round(total_dur_s * sr))]
500
 
501
  audio_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.flac")
@@ -631,7 +624,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
631
  # Crossfade-stitch all segments using shared equal-power helper
632
  full_wav = seg_wavs[0]
633
  for nw in seg_wavs[1:]:
634
- full_wav = _cf_join_stereo(full_wav, nw, CF_S, CF_DB, sr)
635
  # Trim to exact video duration
636
  full_wav = full_wav[:, : int(round(total_dur_s * sr))]
637
 
 
127
  return segments
128
 
129
 
130
+ def _cf_join(a: np.ndarray, b: np.ndarray,
131
+ crossfade_s: float, db_boost: float, sr: int) -> np.ndarray:
132
+ """Equal-power crossfade join. Works for both mono (T,) and stereo (C, T) arrays.
133
+ Stereo arrays are expected in (channels, samples) layout."""
134
+ stereo = a.ndim == 2
135
+ n_a = a.shape[1] if stereo else len(a)
136
+ n_b = b.shape[1] if stereo else len(b)
137
+ cf = min(int(round(crossfade_s * sr)), n_a, n_b)
138
  if cf <= 0:
139
+ return np.concatenate([a, b], axis=1 if stereo else 0)
140
+ gain = 10 ** (db_boost / 20.0)
141
+ t = np.linspace(0.0, 1.0, cf, dtype=np.float32)
142
  fade_out = np.cos(t * np.pi / 2) # 1 → 0
143
  fade_in = np.sin(t * np.pi / 2) # 0 → 1
144
+ if stereo:
145
+ overlap = a[:, -cf:] * fade_out * gain + b[:, :cf] * fade_in * gain
146
+ return np.concatenate([a[:, :-cf], overlap, b[:, cf:]], axis=1)
147
+ else:
148
+ overlap = a[-cf:] * fade_out * gain + b[:cf] * fade_in * gain
149
+ return np.concatenate([a[:-cf], overlap, b[cf:]])
150
 
151
 
152
  # ================================================================== #
 
244
  return wav[:seg_samples]
245
 
246
 
247
+ def _stitch_wavs(wavs: list, crossfade_s: float, db_boost: float,
248
+ total_dur_s: float, sr: int) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  out = wavs[0]
250
  for nw in wavs[1:]:
251
+ out = _cf_join(out, nw, crossfade_s, db_boost, sr)
252
+ return out[:int(round(total_dur_s * sr))]
253
 
254
 
255
  @spaces.GPU(duration=600)
 
353
  wavs.append(wav)
354
  _TARO_INFERENCE_CACHE[cache_key] = {"wavs": wavs}
355
 
356
+ final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR)
357
  audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
358
  sf.write(audio_path, final_wav, TARO_SR)
359
  video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
 
488
  # Crossfade-stitch all segments using shared equal-power helper
489
  full_wav = seg_audios[0]
490
  for nw in seg_audios[1:]:
491
+ full_wav = _cf_join(full_wav, nw, MMA_CF_S, MMA_CF_DB, sr)
492
  full_wav = full_wav[:, : int(round(total_dur_s * sr))]
493
 
494
  audio_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.flac")
 
624
  # Crossfade-stitch all segments using shared equal-power helper
625
  full_wav = seg_wavs[0]
626
  for nw in seg_wavs[1:]:
627
+ full_wav = _cf_join(full_wav, nw, CF_S, CF_DB, sr)
628
  # Trim to exact video duration
629
  full_wav = full_wav[:, : int(round(total_dur_s * sr))]
630