BoxOfColors Claude Opus 4.6 committed on
Commit
13cc4e6
·
1 Parent(s): a4226e1

perf: move CPU work outside @spaces.GPU to reduce ZeroGPU cost

Browse files

Split all generate and regen functions into CPU wrapper + GPU-only
inner function pattern. CPU pre/post-processing (ffmpeg, torchaudio,
numpy stitching, muxing) now runs outside @spaces.GPU boundary.
Saves ~5-12s of GPU reservation time per call.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +282 -127
app.py CHANGED
@@ -402,14 +402,14 @@ def _stitch_wavs(wavs: list, crossfade_s: float, db_boost: float,
402
 
403
 
404
  @spaces.GPU(duration=_taro_duration)
405
- def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
406
- crossfade_s, crossfade_db, num_samples):
407
- """TARO: video-conditioned diffusion, 16 kHz, 8.192 s sliding window."""
 
408
  global _TARO_INFERENCE_CACHE
409
 
410
  seed_val = int(seed_val)
411
  crossfade_s = float(crossfade_s)
412
- crossfade_db = float(crossfade_db)
413
  num_samples = int(num_samples)
414
  if seed_val < 0:
415
  seed_val = random.randint(0, 2**32 - 1)
@@ -418,8 +418,6 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
418
  device = "cuda" if torch.cuda.is_available() else "cpu"
419
  weight_dtype = torch.bfloat16
420
 
421
- # TARO modules use bare imports (e.g. `from cavp_util import ...`) that
422
- # assume the TARO directory is on sys.path. Add it before importing.
423
  _taro_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "TARO")
424
  if _taro_dir not in sys.path:
425
  sys.path.insert(0, _taro_dir)
@@ -427,21 +425,19 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
427
  from TARO.onset_util import extract_onset
428
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
429
 
 
 
 
 
 
 
 
430
  extract_cavp, onset_model = _load_taro_feature_extractors(device)
431
  model, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
432
 
433
- # -- Prepare silent video (shared across all samples) --
434
- tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
435
- silent_video = os.path.join(tmp_dir, "silent_input.mp4")
436
- strip_audio_from_video(video_file, silent_video)
437
-
438
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
439
- # Use actual video duration from ffprobe — CAVP frame count can under-count
440
- # if the extractor drops the last partial window, leading to truncated audio.
441
- total_dur_s = get_video_duration(video_file)
442
- segments = _build_segments(total_dur_s, TARO_MODEL_DUR, crossfade_s)
443
 
444
- outputs = []
445
  for sample_idx in range(num_samples):
446
  sample_seed = seed_val + sample_idx
447
  cache_key = (video_file, sample_seed, float(cfg_scale), int(num_steps), mode, crossfade_s)
@@ -450,7 +446,7 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
450
  cached = _TARO_INFERENCE_CACHE.get(cache_key)
451
  if cached is not None:
452
  print(f"[TARO] Sample {sample_idx+1}: cache hit.")
453
- wavs = cached["wavs"]
454
  else:
455
  set_global_seed(sample_seed)
456
  onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
@@ -478,7 +474,42 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
478
  _TARO_INFERENCE_CACHE[cache_key] = {"wavs": wavs}
479
  while len(_TARO_INFERENCE_CACHE) > _TARO_CACHE_MAXLEN:
480
  _TARO_INFERENCE_CACHE.pop(next(iter(_TARO_INFERENCE_CACHE)))
 
 
 
 
 
 
481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
482
  final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR)
483
  audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
484
  torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(final_wav)).unsqueeze(0), TARO_SR)
@@ -489,7 +520,8 @@ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
489
  cavp_path = os.path.join(tmp_dir, f"taro_{sample_idx}_cavp.npy")
490
  onset_path = os.path.join(tmp_dir, f"taro_{sample_idx}_onset.npy")
491
  np.save(cavp_path, cavp_feats)
492
- np.save(onset_path, onset_feats)
 
493
  seg_meta = {
494
  "segments": segments,
495
  "wav_paths": wav_paths,
@@ -539,51 +571,33 @@ def _mmaudio_duration(video_file, prompt, negative_prompt, seed_val,
539
 
540
 
541
  @spaces.GPU(duration=_mmaudio_duration)
542
- def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
543
- cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
544
- """MMAudio: flow-matching video-to-audio, 44.1 kHz, 8 s sliding window."""
 
545
  _mmaudio_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "MMAudio")
546
  if _mmaudio_dir not in sys.path:
547
  sys.path.insert(0, _mmaudio_dir)
548
 
549
- from mmaudio.eval_utils import generate, load_video, make_video
550
  from mmaudio.model.flow_matching import FlowMatching
551
 
552
  seed_val = int(seed_val)
553
  num_samples = int(num_samples)
554
  crossfade_s = float(crossfade_s)
555
- crossfade_db = float(crossfade_db)
556
 
557
  device = "cuda" if torch.cuda.is_available() else "cpu"
558
  dtype = torch.bfloat16
559
 
560
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
561
 
562
- tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
563
- outputs = []
564
-
565
- # Strip original audio so the muxed output only contains the generated track
566
- silent_video = os.path.join(tmp_dir, "silent_input.mp4")
567
- strip_audio_from_video(video_file, silent_video)
568
-
569
- # MMAudio's fixed window is 8 s. For longer videos we slide over 8 s segments
570
- # with a crossfade overlap and stitch the results into a full-length track.
571
- total_dur_s = get_video_duration(video_file)
572
- segments = _build_segments(total_dur_s, MMAUDIO_WINDOW, crossfade_s)
573
- print(f"[MMAudio] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤8 s")
574
 
575
  sr = seq_cfg.sampling_rate # 44100
576
 
577
- # Pre-extract all segment clips once (shared across samples, saves ffmpeg overhead)
578
- seg_clip_paths = []
579
- for seg_i, (seg_start, seg_end) in enumerate(segments):
580
- seg_dur = seg_end - seg_start
581
- seg_path = os.path.join(tmp_dir, f"mma_seg_{seg_i}.mp4")
582
- ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
583
- seg_path, vcodec="copy", an=None
584
- ).run(overwrite_output=True, quiet=True)
585
- seg_clip_paths.append(seg_path)
586
-
587
  for sample_idx in range(num_samples):
