Added application file
voice_to_qa.py  ADDED  (+148 −0)
@@ -0,0 +1,148 @@
import gradio as gr
import torch

# Speech-to-text
import whisper

# Question answering
from transformers import pipeline

# Text-to-speech
import tempfile
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
from typing import Optional


device = "cuda" if torch.cuda.is_available() else "cpu"

# Whisper: speech-to-text
model = whisper.load_model("base", device=device)
model_med = whisper.load_model("small", device=device)

# RoBERTa Q&A
model_name = "deepset/tinyroberta-squad2"
nlp = pipeline('question-answering', model=model_name, tokenizer=model_name,
               device=0 if torch.cuda.is_available() else -1)
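
# Quick sanity check for the QA pipeline (an illustrative sketch only; the
# question/context strings are made-up examples, not part of the app):
#
#   nlp({'question': 'Who wrote Hamlet?',
#        'context': 'Hamlet is a tragedy written by William Shakespeare.'})
#
# should return a dict of the form
#   {'answer': 'William Shakespeare', 'score': ..., 'start': ..., 'end': ...}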

# TTS
tts_manager = ModelManager()
MAX_TXT_LEN = 100


print(model.device)

# Whisper - speech-to-text
def whisper_stt(audio):
    print("Inside Whisper STT")
    # load audio and pad/trim it to fit 30 seconds
    audio = whisper.load_audio(audio)
    audio = whisper.pad_or_trim(audio)

    # make log-Mel spectrogram and move to the same device as the model
    mel = whisper.log_mel_spectrogram(audio).to(model.device)

    # detect the spoken language
    _, probs = model.detect_language(mel)
    lang = max(probs, key=probs.get)
    print(f"Detected language: {lang}")

    # decode the audio: transcribe in the detected language, and translate to English
    options_transc = whisper.DecodingOptions(fp16=False, language=lang, task='transcribe')
    options_transl = whisper.DecodingOptions(fp16=False, language='en', task='translate')
    result_transc = whisper.decode(model_med, mel, options_transc)
    result_transl = whisper.decode(model_med, mel, options_transl)

    # print the recognized text
    print(f"Transcript is: {result_transc.text}")
    print(f"Translation is: {result_transl.text}")

    return result_transc.text, result_transl.text, lang
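
# Standalone usage sketch for whisper_stt (assumes a local recording at
# 'query.wav', a hypothetical path used only for illustration):
#
#   transcript, translation, lang = whisper_stt("query.wav")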

# Coqui - text-to-speech
def tts(text: str, model_name: str):
    if len(text) > MAX_TXT_LEN:
        text = text[:MAX_TXT_LEN]
        print(f"Input text was cut off since it went over the {MAX_TXT_LEN} character limit.")
    print(text, model_name)
    # download model
    model_path, config_path, model_item = tts_manager.download_model(f"tts_models/{model_name}")
    vocoder_name: Optional[str] = model_item["default_vocoder"]
    # download vocoder, if the model specifies one
    vocoder_path = None
    vocoder_config_path = None
    if vocoder_name is not None:
        vocoder_path, vocoder_config_path, _ = tts_manager.download_model(vocoder_name)
    # init synthesizer
    synthesizer = Synthesizer(
        model_path, config_path, None, None, vocoder_path, vocoder_config_path,
    )

    # synthesize
    wavs = synthesizer.tts(text)

    # write the waveform to a temp file and return its path
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        synthesizer.save_wav(wavs, fp)
        return fp.name
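
# Direct usage sketch for tts; the model name is the same one engine() passes
# below, and the returned path points at the temporary .wav file:
#
#   wav_path = tts("Hello world", model_name='en/ljspeech/tacotron2-DDC_ph')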

def engine(audio, context):
    # Convert the voice query to text (transcript, English translation, language)
    transcript, translation, lang = whisper_stt(audio)

    # Get the query answer from the context
    answer = get_query_result(translation, context)

    # Speak the answer
    answer_speech = tts(answer, model_name='en/ljspeech/tacotron2-DDC_ph')

    return translation, answer, answer_speech
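
# End-to-end sketch: 'query.wav' is a hypothetical recording and the context
# string is a made-up example:
#
#   translation, answer, wav_path = engine("query.wav",
#                                          "The Eiffel Tower was completed in 1889.")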


def get_query_result(query, context):
    QA_input = {
        'question': query,
        'context': context
    }
    answer = nlp(QA_input)['answer']

    return answer


demo = gr.Blocks()

with demo:
    gr.Markdown("<h1><center>Voice to QA</center></h1>")
    gr.Markdown(
        """<center>An app to ask voice queries about a text article.</center>
        """
    )
    gr.Markdown(
        """Model pipeline consisting of - <br>- [**Whisper**](https://github.com/openai/whisper) for speech-to-text, <br>- [**TinyRoBERTa SQuAD2**](https://huggingface.co/deepset/tinyroberta-squad2) for question answering, and <br>- [**CoquiTTS**](https://github.com/coqui-ai/TTS) for text-to-speech.
        <br> Just type/paste your text in the context field, then ask voice questions.""")
    with gr.Column():
        with gr.Row():
            with gr.Column():
                in_audio = gr.Audio(source="microphone", type="filepath", label='Record your voice query here, in English, Spanish or French, for best results')
                in_context = gr.Textbox(label="Context")
                b1 = gr.Button("Generate Answer")

            with gr.Column():
                out_query = gr.Textbox(label='Your Query (Transcribed)')
                out_audio = gr.Audio(label='Voice response')
                out_textbox = gr.Textbox(label="Answer")

    b1.click(engine, inputs=[in_audio, in_context], outputs=[out_query, out_textbox, out_audio])

    #with gr.Row():
    #    gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=ysharma_Voice-to-Youtube)")

demo.launch(enable_queue=True, debug=True)
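
# Rough dependency list for this app, inferred from the imports above (an
# assumption; this commit doesn't pin versions):
#   pip install torch gradio openai-whisper transformers TTS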