BoxOfColors committed on
Commit
d9fa683
·
1 Parent(s): 47fd0ad

Refactor: reduce technical debt across app.py

Browse files

- Remove unused `html` import; move `import math` to top-level
- Add _resolve_seed() helper to unify inconsistent seed handling across
TARO/MMAudio/HunyuanFoley (was: 3 different patterns; now: one call)
- Remove redundant torch.device() wrapping in _hunyuan_gpu_infer and
_regen_hunyuan_gpu (device is already a string from _get_device_and_dtype)
- Fix all unregistered mkdtemp() calls in regen GPU functions so temp dirs
are tracked and cleaned up (prevents /tmp accumulation on long-running Spaces)
- Fix misleading comments in _xregen_splice offset alignment logic

Files changed (1) hide show
  1. app.py +20 -22
app.py CHANGED
@@ -8,7 +8,7 @@ Supported models
8
  HunyuanFoley – text-guided foley via SigLIP2 + Synchformer + CLAP (48 kHz, up to 15 s)
9
  """
10
 
11
- import html as _html
12
  import os
13
  import sys
14
  import json
@@ -210,6 +210,12 @@ def set_global_seed(seed: int) -> None:
210
  def get_random_seed() -> int:
211
  return random.randint(0, 2**32 - 1)
212
 
 
 
 
 
 
 
213
  def get_video_duration(video_path: str) -> float:
214
  """Return video duration in seconds (CPU only)."""
215
  probe = ffmpeg.probe(video_path)
@@ -432,7 +438,6 @@ def _build_segments(total_dur_s: float, window_s: float, crossfade_s: float) ->
432
  crossfade_s = min(crossfade_s, window_s * 0.5)
433
  if total_dur_s <= window_s:
434
  return [(0.0, total_dur_s)]
435
- import math
436
  step_min = window_s - crossfade_s # minimum step to honour crossfade
437
  n = math.ceil((total_dur_s - crossfade_s) / step_min)
438
  n = max(n, 2)
@@ -849,11 +854,9 @@ def _taro_gpu_infer(video_file, seed_val, cfg_scale, num_steps, mode,
849
  crossfade_s, crossfade_db, num_samples):
850
  """GPU-only TARO inference — model loading + feature extraction + diffusion.