588
  rng = torch.Generator(device=device)
589
  if seed_val >= 0:
@@ -591,7 +605,7 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
591
  else:
592
  rng.seed()
593
 
594
- seg_audios = [] # list of (channels, samples) numpy arrays
595
  _t_mma_start = time.perf_counter()
596
 
597
  for seg_i, (seg_start, seg_end) in enumerate(segments):
@@ -633,8 +647,49 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
633
  print(f"[MMAudio] Inference done: {_n_segs_mma} seg(s) × {int(num_steps)} steps in "
634
  f"{_t_mma_elapsed:.1f}s wall → {_secs_per_step_mma:.3f}s/step "
635
  f"(current constant={MMAUDIO_SECS_PER_STEP})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
636
 
637
- # Crossfade-stitch all segments using shared equal-power helper
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
638
  full_wav = seg_audios[0]
639
  for nw in seg_audios[1:]:
640
  full_wav = _cf_join(full_wav, nw, crossfade_s, crossfade_db, sr)
@@ -642,7 +697,6 @@ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
642
 
643
  audio_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.wav")
644
  torchaudio.save(audio_path, torch.from_numpy(full_wav), sr)
645
-
646
  video_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.mp4")
647
  mux_video_audio(silent_video, audio_path, video_path)
648
  wav_paths = _save_seg_wavs(seg_audios, tmp_dir, f"mmaudio_{sample_idx}")
@@ -694,79 +748,52 @@ def _hunyuan_duration(video_file, prompt, negative_prompt, seed_val,
694
 
695
 
696
  @spaces.GPU(duration=_hunyuan_duration)
697
- def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
698
- guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
699
- """HunyuanVideoFoley: text-guided foley, 48 kHz, up to 15 s."""
700
- # Ensure HunyuanVideo-Foley package is importable
701
  _hf_path = str(Path("HunyuanVideo-Foley").resolve())
702
  if _hf_path not in sys.path:
703
  sys.path.insert(0, _hf_path)
704
 
705
  from hunyuanvideo_foley.utils.model_utils import denoise_process
706
  from hunyuanvideo_foley.utils.feature_utils import feature_process
707
- from hunyuanvideo_foley.utils.media_utils import merge_audio_video
708
 
709
  seed_val = int(seed_val)
710
  num_samples = int(num_samples)
711
  crossfade_s = float(crossfade_s)
712
- crossfade_db = float(crossfade_db)
713
  if seed_val >= 0:
714
  set_global_seed(seed_val)
715
 
716
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
717
- model_size = model_size.lower() # "xl" or "xxl"
718
 
719
  model_dict, cfg = _load_hunyuan_model(device, model_size)
720
 
721
- tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
722
- outputs = []
 
 
 
723
 
724
- # Strip original audio so the muxed output only contains the generated track
725
- silent_video = os.path.join(tmp_dir, "silent_input.mp4")
726
- strip_audio_from_video(video_file, silent_video)
727
-
728
- # HunyuanFoley is limited to 15 s per pass. For longer videos we slice the
729
- # input into overlapping segments, generate audio for each, then crossfade-
730
- # stitch the results into a single full-length audio track.
731
- total_dur_s = get_video_duration(silent_video)
732
- segments = _build_segments(total_dur_s, HUNYUAN_MAX_DUR, crossfade_s)
733
- print(f"[HunyuanFoley] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤15 s")
734
-
735
- # Pre-extract text features once (same for every segment; stream-copy, no re-encode)
736
- _dummy_seg_path = os.path.join(tmp_dir, "_seg_dummy.mp4")
737
- ffmpeg.input(silent_video, ss=0, t=min(total_dur_s, HUNYUAN_MAX_DUR)).output(
738
- _dummy_seg_path, vcodec="copy", an=None
739
- ).run(overwrite_output=True, quiet=True)
740
  _, text_feats, _ = feature_process(
741
- _dummy_seg_path,
742
  prompt if prompt else "",
743
  model_dict,
744
  cfg,
745
  neg_prompt=negative_prompt if negative_prompt else None,
746
  )
747
 
748
- # Pre-extract all segment clips once (shared across samples, saves ffmpeg overhead)
749
- hny_seg_clip_paths = []
750
- for seg_i, (seg_start, seg_end) in enumerate(segments):
751
- seg_dur = seg_end - seg_start
752
- seg_path = os.path.join(tmp_dir, f"hny_seg_{seg_i}.mp4")
753
- ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
754
- seg_path, vcodec="copy", an=None
755
- ).run(overwrite_output=True, quiet=True)
756
- hny_seg_clip_paths.append(seg_path)
757
-
758
- # Generate audio per segment, then stitch
759
  for sample_idx in range(num_samples):
760
  seg_wavs = []
761
- sr = 48000 # HunyuanFoley always outputs 48 kHz
762
  _t_hny_start = time.perf_counter()
763
  for seg_i, (seg_start, seg_end) in enumerate(segments):
764
  seg_dur = seg_end - seg_start
765
- seg_path = hny_seg_clip_paths[seg_i]
766
 
767
- # feature_process returns (visual_feats, text_feats, audio_len).
768
- # We discard the returned text_feats (_) and use the pre-computed
769
- # text_feats from above — text encoding runs once, not per segment.
770
  visual_feats, _, seg_audio_len = feature_process(
771
  seg_path,
772
  prompt if prompt else "",
@@ -787,9 +814,7 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
787
  num_inference_steps=int(num_steps),
788
  batch_size=1,
789
  )
790
- # audio_batch shape: (1, channels, samples) — take first (and only) sample
791
- wav = audio_batch[0].float().cpu().numpy() # (channels, samples)
792
- # Trim to exact segment length in samples
793
  seg_samples = int(round(seg_dur * sr))
794
  wav = wav[:, :seg_samples]
795
  seg_wavs.append(wav)
@@ -800,12 +825,66 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
800
  print(f"[HunyuanFoley] Inference done: {_n_segs_hny} seg(s) × {int(num_steps)} steps in "
801
  f"{_t_hny_elapsed:.1f}s wall → {_secs_per_step_hny:.3f}s/step "
802
  f"(current constant={HUNYUAN_SECS_PER_STEP})")
 
803
 
804
- # Crossfade-stitch all segments using shared equal-power helper
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
805
  full_wav = seg_wavs[0]
806
  for nw in seg_wavs[1:]:
807
  full_wav = _cf_join(full_wav, nw, crossfade_s, crossfade_db, sr)
808
- # Trim to exact video duration
809
  full_wav = full_wav[:, : int(round(total_dur_s * sr))]
810
 
811
  audio_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}.wav")
