reecursion committed
Commit df90a53 · verified · 1 parent: 84d30ed

Create speech_conversation_app.py

Files changed (1)
  1. speech_conversation_app.py +325 -0
speech_conversation_app.py ADDED
@@ -0,0 +1,325 @@
+ import os
+ import time
+ import numpy as np
+ import torch
+ import gradio as gr
+ from transformers import (
+     pipeline,
+     AutoModelForCausalLM,
+     AutoModelForSeq2SeqLM,
+     AutoTokenizer,
+     AutoProcessor,
+     AutoModelForCTC,
+ )
+ from datasets import load_dataset
+ import soundfile as sf
+
+ # Global variables to track latency
+ latency_ASR = 0.0
+ latency_LLM = 0.0
+ latency_TTS = 0.0
+
+ # Global variables to store conversation state
+ conversation_history = []
+ audio_output = None
+
+ # ASR Models
+ ASR_OPTIONS = {
+     "Whisper Small": "openai/whisper-small",
+     "Wav2Vec2": "facebook/wav2vec2-base-960h"
+ }
+
+ # LLM Models
+ LLM_OPTIONS = {
+     "Llama-2 7B Chat": "meta-llama/Llama-2-7b-chat-hf",
+     "Flan-T5 Small": "google/flan-t5-small"
+ }
+
+ # TTS Models
+ TTS_OPTIONS = {
+     "VITS": "espnet/kan-bayashi_ljspeech_vits",
+     "FastSpeech2": "espnet/kan-bayashi_ljspeech_fastspeech2"
+ }
+
+ # Caches for loaded models
+ asr_models = {}
+ llm_models = {}
+ tts_models = {}
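+
+ # Note on dependencies (assumed, not pinned in this commit): gradio, torch, numpy,
+ # transformers, datasets, and soundfile are imported above; espnet2 is optional and
+ # only needed for real TTS inference. The meta-llama/Llama-2-7b-chat-hf checkpoint is
+ # gated on the Hub, so it may require accepting the license and authenticating
+ # (e.g. `huggingface-cli login`).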
+ def load_asr_model(model_name):
+     """Load ASR model from Hugging Face"""
+     global asr_models
+
+     if model_name not in asr_models:
+         print(f"Loading ASR model: {model_name}")
+         model_id = ASR_OPTIONS[model_name]
+
+         if "whisper" in model_id:
+             asr_models[model_name] = pipeline("automatic-speech-recognition", model=model_id)
+         else:  # wav2vec2 is a CTC model, so load it with AutoModelForCTC
+             processor = AutoProcessor.from_pretrained(model_id)
+             model = AutoModelForCTC.from_pretrained(model_id)
+             asr_models[model_name] = {"processor": processor, "model": model}
+
+     return asr_models[model_name]
+
+ def load_llm_model(model_name):
+     """Load LLM model from Hugging Face"""
+     global llm_models
+
+     if model_name not in llm_models:
+         print(f"Loading LLM model: {model_name}")
+         model_id = LLM_OPTIONS[model_name]
+
+         tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+         if "t5" in model_id:
+             # Flan-T5 is an encoder-decoder model, not a causal LM
+             model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+         else:
+             model = AutoModelForCausalLM.from_pretrained(
+                 model_id,
+                 torch_dtype=torch.float16,
+                 device_map="auto"
+             )
+
+         llm_models[model_name] = {
+             "model": model,
+             "tokenizer": tokenizer
+         }
+
+     return llm_models[model_name]
+
+ def load_tts_model(model_name):
+     """Load TTS model using ESPnet"""
+     global tts_models
+
+     if model_name not in tts_models:
+         print(f"Loading TTS model: {model_name}")
+         try:
+             # Import ESPnet TTS modules
+             from espnet2.bin.tts_inference import Text2Speech
+
+             model_id = TTS_OPTIONS[model_name]
+             tts = Text2Speech.from_pretrained(model_id)
+             tts_models[model_name] = tts
+
+         except ImportError:
+             print("ESPnet not installed. Using mock TTS for demonstration.")
+             tts_models[model_name] = "mock_tts"
+
+     return tts_models[model_name]
+
+ def transcribe_audio(audio_data, sr, asr_model_name):
+     """Transcribe audio using selected ASR model"""
+     global latency_ASR
+
+     start_time = time.time()
+
+     model = load_asr_model(asr_model_name)
+
+     # Gradio's microphone component delivers int16 PCM; scale to float32 in [-1, 1] and downmix to mono
+     if audio_data.dtype == np.int16:
+         audio_data = audio_data.astype(np.float32) / 32768.0
+     else:
+         audio_data = audio_data.astype(np.float32)
+     if audio_data.ndim > 1:
+         audio_data = audio_data.mean(axis=1)
+
+     if "whisper" in ASR_OPTIONS[asr_model_name]:
+         result = model({"array": audio_data, "sampling_rate": sr})
+         transcript = result["text"]
+     else:  # wav2vec2: greedy CTC decoding (assumes roughly 16 kHz input, the model's training rate)
+         inputs = model["processor"](audio_data, sampling_rate=sr, return_tensors="pt")
+         with torch.no_grad():
+             logits = model["model"](**inputs).logits
+         predicted_ids = torch.argmax(logits, dim=-1)
+         transcript = model["processor"].batch_decode(predicted_ids)[0]
+
+     latency_ASR = time.time() - start_time
+     return transcript
+
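+ # generate_response keeps the multi-turn context in the global conversation_history list so
+ # the Llama chat template can see earlier turns; reset_conversation (below) clears it.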
+ def generate_response(transcript, llm_model_name, system_prompt):
+     """Generate response using selected LLM model"""
+     global latency_LLM, conversation_history
+
+     start_time = time.time()
+
+     model_info = load_llm_model(llm_model_name)
+     model = model_info["model"]
+     tokenizer = model_info["tokenizer"]
+
+     # Format the prompt based on the model
+     if "llama" in LLM_OPTIONS[llm_model_name].lower():
+         # Format for Llama models using the tokenizer's chat template
+         if not conversation_history:
+             conversation_history.append({"role": "system", "content": system_prompt})
+
+         conversation_history.append({"role": "user", "content": transcript})
+
+         prompt = tokenizer.apply_chat_template(
+             conversation_history,
+             tokenize=False,
+             add_generation_prompt=True
+         )
+     else:
+         # Format for T5 models
+         prompt = f"{system_prompt}\nUser: {transcript}\nAssistant:"
+
+     # Generate text
+     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
+
+     with torch.no_grad():
+         outputs = model.generate(
+             input_ids,
+             max_new_tokens=100,
+             do_sample=True,
+             temperature=0.7,
+             top_p=0.9,
+         )
+
+     # Decode the response
+     if "llama" in LLM_OPTIONS[llm_model_name].lower():
+         # Causal LMs echo the prompt, so decode only the newly generated tokens
+         response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
+     else:
+         # Seq2seq models return only the generated answer
+         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     # Add to conversation history
+     conversation_history.append({"role": "assistant", "content": response})
+
+     latency_LLM = time.time() - start_time
+     return response
+
+ def synthesize_speech(text, tts_model_name):
+     """Synthesize speech using selected TTS model"""
+     global latency_TTS
+
+     start_time = time.time()
+
+     tts = load_tts_model(tts_model_name)
+
+     if tts == "mock_tts":
+         # Mock TTS response for demonstration
+         # In a real implementation, this would use the ESPnet model
+         try:
+             sample_rate = 16000
+             # Generate a simple sine wave as demo audio
+             duration = 2  # seconds
+             t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
+             audio_data = 0.5 * np.sin(2 * np.pi * 220 * t)  # 220 Hz sine wave
+         except Exception as e:
+             print(f"Error generating mock audio: {e}")
+             audio_data = np.zeros(16000)  # 1 second of silence
+             sample_rate = 16000
+     else:
+         # Use actual ESPnet TTS model
+         with torch.no_grad():
+             wav = tts(text)["wav"]
+             audio_data = wav.numpy()
+             sample_rate = tts.fs
+
+     latency_TTS = time.time() - start_time
+     return (sample_rate, audio_data)
+
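+ # Note: the (sample_rate, numpy_array) tuple returned above is a format gr.Audio accepts,
+ # so process_speech can pass it straight through to the output component.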
+ def process_speech(
+     audio_input,
+     asr_option,
+     llm_option,
+     tts_option,
+     system_prompt
+ ):
+     """Process speech: ASR -> LLM -> TTS pipeline"""
+     global audio_output
+
+     # Check if audio input is available
+     if audio_input is None:
+         return None, "", "", None
+
+     # Get audio data
+     sr, audio_data = audio_input
+
+     # ASR: Speech to text
+     transcript = transcribe_audio(audio_data, sr, asr_option)
+
+     # LLM: Generate response
+     response = generate_response(transcript, llm_option, system_prompt)
+
+     # TTS: Text to speech
+     audio_output = synthesize_speech(response, tts_option)
+
+     # Return results
+     return audio_input, transcript, response, audio_output
+
+ def display_latency():
+     """Display latency information"""
+     return f"""
+     ASR Latency: {latency_ASR:.2f} seconds
+     LLM Latency: {latency_LLM:.2f} seconds
+     TTS Latency: {latency_TTS:.2f} seconds
+     Total Latency: {latency_ASR + latency_LLM + latency_TTS:.2f} seconds
+     """
+
+ def reset_conversation():
+     """Reset the conversation history"""
+     global conversation_history, audio_output
+     conversation_history = []
+     audio_output = None
+     return None, "", "", None, ""
+
+ # Create Gradio interface
+ with gr.Blocks(title="Conversational Speech System") as demo:
+     gr.Markdown(
+         """
+         # Conversational Speech System with ASR, LLM, and TTS
+
+         This demo showcases a complete speech-to-speech conversation system using:
+         - **ASR** (Automatic Speech Recognition) to convert your speech to text
+         - **LLM** (Large Language Model) to generate responses
+         - **TTS** (Text-to-Speech) to convert the responses to speech
+
+         Speak into your microphone and the system will respond with synthesized speech.
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             # Input components
+             audio_input = gr.Audio(
+                 sources=["microphone"],
+                 type="numpy",
+                 label="Speak here",
+             )
+
+             system_prompt = gr.Textbox(
+                 label="System Prompt (instructions for the LLM)",
+                 value="You are a helpful and friendly AI assistant. Keep your responses concise and under 3 sentences."
+             )
+
+             asr_dropdown = gr.Dropdown(
+                 choices=list(ASR_OPTIONS.keys()),
+                 value=list(ASR_OPTIONS.keys())[0],
+                 label="Select ASR Model"
+             )
+
+             llm_dropdown = gr.Dropdown(
+                 choices=list(LLM_OPTIONS.keys()),
+                 value=list(LLM_OPTIONS.keys())[0],
+                 label="Select LLM Model"
+             )
+
+             tts_dropdown = gr.Dropdown(
+                 choices=list(TTS_OPTIONS.keys()),
+                 value=list(TTS_OPTIONS.keys())[0],
+                 label="Select TTS Model"
+             )
+
+             reset_btn = gr.Button("Reset Conversation")
+
+         with gr.Column(scale=1):
+             # Output components
+             user_transcript = gr.Textbox(label="Your Speech (ASR Output)")
+             system_response = gr.Textbox(label="AI Response (LLM Output)")
+             audio_output_component = gr.Audio(label="AI Voice Response", autoplay=True)
+             latency_info = gr.Textbox(label="Performance Metrics")
+
+     # Set up event handlers
+     audio_input.change(
+         process_speech,
+         inputs=[audio_input, asr_dropdown, llm_dropdown, tts_dropdown, system_prompt],
+         outputs=[audio_input, user_transcript, system_response, audio_output_component]
+     ).then(
+         display_latency,
+         inputs=[],
+         outputs=[latency_info]
+     )
+
+     reset_btn.click(
+         reset_conversation,
+         inputs=[],
+         outputs=[audio_input, user_transcript, system_response, audio_output_component, latency_info]
+     )
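+
+ # To try the app locally (assuming the dependencies noted above are installed):
+ # `python speech_conversation_app.py`; Gradio serves the UI on http://127.0.0.1:7860 by default.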
+
+ # Launch the app
+ if __name__ == "__main__":
+     demo.launch()