import librosa | |
import torch | |
from transformers import Wav2Vec2ForCTC, AutoProcessor | |
from transformers import set_seed | |
import time | |
# Default MMS checkpoint; the language is selected via ``target_lang``.
_DEFAULT_MODEL_ID = "facebook/mms-1b-all"
# Rate (Hz) the audio is resampled to before feature extraction.
_SAMPLING_RATE = 16_000

# Holds at most one ``(processor, model)`` pair keyed by ``(model_id, target_lang)``
# so repeated calls for the same language do not reload the checkpoint.
_model_cache: dict = {}


def _load_model(model_id: str, target_lang: str):
    """Return ``(processor, model)`` for *model_id*/*target_lang*, loading on a cache miss."""
    key = (model_id, target_lang)
    cached = _model_cache.get(key)
    if cached is None:
        # Drop the previously cached model before loading a new one so two
        # large checkpoints are not held by this cache at the same time.
        _model_cache.clear()
        processor = AutoProcessor.from_pretrained(model_id, target_lang=target_lang)
        # ignore_mismatched_sizes=True mirrors the Hugging Face MMS usage example;
        # presumably the language adapter's head differs in size from the
        # checkpoint default — confirm against the MMS docs if changing it.
        model = Wav2Vec2ForCTC.from_pretrained(
            model_id, target_lang=target_lang, ignore_mismatched_sizes=True
        )
        cached = (processor, model)
        _model_cache[key] = cached
    return cached


def transcribe(fp: str, target_lang: str, model_id: str = _DEFAULT_MODEL_ID) -> str:
    """
    Transcribe the speech in an audio file.

    The processor/model pair for the most recently used language is cached
    at module level, so only the first call per language pays the loading
    cost. The elapsed time (including any model loading) is printed to
    stdout.

    Parameters
    ----------
    fp : str
        The file path to the audio file.
    target_lang : str
        The ISO 639-3 code of the target language.
    model_id : str, optional
        Hugging Face model id to load; defaults to ``"facebook/mms-1b-all"``.

    Returns
    -------
    transcript : str
        The transcribed text.
    """
    # Ensure replicability
    set_seed(555)
    # perf_counter is monotonic, unlike time.time(), so the duration cannot go negative.
    start_time = time.perf_counter()

    processor, model = _load_model(model_id, target_lang)

    # Resample to a fixed rate and hand the processor the rate librosa reports.
    signal, sampling_rate = librosa.load(fp, sr=_SAMPLING_RATE)
    inputs = processor(signal, sampling_rate=sampling_rate, return_tensors="pt")

    # Inference only; no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Greedy decoding: highest-scoring token per frame of the first (only) batch item.
    ids = torch.argmax(logits, dim=-1)[0]
    transcript = processor.decode(ids)

    print("Time elapsed: ", int(time.perf_counter() - start_time), " seconds")
    return transcript