BoxOfColors committed on
Commit
efe424b
·
1 Parent(s): 4d46101

feat: add FlashSR post-processing to upsample TARO 16kHz → 48kHz

Browse files

All three models now output at 48kHz (TARO via FlashSR, MMAudio at
44.1kHz natively resampled, HunyuanFoley at 48kHz natively).
FlashSR is applied after generation and after each regen/xregen on
TARO outputs. Console logs confirm each upsampling step with duration
and sample rate. Falls back to sinc resampling if FlashSR errors.

Files changed (2) hide show
  1. app.py +90 -5
  2. requirements.txt +1 -0
app.py CHANGED
@@ -498,6 +498,73 @@ def _taro_infer_segment(
498
  return wav[:seg_samples]
499
 
500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501
  def _stitch_wavs(wavs: list[np.ndarray], crossfade_s: float, db_boost: float,
502
  total_dur_s: float, sr: int) -> np.ndarray:
503
  """Crossfade-join a list of wav arrays and trim to *total_dur_s*.
@@ -672,8 +739,15 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
672
  outputs = []
673
  for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
674
  final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR)
 
 
 
 
 
 
 
675
  audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
676
- _save_wav(audio_path, final_wav, TARO_SR)
677
  video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
678
  mux_video_audio(silent_video, audio_path, video_path)
679
  wav_paths = _save_seg_wavs(wavs, tmp_dir, f"taro_{sample_idx}")
@@ -685,7 +759,7 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
685
  first_cavp_saved = True
686
  seg_meta = _build_seg_meta(
687
  segments=segments, wav_paths=wav_paths, audio_path=audio_path,
688
- video_path=video_path, silent_video=silent_video, sr=TARO_SR,
689
  model="taro", crossfade_s=crossfade_s, crossfade_db=crossfade_db,
690
  total_dur_s=total_dur_s, cavp_path=cavp_path, onset_path=onset_path,
691
  )
@@ -1135,9 +1209,16 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
1135
  seed_val, cfg_scale, num_steps, mode,
1136
  crossfade_s, crossfade_db, slot_id)
1137
 
1138
- # CPU: splice, stitch, mux, save
 
 
 
 
 
 
 
1139
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1140
- new_wav, seg_idx, meta, slot_id
1141
  )
1142
  return video_path, audio_path, json.dumps(updated_meta), waveform_html
1143
 
@@ -1405,7 +1486,11 @@ def xregen_taro(seg_idx, state_json, slot_id,
1405
  new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
1406
  seed_val, cfg_scale, num_steps, mode,
1407
  crossfade_s, crossfade_db, slot_id)
1408
- video_path, waveform_html = _xregen_splice(new_wav_raw, TARO_SR, meta, seg_idx, slot_id)
 
 
 
 
1409
  yield gr.update(value=video_path), gr.update(value=waveform_html)
1410
 
1411
 
 
498
  return wav[:seg_samples]
499
 
500
 
501
+ # ================================================================== #
502
+ # FlashSR (16 → 48 kHz) #
503
+ # ================================================================== #
504
+ # FlashSR is used as a post-processing step on TARO outputs only.
505
+ # TARO generates at 16 kHz; FlashSR upsamples to 48 kHz so all three
506
+ # models produce output at the same sample rate.
507
+ # Model weights are downloaded once from HF Hub and cached on disk.
508
+
509
+ _FLASHSR_MODEL = None # module-level cache — loaded once per process
510
+ _FLASHSR_LOCK = threading.Lock()
511
+
512
+ FLASHSR_SR_IN = 16000
513
+ FLASHSR_SR_OUT = 48000
514
+
515
+
516
+ def _load_flashsr():
517
+ """Load FlashSR model (cached after first call). Returns FASR instance."""
518
+ global _FLASHSR_MODEL
519
+ with _FLASHSR_LOCK:
520
+ if _FLASHSR_MODEL is not None:
521
+ return _FLASHSR_MODEL
522
+ print("[FlashSR] Loading model weights from HF Hub …")
523
+ from huggingface_hub import hf_hub_download
524
+ from FastAudioSR import FASR
525
+ ckpt_path = hf_hub_download(
526
+ repo_id="YatharthS/FlashSR",
527
+ filename="upsampler.pth",
528
+ local_dir=os.path.join(os.path.dirname(os.path.abspath(__file__)), ".flashsr_cache"),
529
+ )
530
+ model = FASR(ckpt_path)
531
+ if torch.cuda.is_available():
532
+ model.model.half().cuda()
533
+ print("[FlashSR] Model loaded on GPU (fp16)")
534
+ else:
535
+ print("[FlashSR] Model loaded on CPU (fp32)")
536
+ _FLASHSR_MODEL = model
537
+ return model
538
+
539
+
540
+ def _apply_flashsr(wav_16k: np.ndarray) -> np.ndarray:
541
+ """Upsample a mono 16 kHz numpy array to 48 kHz using FlashSR.
542
+
543
+ Returns a mono float32 numpy array at 48 kHz.
544
+ Falls back to torchaudio sinc resampling if FlashSR fails.
545
+ """
546
+ try:
547
+ model = _load_flashsr()
548
+ t = torch.from_numpy(wav_16k.astype(np.float32)).unsqueeze(0)
549
+ if torch.cuda.is_available():
550
+ t = t.half().cuda()
551
+ print(f"[FlashSR] Upsampling {len(wav_16k)/FLASHSR_SR_IN:.2f}s @ 16kHz → 48kHz …")
552
+ with torch.no_grad():
553
+ out = model.run(t)
554
+ # out is a tensor or numpy array — normalise to numpy float32 cpu
555
+ if isinstance(out, torch.Tensor):
556
+ out = out.float().cpu().squeeze().numpy()
557
+ else:
558
+ out = np.array(out, dtype=np.float32).squeeze()
559
+ print(f"[FlashSR] Done — output shape {out.shape}, sr={FLASHSR_SR_OUT}")
560
+ return out
561
+ except Exception as e:
562
+ print(f"[FlashSR] ERROR: {e} — falling back to sinc resampling")
563
+ t = torch.from_numpy(wav_16k.astype(np.float32)).unsqueeze(0)
564
+ out = torchaudio.functional.resample(t, FLASHSR_SR_IN, FLASHSR_SR_OUT)
565
+ return out.squeeze().numpy()
566
+
567
+
568
  def _stitch_wavs(wavs: list[np.ndarray], crossfade_s: float, db_boost: float,
569
  total_dur_s: float, sr: int) -> np.ndarray:
570
  """Crossfade-join a list of wav arrays and trim to *total_dur_s*.
 
739
  outputs = []
740
  for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
741
  final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR)
