JackIsNotInTheBox committed on
Commit
8ca132a
·
1 Parent(s): e197e0a

Add app.py with Gradio interface for TARO inference

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py CHANGED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import random
5
+ import soundfile as sf
6
+ import ffmpeg
7
+ import tempfile
8
+ import spaces
9
+ import gradio as gr
10
+ from huggingface_hub import hf_hub_download
11
+
12
# Hub repository holding all TARO checkpoints, and the local download cache.
REPO_ID = "JackIsNotInTheBox/Taro_checkpoints"
CACHE_DIR = "/tmp/taro_ckpts"
os.makedirs(CACHE_DIR, exist_ok=True)


def _fetch_checkpoint(filename):
    """Download one checkpoint file from REPO_ID into CACHE_DIR, returning its local path."""
    return hf_hub_download(repo_id=REPO_ID, filename=filename, cache_dir=CACHE_DIR)


print("Downloading checkpoints...")
cavp_ckpt_path = _fetch_checkpoint("cavp_epoch66.ckpt")
onset_ckpt_path = _fetch_checkpoint("onset_model.ckpt")
taro_ckpt_path = _fetch_checkpoint("taro_ckpt.pt")
print("Checkpoints downloaded.")
21
+
22
def set_global_seed(seed):
    """Seed every RNG used by the pipeline for reproducible sampling.

    Seeds NumPy, Python's ``random``, and PyTorch (CPU and all CUDA
    devices), and forces cuDNN into deterministic mode.

    Args:
        seed: Integer seed. May exceed 32 bits; NumPy's seed is reduced
            modulo 2**32 because ``np.random.seed`` only accepts 32-bit
            values.
    """
    np.random.seed(seed % (2**32))
    random.seed(seed)
    torch.manual_seed(seed)
    # Fix: seed *all* CUDA devices, not just the current one — the original
    # torch.cuda.manual_seed left secondary GPUs unseeded. (No-op on CPU-only
    # machines; the call is safely deferred until CUDA initializes.)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark-mode autotuning picks algorithms non-deterministically per
    # run, which would defeat the determinism requested just above.
    torch.backends.cudnn.benchmark = False
