import gradio as gr

import os
import torch
import librosa
from glob import glob
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TextClassificationPipeline,
    AutoModelForTokenClassification,
    TokenClassificationPipeline,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    Wav2Vec2ProcessorWithLM,
)
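# Pipeline: recorded wav -> per-language wav2vec2 ASR -> intent classification + language identification + NER on the transcript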

SAMPLE_RATE = 16_000  # the wav2vec2 XLSR models expect 16 kHz audio

# ASR processors/models, loaded lazily and cached per language code (see transcribe)
models = {}

models_paths = {
    "en-US": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
    "fr-FR": "jonatasgrosman/wav2vec2-large-xlsr-53-french",
    "nl-NL": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
    "pl-PL": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
    "it-IT": "jonatasgrosman/wav2vec2-large-xlsr-53-italian",
    "ru-RU": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
    "pt-PT": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
    "de-DE": "jonatasgrosman/wav2vec2-large-xlsr-53-german",
    "es-ES": "jonatasgrosman/wav2vec2-large-xlsr-53-spanish",
    "ja-JP": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
    "ar-SA": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
    "fi-FI": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
    "hu-HU": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
    "zh-CN": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    "el-GR": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
}

# Classifier Intent
model_name = 'qanastek/XLMRoberta-Alexa-Intents-Classification'
tokenizer_intent = AutoTokenizer.from_pretrained(model_name)
model_intent = AutoModelForSequenceClassification.from_pretrained(model_name)
classifier_intent = TextClassificationPipeline(model=model_intent, tokenizer=tokenizer_intent)

# Classifier Language
model_name = 'qanastek/51-languages-classifier'
tokenizer_langs = AutoTokenizer.from_pretrained(model_name)
model_langs = AutoModelForSequenceClassification.from_pretrained(model_name)
classifier_language = TextClassificationPipeline(model=model_langs, tokenizer=tokenizer_langs)

# NER Extractor
model_name = 'qanastek/XLMRoberta-Alexa-Intents-NER-NLU'
tokenizer_ner = AutoTokenizer.from_pretrained(model_name)
model_ner = AutoModelForTokenClassification.from_pretrained(model_name)
predict_ner = TokenClassificationPipeline(model=model_ner, tokenizer=tokenizer_ner)

EXAMPLE_DIR = './wavs/'
examples = sorted(glob(os.path.join(EXAMPLE_DIR, '*.wav')))
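# Each example filename encodes its language code before the "=" separator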
examples = [[e, e.split("=")[0].split("/")[-1]] for e in examples]

def transcribe(audio_path, lang_code):

    # Load the audio file and resample it to the rate the wav2vec2 models expect
    speech_array, sampling_rate = librosa.load(audio_path, sr=SAMPLE_RATE)

    # Load and cache the ASR processor/model for this language on first use
    if lang_code not in models:
        models[lang_code] = {}
        models[lang_code]["processor"] = Wav2Vec2Processor.from_pretrained(models_paths[lang_code])
        models[lang_code]["model"] = Wav2Vec2ForCTC.from_pretrained(models_paths[lang_code])
    
    # Retrieve the cached processor and model for this language
    processor_asr = models[lang_code]["processor"]
    model_asr = models[lang_code]["model"]

    inputs = processor_asr(speech_array, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)

    with torch.no_grad():
        logits = model_asr(inputs.input_values, attention_mask=inputs.attention_mask).logits

    # Greedy CTC decoding: take the most likely token per frame; batch_decode collapses repeats and blanks
    predicted_ids = torch.argmax(logits, dim=-1)
    
    return processor_asr.batch_decode(predicted_ids)[0]

# Group the NER pipeline's BIO-tagged subword tokens into (entity, [word pieces]) tuples
def getUniform(text):

    idx = 0
    res = {}

    for t in text:

        raw = t["entity"].replace("B-", "").replace("I-", "")
        word = t["word"].replace("▁", "")

        if t["entity"].startswith("B-"):
            # A "B-" tag opens a new entity group
            idx += 1
            res[f"{raw}|{idx}"] = [word]
        else:
            # An "I-" tag continues the current group (setdefault guards against a stray leading "I-")
            res.setdefault(f"{raw}|{idx}", []).append(word)

    return [(r.split("|")[0], res[r]) for r in res]


def predict(wav_file, lang_code):

    if lang_code not in models_paths:
        # Return a JSON-serializable error instead of a bare set
        return {"error": "The language code is unknown!"}

    # Patch a recurring ASR artifact ("apizza" -> "a pizza") and append a final period
    text = transcribe(wav_file, lang_code).replace("apizza", "a pizza") + " ."

    intent_class = classifier_intent(text)[0]["label"]
    language_class = classifier_language(text)[0]["label"]
    named_entities = getUniform(predict_ner(text))

    return {
        "text": text,
        "language": language_class,
        "intent_class": intent_class,
        "named_entities": named_entities,
    }

iface = gr.Interface(
    predict,
    title='Alexa Clone 👩‍💼 🗪 🤖 Multilingual NLU',
    description='Upload your wav file to test the models (<i>the first execution takes about 20s to 30s; subsequent runs take less than 1s</i>)',
    # thumbnail="",
    inputs=[
        gr.inputs.Audio(label='wav file', source='microphone', type='filepath'),
        gr.inputs.Dropdown(choices=list(models_paths.keys()), label='language code'),
    ],
    outputs=[
        gr.outputs.JSON(label='ASR -> Slot Recognition + Intent Classification + Language Classification'),
    ],
    examples=examples,
    article='Made with ❤️ by <a href="https://www.linkedin.com/in/yanis-labrak-8a7412145/" target="_blank">Yanis Labrak</a> thanks to 🤗',
)

iface.launch()