Alexa-NLU-Clone / app.py
qanastek's picture
Update app.py
43bb9e7
raw
history blame
No virus
3.82 kB
import gradio as gr
import os
import torch
import librosa
from glob import glob
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline, AutoModelForTokenClassification, TokenClassificationPipeline, Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
# ---------------------------------------------------------------------------
# Model loading (runs once at import time).
# NOTE: `model_name` is a scratch variable that is overwritten for every
# stage below, so after this section it holds the NER checkpoint name.
# ---------------------------------------------------------------------------

# ASR: Wav2Vec2 processor + CTC model, both loaded from the same checkpoint.
# Used by transcribe() to turn a waveform into text.
model_name = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
processor_asr = Wav2Vec2Processor.from_pretrained(model_name)
model_asr = Wav2Vec2ForCTC.from_pretrained(model_name)
# Intent classifier: sequence-classification model wrapped in a text
# classification pipeline. process() reads the "label" of its first result.
model_name = 'qanastek/XLMRoberta-Alexa-Intents-Classification'
tokenizer_intent = AutoTokenizer.from_pretrained(model_name)
model_intent = AutoModelForSequenceClassification.from_pretrained(model_name)
classifier_intent = TextClassificationPipeline(model=model_intent, tokenizer=tokenizer_intent)
# Language classifier: same pipeline shape as the intent classifier, applied
# to the same transcript in process().
model_name = 'qanastek/51-languages-classifier'
tokenizer_langs = AutoTokenizer.from_pretrained(model_name)
model_langs = AutoModelForSequenceClassification.from_pretrained(model_name)
classifier_language = TextClassificationPipeline(model=model_langs, tokenizer=tokenizer_langs)
# NER / slot extractor: token-classification pipeline whose per-token output
# (dicts with "entity" and "word" keys) is regrouped by getUniform().
model_name = 'qanastek/XLMRoberta-Alexa-Intents-NER-NLU'
tokenizer_ner = AutoTokenizer.from_pretrained(model_name)
model_ner = AutoModelForTokenClassification.from_pretrained(model_name)
predict_ner = TokenClassificationPipeline(model=model_ner, tokenizer=tokenizer_ner)
# Example clips offered in the UI: every .wav file found directly inside
# EXAMPLE_DIR, in sorted (lexicographic) order.
EXAMPLE_DIR = './'
examples = glob(os.path.join(EXAMPLE_DIR, '*.wav'))
examples.sort()
def transcribe(audio_path):
    """Return the ASR transcript of the audio file at *audio_path*.

    The file is loaded with librosa at 16 kHz, run through the module-level
    ``processor_asr`` / ``model_asr`` pair, and decoded by taking the
    highest-scoring token at every frame. Only the first (and only) item of
    the decoded batch is returned.
    """
    target_rate = 16_000

    # librosa also returns the sampling rate; it is not needed because the
    # rate was requested explicitly.
    waveform, _ = librosa.load(audio_path, sr=target_rate)
    features = processor_asr(
        waveform,
        sampling_rate=target_rate,
        return_tensors="pt",
        padding=True,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model_asr(
            features.input_values,
            attention_mask=features.attention_mask,
        )

    best_token_ids = output.logits.argmax(dim=-1)
    transcripts = processor_asr.batch_decode(best_token_ids)
    return transcripts[0]
def getUniform(text):
    """Group per-token NER predictions into (label, [words]) entities.

    Args:
        text: Iterable of token predictions, each a mapping with an
            ``"entity"`` tag (``"B-<label>"`` / ``"I-<label>"``) and a
            ``"word"`` string (token text, possibly carrying the
            SentencePiece word-start marker).

    Returns:
        A list of ``(label, words)`` tuples in order of appearance, where
        ``words`` is the list of token strings belonging to that entity with
        the word-start marker removed.

    A ``B-`` tag always opens a new entity. Any other tag continues the most
    recent entity when the label matches; if there is no matching open
    entity (e.g. an ``I-`` tag with no preceding ``B-``), it opens a new one
    instead of failing.
    """
    groups = []
    for token in text:
        tag = token["entity"]
        starts_entity = tag.startswith("B-")
        # Strip only the leading BIO prefix so labels that happen to contain
        # "B-" / "I-" elsewhere are left intact.
        label = tag[2:] if tag.startswith(("B-", "I-")) else tag
        word = token["word"].replace("▁", "")

        if starts_entity or not groups or groups[-1][0] != label:
            groups.append((label, [word]))
        else:
            # Continuation token: extend the entity that is currently open.
            groups[-1][1].append(word)
    return groups
def process(path):
    """Run the full speech-to-NLU pipeline on the audio file at *path*.

    The audio is transcribed, then the transcript is passed to the intent
    classifier, the language classifier and the NER pipeline (in that order).

    Returns a dict with the keys ``"text"``, ``"language"``,
    ``"intent_class"`` and ``"named_entities"``.
    """
    transcript = transcribe(path)

    # Each classifier returns a list of predictions; keep the first label.
    top_intent = classifier_intent(transcript)[0]["label"]
    top_language = classifier_language(transcript)[0]["label"]
    entities = getUniform(predict_ner(transcript))

    result = {}
    result["text"] = transcript
    result["language"] = top_language
    result["intent_class"] = top_intent
    result["named_entities"] = entities
    return result
# Hard-coded absolute paths to sample recordings.
# NOTE(review): nothing in the visible part of this file reads `audio_paths`
# (the UI examples come from `examples`), and the paths point into one user's
# home directory — presumably leftovers from local testing; confirm before
# relying on them.
audio_paths = [
"/users/ylabrak/Alexa_NLU/Pipeline/wavs/set-the-volume-to-low.wav",
"/users/ylabrak/Alexa_NLU/Pipeline/wavs/tell-me-a-joke.wav",
"/users/ylabrak/Alexa_NLU/Pipeline/wavs/tell me the artist of this song.wav",
"/users/ylabrak/Alexa_NLU/Pipeline/wavs/order-a-pizza.wav",
"/users/ylabrak/Alexa_NLU/Pipeline/wavs/TTS_1/tell-me-a-good-joke.wav",
"/users/ylabrak/Alexa_NLU/Pipeline/wavs/TTS_1/order me a pizza.wav",
"/users/ylabrak/Alexa_NLU/Pipeline/wavs/TTS_1/tell-me-the-artist-of-this-song.wav",
]
def predict(wav_file):
    """Gradio callback: analyse *wav_file* and return the pipeline result.

    Thin wrapper that forwards the file path to ``process`` and returns its
    dict unchanged.
    """
    return process(wav_file)
# iface = gr.Interface(fn=predict, inputs="text", outputs="text")
# Gradio UI: one audio input (handed to `predict` as a file path) and one JSON
# output showing the transcript, language, intent and extracted slots.
# NOTE(review): `gr.inputs` / `gr.outputs` are the legacy Gradio namespaces;
# confirm the pinned Gradio version still provides them before upgrading.
iface = gr.Interface(
    predict,
    title='Alexa NLU Clone',
    description='Upload your wav file to test the model',
    inputs=[
        gr.inputs.Audio(label='wav file', source='microphone', type='filepath')
    ],
    outputs=[
        gr.outputs.JSON(label='Slot Recognition + Intent Classification + Language Classification + ASR'),
    ],
    examples=examples,
    # The footer previously contained mis-decoded bytes instead of the heart
    # and hugging-face emoji. Unicode escapes are used so the text survives
    # tools that mangle non-ASCII source characters.
    article='Made with \u2764\ufe0f by Yanis Labrak thanks to \U0001F917',
)

# Start the web server (module-level side effect, as in the original script).
iface.launch()