|
from typing import List |
|
import torch |
|
import argparse |
|
import shutil |
|
import tempfile |
|
from speechbrain.pretrained import EncoderDecoderASR |
|
|
|
|
|
def asr_model_inference(model: EncoderDecoderASR, audios: List[str]) -> List[str]:
    """
    Transcribe a batch of audio files and return one transcript per input.

    Arguments
    ---------
    model : EncoderDecoderASR
        Pretrained SpeechBrain ASR model used for transcription.
    audios : List[str]
        Paths (or URIs) of the audio files to transcribe.

    Returns
    -------
    List[str]
        Predicted transcripts, in the same order as ``audios``.
    """
    # TemporaryDirectory guarantees the scratch directory is removed even if
    # transcription raises — the previous mkdtemp()/rmtree() pair leaked the
    # directory on any exception inside the loop.
    with tempfile.TemporaryDirectory() as tmp_dir:
        return [process_audio(model, audio, tmp_dir) for audio in audios]
|
|
|
def process_audio(model: EncoderDecoderASR, audio: str, savedir:str) -> str:
    """
    Transcribe a single audio file with the given ASR model.

    Arguments
    ---------
    model : EncoderDecoderASR
        Pretrained SpeechBrain ASR model.
    audio : str
        Path (or URI) of the audio file to transcribe.
    savedir : str
        Directory the model may use when fetching/caching the audio.

    Returns
    -------
    str
        The predicted transcript for the audio.
    """
    signal = model.load_audio(audio, savedir=savedir)

    # transcribe_batch expects a batched signal plus relative lengths;
    # a single full-length utterance has relative length 1.0.
    batched_signal = signal.unsqueeze(0)
    relative_lengths = torch.tensor([1.0])
    words, _tokens = model.transcribe_batch(batched_signal, relative_lengths)
    return words[0]
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: transcribe a single audio file on CPU and print it.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-I", dest="audio_file", required=True)
    cli_args = arg_parser.parse_args()

    # Load the pretrained model from the local ./inference directory.
    asr_model = EncoderDecoderASR.from_hparams(
        source="./inference",
        hparams_file="hyperparams.yaml",
        savedir="inference",
        run_opts={"device": "cpu"},
    )

    print(asr_model_inference(asr_model, [cli_args.audio_file]))