fffiloni committed
Commit 62787b7
1 Parent(s): fb299f3

Update inference_from_video.py

Files changed (1)
  1. inference_from_video.py +9 -25
inference_from_video.py CHANGED
@@ -3,16 +3,10 @@ import copy
 import json
 import time
 import torch
-
-# Check if CUDA is available and set the device
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print("Using device:", device)
-
 import argparse
 from PIL import Image
 import numpy as np
 import soundfile as sf
-#import wandb
 from tqdm import tqdm
 from diffusers import DDPMScheduler
 from models import build_pretrained_models, AudioDiffusion
@@ -21,6 +15,10 @@ import torchaudio
 import tools.torch_tools as torch_tools
 from datasets import load_dataset
 
+# Check if CUDA is available and set the device
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print("Using device:", device)
+
 class dotdict(dict):
     """dot.notation access to dictionary attributes"""
     __getattr__ = dict.get
@@ -80,7 +78,6 @@ def parse_args():
     )
 
     args = parser.parse_args()
-
     return args
 
 def main():
@@ -93,12 +90,13 @@ def main():
     # Load Models #
     name = train_args.vae_model
     vae, stft = build_pretrained_models(name)
-    vae, stft = vae.cuda(), stft.cuda()
+    vae, stft = vae.to(device), stft.to(device)  # Ensure models are on the correct device
+
     model_class = AudioDiffusion
     if train_args.ib:
         print("*****USING MODEL IMAGEBIND*****")
         from models_imagebind import AudioDiffusion_IB
-        model_class = AudioDiffusion if not train_args.ib else AudioDiffusion_IB
+        model_class = AudioDiffusion_IB
     elif train_args.lb:
         print("*****USING MODEL LANGUAGEBIND*****")
         from models_languagebind import AudioDiffusion_LB
@@ -125,9 +123,8 @@
     model.eval()
 
     # Load Trained Weight #
-
     if args.model.endswith(".pt") or args.model.endswith(".bin"):
-        model.load_state_dict(torch.load(args.model), strict=False)
+        model.load_state_dict(torch.load(args.model, map_location=device), strict=False)
     else:
         from safetensors.torch import load_model
         load_model(model, args.model, strict=False)
@@ -136,8 +133,6 @@
 
     scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")
     sample_rate = args.sample_rate
-    #evaluator = EvaluationHelper(16000, "cuda:0")
-
 
     def audio_text_matching(waveforms, text, sample_freq=24000, max_len_in_seconds=10):
         new_freq = 48000
@@ -163,7 +158,6 @@
     else:
         prefix = ""
 
-    # data_path = "data/video_test/"
    data_path = args.data_path
    wavname = [f"{name.split('.')[0]}.wav" for name in os.listdir(data_path)]
    video_features = []
@@ -171,25 +165,15 @@
         video_path = os.path.join(data_path, video_file)
         video_feature = torch_tools.load_video(video_path, frame_rate=2, size=224)
         print(video_feature.shape)
-        video_features.append(video_feature)
+        video_features.append(video_feature.to(device))  # Move to device
 
     # Generate #
     num_steps, guidance, batch_size, num_samples = args.num_steps, args.guidance, args.batch_size, args.num_samples
     all_outputs = []
 
     for k in tqdm(range(0, len(wavname), batch_size)):
-
         with torch.no_grad():
-            # if train_args.task == 'image2audio':
-            #     prompt = text_prompts[k: k+batch_size]
-            #     imgs = []
-            #     for img_path in prompt:
-            #         img = Image.open(img_path)
-            #         imgs.append(np.array(img))
-            #     prompt = imgs
-            # elif train_args.task == 'video2audio':
             prompt = video_features[k: k+batch_size]
-
             latents = model.inference(scheduler, None, prompt, None, num_steps, guidance, num_samples, disable_progress=True, device=device)
             mel = vae.decode_first_stage(latents)
             wave = vae.decode_to_waveform(mel)
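
For context, the change follows the standard device-agnostic PyTorch pattern: pick a device once, move modules and tensors to it with .to(device), and pass map_location=device to torch.load so checkpoints saved on a GPU machine still load on CPU-only hardware. A minimal standalone sketch of that pattern (the nn.Linear module and the "toy.pt" file below are illustrative placeholders, not part of this repo):

import torch
import torch.nn as nn

# Pick the device once; everything else derives from it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Illustrative module standing in for the VAE / diffusion model.
model = nn.Linear(16, 4).to(device)

# Save a checkpoint, then reload it with map_location so weights written
# on a CUDA box can still be restored on a CPU-only machine.
torch.save(model.state_dict(), "toy.pt")
state = torch.load("toy.pt", map_location=device)
model.load_state_dict(state, strict=False)

# Inputs must live on the same device as the model before the forward pass.
x = torch.randn(2, 16).to(device)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([2, 4])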