@@ -813,7 +892,6 @@ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
813
  video_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}.mp4")
814
  merge_audio_video(audio_path, silent_video, video_path)
815
  wav_paths = _save_seg_wavs(seg_wavs, tmp_dir, f"hunyuan_{sample_idx}")
816
- # Cache text features so regen can skip text encoding (~2-3s saved)
817
  text_feats_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}_text_feats.pt")
818
  torch.save(text_feats, text_feats_path)
819
  seg_meta = {
@@ -922,10 +1000,10 @@ def _taro_regen_duration(video_file, seg_idx, seg_meta_json,
922
 
923
 
924
  @spaces.GPU(duration=_taro_regen_duration)
925
- def regen_taro_segment(video_file, seg_idx, seg_meta_json,
926
- seed_val, cfg_scale, num_steps, mode,
927
- crossfade_s, crossfade_db, slot_id):
928
- """Regenerate one TARO segment with a fresh random seed."""
929
  meta = json.loads(seg_meta_json)
930
  seg_idx = int(seg_idx)
931
  seg_start_s, seg_end_s = meta["segments"][seg_idx]
@@ -940,7 +1018,6 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
940
 
941
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
942
 
943
- # Load cached CAVP + onset features if available (saves ~5-7s of GPU work)
944
  cavp_path = meta.get("cavp_path")
945
  onset_path = meta.get("onset_path")
946
  if cavp_path and os.path.exists(cavp_path) and onset_path and os.path.exists(onset_path):
@@ -960,13 +1037,27 @@ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
960
 
961
  set_global_seed(random.randint(0, 2**32 - 1))
962
 
963
- new_wav = _taro_infer_segment(
964
  model_net, vae, vocoder, cavp_feats, onset_feats,
965
  seg_start_s, seg_end_s, device, weight_dtype,
966
  float(cfg_scale), int(num_steps), mode, latents_scale,
967
  euler_sampler, euler_maruyama_sampler,
968
  )
969
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
970
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
971
  new_wav, seg_idx, meta, slot_id
972
  )
@@ -983,10 +1074,10 @@ def _mmaudio_regen_duration(video_file, seg_idx, seg_meta_json,
983
 
984
 
985
  @spaces.GPU(duration=_mmaudio_regen_duration)
986
- def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
987
- prompt, negative_prompt, seed_val,
988
- cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id):
989
- """Regenerate one MMAudio segment with a fresh random seed."""
990
  meta = json.loads(seg_meta_json)
991
  seg_idx = int(seg_idx)
992
  seg_start, seg_end = meta["segments"][seg_idx]
@@ -1003,14 +1094,18 @@ def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
1003
  dtype = torch.bfloat16
1004
 
1005
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
 
1006
 
1007
- sr = seq_cfg.sampling_rate
1008
- silent_video = meta["silent_video"]
1009
- tmp_dir = tempfile.mkdtemp()
1010
- seg_path = os.path.join(tmp_dir, "regen_seg.mp4")
1011
- ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
1012
- seg_path, vcodec="copy", an=None
1013
- ).run(overwrite_output=True, quiet=True)
 
 
 
1014
 
1015
  rng = torch.Generator(device=device)
1016
  rng.manual_seed(random.randint(0, 2**32 - 1))
@@ -1033,9 +1128,37 @@ def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
1033
  new_wav = audios.float().cpu()[0].numpy()
1034
  seg_samples = int(round(seg_dur * sr))
