Spaces:
Running on Zero
Running on Zero
Commit ·
eaedb53
1
Parent(s): 0429f8a
Add crossfade duration/dB controls, inference caching, share=True
Browse files
app.py
CHANGED
|
@@ -30,6 +30,14 @@ onset_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="onset_model.ckpt",
|
|
| 30 |
taro_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="taro_ckpt.pt", cache_dir=CACHE_DIR)
|
| 31 |
print("Checkpoints downloaded.")
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
def set_global_seed(seed):
|
| 35 |
np.random.seed(seed % (2**32))
|
|
@@ -56,22 +64,22 @@ def infer_segment(model, vae, vocoder, cavp_feats_full, onset_feats_full,
|
|
| 56 |
cfg_scale, num_steps, mode,
|
| 57 |
euler_sampler, euler_maruyama_sampler):
|
| 58 |
"""
|
| 59 |
-
Run one model inference pass for the video window
|
| 60 |
-
Returns a numpy float32 wav array
|
| 61 |
-
trimmed to the actual segment length (seg_end_s - seg_start_s) when shorter.
|
| 62 |
"""
|
| 63 |
-
#
|
| 64 |
cavp_start = int(round(seg_start_s * fps))
|
| 65 |
-
|
| 66 |
-
cavp_slice = cavp_feats_full[cavp_start:cavp_end]
|
| 67 |
-
# pad if near end of video
|
| 68 |
if cavp_slice.shape[0] < truncate_frame:
|
| 69 |
-
pad = np.zeros(
|
|
|
|
|
|
|
|
|
|
| 70 |
cavp_slice = np.concatenate([cavp_slice, pad], axis=0)
|
| 71 |
video_feats = torch.from_numpy(cavp_slice).unsqueeze(0).to(device).to(weight_dtype)
|
| 72 |
|
| 73 |
-
#
|
| 74 |
-
onset_fps = truncate_onset / model_dur
|
| 75 |
onset_start = int(round(seg_start_s * onset_fps))
|
| 76 |
onset_slice = onset_feats_full[onset_start : onset_start + truncate_onset]
|
| 77 |
if onset_slice.shape[0] < truncate_onset:
|
|
@@ -79,7 +87,6 @@ def infer_segment(model, vae, vocoder, cavp_feats_full, onset_feats_full,
|
|
| 79 |
onset_slice = np.pad(onset_slice, ((0, pad_len),), mode="constant", constant_values=0)
|
| 80 |
onset_feats_t = torch.from_numpy(onset_slice).unsqueeze(0).to(device).to(weight_dtype)
|
| 81 |
|
| 82 |
-
# -- Diffusion --
|
| 83 |
z = torch.randn(1, model.in_channels, 204, 16, device=device).to(weight_dtype)
|
| 84 |
sampling_kwargs = dict(
|
| 85 |
model=model,
|
|
@@ -102,171 +109,180 @@ def infer_segment(model, vae, vocoder, cavp_feats_full, onset_feats_full,
|
|
| 102 |
samples = vae.decode(samples / latents_scale).sample
|
| 103 |
wav = vocoder(samples.squeeze().float()).detach().cpu().numpy()
|
| 104 |
|
| 105 |
-
# Trim to actual segment length
|
| 106 |
seg_samples = int(round((seg_end_s - seg_start_s) * sr))
|
| 107 |
return wav[:seg_samples]
|
| 108 |
|
| 109 |
|
| 110 |
-
def crossfade_join(wav_a, wav_b, crossfade_s, sr):
|
| 111 |
"""
|
| 112 |
-
Join wav_a and wav_b with a
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
For a +3 dB bump at midpoint we use *linear* ramps instead:
|
| 123 |
-
fade_out = 1 - t, fade_in = t (t: 0->1 across window)
|
| 124 |
-
At t=0.5: both = 0.5, sum = 1.0 amplitude = +6 dB power... that is not right.
|
| 125 |
-
|
| 126 |
-
DaVinci Resolve "+3 dB" crossfade means the combined level at the midpoint
|
| 127 |
-
is +3 dB above either source, which equals the behaviour where each signal
|
| 128 |
-
is kept at full gain (1.0) across the entire overlap and the two are simply
|
| 129 |
-
summed — then the overlap region has 6 dB of headroom risk, but the *perceived*
|
| 130 |
-
loudness boost at the centre is +3 dB (sqrt(2) in amplitude).
|
| 131 |
-
|
| 132 |
-
Implementation: keep both signals at unity gain in the crossfade window and
|
| 133 |
-
sum them. Outside the window use the respective signal only.
|
| 134 |
"""
|
| 135 |
cf_samples = int(round(crossfade_s * sr))
|
| 136 |
|
| 137 |
-
#
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
overlap = tail_a + head_b # +3 dB sum at centre (unity + unity)
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
])
|
| 148 |
-
return result
|
| 149 |
|
| 150 |
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
|
|
|
|
|
|
| 160 |
|
| 161 |
-
from cavp_util import Extract_CAVP_Features
|
| 162 |
-
from onset_util import VideoOnsetNet, extract_onset
|
| 163 |
-
from models import MMDiT
|
| 164 |
-
from samplers import euler_sampler, euler_maruyama_sampler
|
| 165 |
-
from diffusers import AudioLDM2Pipeline
|
| 166 |
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
|
|
|
| 170 |
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
new_key = key.replace("model.fc", "fc")
|
| 178 |
-
else:
|
| 179 |
-
new_key = key
|
| 180 |
-
new_state_dict[new_key] = value
|
| 181 |
-
onset_model = VideoOnsetNet(False).to(device)
|
| 182 |
-
onset_model.load_state_dict(new_state_dict)
|
| 183 |
-
onset_model.eval()
|
| 184 |
-
|
| 185 |
-
model = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
|
| 186 |
-
ckpt = torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"]
|
| 187 |
-
model.load_state_dict(ckpt)
|
| 188 |
-
model.eval()
|
| 189 |
-
model.to(weight_dtype)
|
| 190 |
-
|
| 191 |
-
model_audioldm = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
|
| 192 |
-
vae = model_audioldm.vae.to(device)
|
| 193 |
-
vae.eval()
|
| 194 |
-
vocoder = model_audioldm.vocoder.to(device)
|
| 195 |
-
|
| 196 |
-
tmp_dir = tempfile.mkdtemp()
|
| 197 |
-
silent_video = os.path.join(tmp_dir, "silent_input.mp4")
|
| 198 |
-
strip_audio_from_video(video_file, silent_video)
|
| 199 |
-
|
| 200 |
-
cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
|
| 201 |
-
onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
|
| 202 |
|
| 203 |
sr = 16000
|
| 204 |
truncate = 131072
|
| 205 |
fps = 4
|
| 206 |
-
truncate_frame = int(fps * truncate / sr)
|
| 207 |
-
truncate_onset = 120
|
| 208 |
-
model_dur = truncate / sr
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
#
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
# seg_end_s is the actual content end (clipped to video length), #
|
| 223 |
-
# but we always run the model for a full model_dur window. #
|
| 224 |
-
# ------------------------------------------------------------------ #
|
| 225 |
-
segments = []
|
| 226 |
-
seg_start = 0.0
|
| 227 |
-
while True:
|
| 228 |
-
seg_end = min(seg_start + model_dur, total_dur_s)
|
| 229 |
-
segments.append((seg_start, seg_end))
|
| 230 |
-
if seg_end >= total_dur_s:
|
| 231 |
-
break
|
| 232 |
-
seg_start += step_s
|
| 233 |
-
|
| 234 |
-
# ------------------------------------------------------------------ #
|
| 235 |
-
# Run inference for every segment #
|
| 236 |
-
# ------------------------------------------------------------------ #
|
| 237 |
-
wavs = []
|
| 238 |
-
for seg_start_s, seg_end_s in segments:
|
| 239 |
-
print(f"Inferring segment {seg_start_s:.2f}s – {seg_end_s:.2f}s ...")
|
| 240 |
-
wav = infer_segment(
|
| 241 |
-
model, vae, vocoder,
|
| 242 |
-
cavp_feats, onset_feats,
|
| 243 |
-
seg_start_s, seg_end_s,
|
| 244 |
-
sr, fps, truncate_frame, truncate_onset, model_dur,
|
| 245 |
-
latents_scale, device, weight_dtype,
|
| 246 |
-
cfg_scale, num_steps, mode,
|
| 247 |
-
euler_sampler, euler_maruyama_sampler,
|
| 248 |
-
)
|
| 249 |
-
wavs.append(wav)
|
| 250 |
-
|
| 251 |
-
# ------------------------------------------------------------------ #
|
| 252 |
-
# Stitch with crossfades #
|
| 253 |
-
# Single segment: no crossfade needed #
|
| 254 |
-
# ------------------------------------------------------------------ #
|
| 255 |
-
if len(wavs) == 1:
|
| 256 |
-
final_wav = wavs[0]
|
| 257 |
else:
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
|
| 266 |
-
audio_path
|
| 267 |
sf.write(audio_path, final_wav, sr)
|
| 268 |
|
| 269 |
-
# Mux original silent video (full length) with generated audio
|
| 270 |
output_video = os.path.join(tmp_dir, "output.mp4")
|
| 271 |
input_v = ffmpeg.input(silent_video)
|
| 272 |
input_a = ffmpeg.input(audio_path)
|
|
@@ -292,12 +308,14 @@ demo = gr.Interface(
|
|
| 292 |
gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5),
|
| 293 |
gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1),
|
| 294 |
gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde"),
|
|
|
|
|
|
|
| 295 |
],
|
| 296 |
outputs=[
|
| 297 |
gr.Video(label="Output Video with Audio"),
|
| 298 |
gr.Audio(label="Generated Audio"),
|
| 299 |
],
|
| 300 |
title="TARO: Video-to-Audio Synthesis (ICCV 2025)",
|
| 301 |
-
description="Upload a video and generate synchronized audio using TARO. Optimal duration is 8.2s.",
|
| 302 |
)
|
| 303 |
-
demo.queue().launch()
|
|
|
|
| 30 |
taro_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="taro_ckpt.pt", cache_dir=CACHE_DIR)
|
| 31 |
print("Checkpoints downloaded.")
|
| 32 |
|
| 33 |
+
# ------------------------------------------------------------------ #
|
| 34 |
+
# Inference cache: keyed by (video_path, seed, cfg_scale, #
|
| 35 |
+
# num_steps, mode, crossfade_s) #
|
| 36 |
+
# Stores the raw per-segment wavs so that only the dB value can be #
|
| 37 |
+
# changed without re-running the model. #
|
| 38 |
+
# ------------------------------------------------------------------ #
|
| 39 |
+
# Maps cache_key -> {"wavs": list of per-segment wav arrays, "total_dur_s": float,
#                    "tmp_dir": str, "silent_video": str}
# NOTE(review): no eviction is visible in this view, so entries (and the temp
# directories they point at) appear to accumulate for the process lifetime — confirm.
_INFERENCE_CACHE = {}
|
| 40 |
+
|
| 41 |
|
| 42 |
def set_global_seed(seed):
|
| 43 |
np.random.seed(seed % (2**32))
|
|
|
|
| 64 |
cfg_scale, num_steps, mode,
|
| 65 |
euler_sampler, euler_maruyama_sampler):
|
| 66 |
"""
|
| 67 |
+
Run one model inference pass for the video window starting at seg_start_s.
|
| 68 |
+
Returns a numpy float32 wav array trimmed to (seg_end_s - seg_start_s).
|
|
|
|
| 69 |
"""
|
| 70 |
+
# CAVP features at fps (4 fps)
|
| 71 |
cavp_start = int(round(seg_start_s * fps))
|
| 72 |
+
cavp_slice = cavp_feats_full[cavp_start : cavp_start + truncate_frame]
|
|
|
|
|
|
|
| 73 |
if cavp_slice.shape[0] < truncate_frame:
|
| 74 |
+
pad = np.zeros(
|
| 75 |
+
(truncate_frame - cavp_slice.shape[0],) + cavp_slice.shape[1:],
|
| 76 |
+
dtype=cavp_slice.dtype,
|
| 77 |
+
)
|
| 78 |
cavp_slice = np.concatenate([cavp_slice, pad], axis=0)
|
| 79 |
video_feats = torch.from_numpy(cavp_slice).unsqueeze(0).to(device).to(weight_dtype)
|
| 80 |
|
| 81 |
+
# Onset features at truncate_onset / model_dur frames per second
|
| 82 |
+
onset_fps = truncate_onset / model_dur
|
| 83 |
onset_start = int(round(seg_start_s * onset_fps))
|
| 84 |
onset_slice = onset_feats_full[onset_start : onset_start + truncate_onset]
|
| 85 |
if onset_slice.shape[0] < truncate_onset:
|
|
|
|
| 87 |
onset_slice = np.pad(onset_slice, ((0, pad_len),), mode="constant", constant_values=0)
|
| 88 |
onset_feats_t = torch.from_numpy(onset_slice).unsqueeze(0).to(device).to(weight_dtype)
|
| 89 |
|
|
|
|
| 90 |
z = torch.randn(1, model.in_channels, 204, 16, device=device).to(weight_dtype)
|
| 91 |
sampling_kwargs = dict(
|
| 92 |
model=model,
|
|
|
|
| 109 |
samples = vae.decode(samples / latents_scale).sample
|
| 110 |
wav = vocoder(samples.squeeze().float()).detach().cpu().numpy()
|
| 111 |
|
|
|
|
| 112 |
seg_samples = int(round((seg_end_s - seg_start_s) * sr))
|
| 113 |
return wav[:seg_samples]
|
| 114 |
|
| 115 |
|
| 116 |
+
def crossfade_join(wav_a, wav_b, crossfade_s, db_boost, sr):
    """
    Join wav_a and wav_b, overlapping them by crossfade_s seconds.

    The last crossfade_s seconds of wav_a are summed with the first
    crossfade_s seconds of wav_b. No fade ramps are applied: each signal in
    the overlap is multiplied by the same constant gain and the two are added.

        gain = 10 ** (db_boost / 20)

    db_boost = 0 gives gain 1.0 (a plain unity sum of both signals);
    db_boost = 3 gives gain ~= 1.41 on top of that sum; negative values
    attenuate the overlap. Samples outside the overlap are left untouched.

    Args:
        wav_a: 1-D array-like of samples that comes first.
        wav_b: 1-D array-like of samples that comes second.
        crossfade_s: overlap length in seconds.
        db_boost: gain, in dB, applied to both signals inside the overlap.
        sr: sample rate in Hz, used to convert crossfade_s to samples.

    Returns:
        A numpy array of length len(wav_a) + len(wav_b) - overlap, where
        overlap is the crossfade length in samples after clamping to the
        shorter input. If the overlap is zero or negative, the inputs are
        simply concatenated.
    """
    # Accept lists/tuples as well as ndarrays; ndarray inputs keep their dtype.
    wav_a = np.asarray(wav_a)
    wav_b = np.asarray(wav_b)

    cf_samples = int(round(crossfade_s * sr))

    # Guard: if either wav is shorter than the crossfade window, shrink the window.
    cf_samples = min(cf_samples, len(wav_a), len(wav_b))
    if cf_samples <= 0:
        # Nothing to overlap (also avoids wav_a[:-0], which would be empty).
        return np.concatenate([wav_a, wav_b])

    gain = 10 ** (db_boost / 20.0)

    tail_a = wav_a[-cf_samples:] * gain
    head_b = wav_b[:cf_samples] * gain
    overlap = tail_a + head_b

    return np.concatenate([
        wav_a[:-cf_samples],
        overlap,
        wav_b[cf_samples:],
    ])
|
|
|
|
| 147 |
|
| 148 |
|
| 149 |
+
def stitch_wavs(wavs, crossfade_s, db_boost, sr, total_dur_s):
    """
    Stitch a list of wav arrays using crossfade_join, then clip to total_dur_s.

    Args:
        wavs: sequence of per-segment sample arrays, in playback order.
        crossfade_s: overlap between consecutive segments, in seconds.
        db_boost: gain in dB passed through to crossfade_join.
        sr: sample rate in Hz.
        total_dur_s: target duration in seconds; the result is truncated to
            int(round(total_dur_s * sr)) samples (never padded).

    Returns:
        The joined samples. A single segment is returned trimmed without any
        crossfade; an empty `wavs` yields an empty float32 array instead of
        raising IndexError.
    """
    if len(wavs) == 0:
        return np.zeros(0, dtype=np.float32)

    # Fold segments left to right; with one segment the loop body never runs.
    final_wav = wavs[0]
    for next_wav in wavs[1:]:
        final_wav = crossfade_join(final_wav, next_wav, crossfade_s, db_boost, sr)

    # Clamp at zero so a negative duration cannot turn into a from-the-end slice.
    target_samples = max(int(round(total_dur_s * sr)), 0)
    return final_wav[:target_samples]
|
| 160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
+
@spaces.GPU(duration=300)
|
| 163 |
+
def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode,
|
| 164 |
+
crossfade_s, crossfade_db):
|
| 165 |
+
global _INFERENCE_CACHE
|
| 166 |
|
| 167 |
+
seed_val = int(seed_val)
|
| 168 |
+
crossfade_s = float(crossfade_s)
|
| 169 |
+
crossfade_db = float(crossfade_db)
|
| 170 |
+
|
| 171 |
+
if seed_val < 0:
|
| 172 |
+
seed_val = random.randint(0, 2**32 - 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
sr = 16000
|
| 175 |
truncate = 131072
|
| 176 |
fps = 4
|
| 177 |
+
truncate_frame = int(fps * truncate / sr)
|
| 178 |
+
truncate_onset = 120
|
| 179 |
+
model_dur = truncate / sr # 8.192 s
|
| 180 |
+
step_s = model_dur - crossfade_s
|
| 181 |
+
|
| 182 |
+
# Cache key covers everything that affects segmentation and inference
|
| 183 |
+
cache_key = (video_file, seed_val, float(cfg_scale), int(num_steps), mode,
|
| 184 |
+
crossfade_s)
|
| 185 |
+
|
| 186 |
+
if cache_key in _INFERENCE_CACHE:
|
| 187 |
+
print("Cache hit — skipping inference, re-stitching with new dB value.")
|
| 188 |
+
cached = _INFERENCE_CACHE[cache_key]
|
| 189 |
+
wavs = cached["wavs"]
|
| 190 |
+
total_dur_s = cached["total_dur_s"]
|
| 191 |
+
tmp_dir = cached["tmp_dir"]
|
| 192 |
+
silent_video = cached["silent_video"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
else:
|
| 194 |
+
set_global_seed(seed_val)
|
| 195 |
+
torch.set_grad_enabled(False)
|
| 196 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 197 |
+
weight_dtype = torch.bfloat16
|
| 198 |
+
|
| 199 |
+
from cavp_util import Extract_CAVP_Features
|
| 200 |
+
from onset_util import VideoOnsetNet, extract_onset
|
| 201 |
+
from models import MMDiT
|
| 202 |
+
from samplers import euler_sampler, euler_maruyama_sampler
|
| 203 |
+
from diffusers import AudioLDM2Pipeline
|
| 204 |
+
|
| 205 |
+
extract_cavp = Extract_CAVP_Features(
|
| 206 |
+
device=device, config_path="./cavp/cavp.yaml", ckpt_path=cavp_ckpt_path
|
| 207 |
+
)
|
| 208 |
|
| 209 |
+
state_dict = torch.load(onset_ckpt_path, map_location=device, weights_only=False)["state_dict"]
|
| 210 |
+
new_state_dict = {}
|
| 211 |
+
for key, value in state_dict.items():
|
| 212 |
+
if "model.net.model" in key:
|
| 213 |
+
new_key = key.replace("model.net.model", "net.model")
|
| 214 |
+
elif "model.fc." in key:
|
| 215 |
+
new_key = key.replace("model.fc", "fc")
|
| 216 |
+
else:
|
| 217 |
+
new_key = key
|
| 218 |
+
new_state_dict[new_key] = value
|
| 219 |
+
onset_model = VideoOnsetNet(False).to(device)
|
| 220 |
+
onset_model.load_state_dict(new_state_dict)
|
| 221 |
+
onset_model.eval()
|
| 222 |
+
|
| 223 |
+
model = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
|
| 224 |
+
ckpt = torch.load(taro_ckpt_path, map_location=device, weights_only=False)["ema"]
|
| 225 |
+
model.load_state_dict(ckpt)
|
| 226 |
+
model.eval()
|
| 227 |
+
model.to(weight_dtype)
|
| 228 |
+
|
| 229 |
+
model_audioldm = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
|
| 230 |
+
vae = model_audioldm.vae.to(device)
|
| 231 |
+
vae.eval()
|
| 232 |
+
vocoder = model_audioldm.vocoder.to(device)
|
| 233 |
+
|
| 234 |
+
tmp_dir = tempfile.mkdtemp()
|
| 235 |
+
silent_video = os.path.join(tmp_dir, "silent_input.mp4")
|
| 236 |
+
strip_audio_from_video(video_file, silent_video)
|
| 237 |
+
|
| 238 |
+
cavp_feats = extract_cavp(silent_video, tmp_path=tmp_dir)
|
| 239 |
+
onset_feats = extract_onset(silent_video, onset_model, tmp_path=tmp_dir, device=device)
|
| 240 |
+
|
| 241 |
+
latents_scale = torch.tensor([0.18215] * 8).view(1, 8, 1, 1).to(device)
|
| 242 |
+
|
| 243 |
+
total_frames = cavp_feats.shape[0]
|
| 244 |
+
total_dur_s = total_frames / fps
|
| 245 |
+
|
| 246 |
+
# Build segment list
|
| 247 |
+
segments = []
|
| 248 |
+
seg_start = 0.0
|
| 249 |
+
while True:
|
| 250 |
+
seg_end = min(seg_start + model_dur, total_dur_s)
|
| 251 |
+
segments.append((seg_start, seg_end))
|
| 252 |
+
if seg_end >= total_dur_s:
|
| 253 |
+
break
|
| 254 |
+
seg_start += step_s
|
| 255 |
+
|
| 256 |
+
# Run inference for every segment
|
| 257 |
+
wavs = []
|
| 258 |
+
for seg_start_s, seg_end_s in segments:
|
| 259 |
+
print(f"Inferring segment {seg_start_s:.2f}s – {seg_end_s:.2f}s ...")
|
| 260 |
+
wav = infer_segment(
|
| 261 |
+
model, vae, vocoder,
|
| 262 |
+
cavp_feats, onset_feats,
|
| 263 |
+
seg_start_s, seg_end_s,
|
| 264 |
+
sr, fps, truncate_frame, truncate_onset, model_dur,
|
| 265 |
+
latents_scale, device, weight_dtype,
|
| 266 |
+
cfg_scale, num_steps, mode,
|
| 267 |
+
euler_sampler, euler_maruyama_sampler,
|
| 268 |
+
)
|
| 269 |
+
wavs.append(wav)
|
| 270 |
+
|
| 271 |
+
# Store in cache
|
| 272 |
+
_INFERENCE_CACHE[cache_key] = {
|
| 273 |
+
"wavs": wavs,
|
| 274 |
+
"total_dur_s": total_dur_s,
|
| 275 |
+
"tmp_dir": tmp_dir,
|
| 276 |
+
"silent_video": silent_video,
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
# Stitch with current crossfade params
|
| 280 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 281 |
+
final_wav = stitch_wavs(wavs, crossfade_s, crossfade_db, sr, total_dur_s)
|
| 282 |
|
| 283 |
+
audio_path = os.path.join(tmp_dir, "output.wav")
|
| 284 |
sf.write(audio_path, final_wav, sr)
|
| 285 |
|
|
|
|
| 286 |
output_video = os.path.join(tmp_dir, "output.mp4")
|
| 287 |
input_v = ffmpeg.input(silent_video)
|
| 288 |
input_a = ffmpeg.input(audio_path)
|
|
|
|
| 308 |
gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5),
|
| 309 |
gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1),
|
| 310 |
gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde"),
|
| 311 |
+
gr.Slider(label="Crossfade Duration (s)", minimum=0, maximum=8, value=2, step=0.1),
|
| 312 |
+
gr.Textbox(label="Crossfade Boost (dB)", value="3"),
|
| 313 |
],
|
| 314 |
outputs=[
|
| 315 |
gr.Video(label="Output Video with Audio"),
|
| 316 |
gr.Audio(label="Generated Audio"),
|
| 317 |
],
|
| 318 |
title="TARO: Video-to-Audio Synthesis (ICCV 2025)",
|
| 319 |
+
description="Upload a video and generate synchronized audio using TARO. Optimal clip duration is 8.2s. Longer videos are automatically split into overlapping segments and stitched with a crossfade.",
|
| 320 |
)
|
| 321 |
+
demo.queue().launch(share=True)
|