File size: 2,185 Bytes
7bcf8d7
 
 
d697dab
7bcf8d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d697dab
 
 
 
a043894
71996b7
e628248
71996b7
e628248
d697dab
 
 
 
7bcf8d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d697dab
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
import torch
import librosa
import numpy as np

# Tunables for language identification (LID).
LID_SAMPLING_RATE = 16_000  # rate (Hz) audio is loaded at / resampled to before inference
LID_TOPK = 10  # NOTE(review): no use of this constant is visible in this file — confirm
LID_THRESHOLD = 0.33  # predictions are withheld when the best probability is below this

# Checkpoint identifier shared by the feature extractor and the classifier.
model_id = "facebook/mms-lid-1024"

# Both objects are created once at import time and reused by every call.
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)
processor = AutoFeatureExtractor.from_pretrained(model_id)

# Lookup table from ISO language code to display name, read from a text file
# holding one "<iso><whitespace><name>" entry per line.
LID_LANGUAGES = {}
# Explicit UTF-8 so language names decode the same way on every platform.
with open("data/lid/all_langs.tsv", encoding="utf-8") as f:
    for line in f:
        entry = line.strip()
        if not entry:
            # Tolerate blank lines (e.g. an extra newline at the end of the file).
            continue
        # Split on the first run of whitespace so space- and tab-separated
        # files both parse; the name itself may contain spaces. Stripping
        # first keeps the trailing newline out of the stored name. A line
        # with a code but no name still raises ValueError, as before.
        iso, name = entry.split(maxsplit=1)
        LID_LANGUAGES[iso] = name


def identify(audio_data):
    """Predict the most likely spoken languages of an audio clip.

    Args:
        audio_data: Either a ``(sample_rate, samples)`` tuple (the microphone
            path; samples are scaled by 1/32768, i.e. treated as 16-bit PCM),
            or any other value, which is handed to ``librosa.load`` as the
            audio source (the file-upload path).

    Returns:
        A dict mapping language name to probability for the 5 highest-scoring
        labels, or a plain message string when the best probability is below
        ``LID_THRESHOLD``. A label with no entry in ``LID_LANGUAGES`` is
        reported under its raw model label instead of raising ``KeyError``.
    """
    if isinstance(audio_data, tuple):
        # microphone
        sr, audio_samples = audio_data
        audio_samples = (np.asarray(audio_samples) / 32768.0).astype(np.float32)
        if audio_samples.ndim > 1:
            # Downmix multi-channel input to mono; the model consumes a single
            # waveform. Assumes a (samples, channels) layout — TODO confirm
            # against the caller that produces the tuple.
            audio_samples = audio_samples.mean(axis=1)
        if sr != LID_SAMPLING_RATE:
            audio_samples = librosa.resample(
                audio_samples, orig_sr=sr, target_sr=LID_SAMPLING_RATE
            )
    else:
        # file upload: librosa decodes, downmixes and resamples in one step
        audio_samples = librosa.load(audio_data, sr=LID_SAMPLING_RATE, mono=True)[0]

    inputs = processor(
        audio_samples, sampling_rate=LID_SAMPLING_RATE, return_tensors="pt"
    )

    # set device: prefer CUDA, then Apple MPS, otherwise fall back to CPU
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif (
        hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
        and torch.backends.mps.is_built()
    ):
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    model.to(device)
    inputs = inputs.to(device)

    with torch.no_grad():
        logit = model(**inputs).logits

    # Turn the logits into probabilities and keep the 5 best labels.
    # NOTE(review): the module-level LID_TOPK (10) is not used here; the
    # literal 5 is kept so the number of returned labels does not change.
    logit_lsm = torch.log_softmax(logit.squeeze(), dim=-1)
    scores, indices = torch.topk(logit_lsm, 5, dim=-1)
    scores = torch.exp(scores).to("cpu").tolist()
    indices = indices.to("cpu").tolist()
    iso2score = {model.config.id2label[int(i)]: s for s, i in zip(scores, indices)}
    if max(iso2score.values()) < LID_THRESHOLD:
        return "Low confidence in the language identification predictions. Output is not shown!"
    # Fall back to the raw label when the name table has no entry for it.
    return {LID_LANGUAGES.get(iso, iso): score for iso, score in iso2score.items()}


# Paths to example audio clips, each wrapped in its own single-element list.
LID_EXAMPLES = [
    [clip_path]
    for clip_path in (
        "./assets/english.mp3",
        "./assets/tamil.mp3",
        "./assets/burmese.mp3",
    )
]