1035
  new_wav = new_wav[:, :seg_samples]
 
1036
 
1037
- meta["sr"] = sr
1038
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1039
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1040
  new_wav, seg_idx, meta, slot_id
1041
  )
@@ -1053,11 +1176,11 @@ def _hunyuan_regen_duration(video_file, seg_idx, seg_meta_json,
1053
 
1054
 
1055
  @spaces.GPU(duration=_hunyuan_regen_duration)
1056
- def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1057
- prompt, negative_prompt, seed_val,
1058
- guidance_scale, num_steps, model_size,
1059
- crossfade_s, crossfade_db, slot_id):
1060
- """Regenerate one HunyuanFoley segment with a fresh random seed."""
1061
  meta = json.loads(seg_meta_json)
1062
  seg_idx = int(seg_idx)
1063
  seg_start, seg_end = meta["segments"][seg_idx]
@@ -1075,14 +1198,16 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1075
 
1076
  set_global_seed(random.randint(0, 2**32 - 1))
1077
 
1078
- silent_video = meta["silent_video"]
1079
- tmp_dir = tempfile.mkdtemp()
1080
- seg_path = os.path.join(tmp_dir, "regen_seg.mp4")
1081
- ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
1082
- seg_path, vcodec="copy", an=None
1083
- ).run(overwrite_output=True, quiet=True)
 
 
 
1084
 
1085
- # Load cached text features if available (saves ~2-3s text encoding)
1086
  text_feats_path = meta.get("text_feats_path")
1087
  if text_feats_path and os.path.exists(text_feats_path):
1088
  print("[HunyuanFoley regen] Loading cached text features, extracting visual only")
@@ -1104,9 +1229,39 @@ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1104
  new_wav = audio_batch[0].float().cpu().numpy()
1105
  seg_samples = int(round(seg_dur * sr))
1106
  new_wav = new_wav[:, :seg_samples]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1107
 
1108
- meta["sr"] = sr
1109
 
 
1110
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1111
  new_wav, seg_idx, meta, slot_id
1112
  )
 
402
 
403
 
404
  @spaces.GPU(duration=_taro_duration)
405
+ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
406
+ crossfade_s, crossfade_db, num_samples):
407
+ """GPU-only TARO inference model loading + feature extraction + diffusion.
408
+ Returns list of (wavs_list, onset_feats) per sample."""
409
  global _TARO_INFERENCE_CACHE
410
 
411
  seed_val = int(seed_val)
412
  crossfade_s = float(crossfade_s)
 
413
  num_samples = int(num_samples)
414
  if seed_val < 0:
415
  seed_val = random.randint(0, 2**32 - 1)
 
418
  device = "cuda" if torch.cuda.is_available() else "cpu"
419
  weight_dtype = torch.bfloat16
420
 
 
 
421
  _taro_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "TARO")
422
  if _taro_dir not in sys.path:
423
  sys.path.insert(0, _taro_dir)
 
425
  from TARO.onset_util import extract_onset
426
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
427
 
428
+ # Use pre-computed CPU results from the wrapper
429
+ ctx = _taro_gpu_infer._cpu_ctx
430
+ tmp_dir = ctx["tmp_dir"]
431
+ silent_video = ctx["silent_video"]
432
+ segments = ctx["segments"]
433
+ total_dur_s = ctx["total_dur_s"]
434
+
435
  extract_cavp, onset_model = _load_taro_feature_extractors(device)
436
  model, vae, vocoder, latents_scale = _load_taro_models(device, weight_dtype)
437
 
 
 
 
 
 
438
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
 
 
 
 
439
 
440
+ results = [] # list of (wavs, onset_feats) per sample
441
  for sample_idx in range(num_samples):
442
  sample_seed = seed_val + sample_idx
443
  cache_key = (video_file, sample_seed, float(cfg_scale), int(num_steps), mode, crossfade_s)
 
446
  cached = _TARO_INFERENCE_CACHE.get(cache_key)
447
  if cached is not None:
448
  print(f"[TARO] Sample {sample_idx+1}: cache hit.")
449
+ results.append((cached["wavs"], cavp_feats, None))
450
  else:
451
  set_global_seed(sample_seed)
452
  onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
 
474
  _TARO_INFERENCE_CACHE[cache_key] = {"wavs": wavs}
475
  while len(_TARO_INFERENCE_CACHE) > _TARO_CACHE_MAXLEN:
476
  _TARO_INFERENCE_CACHE.pop(next(iter(_TARO_INFERENCE_CACHE)))
477
+ results.append((wavs, cavp_feats, onset_feats))
478
+
479
+ return results
480
+
481
+ # Attach a context slot for the CPU wrapper to pass pre-computed data
482
+ _taro_gpu_infer._cpu_ctx = {}
483
 
