File size: 1,294 Bytes
a84c313
e5e9b34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a84c313
 
e5e9b34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a84c313
 
e5e9b34
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import functools
import time

import librosa
import torch
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
from transformers import set_seed

# Rate (Hz) the audio is resampled to before feature extraction.
_TARGET_SAMPLING_RATE = 16_000


@functools.lru_cache(maxsize=4)
def _load_lid_model(model_id: str):
    """Load and memoize the feature extractor and classifier for *model_id*.

    ``from_pretrained`` is the most expensive step of language identification,
    so caching it lets repeated ``identify_language`` calls reuse the same
    objects instead of reloading them every time. ``maxsize`` bounds how many
    distinct models are kept alive.
    """
    processor = AutoFeatureExtractor.from_pretrained(model_id)
    model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)
    model.eval()  # ensure dropout etc. are disabled for inference
    return processor, model


def identify_language(fp: str, *, model_id: str = "facebook/mms-lid-256") -> str:
    '''
    For given audio file, identify what language it uses.

    Parameters
    ----------
    fp: str
        The file path to the audio file.
    model_id: str, optional
        Hugging Face model id of the audio-classification model used for
        language identification. Defaults to "facebook/mms-lid-256", which
        (per the original author's note) is the smallest MMS-LID variant
        whose label set contains MOS.

    Returns
    ----------
    detected_lang: str
        The label the model assigns the highest score. This is the
        ``id2label`` entry of the model config, i.e. the iso3 code of the
        detected language for MMS-LID models.

    Notes
    -----
    The model and feature extractor are cached after the first call for a
    given ``model_id``. Errors from loading the audio or the model (for
    example, an unreadable file) are not caught and propagate to the caller.
    '''
    # Ensure replicability
    set_seed(555)

    start_time = time.perf_counter()

    # Load (or reuse the cached) language ID model
    processor, model = _load_lid_model(model_id)

    # Decode and resample the audio, then convert it to model inputs.
    # Use the rate librosa actually returns, so the extractor is told the
    # true rate of the signal.
    signal, sampling_rate = librosa.load(fp, sr=_TARGET_SAMPLING_RATE)
    inputs = processor(signal, sampling_rate=sampling_rate, return_tensors="pt")

    # Inference; no gradients are needed
    with torch.no_grad():
        logits = model(**inputs).logits

    # Batch of one: take the highest-scoring class of the first (only) item.
    lang_id = torch.argmax(logits, dim=-1)[0].item()
    detected_lang = model.config.id2label[lang_id]

    print("Time elapsed: ", int(time.perf_counter() - start_time), " seconds")
    return detected_lang