AudioGPT / audio_to_text /inference_waveform.py
lmzjms's picture
Update audio_to_text/inference_waveform.py
f57acc9
raw history blame
No virus
3.41 kB
import sys
import os
import librosa
import numpy as np
import torch
import audio_to_text.captioning.models
import audio_to_text.captioning.models.encoder
import audio_to_text.captioning.models.decoder
import audio_to_text.captioning.utils.train_util as train_util
def load_model(config, checkpoint):
    """Build the captioning model described by ``config`` and load ``checkpoint`` into it.

    The encoder and decoder are instantiated from ``config["model"]["encoder"]``
    and ``config["model"]["decoder"]`` via ``train_util.init_obj``. Each
    sub-module that has a ``"pretrained"`` entry gets those weights loaded
    first. Then the full model is built around them and the checkpoint
    weights are loaded on top.

    Note: if ``config["model"]["decoder"]["args"]`` has no ``"vocab_size"``,
    it is filled in *in place* with the size of the checkpoint vocabulary,
    so ``config`` is mutated.

    Args:
        config: parsed configuration dict with a ``"model"`` section.
        checkpoint: path to the checkpoint file, loaded onto the CPU.

    Returns:
        dict with ``"model"`` (the assembled model, switched to eval mode) and
        ``"vocabulary"`` (taken from the checkpoint).
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(checkpoint, "cpu")
    model_cfg = config["model"]

    def _load_if_pretrained(module, module_cfg):
        # Optional per-module warm start, progress reported on stdout.
        if "pretrained" in module_cfg:
            train_util.load_pretrained_model(module,
                                             module_cfg["pretrained"],
                                             sys.stdout.write)

    enc_cfg = model_cfg["encoder"]
    encoder = train_util.init_obj(audio_to_text.captioning.models.encoder,
                                  enc_cfg)
    _load_if_pretrained(encoder, enc_cfg)

    dec_cfg = model_cfg["decoder"]
    dec_args = dec_cfg["args"]
    if "vocab_size" not in dec_args:
        # The decoder output size defaults to the checkpoint's vocabulary size.
        dec_args["vocab_size"] = len(state["vocabulary"])
    decoder = train_util.init_obj(audio_to_text.captioning.models.decoder,
                                  dec_cfg)
    if "word_embedding" in dec_cfg:
        decoder.load_word_embedding(**dec_cfg["word_embedding"])
    _load_if_pretrained(decoder, dec_cfg)

    model = train_util.init_obj(audio_to_text.captioning.models,
                                model_cfg,
                                encoder=encoder,
                                decoder=decoder)
    train_util.load_pretrained_model(model, state)
    model.eval()
    return {"model": model, "vocabulary": state["vocabulary"]}
def decode_caption(word_ids, vocabulary):
    """Turn a sequence of token ids into a space-separated caption string.

    Each id is looked up in ``vocabulary`` in order. Any ``"<start>"`` token is
    dropped wherever it appears. Decoding stops at the first ``"<end>"`` token,
    which is not included in the output.

    Args:
        word_ids: iterable of ids usable as keys/indices into ``vocabulary``.
        vocabulary: indexable mapping from id to token string.

    Returns:
        The kept tokens joined by single spaces (``""`` when none are kept).
    """
    def _kept_tokens():
        for token_id in word_ids:
            token = vocabulary[token_id]
            if token == "<end>":
                return
            if token != "<start>":
                yield token

    return " ".join(_kept_tokens())
class AudioCapModel(object):
    """Caption audio waveforms with a pretrained audio-captioning model.

    ``weight_dir`` must contain ``config.yaml`` (parsed with
    ``train_util.parse_config_or_kwargs``) and the checkpoint ``swa.pth``.
    Instances are callable: ``model(audio)`` is ``model.caption(audio)``.
    """

    # Sample rate (Hz) used when loading audio files with librosa.
    # NOTE(review): presumably the rate the model was trained at — confirm.
    SAMPLE_RATE = 32000

    def __init__(self, weight_dir, device='cpu'):
        """Load the config and checkpoint from ``weight_dir`` onto ``device``.

        Args:
            weight_dir: directory containing ``config.yaml`` and ``swa.pth``.
            device: torch device for the model and its inputs.
        """
        config = os.path.join(weight_dir, 'config.yaml')
        self.config = train_util.parse_config_or_kwargs(config)
        checkpoint = os.path.join(weight_dir, 'swa.pth')
        resumed = load_model(self.config, checkpoint)
        model = resumed["model"]
        self.vocabulary = resumed["vocabulary"]
        self.model = model.to(device)
        self.device = device

    def _load_waveform(self, audio):
        """Return ``audio`` unchanged, or load it with librosa if it is a path."""
        if isinstance(audio, (str, os.PathLike)):
            return librosa.load(audio, sr=self.SAMPLE_RATE)[0]
        return audio

    def caption(self, audio_list):
        """Generate captions for one or more audio inputs.

        Args:
            audio_list: a single waveform (``np.ndarray``), a single audio file
                path, or an iterable whose items are waveforms or file paths.
                Paths are loaded with librosa at ``SAMPLE_RATE`` Hz.
                NOTE(review): waveforms are presumably expected to be mono and
                already at ``SAMPLE_RATE`` Hz — confirm.

        Returns:
            list of caption strings decoded from the model's ``"seq"`` output
            for each input, in input order.
        """
        if isinstance(audio_list, (np.ndarray, str, os.PathLike)):
            audio_list = [audio_list]
        captions = []
        # Pure inference: disable autograd so no graph is built during decoding.
        with torch.no_grad():
            for audio in audio_list:
                wav = self._load_waveform(audio)
                # Leading batch dimension of size 1.
                inputwav = torch.as_tensor(wav).float().unsqueeze(0).to(self.device)
                wav_len = torch.LongTensor([len(wav)])
                input_dict = {
                    "mode": "inference",
                    "wav": inputwav,
                    "wav_len": wav_len,
                    "specaug": False,
                    "sample_method": "beam",
                }
                out_dict = self.model(input_dict)
                captions.extend(
                    decode_caption(seq, self.vocabulary)
                    for seq in out_dict["seq"].cpu().numpy()
                )
        return captions

    def __call__(self, audio_list):
        """Alias for :meth:`caption`."""
        return self.caption(audio_list)