jayesh95 committed
Commit ddd3649 · 1 Parent(s): c3eee0f

Added application file

Files changed (1): voice_to_qa.py (+148, -0)
voice_to_qa.py ADDED
@@ -0,0 +1,148 @@
+ import gradio as gr
+ import torch
+
+ # Speech-to-text
+ import whisper
+
+ # Question answering
+ from transformers import pipeline
+
+ # Text-to-speech
+ import tempfile
+ from TTS.utils.manage import ModelManager
+ from TTS.utils.synthesizer import Synthesizer
+ from typing import Optional
+
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Whisper: speech-to-text. The "base" model is used for language detection,
+ # the "small" model for the actual decoding passes below.
+ model = whisper.load_model("base", device=device)
+ model_med = whisper.load_model("small", device=device)
+
+ # RoBERTa Q&A (device=-1 keeps the pipeline on CPU when no GPU is available)
+ model_name = "deepset/tinyroberta-squad2"
+ nlp = pipeline('question-answering', model=model_name, tokenizer=model_name,
+                device=0 if torch.cuda.is_available() else -1)
+
+ # TTS
+ tts_manager = ModelManager()
+ MAX_TXT_LEN = 100
+
+ print(f"Whisper model device: {model.device}")
+
+ # Whisper - speech-to-text
+ def whisper_stt(audio):
+     print("Inside Whisper STT")
+     # load audio and pad/trim it to fit 30 seconds
+     audio = whisper.load_audio(audio)
+     audio = whisper.pad_or_trim(audio)
+
+     # make log-Mel spectrogram and move it to the same device as the model
+     mel = whisper.log_mel_spectrogram(audio).to(model.device)
+
+     # detect the spoken language
+     _, probs = model.detect_language(mel)
+     lang = max(probs, key=probs.get)
+     print(f"Detected language: {lang}")
+
+     # decode the audio twice: a transcription in the detected language and a
+     # translation into English (the QA model expects English input)
+     options_transc = whisper.DecodingOptions(fp16=False, language=lang, task='transcribe')
+     options_transl = whisper.DecodingOptions(fp16=False, language='en', task='translate')
+     result_transc = whisper.decode(model_med, mel, options_transc)
+     result_transl = whisper.decode(model_med, mel, options_transl)
+
+     # print the recognized text
+     print(f"transcript is : {result_transc.text}")
+     print(f"translation is : {result_transl.text}")
+
+     return result_transc.text, result_transl.text, lang
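+
+ # A quick local check of the function above might look like this (sketch only;
+ # "sample.wav" is a hypothetical path to a short speech recording):
+ #   transcript, translation, lang = whisper_stt("sample.wav")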
+
+ # Coqui - text-to-speech
+ def tts(text: str, model_name: str):
+     if len(text) > MAX_TXT_LEN:
+         text = text[:MAX_TXT_LEN]
+         print(f"Input text was cut off at the {MAX_TXT_LEN}-character limit.")
+     print(text, model_name)
+     # download the TTS model
+     model_path, config_path, model_item = tts_manager.download_model(f"tts_models/{model_name}")
+     vocoder_name: Optional[str] = model_item["default_vocoder"]
+     # download the vocoder, if the model specifies a default one
+     vocoder_path = None
+     vocoder_config_path = None
+     if vocoder_name is not None:
+         vocoder_path, vocoder_config_path, _ = tts_manager.download_model(vocoder_name)
+     # init synthesizer
+     synthesizer = Synthesizer(
+         model_path, config_path, None, None, vocoder_path, vocoder_config_path,
+     )
+
+     # synthesize
+     wavs = synthesizer.tts(text)
+
+     # write the waveform to a temporary .wav file and return its path
+     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
+         synthesizer.save_wav(wavs, fp)
+     return fp.name
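+
+ # NOTE: delete=False above means the synthesized .wav files are not removed
+ # automatically; Gradio reads them from the returned path, and they persist in
+ # the temp directory until it is cleaned up.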
+
+ def engine(audio, context):
+     # voice query -> text (transcript plus English translation)
+     transcribe, translation, lang = whisper_stt(audio)
+
+     # answer the translated query against the given context
+     answer = get_query_result(translation, context)
+
+     # speak the answer with an English LJSpeech voice
+     answer_speech = tts(answer, model_name='en/ljspeech/tacotron2-DDC_ph')
+
+     return translation, answer, answer_speech
+
+
+ def get_query_result(query, context):
+     QA_input = {
+         'question': query,
+         'context': context
+     }
+     answer = nlp(QA_input)['answer']
+     return answer
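+
+ # Illustrative note: the question-answering pipeline returns a dict of the form
+ #   {'score': ..., 'start': ..., 'end': ..., 'answer': '...'}
+ # and only the 'answer' string is used here.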
+
+
+ demo = gr.Blocks()
+
+ with demo:
+     gr.Markdown("<h1><center>Voice to QA</center></h1>")
+     gr.Markdown(
+         """<center> An app to ask voice queries about a text article.</center>
+         """
+     )
+     gr.Markdown(
+         """Model pipeline consisting of - <br>- [**Whisper**](https://github.com/openai/whisper) for speech-to-text, <br>- [**TinyRoBERTa SQuAD2**](https://huggingface.co/deepset/tinyroberta-squad2) for question answering, and <br>- [**CoquiTTS**](https://github.com/coqui-ai/TTS) for text-to-speech.
+         <br> Just type/paste your text in the context field, and then ask voice questions.""")
+     with gr.Column():
+         with gr.Row():
+             with gr.Column():
+                 in_audio = gr.Audio(source="microphone", type="filepath", label='Record your voice query here (English, Spanish, or French work best)')
+                 in_context = gr.Textbox(label="Context")
+                 b1 = gr.Button("Generate Answer")
+
+             with gr.Column():
+                 out_query = gr.Textbox(label='Your Query (Transcribed)')
+                 out_audio = gr.Audio(label='Voice response')
+                 out_textbox = gr.Textbox(label="Answer")
+
+     b1.click(engine, inputs=[in_audio, in_context], outputs=[out_query, out_textbox, out_audio])
+
+     #with gr.Row():
+     #    gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=ysharma_Voice-to-Youtube)")
+
+ demo.launch(enable_queue=True, debug=True)
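
A minimal smoke test (a sketch, assuming the dependencies above are installed; "query.wav" is a hypothetical path to a short spoken recording) could exercise the pipeline without the Gradio UI:

    # illustrative only - reuses the functions defined in voice_to_qa.py
    transcript, translation, lang = whisper_stt("query.wav")
    answer = get_query_result(translation, "Gradio lets developers build machine learning demos in Python.")
    wav_path = tts(answer, model_name='en/ljspeech/tacotron2-DDC_ph')
    print(transcript, translation, lang, answer, wav_path)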