484
+
485
+ def generate_taro(video_file, seed_val, cfg_scale, num_steps, mode,
486
+ crossfade_s, crossfade_db, num_samples):
487
+ """TARO: video-conditioned diffusion, 16 kHz, 8.192 s sliding window.
488
+ CPU pre/post-processing wraps the GPU-only inference to minimize ZeroGPU cost."""
489
+ crossfade_s = float(crossfade_s)
490
+ crossfade_db = float(crossfade_db)
491
+ num_samples = int(num_samples)
492
+
493
+ # ── CPU pre-processing (no GPU needed) ──
494
+ tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
495
+ silent_video = os.path.join(tmp_dir, "silent_input.mp4")
496
+ strip_audio_from_video(video_file, silent_video)
497
+ total_dur_s = get_video_duration(video_file)
498
+ segments = _build_segments(total_dur_s, TARO_MODEL_DUR, crossfade_s)
499
+
500
+ # Pass pre-computed CPU results to the GPU function via context
501
+ _taro_gpu_infer._cpu_ctx = {
502
+ "tmp_dir": tmp_dir, "silent_video": silent_video,
503
+ "segments": segments, "total_dur_s": total_dur_s,
504
+ }
505
+
506
+ # ── GPU inference only ──
507
+ results = _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
508
+ crossfade_s, crossfade_db, num_samples)
509
+
510
+ # ── CPU post-processing (no GPU needed) ──
511
+ outputs = []
512
+ for sample_idx, (wavs, cavp_feats, onset_feats) in enumerate(results):
513
  final_wav = _stitch_wavs(wavs, crossfade_s, crossfade_db, total_dur_s, TARO_SR)
514
  audio_path = os.path.join(tmp_dir, f"taro_{sample_idx}.wav")
515
  torchaudio.save(audio_path, torch.from_numpy(np.ascontiguousarray(final_wav)).unsqueeze(0), TARO_SR)
 
520
  cavp_path = os.path.join(tmp_dir, f"taro_{sample_idx}_cavp.npy")
521
  onset_path = os.path.join(tmp_dir, f"taro_{sample_idx}_onset.npy")
522
  np.save(cavp_path, cavp_feats)
523
+ if onset_feats is not None:
524
+ np.save(onset_path, onset_feats)
525
  seg_meta = {
526
  "segments": segments,
527
  "wav_paths": wav_paths,
 
571
 
572
 
573
  @spaces.GPU(duration=_mmaudio_duration)
574
+ def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
575
+ cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
576
+ """GPU-only MMAudio inference model loading + flow-matching generation.
577
+ Returns list of (seg_audios, sr) per sample."""
578
  _mmaudio_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "MMAudio")
579
  if _mmaudio_dir not in sys.path:
580
  sys.path.insert(0, _mmaudio_dir)
581
 
582
+ from mmaudio.eval_utils import generate, load_video
583
  from mmaudio.model.flow_matching import FlowMatching
584
 
585
  seed_val = int(seed_val)
586
  num_samples = int(num_samples)
587
  crossfade_s = float(crossfade_s)
 
588
 
589
  device = "cuda" if torch.cuda.is_available() else "cpu"
590
  dtype = torch.bfloat16
591
 
592
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
593
 
594
+ ctx = _mmaudio_gpu_infer._cpu_ctx
595
+ segments = ctx["segments"]
596
+ seg_clip_paths = ctx["seg_clip_paths"]
 
 
 
 
 
 
 
 
 
597
 
598
  sr = seq_cfg.sampling_rate # 44100
599
 
600
+ results = []
 
 
 
 
 
 
 
 
 
601
  for sample_idx in range(num_samples):
602
  rng = torch.Generator(device=device)
603
  if seed_val >= 0:
 
605
  else:
606
  rng.seed()
607
 
608
+ seg_audios = []
609
  _t_mma_start = time.perf_counter()
610
 
611
  for seg_i, (seg_start, seg_end) in enumerate(segments):
 
647
  print(f"[MMAudio] Inference done: {_n_segs_mma} seg(s) × {int(num_steps)} steps in "
648
  f"{_t_mma_elapsed:.1f}s wall → {_secs_per_step_mma:.3f}s/step "
649
  f"(current constant={MMAUDIO_SECS_PER_STEP})")
650
+ results.append((seg_audios, sr))
651
+
652
+ return results
653
+
654
+ _mmaudio_gpu_infer._cpu_ctx = {}
655
+
656
+
657
+ def generate_mmaudio(video_file, prompt, negative_prompt, seed_val,
658
+ cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples):
659
+ """MMAudio: flow-matching video-to-audio, 44.1 kHz, 8 s sliding window.
660
+ CPU pre/post-processing wraps the GPU-only inference to minimize ZeroGPU cost."""
661
+ num_samples = int(num_samples)
662
+ crossfade_s = float(crossfade_s)
663
+ crossfade_db = float(crossfade_db)
664
+
665
+ # ── CPU pre-processing ──
666
+ tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
667
+ silent_video = os.path.join(tmp_dir, "silent_input.mp4")
668
+ strip_audio_from_video(video_file, silent_video)
669
+ total_dur_s = get_video_duration(video_file)
670
+ segments = _build_segments(total_dur_s, MMAUDIO_WINDOW, crossfade_s)
671
+ print(f"[MMAudio] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤8 s")
672
 
673
+ seg_clip_paths = []
674
+ for seg_i, (seg_start, seg_end) in enumerate(segments):
675
+ seg_dur = seg_end - seg_start
676
+ seg_path = os.path.join(tmp_dir, f"mma_seg_{seg_i}.mp4")
677
+ ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
678
+ seg_path, vcodec="copy", an=None
679
+ ).run(overwrite_output=True, quiet=True)
680
+ seg_clip_paths.append(seg_path)
681
+
682
+ _mmaudio_gpu_infer._cpu_ctx = {
683
+ "segments": segments, "seg_clip_paths": seg_clip_paths,
684
+ }
685
+
686
+ # ── GPU inference only ──
687
+ results = _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
688
+ cfg_strength, num_steps, crossfade_s, crossfade_db, num_samples)
689
+
690
+ # ── CPU post-processing ──
691
+ outputs = []
692
+ for sample_idx, (seg_audios, sr) in enumerate(results):
693
  full_wav = seg_audios[0]
