Update inference_from_video.py
inference_from_video.py  CHANGED  (+9 -25)
@@ -3,16 +3,10 @@ import copy
 import json
 import time
 import torch
-
-# Check if CUDA is available and set the device
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print("Using device:", device)
-
 import argparse
 from PIL import Image
 import numpy as np
 import soundfile as sf
-#import wandb
 from tqdm import tqdm
 from diffusers import DDPMScheduler
 from models import build_pretrained_models, AudioDiffusion
@@ -21,6 +15,10 @@ import torchaudio
 import tools.torch_tools as torch_tools
 from datasets import load_dataset
 
+# Check if CUDA is available and set the device
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print("Using device:", device)
+
 class dotdict(dict):
     """dot.notation access to dictionary attributes"""
     __getattr__ = dict.get
@@ -80,7 +78,6 @@ def parse_args():
     )
 
     args = parser.parse_args()
-
     return args
 
 def main():
@@ -93,12 +90,13 @@ def main():
     # Load Models #
     name = train_args.vae_model
     vae, stft = build_pretrained_models(name)
-    vae, stft = vae.
+    vae, stft = vae.to(device), stft.to(device)  # Ensure models are on the correct device
+
     model_class = AudioDiffusion
     if train_args.ib:
         print("*****USING MODEL IMAGEBIND*****")
         from models_imagebind import AudioDiffusion_IB
-        model_class =
+        model_class = AudioDiffusion_IB
     elif train_args.lb:
         print("*****USING MODEL LANGUAGEBIND*****")
         from models_languagebind import AudioDiffusion_LB
@@ -125,9 +123,8 @@ def main():
     model.eval()
 
     # Load Trained Weight #
-
     if args.model.endswith(".pt") or args.model.endswith(".bin"):
-        model.load_state_dict(torch.load(args.model), strict=False)
+        model.load_state_dict(torch.load(args.model, map_location=device), strict=False)
     else:
         from safetensors.torch import load_model
         load_model(model, args.model, strict=False)
@@ -136,8 +133,6 @@ def main():
 
     scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")
     sample_rate = args.sample_rate
-    #evaluator = EvaluationHelper(16000, "cuda:0")
-
 
     def audio_text_matching(waveforms, text, sample_freq=24000, max_len_in_seconds=10):
         new_freq = 48000
@@ -163,7 +158,6 @@ def main():
     else:
         prefix = ""
 
-    # data_path = "data/video_test/"
     data_path = args.data_path
     wavname = [f"{name.split('.')[0]}.wav" for name in os.listdir(data_path)]
     video_features = []
@@ -171,25 +165,15 @@ def main():
         video_path = os.path.join(data_path, video_file)
         video_feature = torch_tools.load_video(video_path, frame_rate=2, size=224)
         print(video_feature.shape)
-        video_features.append(video_feature)
+        video_features.append(video_feature.to(device))  # Move to device
 
     # Generate #
    num_steps, guidance, batch_size, num_samples = args.num_steps, args.guidance, args.batch_size, args.num_samples
     all_outputs = []
 
     for k in tqdm(range(0, len(wavname), batch_size)):
-
         with torch.no_grad():
-            # if train_args.task == 'image2audio':
-            #     prompt = text_prompts[k: k+batch_size]
-            #     imgs = []
-            #     for img_path in prompt:
-            #         img = Image.open(img_path)
-            #         imgs.append(np.array(img))
-            #     prompt = imgs
-            # elif train_args.task == 'video2audio':
             prompt = video_features[k: k+batch_size]
-
             latents = model.inference(scheduler, None, prompt, None, num_steps, guidance, num_samples, disable_progress=True, device=device)
             mel = vae.decode_first_stage(latents)
             wave = vae.decode_to_waveform(mel)
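
The changes above all follow the standard PyTorch device-placement pattern: resolve the device once, move modules and input tensors onto it with .to(device), and pass map_location=device to torch.load so checkpoints saved on a GPU still load on CPU-only machines. Below is a minimal, self-contained sketch of that pattern; the toy model, tensor shapes, and "checkpoint.pt" path are placeholders for illustration, not names from this repository.

import torch
import torch.nn as nn

# Resolve the device once and reuse it everywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Placeholder module standing in for the real VAE / diffusion model.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
model = model.to(device)  # parameters and buffers move to the chosen device
model.eval()

# map_location remaps CUDA-saved storages, so loading also works on CPU-only hosts.
# Commented out because "checkpoint.pt" is a placeholder path.
# state_dict = torch.load("checkpoint.pt", map_location=device)
# model.load_state_dict(state_dict, strict=False)

# Inputs must sit on the same device as the model before the forward pass.
batch = torch.randn(4, 16).to(device)
with torch.no_grad():
    output = model(batch)
print("Output computed on:", output.device)

Note that strict=False, as used in the script, lets loading succeed when checkpoint and model keys only partially overlap; this is convenient here but silently ignores mismatched keys, so it will not flag a wrong or corrupted checkpoint.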