851
  Returns list of (wavs_list, onset_feats) per sample."""
852
- seed_val = int(seed_val)
853
  crossfade_s = float(crossfade_s)
854
  num_samples = int(num_samples)
855
- if seed_val < 0:
856
- seed_val = random.randint(0, 2**32 - 1)
857
 
858
  torch.set_grad_enabled(False)
859
  device, weight_dtype = _get_device_and_dtype()
@@ -1005,7 +1008,7 @@ def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1005
  from mmaudio.eval_utils import generate, load_video
1006
  from mmaudio.model.flow_matching import FlowMatching
1007
 
1008
- seed_val = int(seed_val)
1009
  num_samples = int(num_samples)
1010
  crossfade_s = float(crossfade_s)
1011
 
@@ -1022,10 +1025,7 @@ def _mmaudio_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1022
  results = []
1023
  for sample_idx in range(num_samples):
1024
  rng = torch.Generator(device=device)
1025
- if seed_val >= 0:
1026
- rng.manual_seed(seed_val + sample_idx)
1027
- else:
1028
- rng.seed()
1029
 
1030
  seg_audios = []
1031
  _t_mma_start = time.perf_counter()
@@ -1149,14 +1149,12 @@ def _hunyuan_gpu_infer(video_file, prompt, negative_prompt, seed_val,
1149
  from hunyuanvideo_foley.utils.model_utils import denoise_process
1150
  from hunyuanvideo_foley.utils.feature_utils import feature_process
1151
 
1152
- seed_val = int(seed_val)
1153
  num_samples = int(num_samples)
1154
  crossfade_s = float(crossfade_s)
1155
- if seed_val >= 0:
1156
- set_global_seed(seed_val)
1157
 
1158
  device, _ = _get_device_and_dtype()
1159
- device = torch.device(device)
1160
  model_size = model_size.lower()
1161
 
1162
  model_dict, cfg = _load_hunyuan_model(device, model_size)
@@ -1375,7 +1373,7 @@ def _regen_taro_gpu(video_file, seg_idx, seg_meta_json,
1375
  from TARO.onset_util import extract_onset
1376
  extract_cavp, onset_model = _load_taro_feature_extractors(device)
1377
  silent_video = meta["silent_video"]
1378
- tmp_dir = tempfile.mkdtemp()
1379
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
1380
  onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
1381
  del extract_cavp, onset_model
@@ -1446,7 +1444,7 @@ def _regen_mmaudio_gpu(video_file, seg_idx, seg_meta_json,
1446
  # This avoids any cross-process context passing that fails under ZeroGPU isolation.
1447
  seg_path = _extract_segment_clip(
1448
  meta["silent_video"], seg_start, seg_dur,
1449
- os.path.join(tempfile.mkdtemp(), "regen_seg.mp4"),
1450
  )
1451
 
1452
  rng = torch.Generator(device=device)
@@ -1521,7 +1519,6 @@ def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1521
  from hunyuanvideo_foley.utils.feature_utils import feature_process
1522
 
1523
  device, _ = _get_device_and_dtype()
1524
- device = torch.device(device)
1525
  model_dict, cfg = _load_hunyuan_model(device, model_size)
1526
 
1527
  set_global_seed(random.randint(0, 2**32 - 1))
@@ -1529,7 +1526,7 @@ def _regen_hunyuan_gpu(video_file, seg_idx, seg_meta_json,
1529
  # Extract segment clip inside the GPU function — ffmpeg is CPU-only and safe here.
1530
  seg_path = _extract_segment_clip(
1531
  meta["silent_video"], seg_start, seg_dur,
1532
- os.path.join(tempfile.mkdtemp(), "regen_seg.mp4"),
1533
  )
1534
 
1535
  text_feats_path = meta.get("text_feats_path", "")
@@ -1661,13 +1658,14 @@ def _xregen_splice(new_wav_raw: np.ndarray, src_sr: int,
1661
  slot_wavs = _load_seg_wavs(meta["wav_paths"])
1662
  new_wav = _resample_to_slot_sr(new_wav_raw, src_sr, slot_sr, slot_wavs[0])
1663
 
1664
- # If the clip started before the original segment start, prepend silence
1665
- # so that sample index 0 of new_wav corresponds to seg_start in video time.
 
 
1666
  if clip_start_s is not None:
1667
  seg_start = meta["segments"][seg_idx][0]
1668
- offset_s = seg_start - clip_start_s # positive = seg starts after clip start
1669
  if offset_s < 0:
1670
- # clip started after seg_start — prepend silence to align
1671
  pad_samples = int(round(abs(offset_s) * slot_sr))
1672
  silence = np.zeros(
1673
  (new_wav.shape[0], pad_samples) if new_wav.ndim == 2 else pad_samples,
 
8
  HunyuanFoley – text-guided foley via SigLIP2 + Synchformer + CLAP (48 kHz, up to 15 s)
9
  """
10
 
11
+ import math
12
  import os
13
  import sys
14
  import json
 
210
  def get_random_seed() -> int:
211
  return random.randint(0, 2**32 - 1)
212
 
213
+ def _resolve_seed(seed_val) -> int:
214
+ """Normalise seed_val to a non-negative int.
215
+ Negative values (UI default 'random') produce a fresh random seed."""
216
+ seed_val = int(seed_val)
217
+ return seed_val if seed_val >= 0 else get_random_seed()
218
+
219
  def get_video_duration(video_path: str) -> float:
220
  """Return video duration in seconds (CPU only)."""
221
  probe = ffmpeg.probe(video_path)
 
438
  crossfade_s = min(crossfade_s, window_s * 0.5)
439
  if total_dur_s <= window_s:
440
  return [(0.0, total_dur_s)]
 
441
  step_min = window_s - crossfade_s # minimum step to honour crossfade
442
  n = math.ceil((total_dur_s - crossfade_s) / step_min)
443
  n = max(n, 2)
 
854
  crossfade_s, crossfade_db, num_samples):