694
  for nw in seg_audios[1:]:
695
  full_wav = _cf_join(full_wav, nw, crossfade_s, crossfade_db, sr)
 
697
 
698
  audio_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.wav")
699
  torchaudio.save(audio_path, torch.from_numpy(full_wav), sr)
 
700
  video_path = os.path.join(tmp_dir, f"mmaudio_{sample_idx}.mp4")
701
  mux_video_audio(silent_video, audio_path, video_path)
702
  wav_paths = _save_seg_wavs(seg_audios, tmp_dir, f"mmaudio_{sample_idx}")
 
748
 
749
 
750
  @spaces.GPU(duration=_hunyuan_duration)
751
+ def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
752
+ guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
753
+ """GPU-only HunyuanFoley inference model loading + feature extraction + denoising.
754
+ Returns list of (seg_wavs, sr, text_feats) per sample."""
755
  _hf_path = str(Path("HunyuanVideo-Foley").resolve())
756
  if _hf_path not in sys.path:
757
  sys.path.insert(0, _hf_path)
758
 
759
  from hunyuanvideo_foley.utils.model_utils import denoise_process
760
  from hunyuanvideo_foley.utils.feature_utils import feature_process
 
761
 
762
  seed_val = int(seed_val)
763
  num_samples = int(num_samples)
764
  crossfade_s = float(crossfade_s)
 
765
  if seed_val >= 0:
766
  set_global_seed(seed_val)
767
 
768
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
769
+ model_size = model_size.lower()
770
 
771
  model_dict, cfg = _load_hunyuan_model(device, model_size)
772
 
773
+ ctx = _hunyuan_gpu_infer._cpu_ctx
774
+ segments = ctx["segments"]
775
+ total_dur_s = ctx["total_dur_s"]
776
+ dummy_seg_path = ctx["dummy_seg_path"]
777
+ seg_clip_paths = ctx["seg_clip_paths"]
778
 
779
+ # Text feature extraction (GPU runs once for all segments)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
780
  _, text_feats, _ = feature_process(
781
+ dummy_seg_path,
782
  prompt if prompt else "",
783
  model_dict,
784
  cfg,
785
  neg_prompt=negative_prompt if negative_prompt else None,
786
  )
787
 
788
+ results = []
 
 
 
 
 
 
 
 
 
 
789
  for sample_idx in range(num_samples):
790
  seg_wavs = []
791
+ sr = 48000
792
  _t_hny_start = time.perf_counter()
793
  for seg_i, (seg_start, seg_end) in enumerate(segments):
794
  seg_dur = seg_end - seg_start
795
+ seg_path = seg_clip_paths[seg_i]
796
 
 
 
 
797
  visual_feats, _, seg_audio_len = feature_process(
798
  seg_path,
799
  prompt if prompt else "",
 
814
  num_inference_steps=int(num_steps),
815
  batch_size=1,
816
  )
817
+ wav = audio_batch[0].float().cpu().numpy()
 
 
818
  seg_samples = int(round(seg_dur * sr))
819
  wav = wav[:, :seg_samples]
820
  seg_wavs.append(wav)
 
825
  print(f"[HunyuanFoley] Inference done: {_n_segs_hny} seg(s) × {int(num_steps)} steps in "
826
  f"{_t_hny_elapsed:.1f}s wall → {_secs_per_step_hny:.3f}s/step "
827
  f"(current constant={HUNYUAN_SECS_PER_STEP})")
828
+ results.append((seg_wavs, sr, text_feats))
829
 
830
+ return results
831
+
832
+ _hunyuan_gpu_infer._cpu_ctx = {}
833
+
834
+
835
+ def generate_hunyuan(video_file, prompt, negative_prompt, seed_val,
836
+ guidance_scale, num_steps, model_size, crossfade_s, crossfade_db, num_samples):
837
+ """HunyuanVideoFoley: text-guided foley, 48 kHz, up to 15 s.
838
+ CPU pre/post-processing wraps the GPU-only inference to minimize ZeroGPU cost."""
839
+ num_samples = int(num_samples)
840
+ crossfade_s = float(crossfade_s)
841
+ crossfade_db = float(crossfade_db)
842
+
843
+ # ── CPU pre-processing (no GPU needed) ──
844
+ tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
845
+ silent_video = os.path.join(tmp_dir, "silent_input.mp4")
846
+ strip_audio_from_video(video_file, silent_video)
847
+ total_dur_s = get_video_duration(silent_video)
848
+ segments = _build_segments(total_dur_s, HUNYUAN_MAX_DUR, crossfade_s)
849
+ print(f"[HunyuanFoley] Video={total_dur_s:.2f}s | {len(segments)} segment(s) × ≤15 s")
850
+
851
+ # Pre-extract dummy segment for text feature extraction (ffmpeg, CPU)
852
+ dummy_seg_path = os.path.join(tmp_dir, "_seg_dummy.mp4")
853
+ ffmpeg.input(silent_video, ss=0, t=min(total_dur_s, HUNYUAN_MAX_DUR)).output(
854
+ dummy_seg_path, vcodec="copy", an=None
855
+ ).run(overwrite_output=True, quiet=True)
856
+
857
+ # Pre-extract all segment clips (ffmpeg, CPU)
858
+ seg_clip_paths = []
859
+ for seg_i, (seg_start, seg_end) in enumerate(segments):
860
+ seg_dur = seg_end - seg_start
861
+ seg_path = os.path.join(tmp_dir, f"hny_seg_{seg_i}.mp4")
862
+ ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
863
+ seg_path, vcodec="copy", an=None
864
+ ).run(overwrite_output=True, quiet=True)
865
+ seg_clip_paths.append(seg_path)
866
+
867
+ _hunyuan_gpu_infer._cpu_ctx = {
868
+ "segments": segments, "total_dur_s": total_dur_s,
869
+ "dummy_seg_path": dummy_seg_path, "seg_clip_paths": seg_clip_paths,
870
+ }
871
+
872
+ # ── GPU inference only ──
873
+ results = _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
874
+ guidance_scale, num_steps, model_size,
875
+ crossfade_s, crossfade_db, num_samples)
876
+
877
+ # ── CPU post-processing (no GPU needed) ──
878
+ _hf_path = str(Path("HunyuanVideo-Foley").resolve())
879
+ if _hf_path not in sys.path:
880
+ sys.path.insert(0, _hf_path)
881
+ from hunyuanvideo_foley.utils.media_utils import merge_audio_video
882
+
883
+ outputs = []
884
+ for sample_idx, (seg_wavs, sr, text_feats) in enumerate(results):
885
  full_wav = seg_wavs[0]