28
+
29
@spaces.GPU(duration=300)  # HF Spaces: allocate a GPU for up to 300 s per call
def generate_audio(video_file, seed_val, cfg_scale, num_steps, mode):
    """Generate synchronized audio for a video with the TARO diffusion model.

    Runs the full inference pipeline: extracts CAVP video features and
    onset features, samples an audio latent with the MMDiT model (ODE or
    SDE sampler), decodes it through the AudioLDM2 VAE + vocoder, then
    muxes the generated audio back onto a trimmed copy of the input video.

    Args:
        video_file: Path to the input video (as provided by gr.Video).
        seed_val: RNG seed; cast to int before use.
        cfg_scale: Classifier-free guidance scale; cast to float.
        num_steps: Number of sampler steps; cast to int.
        mode: "sde" selects the Euler–Maruyama sampler; anything else
            uses the deterministic Euler sampler.

    Returns:
        Tuple of (path to muxed output .mp4, path to generated .wav),
        both inside a fresh temporary directory.

    NOTE(review): all models are re-loaded from disk on every call — simple
    but slow; acceptable on Spaces where the GPU worker may be ephemeral.
    """
    set_global_seed(int(seed_val))
    torch.set_grad_enabled(False)  # inference only; also covers code outside the no_grad block below
    device = "cuda" if torch.cuda.is_available() else "cpu"
    weight_dtype = torch.bfloat16
    # Deferred imports: these pull in heavy project/third-party modules and
    # are only needed inside the GPU worker.
    from cavp_util import Extract_CAVP_Features
    from onset_util import VideoOnsetNet, extract_onset
    from models import MMDiT
    from samplers import euler_sampler, euler_maruyama_sampler
    from diffusers import AudioLDM2Pipeline
    extract_cavp = Extract_CAVP_Features(device=device, config_path="./cavp/cavp.yaml", ckpt_path=cavp_ckpt_path)
    state_dict = torch.load(onset_ckpt_path, map_location=device)["state_dict"]
    # The onset checkpoint was saved from a wrapper ("model." prefix);
    # strip/remap those prefixes so keys match VideoOnsetNet's layout.
    new_state_dict = {}
    for key, value in state_dict.items():
        if "model.net.model" in key:
            new_key = key.replace("model.net.model", "net.model")
        elif "model.fc." in key:
            new_key = key.replace("model.fc", "fc")
        else:
            new_key = key
        new_state_dict[new_key] = value
    onset_model = VideoOnsetNet(False).to(device)
    onset_model.load_state_dict(new_state_dict)
    onset_model.eval()
    model = MMDiT(adm_in_channels=120, z_dims=[768], encoder_depth=4).to(device)
    # TARO checkpoint: use the EMA weights for sampling.
    ckpt = torch.load(taro_ckpt_path, map_location=device)["ema"]
    model.load_state_dict(ckpt)
    model.eval()
    model.to(weight_dtype)
    # Only the VAE and vocoder of AudioLDM2 are used (latent decode + waveform synthesis).
    model_audioldm = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
    vae = model_audioldm.vae.to(device)
    vae.eval()
    vocoder = model_audioldm.vocoder.to(device)
    tmp_dir = tempfile.mkdtemp()
    cavp_feats = extract_cavp(video_file, tmp_path=tmp_dir)
    onset_feats = extract_onset(video_file, onset_model, tmp_path=tmp_dir, device=device)
    sr = 16000          # output sample rate (Hz)
    truncate = 131072   # audio samples kept (= 2**17, ~8.19 s at 16 kHz)
    fps = 4             # presumably the CAVP feature rate in frames/s — TODO confirm against cavp_util
    truncate_frame = int(fps * truncate / sr)  # video-feature frames matching the audio window
    truncate_onset = 120
    # AudioLDM2 latent scaling factor, broadcast over the 8 latent channels.
    latents_scale = torch.tensor([0.18215]*8).view(1, 8, 1, 1).to(device)
    video_feats = torch.from_numpy(cavp_feats[:truncate_frame]).unsqueeze(0).to(device).to(weight_dtype)
    onset_feats_t = torch.from_numpy(onset_feats[:truncate_onset]).unsqueeze(0).to(device).to(weight_dtype)
    # Initial noise latent; 204 x 16 looks like the (time, freq) latent grid — TODO confirm
    z = torch.randn(len(video_feats), model.in_channels, 204, 16, device=device).to(weight_dtype)
    sampling_kwargs = dict(model=model, latents=z, y=onset_feats_t, context=video_feats, num_steps=int(num_steps), heun=False, cfg_scale=float(cfg_scale), guidance_low=0.0, guidance_high=0.7, path_type="linear")
    with torch.no_grad():
        if mode == "sde":
            samples = euler_maruyama_sampler(**sampling_kwargs)
        else:
            samples = euler_sampler(**sampling_kwargs)
    # Undo latent scaling, decode to mel-spectrogram-like output, then vocode to waveform.
    samples = vae.decode(samples / latents_scale).sample
    wav_samples = vocoder(samples.squeeze()).detach().cpu().numpy()
    audio_path = os.path.join(tmp_dir, "output.wav")
    sf.write(audio_path, wav_samples, sr)
    duration = truncate / sr  # seconds of audio generated; trim video to match
    trimmed_video = os.path.join(tmp_dir, "trimmed.mp4")
    output_video = os.path.join(tmp_dir, "output.mp4")
    # Trim the original video (an=None drops its audio track), then mux in the generated audio.
    ffmpeg.input(video_file, ss=0, t=duration).output(trimmed_video, vcodec="libx264", an=None).run(overwrite_output=True, quiet=True)
    input_v = ffmpeg.input(trimmed_video)
    input_a = ffmpeg.input(audio_path)
    ffmpeg.output(input_v, input_a, output_video, vcodec="libx264", acodec="aac", strict="experimental").run(overwrite_output=True, quiet=True)
    return output_video, audio_path
93
+
94
# Gradio UI: video in, (video-with-audio, audio) out.
_inputs = [
    gr.Video(label="Input Video"),
    gr.Number(label="Seed", value=0, precision=0),
    gr.Slider(label="CFG Scale", minimum=1, maximum=15, value=8, step=0.5),
    gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=25, step=1),
    gr.Radio(label="Sampling Mode", choices=["sde", "ode"], value="sde"),
]
_outputs = [
    gr.Video(label="Output Video with Audio"),
    gr.Audio(label="Generated Audio"),
]
demo = gr.Interface(
    fn=generate_audio,
    inputs=_inputs,
    outputs=_outputs,
    title="TARO: Video-to-Audio Synthesis (ICCV 2025)",
    description="Upload a video and generate synchronized audio using TARO.",
)
demo.queue().launch()