855
  """GPU-only TARO inference — model loading + feature extraction + diffusion.
856
  Returns list of (wavs_list, onset_feats) per sample."""
857
+ seed_val = _resolve_seed(seed_val)
858
  crossfade_s = float(crossfade_s)
859
  num_samples = int(num_samples)
 
 
860
 
861
  torch.set_grad_enabled(False)
862
  device, weight_dtype = _get_device_and_dtype()
 
1008
  from mmaudio.eval_utils import generate, load_video
1009
  from mmaudio.model.flow_matching import FlowMatching
1010
 
1011
+ seed_val = _resolve_seed(seed_val)
1012
  num_samples = int(num_samples)
1013
  crossfade_s = float(crossfade_s)
1014
 
 
1025
  results = []
1026
  for sample_idx in range(num_samples):
1027
  rng = torch.Generator(device=device)
1028
+ rng.manual_seed(seed_val + sample_idx)
 
 
 
1029
 
1030
  seg_audios = []
1031
  _t_mma_start = time.perf_counter()
 
1149
  from hunyuanvideo_foley.utils.model_utils import denoise_process
1150
  from hunyuanvideo_foley.utils.feature_utils import feature_process
1151
 
1152
+ seed_val = _resolve_seed(seed_val)
1153
  num_samples = int(num_samples)
1154
  crossfade_s = float(crossfade_s)
1155
+ set_global_seed(seed_val)
 
1156
 
1157
  device, _ = _get_device_and_dtype()
 
1158
  model_size = model_size.lower()
1159
 
1160
  model_dict, cfg = _load_hunyuan_model(device, model_size)
 
1373
  from TARO.onset_util import extract_onset
1374
  extract_cavp, onset_model = _load_taro_feature_extractors(device)
1375
  silent_video = meta["silent_video"]
1376
+ tmp_dir = _register_tmp_dir(tempfile.mkdtemp())
1377
  cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
1378
  onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
1379
  del extract_cavp, onset_model
 
1444
  # This avoids any cross-process context passing that fails under ZeroGPU isolation.
1445
  seg_path = _extract_segment_clip(
1446
  meta["silent_video"], seg_start, seg_dur,
1447
+ os.path.join(_register_tmp_dir(tempfile.mkdtemp()), "regen_seg.mp4"),
1448
  )
1449
 
1450
  rng = torch.Generator(device=device)
 
1519
  from hunyuanvideo_foley.utils.feature_utils import feature_process
1520
 
1521
  device, _ = _get_device_and_dtype()
 
1522
  model_dict, cfg = _load_hunyuan_model(device, model_size)
1523
 
1524
  set_global_seed(random.randint(0, 2**32 - 1))
 
1526
  # Extract segment clip inside the GPU function — ffmpeg is CPU-only and safe here.
1527
  seg_path = _extract_segment_clip(
1528
  meta["silent_video"], seg_start, seg_dur,
1529
+ os.path.join(_register_tmp_dir(tempfile.mkdtemp()), "regen_seg.mp4"),
1530
  )
1531
 
1532
  text_feats_path = meta.get("text_feats_path", "")
 
1658
  slot_wavs = _load_seg_wavs(meta["wav_paths"])
1659
  new_wav = _resample_to_slot_sr(new_wav_raw, src_sr, slot_sr, slot_wavs[0])
1660
 
1661
+ # Align new_wav so sample index 0 corresponds to seg_start in video time.
1662
+ # _stitch_wavs trims using seg_start as the time origin, so if the clip
1663
+ # started AFTER seg_start (clip_start_s > seg_start), we prepend silence
1664
+ # equal to (clip_start_s - seg_start) to shift the audio back to seg_start.
1665
  if clip_start_s is not None:
1666
  seg_start = meta["segments"][seg_idx][0]
1667
+ offset_s = seg_start - clip_start_s # negative when clip starts after seg_start
1668
  if offset_s < 0:
 
1669
  pad_samples = int(round(abs(offset_s) * slot_sr))
1670
  silence = np.zeros(
1671
  (new_wav.shape[0], pad_samples) if new_wav.ndim == 2 else pad_samples,