886
  for nw in seg_wavs[1:]:
887
  full_wav = _cf_join(full_wav, nw, crossfade_s, crossfade_db, sr)
 
888
  full_wav = full_wav[:, : int(round(total_dur_s * sr))]
889
 
890
  audio_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}.wav")
 
892
  video_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}.mp4")
893
  merge_audio_video(audio_path, silent_video, video_path)
894
  wav_paths = _save_seg_wavs(seg_wavs, tmp_dir, f"hunyuan_{sample_idx}")
 
895
  text_feats_path = os.path.join(tmp_dir, f"hunyuan_{sample_idx}_text_feats.pt")
896
  torch.save(text_feats, text_feats_path)
897
  seg_meta = {
 
1000
 
1001
 
1002
  @spaces.GPU(duration=_taro_regen_duration)
1003
+ def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
1004
+ seed_val, cfg_scale, num_steps, mode,
1005
+ crossfade_s, crossfade_db, slot_id=None):
1006
+ """GPU-only TARO regen returns new_wav for a single segment."""
1007
  meta = json.loads(seg_meta_json)
1008
  seg_idx = int(seg_idx)
1009
  seg_start_s, seg_end_s = meta["segments"][seg_idx]
 
1018
 
1019
  from TARO.samplers import euler_sampler, euler_maruyama_sampler
1020
 
 
1021
  cavp_path = meta.get("cavp_path")
1022
  onset_path = meta.get("onset_path")
1023
  if cavp_path and os.path.exists(cavp_path) and onset_path and os.path.exists(onset_path):
 
1037
 
1038
  set_global_seed(random.randint(0, 2**32 - 1))
1039
 
1040
+ return _taro_infer_segment(
1041
  model_net, vae, vocoder, cavp_feats, onset_feats,
1042
  seg_start_s, seg_end_s, device, weight_dtype,
1043
  float(cfg_scale), int(num_steps), mode, latents_scale,
1044
  euler_sampler, euler_maruyama_sampler,
1045
  )
1046
 
1047
+
1048
+ def regen_taro_segment(video_file, seg_idx, seg_meta_json,
1049
+ seed_val, cfg_scale, num_steps, mode,
1050
+ crossfade_s, crossfade_db, slot_id):
1051
+ """Regenerate one TARO segment. GPU inference + CPU splice/save."""
1052
+ meta = json.loads(seg_meta_json)
1053
+ seg_idx = int(seg_idx)
1054
+
1055
+ # GPU: inference only
1056
+ new_wav = _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
1057
+ seed_val, cfg_scale, num_steps, mode,
1058
+ crossfade_s, crossfade_db, slot_id)
1059
+
1060
+ # CPU: splice, stitch, mux, save
1061
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1062
  new_wav, seg_idx, meta, slot_id
1063
  )
 
1074
 
1075
 
1076
  @spaces.GPU(duration=_mmaudio_regen_duration)
1077
+ def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1078
+ prompt, negative_prompt, seed_val,
1079
+ cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id=None):
1080
+ """GPU-only MMAudio regen returns (new_wav, sr) for a single segment."""
1081
  meta = json.loads(seg_meta_json)
1082
  seg_idx = int(seg_idx)
1083
  seg_start, seg_end = meta["segments"][seg_idx]
 
1094
  dtype = torch.bfloat16
1095
 
1096
  net, feature_utils, model_cfg, seq_cfg = _load_mmaudio_models(device, dtype)
1097
+ sr = seq_cfg.sampling_rate
1098
 
1099
+ # Use pre-extracted segment clip from the wrapper
1100
+ seg_path = _regen_mmaudio_gpu._cpu_ctx.get("seg_path")
1101
+ if not seg_path:
1102
+ # Fallback: extract inside GPU (shouldn't happen)
1103
+ silent_video = meta["silent_video"]
1104
+ tmp_dir = tempfile.mkdtemp()
1105
+ seg_path = os.path.join(tmp_dir, "regen_seg.mp4")
1106
+ ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
1107
+ seg_path, vcodec="copy", an=None
1108
+ ).run(overwrite_output=True, quiet=True)
1109
 
