namratac commited on
Commit
1a32e6c
1 Parent(s): 51e6f46

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +185 -0
app.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio
2
+ from base64 import b64encode
3
+ from io import BytesIO
4
+
5
+ from gradio import Audio, Interface, Textbox
6
+ from gtts import gTTS
7
+ from mtranslate import translate
8
+ from speech_recognition import AudioFile, Recognizer
9
+ from transformers import (BlenderbotSmallForConditionalGeneration,
10
+ BlenderbotSmallTokenizer)
11
+
12
def stt(audio: object, language: str) -> str:
    """Transcribe a recording of the user's speech to text.

    Args:
        audio: path to the recorded user speech (gradio filepath input)
        language (str): language code forwarded to the recognizer

    Returns:
        str: recognized speech of the user
    """
    recognizer = Recognizer()
    # Load the entire recording into memory, then hand it to
    # Google's speech-to-text API for transcription.
    with AudioFile(audio) as source:
        captured_audio = recognizer.record(source)
        transcript = recognizer.recognize_google(captured_audio, language=language)
    return transcript
29
+
30
def to_en_translation(text: str, language: str) -> str:
    """Translate text from the given source language into English.

    Args:
        text (str): input text
        language (str): source language code of *text*

    Returns:
        str: English translation of the input
    """
    # mtranslate signature: translate(text, to_language, from_language)
    english_text = translate(text, "en", language)
    return english_text
39
+
40
+
41
def from_en_translation(text: str, language: str) -> str:
    """Translate English text into the specified target language.

    Args:
        text (str): English input text
        language (str): target language code

    Returns:
        str: translated text
    """
    # mtranslate signature: translate(text, to_language, from_language)
    localized_text = translate(text, language, "en")
    return localized_text
50
+
51
class TextGenerationPipeline:
    """Callable text-generation pipeline around BlenderBot-small.

    Instances are called with a prompt string and return the
    generated reply string.
    """

    # Tokenizer and model are loaded once, at class-definition time,
    # and shared by every instance.
    model_name = "facebook/blenderbot_small-90M"
    tokenizer = BlenderbotSmallTokenizer.from_pretrained(model_name)
    model = BlenderbotSmallForConditionalGeneration.from_pretrained(model_name)

    def __init__(self, **kwargs):
        """Store text-generation parameters, e.g. ``max_length=100``.

        Every keyword is kept as an instance attribute and later
        forwarded verbatim to ``model.generate``.  Visit:
        https://huggingface.co/docs/transformers/main_classes/text_generation
        for the accepted parameters.
        """
        self.__dict__.update(kwargs)

    def preprocess(self, text) -> str:
        """Tokenize input text into model-ready PyTorch tensors.

        Args:
            text (str): user specified text

        Returns:
            obj: tokenizer output with the text encoded as tensors
        """
        return self.tokenizer(text, return_tensors="pt")

    def postprocess(self, outputs) -> str:
        """Decode the first generated sequence back into plain text.

        Args:
            outputs: model text-generation output tensors

        Returns:
            str: generated text, special tokens stripped
        """
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def __call__(self, text: str) -> str:
        """Generate a reply for the given input text.

        Args:
            text (str): user specified text

        Returns:
            str: generated text
        """
        encoded = self.preprocess(text)
        # All attributes set in __init__ act as generation parameters here.
        generated = self.model.generate(**encoded, **self.__dict__)
        return self.postprocess(generated)
99
+
100
+
101
def tts(text: str, language: str) -> object:
    """Synthesize text into a gTTS audio object.

    Args:
        text (str): generated answer of the bot
        language (str): language code for the synthesized voice

    Returns:
        object: gTTS text-to-speech object
    """
    speech = gTTS(text=text, lang=language, slow=False)
    return speech
109
+
110
def tts_to_bytesio(tts_object: object) -> bytes:
    """Serialize a gTTS audio object into raw audio bytes.

    Args:
        tts_object (object): audio object obtained from gtts; must
            expose ``write_to_fp(file_like)``

    Returns:
        bytes: the encoded audio data
    """
    buffer = BytesIO()
    tts_object.write_to_fp(buffer)
    # getvalue() returns the entire buffer contents regardless of the
    # current stream position, so the former seek(0) was redundant.
    return buffer.getvalue()
121
+
122
+
123
def html_audio_autoplay(audio_bytes: bytes) -> str:
    """Create an HTML snippet that autoplays the given audio in gradio.

    The parameter was renamed from ``bytes`` (which shadowed the
    builtin) and the return annotation corrected to ``str``; the only
    caller in this file passes the argument positionally.

    Args:
        audio_bytes (bytes): raw audio data

    Returns:
        str: HTML <audio> element embedding the audio as base64
    """
    b64 = b64encode(audio_bytes).decode()
    html = f"""
    <audio controls autoplay>
    <source src="data:audio/wav;base64,{b64}" type="audio/wav">
    </audio>
    """
    return html
137
+
138
# Maximum number of tokens the bot may generate per answer.
max_answer_length=100
# Language the app listens in and replies with (ISO 639-1 code).
desired_language = "de"
# Shared pipeline instance; loads the BlenderBot model once at import time.
response_generator_pipe = TextGenerationPipeline(max_length=max_answer_length)
141
+
142
+
143
def main(audio: object):
    """Run one voice-chat round for the gradio app.

    It responds both verbally and in text by taking voice input
    from the user.  (Fixes the misspelled local ``tranlated_text``.)

    Args:
        audio (object): recorded speech of user

    Returns:
        tuple containing

        - user_speech_text (str) : recognized speech
        - bot_response_de (str) : translated answer of bot
        - bot_response_en (str) : bot's original answer
        - html (object) : autoplayer for bot's speech
    """
    user_speech_text = stt(audio, desired_language)
    # The bot model is English-only: translate the user's speech first.
    translated_text = to_en_translation(user_speech_text, desired_language)
    bot_response_en = response_generator_pipe(translated_text)
    # Translate the bot's English reply back into the user's language.
    bot_response_de = from_en_translation(bot_response_en, desired_language)
    bot_voice = tts(bot_response_de, desired_language)
    bot_voice_bytes = tts_to_bytesio(bot_voice)
    html = html_audio_autoplay(bot_voice_bytes)
    return user_speech_text, bot_response_de, bot_response_en, html
168
+
169
# Build and launch the gradio UI: microphone input in, three textboxes
# plus an autoplaying HTML audio element out.  Runs at import time.
Interface(
    fn=main,
    inputs=[
        # Record from the user's microphone; gradio hands main() a file path.
        Audio(
            source="microphone",
            type="filepath",
        ),
    ],
    outputs=[
        Textbox(label="You said: "),
        Textbox(label="AI said: "),
        Textbox(label="AI said (English): "),
        # Raw HTML output slot for the autoplaying <audio> element.
        "html",
    ],
    # live=True re-runs main() automatically after each recording stops.
    live=True,
    allow_flagging="never",
).launch()