742
+
743
+ # ── FlashSR: upsample 16 kHz → 48 kHz ──
744
+ print(f"[TARO] Sample {sample_idx+1}: running FlashSR upsampler (16kHz → 48kHz) …")
745
+ final_wav = _apply_flashsr(final_wav)
746
+ out_sr = FLASHSR_SR_OUT
747
+ print(f"[TARO] Sample {sample_idx+1}: FlashSR complete — {len(final_wav)/out_sr:.2f}s @ {out_sr}Hz")
748
+
749
  audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
750
+ _save_wav(audio_path, final_wav, out_sr)
751
  video_path = os.path.join(tmp_dir, f"taro_{sample_idx}.mp4")
752
  mux_video_audio(silent_video, audio_path, video_path)
753
  wav_paths = _save_seg_wavs(wavs, tmp_dir, f"taro_{sample_idx}")
 
759
  first_cavp_saved = True
760
  seg_meta = _build_seg_meta(
761
  segments=segments, wav_paths=wav_paths, audio_path=audio_path,
762
+ video_path=video_path, silent_video=silent_video, sr=out_sr,
763
  model="taro", crossfade_s=crossfade_s, crossfade_db=crossfade_db,
764
  total_dur_s=total_dur_s, cavp_path=cavp_path, onset_path=onset_path,
765
  )
 
1209
  seed_val, cfg_scale, num_steps, mode,
1210
  crossfade_s, crossfade_db, slot_id)
1211
 
1212
+ # FlashSR: upsample 16 kHz → 48 kHz before splicing
1213
+ print(f"[TARO regen] Running FlashSR upsampler (16kHz → 48kHz) on seg {seg_idx} …")
1214
+ new_wav = _apply_flashsr(new_wav)
1215
+ print(f"[TARO regen] FlashSR complete — {len(new_wav)/FLASHSR_SR_OUT:.2f}s @ {FLASHSR_SR_OUT}Hz")
1216
+
1217
+ # CPU: splice, stitch, mux, save — meta["sr"] must reflect the upsampled rate
1218
+ meta_48k = dict(meta)
1219
+ meta_48k["sr"] = FLASHSR_SR_OUT
1220
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1221
+ new_wav, seg_idx, meta_48k, slot_id
1222
  )
1223
  return video_path, audio_path, json.dumps(updated_meta), waveform_html
1224
 
 
1486
  new_wav_raw = _regen_taro_gpu(None, seg_idx, state_json,
1487
  seed_val, cfg_scale, num_steps, mode,
1488
  crossfade_s, crossfade_db, slot_id)
1489
+ # FlashSR: upsample 16 kHz → 48 kHz before splicing into slot
1490
+ print(f"[xregen TARO] Running FlashSR upsampler (16kHz → 48kHz) on seg {seg_idx} …")
1491
+ new_wav_raw = _apply_flashsr(new_wav_raw)
1492
+ print(f"[xregen TARO] FlashSR complete — {len(new_wav_raw)/FLASHSR_SR_OUT:.2f}s @ {FLASHSR_SR_OUT}Hz")
1493
+ video_path, waveform_html = _xregen_splice(new_wav_raw, FLASHSR_SR_OUT, meta, seg_idx, slot_id)
1494
  yield gr.update(value=video_path), gr.update(value=waveform_html)
1495
 
1496
 
requirements.txt CHANGED
@@ -21,6 +21,7 @@ loguru
21
  torchdiffeq
22
  open_clip_torch
23
  git+https://github.com/descriptinc/audiotools
 
24
  --extra-index-url https://download.pytorch.org/whl/cu124
25
  torchaudio==2.5.1+cu124
26
  --find-links https://download.openmmlab.com/mmcv/dist/cu121/torch2.4.0/index.html
 
21
  torchdiffeq
22
  open_clip_torch
23
  git+https://github.com/descriptinc/audiotools
24
+ git+https://github.com/ysharma3501/FlashSR.git
25
  --extra-index-url https://download.pytorch.org/whl/cu124
26
  torchaudio==2.5.1+cu124
27
  --find-links https://download.openmmlab.com/mmcv/dist/cu121/torch2.4.0/index.html