1110
  rng = torch.Generator(device=device)
1111
  rng.manual_seed(random.randint(0, 2**32 - 1))
 
1128
  new_wav = audios.float().cpu()[0].numpy()
1129
  seg_samples = int(round(seg_dur * sr))
1130
  new_wav = new_wav[:, :seg_samples]
1131
+ return new_wav, sr
1132
 
1133
+ _regen_mmaudio_gpu._cpu_ctx = {}
1134
 
1135
+
1136
+ def regen_mmaudio_segment(video_file, seg_idx, seg_meta_json,
1137
+ prompt, negative_prompt, seed_val,
1138
+ cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id):
1139
+ """Regenerate one MMAudio segment. GPU inference + CPU splice/save."""
1140
+ meta = json.loads(seg_meta_json)
1141
+ seg_idx = int(seg_idx)
1142
+ seg_start, seg_end = meta["segments"][seg_idx]
1143
+ seg_dur = seg_end - seg_start
1144
+
1145
+ # CPU: pre-extract segment clip
1146
+ silent_video = meta["silent_video"]
1147
+ tmp_dir = tempfile.mkdtemp()
1148
+ seg_path = os.path.join(tmp_dir, "regen_seg.mp4")
1149
+ ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
1150
+ seg_path, vcodec="copy", an=None
1151
+ ).run(overwrite_output=True, quiet=True)
1152
+ _regen_mmaudio_gpu._cpu_ctx = {"seg_path": seg_path}
1153
+
1154
+ # GPU: inference only
1155
+ new_wav, sr = _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1156
+ prompt, negative_prompt, seed_val,
1157
+ cfg_strength, num_steps, crossfade_s, crossfade_db, slot_id)
1158
+
1159
+ meta["sr"] = sr
1160
+
1161
+ # CPU: splice, stitch, mux, save
1162
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1163
  new_wav, seg_idx, meta, slot_id
1164
  )
 
1176
 
1177
 
1178
  @spaces.GPU(duration=_hunyuan_regen_duration)
1179
+ def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1180
+ prompt, negative_prompt, seed_val,
1181
+ guidance_scale, num_steps, model_size,
1182
+ crossfade_s, crossfade_db, slot_id=None):
1183
+ """GPU-only HunyuanFoley regen returns (new_wav, sr) for a single segment."""
1184
  meta = json.loads(seg_meta_json)
1185
  seg_idx = int(seg_idx)
1186
  seg_start, seg_end = meta["segments"][seg_idx]
 
1198
 
1199
  set_global_seed(random.randint(0, 2**32 - 1))
1200
 
1201
+ # Use pre-extracted segment clip from wrapper
1202
+ seg_path = _regen_hunyuan_gpu._cpu_ctx.get("seg_path")
1203
+ if not seg_path:
1204
+ silent_video = meta["silent_video"]
1205
+ tmp_dir = tempfile.mkdtemp()
1206
+ seg_path = os.path.join(tmp_dir, "regen_seg.mp4")
1207
+ ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
1208
+ seg_path, vcodec="copy", an=None
1209
+ ).run(overwrite_output=True, quiet=True)
1210
 
 
1211
  text_feats_path = meta.get("text_feats_path")
1212
  if text_feats_path and os.path.exists(text_feats_path):
1213
  print("[HunyuanFoley regen] Loading cached text features, extracting visual only")
 
1229
  new_wav = audio_batch[0].float().cpu().numpy()
1230
  seg_samples = int(round(seg_dur * sr))
1231
  new_wav = new_wav[:, :seg_samples]
1232
+ return new_wav, sr
1233
+
1234
+ _regen_hunyuan_gpu._cpu_ctx = {}
1235
+
1236
+
1237
+ def regen_hunyuan_segment(video_file, seg_idx, seg_meta_json,
1238
+ prompt, negative_prompt, seed_val,
1239
+ guidance_scale, num_steps, model_size,
1240
+ crossfade_s, crossfade_db, slot_id):
1241
+ """Regenerate one HunyuanFoley segment. GPU inference + CPU splice/save."""
1242
+ meta = json.loads(seg_meta_json)
1243
+ seg_idx = int(seg_idx)
1244
+ seg_start, seg_end = meta["segments"][seg_idx]
1245
+ seg_dur = seg_end - seg_start
1246
+
1247
+ # CPU: pre-extract segment clip
1248
+ silent_video = meta["silent_video"]
1249
+ tmp_dir = tempfile.mkdtemp()
1250
+ seg_path = os.path.join(tmp_dir, "regen_seg.mp4")
1251
+ ffmpeg.input(silent_video, ss=seg_start, t=seg_dur).output(
1252
+ seg_path, vcodec="copy", an=None
1253
+ ).run(overwrite_output=True, quiet=True)
1254
+ _regen_hunyuan_gpu._cpu_ctx = {"seg_path": seg_path}
1255
+
1256
+ # GPU: inference only
1257
+ new_wav, sr = _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1258
+ prompt, negative_prompt, seed_val,
1259
+ guidance_scale, num_steps, model_size,
1260
+ crossfade_s, crossfade_db, slot_id)
1261
 
1262
+ meta["sr"] = sr
1263
 
1264
+ # CPU: splice, stitch, mux, save
1265
  video_path, audio_path, updated_meta, waveform_html = _splice_and_save(
1266
  new_wav, seg_idx, meta, slot_id
1267
  )