import functools
import os
import time

import librosa
import torch
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
from transformers import set_seed
# Default checkpoint. Per the original author's note, the 256-language MMS LID
# model is the first variant that contains MOS.
_DEFAULT_LID_MODEL_ID = "facebook/mms-lid-256"

# Rate the audio is resampled to on load and declared to the feature extractor.
_LID_SAMPLING_RATE = 16_000


@functools.lru_cache(maxsize=2)
def _load_lid_model(model_id: str):
    """Load the feature extractor and classifier for *model_id*, once per process."""
    processor = AutoFeatureExtractor.from_pretrained(model_id)
    model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)
    return processor, model


def identify_language(fp: str, *, model_id: str = _DEFAULT_LID_MODEL_ID) -> str:
    """
    For given audio file, identify what language it uses.

    The pretrained model and its feature extractor are loaded on the first
    call for a given ``model_id`` and reused afterwards, so only the first
    call pays the loading cost. The elapsed wall-clock time is printed to
    stdout on every call.

    Parameters
    ----------
    fp: str
        The file path to the audio file.
    model_id: str, optional
        Hugging Face model identifier of the language-ID checkpoint to use.
        Defaults to ``"facebook/mms-lid-256"``.

    Returns
    -------
    detected_lang: str
        The label of the highest-scoring class, looked up in the model's
        ``id2label`` mapping (the iso3 code of the detected language).

    Raises
    ------
    FileNotFoundError
        If ``fp`` is a path that does not point to an existing file.
    """
    # Fail fast on a bad path: checking here avoids loading a large model
    # only to discover afterwards that there is nothing to classify.
    if isinstance(fp, (str, os.PathLike)) and not os.path.isfile(fp):
        raise FileNotFoundError(f"Audio file not found: {fp!r}")

    # Ensure replicability
    set_seed(555)

    # perf_counter is monotonic, so the reported duration cannot be skewed
    # by system clock adjustments.
    start_time = time.perf_counter()

    # Load (or reuse the cached) language ID model
    processor, model = _load_lid_model(model_id)

    # Process the audio: resample on load so the waveform matches the rate
    # that is declared to the feature extractor.
    signal, _ = librosa.load(fp, sr=_LID_SAMPLING_RATE)
    inputs = processor(
        signal, sampling_rate=_LID_SAMPLING_RATE, return_tensors="pt"
    )

    # Inference (no gradients are needed for prediction)
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single file is processed, so take the top class of the first row.
    lang_id = torch.argmax(logits, dim=-1)[0].item()
    detected_lang = model.config.id2label[lang_id]

    print("Time elapsed: ", int(time.perf_counter() - start_time), " seconds")
    return detected_lang