barreloflube committed on
Commit
70eeaf7
1 Parent(s): 099c588

Refactor code to update UI buttons in audio_tab()

playground/refs/audio.m4a ADDED
Binary file (268 kB). View file
 
playground/refs/audio.npy ADDED
Binary file (102 kB). View file
 
playground/refs/test.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
playground/refs/test.py ADDED
@@ -0,0 +1,330 @@
1
+ import fastapi
2
+ import numpy as np
3
+ import torch
4
+ import torchaudio
5
+ from silero_vad import get_speech_timestamps, load_silero_vad
6
+ import whisperx
7
+ import edge_tts
8
+ import gc
9
+ import logging
10
+ import time
11
+ import os
+ from openai import OpenAI
12
+ import threading
13
+
14
+ # Configure logging
15
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
16
+
17
+ # Configure FastAPI
18
+ app = fastapi.FastAPI()
19
+
20
+ # Load Silero VAD model
21
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
22
+ logging.info(f'Using device: {device}')
23
+ vad_model = load_silero_vad().to(device) # Ensure the model is on the correct device
24
+ logging.info('Loaded Silero VAD model')
25
+
26
+ # Load WhisperX model
27
+ whisper_model = whisperx.load_model("tiny", device, compute_type="float16")
28
+ logging.info('Loaded WhisperX model')
29
+
30
+ OPENAI_API_KEY = "sk-proj-gcrtuxd5qzaRYT82Ii3eT3BlbkFJpVQHBc9ZJrmSksLbQc3C" # os.getenv("OPENAI_API_KEY")
31
+ if not OPENAI_API_KEY:
32
+ logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
33
+ raise ValueError("OpenAI API key not found.")
34
+
35
+ # Initialize OpenAI client
36
+ openai_client = OpenAI(api_key=OPENAI_API_KEY)
37
+ logging.info('Initialized OpenAI client')
38
+
39
+ # TTS Voice
40
+ TTS_VOICE = "en-GB-SoniaNeural"
41
+
42
+ # Function to check voice activity using Silero VAD
43
+ def check_vad(audio_data, sample_rate):
44
+ logging.info('Checking voice activity')
45
+ # Resample to 16000 Hz if necessary
46
+ target_sample_rate = 16000
47
+ if sample_rate != target_sample_rate:
48
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
49
+ audio_tensor = resampler(torch.from_numpy(audio_data))
50
+ else:
51
+ audio_tensor = torch.from_numpy(audio_data)
52
+ audio_tensor = audio_tensor.to(device)
53
+
54
+ # Log audio data details
55
+ logging.info(f'Audio tensor shape: {audio_tensor.shape}, dtype: {audio_tensor.dtype}, device: {audio_tensor.device}')
56
+
57
+ # Get speech timestamps
58
+ speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)
59
+ logging.info(f'Found {len(speech_timestamps)} speech timestamps')
60
+ return len(speech_timestamps) > 0
61
+
62
+ # Function to transcribe audio using WhisperX
63
+ def transcript(audio_data, sample_rate):
64
+ logging.info('Transcribing audio')
65
+ # Resample to 16000 Hz if necessary
66
+ target_sample_rate = 16000
67
+ if sample_rate != target_sample_rate:
68
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
69
+ audio_data = resampler(torch.from_numpy(audio_data)).numpy()
70
+ else:
71
+ audio_data = audio_data
72
+
73
+ # Transcribe
74
+ batch_size = 16 # Adjust as needed
75
+ result = whisper_model.transcribe(audio_data, batch_size=batch_size)
76
+ text = result["segments"][0]["text"] if len(result["segments"]) > 0 else ""
77
+ logging.info(f'Transcription result: {text}')
78
+ # Clear GPU memory
79
+ del result
80
+ gc.collect()
81
+ if device == 'cuda':
82
+ torch.cuda.empty_cache()
83
+ return text
84
+
85
+ # Function to get streaming response from OpenAI API
86
+ def llm(text):
87
+ logging.info('Getting response from OpenAI API')
88
+ response = openai_client.chat.completions.create(
89
+ model="gpt-4o", # Updated to a more recent model
90
+ messages=[
91
+ {"role": "system", "content": "You respond to the following transcript from the conversation that you are having with the user."},
92
+ {"role": "user", "content": text}
93
+ ],
94
+ stream=True,
95
+ temperature=0.7, # Optional: Adjust as needed
96
+ top_p=0.9, # Optional: Adjust as needed
97
+ )
98
+ for chunk in response:
99
+ yield chunk.choices[0].delta.content
100
+
101
+ # Function to perform TTS per sentence using Edge-TTS
102
+ def tts_streaming(text_stream):
103
+ logging.info('Performing TTS')
104
+ buffer = ""
105
+ punctuation = {'.', '!', '?'}
106
+ for text_chunk in text_stream:
107
+ if text_chunk is not None:
108
+ buffer += text_chunk
109
+ # Check for sentence completion
110
+ sentences = []
111
+ start = 0
112
+ for i, char in enumerate(buffer):
113
+ if (char in punctuation):
114
+ sentences.append(buffer[start:i+1].strip())
115
+ start = i+1
116
+ buffer = buffer[start:]
117
+
118
+ for sentence in sentences:
119
+ if sentence:
120
+ communicate = edge_tts.Communicate(sentence, TTS_VOICE)
121
+ for chunk in communicate.stream_sync():
122
+ if chunk["type"] == "audio":
123
+ yield chunk["data"]
124
+ # Process any remaining text
125
+ if buffer.strip():
126
+ communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)
127
+ for chunk in communicate.stream_sync():
128
+ if chunk["type"] == "audio":
129
+ yield chunk["data"]
130
+
131
+ # Function to handle LLM and TTS
132
+ def llm_and_tts(transcribed_text, state):
133
+ logging.info('Handling LLM and TTS')
134
+ # Get streaming response from LLM
135
+ for text_chunk in llm(transcribed_text):
136
+ if state.get('stop_signal'):
137
+ logging.info('LLM and TTS task stopped')
138
+ break
139
+ # Get audio data from TTS
140
+ for audio_chunk in tts_streaming([text_chunk]):
141
+ if state.get('stop_signal'):
142
+ logging.info('LLM and TTS task stopped during TTS')
143
+ break
144
+ yield np.frombuffer(audio_chunk, dtype=np.int16)
145
+
146
+ state = {
147
+ 'mode': 'idle',
148
+ 'chunk_queue': [],
149
+ 'transcription': '',
150
+ 'in_transcription': False,
151
+ 'previous_no_vad_audio': [],
152
+ 'llm_task': None,
153
+ 'instream': None,
154
+ 'stop_signal': False,
155
+ 'args': {
156
+ 'sample_rate': 16000,
157
+ 'chunk_size': 0.5, # seconds
158
+ 'transcript_chunk_size': 2, # seconds
159
+ }
160
+ }
161
+
162
+ def transcript_loop():
163
+ while True:
164
+ if len(state['chunk_queue']) > 0:
165
+ accumulated_audio = np.concatenate(state['chunk_queue'])
166
+ total_samples = sum(len(chunk) for chunk in state['chunk_queue'])
167
+ total_duration = total_samples / state['args']['sample_rate']
168
+
169
+ # Run transcription on the first 2 seconds if len > 3 seconds
170
+ if total_duration > 3.0 and state['in_transcription'] == True:
171
+ first_two_seconds_samples = int(2.0 * state['args']['sample_rate'])
172
+ first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]
173
+ transcribed_text = transcript(first_two_seconds_audio, state['args']['sample_rate'])
174
+ state['transcription'] += transcribed_text
175
+ remaining_audio = accumulated_audio[first_two_seconds_samples:]
176
+ state['chunk_queue'] = [remaining_audio]
177
+ else: # Run transcription on the accumulated audio
178
+ transcribed_text = transcript(accumulated_audio, state['args']['sample_rate'])
179
+ state['transcription'] += transcribed_text
180
+ state['chunk_queue'] = []
181
+ state['in_transcription'] = False
182
+ else:
183
+ time.sleep(0.1)
184
+
185
+ if len(state['chunk_queue']) == 0 and state['mode'] in ('idle', 'processing'):
186
+ state['in_transcription'] = False
187
+ break
188
+
189
+ def process_audio(audio_chunk):
190
+ # returns output audio
191
+
192
+ sample_rate, audio_data = audio_chunk
193
+ audio_data = np.array(audio_data, dtype=np.float32)
194
+
195
+ # convert to mono if necessary
196
+ if audio_data.ndim > 1:
197
+ audio_data = np.mean(audio_data, axis=1)
198
+
199
+ mode = state['mode']
200
+ chunk_queue = state['chunk_queue']
201
+ transcription = state['transcription']
202
+ in_transcription = state['in_transcription']
203
+ previous_no_vad_audio = state['previous_no_vad_audio']
204
+ llm_task = state['llm_task']
205
+ instream = state['instream']
206
+ stop_signal = state['stop_signal']
207
+ args = state['args']
208
+
209
+ args['sample_rate'] = sample_rate
210
+
211
+ # check for voice activity
212
+ vad = check_vad(audio_data, sample_rate)
213
+
214
+ if vad:
215
+ logging.info(f'Voice activity detected in mode: {mode}')
216
+ if mode == 'idle':
217
+ mode = 'listening'
218
+ elif mode == 'speaking':
219
+ # Stop llm and tts tasks
220
+ if llm_task and llm_task.is_alive():
221
+ # Implement task cancellation logic if possible
222
+ logging.info('Stopping LLM and TTS tasks')
223
+ # Since we cannot kill threads directly, we need to handle this in the tasks
224
+ stop_signal = True
225
+ llm_task.join()
226
+ mode = 'listening'
227
+
228
+ if mode == 'listening':
229
+ if previous_no_vad_audio is not None:
230
+ chunk_queue.append(previous_no_vad_audio)
231
+ previous_no_vad_audio = None
232
+ # Accumulate audio chunks
233
+ chunk_queue.append(audio_data)
234
+
235
+ # Start transcription thread if not already running
236
+ if not in_transcription:
237
+ in_transcription = True
238
+ transcription_task = threading.Thread(target=transcript_loop)
239
+ transcription_task.start()
240
+
241
+ elif mode == 'speaking':
242
+ # Continue accumulating audio chunks
243
+ chunk_queue.append(audio_data)
244
+ else:
245
+ logging.info(f'No voice activity detected in mode: {mode}')
246
+ if mode == 'listening':
247
+ # Add the last chunk to queue
248
+ chunk_queue.append(audio_data)
249
+
250
+ # Change mode to processing
251
+ mode = 'processing'
252
+
253
+ # Wait for transcription to complete
254
+ while in_transcription:
255
+ time.sleep(0.1)
256
+
257
+ # Check if transcription is complete
258
+ if len(chunk_queue) == 0:
259
+ # Start the LLM and TTS pipeline (llm_and_tts is a generator,
260
+ # so it is consumed below rather than wrapped in a Thread)
261
+ if llm_task is None:
262
+ stop_signal = False
263
+ llm_task = llm_and_tts(transcription, state)
264
+
265
+ if mode == 'processing':
266
+ # Once the LLM/TTS generator exists, start responding
267
+ if llm_task is not None:
268
+ mode = 'responding'
269
+
270
+ if mode == 'responding':
271
+ for audio_chunk in llm_task:
272
+ if instream is None:
273
+ instream = audio_chunk
274
+ else:
275
+ instream = np.concatenate((instream, audio_chunk))
276
+
277
+ # Send audio to output stream
278
+ yield instream
279
+
280
+ # Cleanup
281
+ llm_task = None
282
+ transcription = ''
283
+ mode = 'idle'
284
+
285
+ # Update state
286
+ state['mode'] = mode
287
+ state['chunk_queue'] = chunk_queue
288
+ state['transcription'] = transcription
289
+ state['in_transcription'] = in_transcription
290
+ state['previous_no_vad_audio'] = previous_no_vad_audio
291
+ state['llm_task'] = llm_task
292
+ state['instream'] = instream
293
+ state['stop_signal'] = stop_signal
294
+ state['args'] = args
295
+
296
+ # Store previous audio chunk with no voice activity
297
+ previous_no_vad_audio = audio_data
298
+
299
+ # Update state
300
+ state['mode'] = mode
301
+ state['chunk_queue'] = chunk_queue
302
+ state['transcription'] = transcription
303
+ state['in_transcription'] = in_transcription
304
+ state['previous_no_vad_audio'] = previous_no_vad_audio
305
+ state['llm_task'] = llm_task
306
+ state['instream'] = instream
307
+ state['stop_signal'] = stop_signal
308
+ state['args'] = args
309
+
310
+
311
+ @app.websocket('/ws')
312
+ async def websocket_endpoint(websocket: fastapi.WebSocket):
313
+ await websocket.accept()
314
+ logging.info('WebSocket connection established')
315
+ try:
316
+ while True:
317
+ audio_chunk = await websocket.receive_bytes()
318
+ if audio_chunk is None:
319
+ break
320
+ for audio_data in process_audio(audio_chunk):
321
+ await websocket.send_bytes(audio_data.tobytes())
322
+ except Exception as e:
323
+ logging.error(f'WebSocket error: {e}')
324
+ finally:
325
+ logging.info('WebSocket connection closed')
326
+ await websocket.close()
327
+
328
+ @app.get('/')
329
+ def index():
330
+ return fastapi.responses.FileResponse('index.html')
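Unlike playground/testapp/main.py, this prototype defines the FastAPI app but has no entry point. A minimal way to serve it locally could look like the sketch below (an assumption, not part of the commit; the module name "test" and the port are illustrative):

    # run_refs_server.py -- hypothetical launcher for playground/refs/test.py
    import uvicorn

    if __name__ == "__main__":
        # "test:app" assumes the prototype file is importable as test.py
        uvicorn.run("test:app", host="0.0.0.0", port=8000)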
playground/testapp/audio.mp3 ADDED
Binary file (386 kB). View file
 
playground/testapp/index.html ADDED
@@ -0,0 +1,79 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Voice Assistant</title>
7
+ <style>
8
+ body {
9
+ font-family: Arial, sans-serif;
10
+ margin: 20px;
11
+ }
12
+ #transcription {
13
+ margin-top: 20px;
14
+ padding: 10px;
15
+ border: 1px solid #ccc;
16
+ height: 150px;
17
+ overflow-y: auto;
18
+ }
19
+ #audio-player {
20
+ margin-top: 20px;
21
+ }
22
+ </style>
23
+ </head>
24
+ <body>
25
+ <h1>Voice Assistant</h1>
26
+ <button id="start-btn">Start Recording</button>
27
+ <button id="stop-btn" disabled>Stop Recording</button>
28
+ <div id="transcription"></div>
29
+ <audio id="audio-player" controls></audio>
30
+
31
+ <script>
32
+ const startBtn = document.getElementById('start-btn');
33
+ const stopBtn = document.getElementById('stop-btn');
34
+ const transcriptionDiv = document.getElementById('transcription');
35
+ const audioPlayer = document.getElementById('audio-player');
36
+ let websocket;
37
+ let mediaRecorder;
38
+ let audioChunks = [];
39
+
40
+ startBtn.addEventListener('click', async () => {
41
+ startBtn.disabled = true;
42
+ stopBtn.disabled = false;
43
+
44
+ websocket = new WebSocket('ws://localhost:8000/ws');
45
+ websocket.binaryType = 'arraybuffer';
46
+
47
+ websocket.onmessage = (event) => {
48
+ if (event.data instanceof ArrayBuffer) {
49
+ const audioBlob = new Blob([event.data], { type: 'audio/wav' });
50
+ audioPlayer.src = URL.createObjectURL(audioBlob);
51
+ audioPlayer.play();
52
+ } else {
53
+ transcriptionDiv.innerText += event.data + '\n';
54
+ }
55
+ };
56
+
57
+ const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
58
+ mediaRecorder = new MediaRecorder(stream);
59
+
60
+ mediaRecorder.ondataavailable = (event) => {
61
+ if (event.data.size > 0) {
62
+ audioChunks.push(event.data);
63
+ websocket.send(event.data);
64
+ }
65
+ };
66
+
67
+ mediaRecorder.start(1000); // Send audio data every second
68
+ });
69
+
70
+ stopBtn.addEventListener('click', () => {
71
+ startBtn.disabled = false;
72
+ stopBtn.disabled = true;
73
+
74
+ mediaRecorder.stop();
75
+ websocket.close();
76
+ });
77
+ </script>
78
+ </body>
79
+ </html>
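Note that MediaRecorder typically emits compressed webm/opus blobs, while the Python prototypes in this commit treat the incoming bytes as raw 16 kHz int16 PCM (or WAV). One way a server could normalize such a blob before VAD and transcription is sketched below; this is an assumption, not part of the commit, and it requires pydub with ffmpeg and a header-complete recording rather than an arbitrary mid-stream chunk:

    # Sketch: decode a browser recording blob to 16 kHz mono int16 PCM.
    from io import BytesIO
    import numpy as np
    from pydub import AudioSegment

    def blob_to_pcm16(blob: bytes):
        segment = AudioSegment.from_file(BytesIO(blob), format="webm")
        segment = segment.set_channels(1).set_frame_rate(16000).set_sample_width(2)
        return segment.frame_rate, np.array(segment.get_array_of_samples(), dtype=np.int16)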
playground/testapp/main.py ADDED
@@ -0,0 +1,292 @@
1
+ import fastapi
2
+ import numpy as np
3
+ import torch
4
+ import torchaudio
5
+ from silero_vad import get_speech_timestamps, load_silero_vad
6
+ import whisperx
7
+ import edge_tts
8
+ import gc
9
+ import logging
10
+ import time
11
+ import os
12
+ from openai import AsyncOpenAI
13
+ import asyncio
14
+
15
+ # Configure logging
16
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
17
+
18
+ # Configure FastAPI
19
+ app = fastapi.FastAPI()
20
+
21
+ # Load Silero VAD model
22
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
23
+ logging.info(f'Using device: {device}')
24
+ vad_model = load_silero_vad().to(device)
25
+ logging.info('Loaded Silero VAD model')
26
+
27
+ # Load WhisperX model
28
+ whisper_model = whisperx.load_model("tiny", device, compute_type="float16")
29
+ logging.info('Loaded WhisperX model')
30
+
31
+ OPENAI_API_KEY = "sk-proj-gcrtuxd5qzaRYT82Ii3eT3BlbkFJpVQHBc9ZJrmSksLbQc3C"
32
+ if not OPENAI_API_KEY:
33
+ logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
34
+ raise ValueError("OpenAI API key not found.")
35
+ logging.info('Initialized OpenAI client')
36
+ aclient = AsyncOpenAI(api_key=OPENAI_API_KEY) # Corrected import
37
+
38
+ # TTS Voice
39
+ TTS_VOICE = "en-GB-SoniaNeural"
40
+
41
+ # Function to check voice activity using Silero VAD
42
+ def check_vad(audio_data, sample_rate):
43
+ logging.info('Checking voice activity')
44
+ target_sample_rate = 16000
45
+ if sample_rate != target_sample_rate:
46
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
47
+ audio_tensor = resampler(torch.from_numpy(audio_data))
48
+ else:
49
+ audio_tensor = torch.from_numpy(audio_data)
50
+ audio_tensor = audio_tensor.to(device)
51
+
52
+ speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)
53
+ logging.info(f'Found {len(speech_timestamps)} speech timestamps')
54
+ return len(speech_timestamps) > 0
55
+
56
+ # Async function to transcribe audio using WhisperX
57
+ def transcript_sync(audio_data, sample_rate):
58
+ logging.info('Transcribing audio')
59
+ target_sample_rate = 16000
60
+ if sample_rate != target_sample_rate:
61
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
62
+ audio_data = resampler(torch.from_numpy(audio_data)).numpy()
63
+ else:
64
+ audio_data = audio_data
65
+
66
+ batch_size = 16 # Adjust as needed
67
+ result = whisper_model.transcribe(audio_data, batch_size=batch_size)
68
+ text = result["segments"][0]["text"] if len(result["segments"]) > 0 else ""
69
+ logging.info(f'Transcription result: {text}')
70
+ del result
71
+ gc.collect()
72
+ if device == 'cuda':
73
+ torch.cuda.empty_cache()
74
+ return text
75
+
76
+ async def transcript(audio_data, sample_rate):
77
+ loop = asyncio.get_running_loop()
78
+ text = await loop.run_in_executor(None, transcript_sync, audio_data, sample_rate)
79
+ return text
80
+
81
+ # Async function to get streaming response from OpenAI API
82
+ async def llm(text):
83
+ logging.info('Getting response from OpenAI API')
84
+ response = await aclient.chat.completions.create(model="gpt-4", # Updated to a more recent model
85
+ messages=[
86
+ {"role": "system", "content": "You respond to the following transcript from the conversation that you are having with the user."},
87
+ {"role": "user", "content": text}
88
+ ],
89
+ stream=True,
90
+ temperature=0.7,
91
+ top_p=0.9)
92
+ async for chunk in response:
93
+ yield chunk.choices[0].delta.content
94
+
95
+ # Async function to perform TTS using Edge-TTS
96
+ async def tts_streaming(text_stream):
97
+ logging.info('Performing TTS')
98
+ buffer = ""
99
+ punctuation = {'.', '!', '?'}
100
+ for text_chunk in text_stream:
101
+ if text_chunk is not None:
102
+ buffer += text_chunk
103
+ # Check for sentence completion
104
+ sentences = []
105
+ start = 0
106
+ for i, char in enumerate(buffer):
107
+ if char in punctuation:
108
+ sentences.append(buffer[start:i+1].strip())
109
+ start = i+1
110
+ buffer = buffer[start:]
111
+
112
+ for sentence in sentences:
113
+ if sentence:
114
+ communicate = edge_tts.Communicate(sentence, TTS_VOICE)
115
+ async for chunk in communicate.stream():
116
+ if chunk["type"] == "audio":
117
+ yield chunk["data"]
118
+ # Process any remaining text
119
+ if buffer.strip():
120
+ communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)
121
+ async for chunk in communicate.stream():
122
+ if chunk["type"] == "audio":
123
+ yield chunk["data"]
124
+
125
+ class Conversation:
126
+ def __init__(self):
127
+ self.mode = 'idle'
128
+ self.chunk_queue = []
129
+ self.transcription = ''
130
+ self.in_transcription = False
131
+ self.previous_no_vad_audio = None
132
+ self.llm_task = None
133
+ self.transcription_task = None
134
+ self.stop_signal = False
135
+ self.sample_rate = 16000 # default sample rate
136
+ self.instream = None
137
+
138
+ async def process_audio(self, audio_chunk):
139
+ sample_rate, audio_data = audio_chunk
140
+ self.sample_rate = sample_rate
141
+ audio_data = np.array(audio_data, dtype=np.float32)
142
+
143
+ # convert to mono if necessary
144
+ if audio_data.ndim > 1:
145
+ audio_data = np.mean(audio_data, axis=1)
146
+
147
+ # check for voice activity
148
+ vad = check_vad(audio_data, sample_rate)
149
+
150
+ if vad:
151
+ logging.info(f'Voice activity detected in mode: {self.mode}')
152
+ if self.mode == 'idle':
153
+ self.mode = 'listening'
154
+ elif self.mode == 'speaking':
155
+ # Stop llm and tts tasks
156
+ if self.llm_task and not self.llm_task.done():
157
+ logging.info('Stopping LLM and TTS tasks')
158
+ self.stop_signal = True
159
+ await self.llm_task
160
+ self.mode = 'listening'
161
+
162
+ if self.mode == 'listening':
163
+ if self.previous_no_vad_audio is not None:
164
+ self.chunk_queue.append(self.previous_no_vad_audio)
165
+ self.previous_no_vad_audio = None
166
+ # Accumulate audio chunks
167
+ self.chunk_queue.append(audio_data)
168
+
169
+ # Start transcription task if not already running
170
+ if not self.in_transcription:
171
+ self.in_transcription = True
172
+ self.transcription_task = asyncio.create_task(self.transcript_loop())
173
+
174
+ else:
175
+ logging.info(f'No voice activity detected in mode: {self.mode}')
176
+ if self.mode == 'listening':
177
+ # Add the last chunk to queue
178
+ self.chunk_queue.append(audio_data)
179
+
180
+ # Change mode to processing
181
+ self.mode = 'processing'
182
+
183
+ # Wait for transcription to complete
184
+ while self.in_transcription:
185
+ await asyncio.sleep(0.1)
186
+
187
+ # Check if transcription is complete
188
+ if len(self.chunk_queue) == 0:
189
+ # Start LLM and TTS tasks
190
+ if not self.llm_task or self.llm_task.done():
191
+ self.stop_signal = False
192
+ self.llm_task = self.llm_and_tts()
193
+ self.mode = 'responding'
194
+
195
+ if self.mode == 'responding':
196
+ async for audio_chunk in self.llm_task:
197
+ if self.instream is None:
198
+ self.instream = audio_chunk
199
+ else:
200
+ self.instream = np.concatenate((self.instream, audio_chunk))
201
+ # Send audio to output stream
202
+ yield self.instream
203
+
204
+ # Cleanup
205
+ self.llm_task = None
206
+ self.transcription = ''
207
+ self.mode = 'idle'
208
+ self.instream = None
209
+
210
+ # Store previous audio chunk with no voice activity
211
+ self.previous_no_vad_audio = audio_data
212
+
213
+ async def transcript_loop(self):
214
+ while True:
215
+ if len(self.chunk_queue) > 0:
216
+ accumulated_audio = np.concatenate(self.chunk_queue)
217
+ total_samples = len(accumulated_audio)
218
+ total_duration = total_samples / self.sample_rate
219
+
220
+ if total_duration > 3.0 and self.in_transcription == True:
221
+ first_two_seconds_samples = int(2.0 * self.sample_rate)
222
+ first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]
223
+ transcribed_text = await transcript(first_two_seconds_audio, self.sample_rate)
224
+ self.transcription += transcribed_text
225
+ remaining_audio = accumulated_audio[first_two_seconds_samples:]
226
+ self.chunk_queue = [remaining_audio]
227
+ else:
228
+ transcribed_text = await transcript(accumulated_audio, self.sample_rate)
229
+ self.transcription += transcribed_text
230
+ self.chunk_queue = []
231
+ self.in_transcription = False
232
+ else:
233
+ await asyncio.sleep(0.1)
234
+
235
+ if len(self.chunk_queue) == 0 and self.mode in ['idle', 'processing']:
236
+ self.in_transcription = False
237
+ break
238
+
239
+ async def llm_and_tts(self):
240
+ logging.info('Handling LLM and TTS')
241
+ async for text_chunk in llm(self.transcription):
242
+ if self.stop_signal:
243
+ logging.info('LLM and TTS task stopped')
244
+ break
245
+ async for audio_chunk in tts_streaming([text_chunk]):
246
+ if self.stop_signal:
247
+ logging.info('LLM and TTS task stopped during TTS')
248
+ break
249
+ yield np.frombuffer(audio_chunk, dtype=np.int16)
250
+
251
+ @app.websocket('/ws')
252
+ async def websocket_endpoint(websocket: fastapi.WebSocket):
253
+ await websocket.accept()
254
+ logging.info('WebSocket connection established')
255
+ conversation = Conversation()
256
+ audio_buffer = []
257
+ buffer_duration = 0.5 # 500ms
258
+ try:
259
+ while True:
260
+ audio_chunk_bytes = await websocket.receive_bytes()
261
+ if audio_chunk_bytes is None:
262
+ break
263
+
264
+ audio_chunk = (conversation.sample_rate, np.frombuffer(audio_chunk_bytes, dtype=np.int16))
265
+ audio_buffer.append(audio_chunk[1])
266
+
267
+ # Calculate the duration of the buffered audio
268
+ total_samples = sum(len(chunk) for chunk in audio_buffer)
269
+ total_duration = total_samples / conversation.sample_rate
270
+
271
+ if total_duration >= buffer_duration:
272
+ # Concatenate buffered audio chunks
273
+ buffered_audio = np.concatenate(audio_buffer)
274
+ audio_buffer = [] # Reset buffer
275
+
276
+ # Process the buffered audio
277
+ async for audio_data in conversation.process_audio((conversation.sample_rate, buffered_audio)):
278
+ if audio_data is not None:
279
+ await websocket.send_bytes(audio_data.tobytes())
280
+ except Exception as e:
281
+ logging.error(f'WebSocket error: {e}')
282
+ finally:
283
+ logging.info('WebSocket connection closed')
284
+ await websocket.close()
285
+
286
+ @app.get('/')
287
+ def index():
288
+ return fastapi.responses.FileResponse('index.html')
289
+
290
+ if __name__ == '__main__':
291
+ import uvicorn
292
+ uvicorn.run(app, host='0.0.0.0', port=8000)
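main.py interrupts an in-flight response by setting stop_signal and letting the llm_and_tts async generator notice it. A minimal sketch of that cooperative-cancellation pattern, with illustrative names that are not from the commit:

    # Sketch: cooperatively stopping an async generator, mirroring stop_signal.
    import asyncio

    async def audio_stream(stop: asyncio.Event):
        for i in range(100):
            if stop.is_set():          # cooperative check, like self.stop_signal
                break
            await asyncio.sleep(0.1)   # stands in for LLM/TTS work
            yield i

    async def main():
        stop = asyncio.Event()
        stream = audio_stream(stop)
        async for chunk in stream:
            if chunk == 3:             # e.g. new voice activity detected
                stop.set()
                break
        await stream.aclose()          # finalize the generator explicitly

    asyncio.run(main())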
playground/testapp/test.ipynb ADDED
@@ -0,0 +1,478 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import fastapi\n",
10
+ "import numpy as np\n",
11
+ "import torch\n",
12
+ "import torchaudio\n",
13
+ "from silero_vad import get_speech_timestamps, load_silero_vad\n",
14
+ "import whisperx\n",
15
+ "import edge_tts\n",
16
+ "import gc\n",
17
+ "import logging\n",
18
+ "import time\n",
19
+ "from openai import OpenAI\n",
20
+ "import threading\n",
21
+ "import asyncio\n",
22
+ "\n",
23
+ "# Configure logging\n",
24
+ "logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n",
25
+ "\n",
26
+ "# Configure FastAPI\n",
27
+ "app = fastapi.FastAPI()\n",
28
+ "\n",
29
+ "# Load Silero VAD model\n",
30
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
31
+ "logging.info(f'Using device: {device}')\n",
32
+ "vad_model = load_silero_vad().to(device) # Ensure the model is on the correct device\n",
33
+ "logging.info('Loaded Silero VAD model')\n",
34
+ "\n",
35
+ "# Load WhisperX model\n",
36
+ "whisper_model = whisperx.load_model(\"tiny\", device, compute_type=\"float16\")\n",
37
+ "logging.info('Loaded WhisperX model')\n",
38
+ "\n",
39
+ "# OpenAI API Key from environment variable for security\n",
40
+ "OPENAI_API_KEY = \"sk-proj-gcrtuxd5qzaRYT82Ii3eT3BlbkFJpVQHBc9ZJrmSksLbQc3C\" # os.getenv(\"OPENAI_API_KEY\")\n",
41
+ "if not OPENAI_API_KEY:\n",
42
+ " logging.error(\"OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.\")\n",
43
+ " raise ValueError(\"OpenAI API key not found.\")\n",
44
+ "\n",
45
+ "# Initialize OpenAI client\n",
46
+ "openai_client = OpenAI(api_key=OPENAI_API_KEY)\n",
47
+ "logging.info('Initialized OpenAI client')\n",
48
+ "\n",
49
+ "# TTS Voice\n",
50
+ "TTS_VOICE = \"en-GB-SoniaNeural\"\n",
51
+ "\n",
52
+ "# Function to check voice activity using Silero VAD\n",
53
+ "def check_vad(audio_data, sample_rate):\n",
54
+ " logging.info('Checking voice activity')\n",
55
+ " # Resample to 16000 Hz if necessary\n",
56
+ " target_sample_rate = 16000\n",
57
+ " if sample_rate != target_sample_rate:\n",
58
+ " resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)\n",
59
+ " audio_tensor = resampler(torch.from_numpy(audio_data))\n",
60
+ " else:\n",
61
+ " audio_tensor = torch.from_numpy(audio_data)\n",
62
+ " audio_tensor = audio_tensor.to(device)\n",
63
+ "\n",
64
+ " # Log audio data details\n",
65
+ " logging.info(f'Audio tensor shape: {audio_tensor.shape}, dtype: {audio_tensor.dtype}, device: {audio_tensor.device}')\n",
66
+ "\n",
67
+ " # Get speech timestamps\n",
68
+ " speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)\n",
69
+ " logging.info(f'Found {len(speech_timestamps)} speech timestamps')\n",
70
+ " return len(speech_timestamps) > 0\n",
71
+ "\n",
72
+ "# Function to transcribe audio using WhisperX\n",
73
+ "def transcript(audio_data, sample_rate):\n",
74
+ " logging.info('Transcribing audio')\n",
75
+ " # Resample to 16000 Hz if necessary\n",
76
+ " target_sample_rate = 16000\n",
77
+ " if sample_rate != target_sample_rate:\n",
78
+ " resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)\n",
79
+ " audio_data = resampler(torch.from_numpy(audio_data)).numpy()\n",
80
+ " else:\n",
81
+ " audio_data = audio_data\n",
82
+ "\n",
83
+ " # Transcribe\n",
84
+ " batch_size = 16 # Adjust as needed\n",
85
+ " result = whisper_model.transcribe(audio_data, batch_size=batch_size)\n",
86
+ " text = result[\"segments\"][0][\"text\"] if len(result[\"segments\"]) > 0 else \"\"\n",
87
+ " logging.info(f'Transcription result: {text}')\n",
88
+ " # Clear GPU memory\n",
89
+ " del result\n",
90
+ " gc.collect()\n",
91
+ " if device == 'cuda':\n",
92
+ " torch.cuda.empty_cache()\n",
93
+ " return text\n",
94
+ "\n",
95
+ "# Function to get streaming response from OpenAI API\n",
96
+ "def llm(text):\n",
97
+ " logging.info('Getting response from OpenAI API')\n",
98
+ " response = openai_client.chat.completions.create(\n",
99
+ " model=\"gpt-4o\", # Updated to a more recent model\n",
100
+ " messages=[\n",
101
+ " {\"role\": \"system\", \"content\": \"You respond to the following transcript from the conversation that you are having with the user.\"},\n",
102
+ " {\"role\": \"user\", \"content\": text} \n",
103
+ " ],\n",
104
+ " stream=True,\n",
105
+ " temperature=0.7, # Optional: Adjust as needed\n",
106
+ " top_p=0.9, # Optional: Adjust as needed\n",
107
+ " )\n",
108
+ " for chunk in response:\n",
109
+ " yield chunk.choices[0].delta.content\n",
110
+ "\n",
111
+ "# Function to perform TTS per sentence using Edge-TTS\n",
112
+ "def tts_streaming(text_stream):\n",
113
+ " logging.info('Performing TTS')\n",
114
+ " buffer = \"\"\n",
115
+ " punctuation = {'.', '!', '?'}\n",
116
+ " for text_chunk in text_stream:\n",
117
+ " if text_chunk is not None:\n",
118
+ " buffer += text_chunk\n",
119
+ " # Check for sentence completion\n",
120
+ " sentences = []\n",
121
+ " start = 0\n",
122
+ " for i, char in enumerate(buffer):\n",
123
+ " if (char in punctuation):\n",
124
+ " sentences.append(buffer[start:i+1].strip())\n",
125
+ " start = i+1\n",
126
+ " buffer = buffer[start:]\n",
127
+ "\n",
128
+ " for sentence in sentences:\n",
129
+ " if sentence:\n",
130
+ " communicate = edge_tts.Communicate(sentence, TTS_VOICE)\n",
131
+ " for chunk in communicate.stream_sync():\n",
132
+ " if chunk[\"type\"] == \"audio\":\n",
133
+ " yield chunk[\"data\"]\n",
134
+ " # Process any remaining text\n",
135
+ " if buffer.strip():\n",
136
+ " communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)\n",
137
+ " for chunk in communicate.stream_sync():\n",
138
+ " if chunk[\"type\"] == \"audio\":\n",
139
+ " yield chunk[\"data\"]\n",
140
+ "\n",
141
+ "# Function to handle LLM and TTS\n",
142
+ "def llm_and_tts(transcribed_text):\n",
143
+ " logging.info('Handling LLM and TTS')\n",
144
+ " # Get streaming response from LLM\n",
145
+ " for text_chunk in llm(transcribed_text):\n",
146
+ " if state.get('stop_signal'):\n",
147
+ " logging.info('LLM and TTS task stopped')\n",
148
+ " break\n",
149
+ " # Get audio data from TTS\n",
150
+ " for audio_chunk in tts_streaming([text_chunk]):\n",
151
+ " if state.get('stop_signal'):\n",
152
+ " logging.info('LLM and TTS task stopped during TTS')\n",
153
+ " break\n",
154
+ " yield np.frombuffer(audio_chunk, dtype=np.int16)\n",
155
+ "\n",
156
+ "state = {\n",
157
+ " 'mode': 'idle',\n",
158
+ " 'chunk_queue': [],\n",
159
+ " 'transcription': '',\n",
160
+ " 'in_transcription': False,\n",
161
+ " 'previous_no_vad_audio': [],\n",
162
+ " 'llm_task': None,\n",
163
+ " 'instream': None,\n",
164
+ " 'stop_signal': False,\n",
165
+ " 'args': {\n",
166
+ " 'sample_rate': 16000,\n",
167
+ " 'chunk_size': 0.5, # seconds\n",
168
+ " 'transcript_chunk_size': 2, # seconds\n",
169
+ " }\n",
170
+ "}\n",
171
+ "\n",
172
+ "def transcript_loop():\n",
173
+ " while True:\n",
174
+ " if len(state['chunk_queue']) > 0:\n",
175
+ " accumulated_audio = np.concatenate(state['chunk_queue'])\n",
176
+ " total_samples = sum(len(chunk) for chunk in state['chunk_queue'])\n",
177
+ " total_duration = total_samples / state['args']['sample_rate']\n",
178
+ " \n",
179
+ " # Run transcription on the first 2 seconds if len > 3 seconds\n",
180
+ " if total_duration > 3.0 and state['in_transcription'] == True:\n",
181
+ " first_two_seconds_samples = int(2.0 * state['args']['sample_rate'])\n",
182
+ " first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]\n",
183
+ " transcribed_text = transcript(first_two_seconds_audio, state['args']['sample_rate'])\n",
184
+ " state['transcription'] += transcribed_text\n",
185
+ " remaining_audio = accumulated_audio[first_two_seconds_samples:]\n",
186
+ " state['chunk_queue'] = [remaining_audio]\n",
187
+ " else: # Run transcription on the accumulated audio\n",
188
+ " transcribed_text = transcript(accumulated_audio, state['args']['sample_rate'])\n",
189
+ " state['transcription'] += transcribed_text\n",
190
+ " state['chunk_queue'] = []\n",
191
+ " state['in_transcription'] = False\n",
192
+ " else:\n",
193
+ " time.sleep(0.1)\n",
194
+ "\n",
195
+ " if len(state['chunk_queue']) == 0 and state['mode'] == any(['idle', 'processing']):\n",
196
+ " state['in_transcription'] = False\n",
197
+ " break\n",
198
+ "\n",
199
+ "def process_audio(audio_chunk):\n",
200
+ " # returns output audio\n",
201
+ " \n",
202
+ " sample_rate, audio_data = audio_chunk\n",
203
+ " audio_data = np.array(audio_data, dtype=np.float32)\n",
204
+ " \n",
205
+ " # convert to mono if necessary\n",
206
+ " if audio_data.ndim > 1:\n",
207
+ " audio_data = np.mean(audio_data, axis=1)\n",
208
+ "\n",
209
+ " mode = state['mode']\n",
210
+ " chunk_queue = state['chunk_queue']\n",
211
+ " transcription = state['transcription']\n",
212
+ " in_transcription = state['in_transcription']\n",
213
+ " previous_no_vad_audio = state['previous_no_vad_audio']\n",
214
+ " llm_task = state['llm_task']\n",
215
+ " instream = state['instream']\n",
216
+ " stop_signal = state['stop_signal']\n",
217
+ " args = state['args']\n",
218
+ " \n",
219
+ " args['sample_rate'] = sample_rate\n",
220
+ " \n",
221
+ " # check for voice activity\n",
222
+ " vad = check_vad(audio_data, sample_rate)\n",
223
+ " \n",
224
+ " if vad:\n",
225
+ " logging.info(f'Voice activity detected in mode: {mode}')\n",
226
+ " if mode == 'idle':\n",
227
+ " mode = 'listening'\n",
228
+ " elif mode == 'speaking':\n",
229
+ " # Stop llm and tts tasks\n",
230
+ " if llm_task and llm_task.is_alive():\n",
231
+ " # Implement task cancellation logic if possible\n",
232
+ " logging.info('Stopping LLM and TTS tasks')\n",
233
+ " # Since we cannot kill threads directly, we need to handle this in the tasks\n",
234
+ " stop_signal = True\n",
235
+ " llm_task.join()\n",
236
+ " mode = 'listening'\n",
237
+ "\n",
238
+ " if mode == 'listening':\n",
239
+ " if previous_no_vad_audio is not None:\n",
240
+ " chunk_queue.append(previous_no_vad_audio)\n",
241
+ " previous_no_vad_audio = None\n",
242
+ " # Accumulate audio chunks\n",
243
+ " chunk_queue.append(audio_data)\n",
244
+ " \n",
245
+ " # Start transcription thread if not already running\n",
246
+ " if not in_transcription:\n",
247
+ " in_transcription = True\n",
248
+ " transcription_task = threading.Thread(target=transcript_loop)\n",
249
+ " transcription_task.start()\n",
250
+ " \n",
251
+ " elif mode == 'speaking':\n",
252
+ " # Continue accumulating audio chunks\n",
253
+ " chunk_queue.append(audio_data)\n",
254
+ " else:\n",
255
+ " logging.info(f'No voice activity detected in mode: {mode}')\n",
256
+ " if mode == 'listening':\n",
257
+ " # Add the last chunk to queue\n",
258
+ " chunk_queue.append(audio_data)\n",
259
+ " \n",
260
+ " # Change mode to processing\n",
261
+ " mode = 'processing'\n",
262
+ " \n",
263
+ " # Wait for transcription to complete\n",
264
+ " while in_transcription:\n",
265
+ " time.sleep(0.1)\n",
266
+ " \n",
267
+ " # Check if transcription is complete\n",
268
+ " if len(chunk_queue) == 0:\n",
269
+ " # Start LLM and TTS tasks\n",
270
+ " if not llm_task or not llm_task.is_alive():\n",
271
+ " stop_signal = False\n",
272
+ " llm_task = threading.Thread(target=llm_and_tts, args=(transcription))\n",
273
+ " llm_task.start()\n",
274
+ " \n",
275
+ " if mode == 'processing':\n",
276
+ " # Wait for LLM and TTS tasks to start yielding audio\n",
277
+ " if llm_task and llm_task.is_alive():\n",
278
+ " mode = 'responding'\n",
279
+ " \n",
280
+ " if mode == 'responding':\n",
281
+ " for audio_chunk in llm_task:\n",
282
+ " if instream is None:\n",
283
+ " instream = audio_chunk\n",
284
+ " else:\n",
285
+ " instream = np.concatenate((instream, audio_chunk))\n",
286
+ " \n",
287
+ " # Send audio to output stream\n",
288
+ " yield instream\n",
289
+ " \n",
290
+ " # Cleanup\n",
291
+ " llm_task = None\n",
292
+ " transcription = ''\n",
293
+ " mode = 'idle'\n",
294
+ " \n",
295
+ " # Updaate state\n",
296
+ " state['mode'] = mode\n",
297
+ " state['chunk_queue'] = chunk_queue\n",
298
+ " state['transcription'] = transcription\n",
299
+ " state['in_transcription'] = in_transcription\n",
300
+ " state['previous_no_vad_audio'] = previous_no_vad_audio\n",
301
+ " state['llm_task'] = llm_task\n",
302
+ " state['instream'] = instream\n",
303
+ " state['stop_signal'] = stop_signal\n",
304
+ " state['args'] = args\n",
305
+ " \n",
306
+ " # Store previous audio chunk with no voice activity\n",
307
+ " previous_no_vad_audio = audio_data\n",
308
+ " \n",
309
+ " # Update state\n",
310
+ " state['mode'] = mode\n",
311
+ " state['chunk_queue'] = chunk_queue\n",
312
+ " state['transcription'] = transcription\n",
313
+ " state['in_transcription'] = in_transcription\n",
314
+ " state['previous_no_vad_audio'] = previous_no_vad_audio\n",
315
+ " state['llm_task'] = llm_task\n",
316
+ " state['instream'] = instream\n",
317
+ " state['stop_signal'] = stop_signal\n",
318
+ " state['args'] = args"
319
+ ]
320
+ },
321
+ {
322
+ "cell_type": "code",
323
+ "execution_count": null,
324
+ "metadata": {},
325
+ "outputs": [],
326
+ "source": [
327
+ "# 1. Load audio.mp3\n",
328
+ "# 2. Split audio into chunks\n",
329
+ "# 3. Process each chunk inside a loop\n",
330
+ "\n",
331
+ "# Split audio into chunks of 500 ms or less\n",
332
+ "from pydub import AudioSegment\n",
333
+ "audio_segment = AudioSegment.from_file('audio.mp3')\n",
334
+ "chunks = [chunk for chunk in audio_segment[::500]]\n",
335
+ "chunks[0]\n",
336
+ "chunks = [(chunk.frame_rate, np.array(chunk.get_array_of_samples(), dtype=np.int16)) for chunk in chunks]\n",
337
+ "\n",
338
+ "output_audio = []\n",
339
+ "# Process each chunk\n",
340
+ "for chunk in chunks:\n",
341
+ " for audio_chunk in process_audio(chunk):\n",
342
+ " output_audio.append(audio_chunk)"
343
+ ]
344
+ },
345
+ {
346
+ "cell_type": "code",
347
+ "execution_count": null,
348
+ "metadata": {},
349
+ "outputs": [],
350
+ "source": [
351
+ "output_audio"
352
+ ]
353
+ },
354
+ {
355
+ "cell_type": "code",
356
+ "execution_count": null,
357
+ "metadata": {},
358
+ "outputs": [],
359
+ "source": [
360
+ "import asyncio\n",
361
+ "import websockets\n",
362
+ "from pydub import AudioSegment\n",
363
+ "import numpy as np\n",
364
+ "import simpleaudio as sa\n",
365
+ "\n",
366
+ "# Constants\n",
367
+ "AUDIO_FILE = 'audio.mp3' # Input audio file\n",
368
+ "CHUNK_DURATION_MS = 250 # Duration of each chunk in milliseconds\n",
369
+ "WEBSOCKET_URI = 'ws://localhost:8000/ws' # WebSocket endpoint\n",
370
+ "\n",
371
+ "async def send_audio_chunks(uri):\n",
372
+ " # Load audio file using pydub\n",
373
+ " audio = AudioSegment.from_file(AUDIO_FILE)\n",
374
+ "\n",
375
+ " # Ensure audio is mono and 16kHz\n",
376
+ " if audio.channels > 1:\n",
377
+ " audio = audio.set_channels(1)\n",
378
+ " if audio.frame_rate != 16000:\n",
379
+ " audio = audio.set_frame_rate(16000)\n",
380
+ " if audio.sample_width != 2: # 2 bytes for int16\n",
381
+ " audio = audio.set_sample_width(2)\n",
382
+ "\n",
383
+ " # Split audio into chunks\n",
384
+ " chunks = [audio[i:i+CHUNK_DURATION_MS] for i in range(0, len(audio), CHUNK_DURATION_MS)]\n",
385
+ "\n",
386
+ " # Store received audio data\n",
387
+ " received_audio_data = b''\n",
388
+ "\n",
389
+ " async with websockets.connect(uri) as websocket:\n",
390
+ " print(\"Connected to server.\")\n",
391
+ " for idx, chunk in enumerate(chunks):\n",
392
+ " # Get raw audio data\n",
393
+ " raw_data = chunk.raw_data\n",
394
+ "\n",
395
+ " # Send audio chunk to server\n",
396
+ " await websocket.send(raw_data)\n",
397
+ " print(f\"Sent chunk {idx+1}/{len(chunks)}\")\n",
398
+ "\n",
399
+ " # Receive response (non-blocking)\n",
400
+ " try:\n",
401
+ " response = await asyncio.wait_for(websocket.recv(), timeout=0.1)\n",
402
+ " if isinstance(response, bytes):\n",
403
+ " received_audio_data += response\n",
404
+ " print(f\"Received audio data of length {len(response)} bytes\")\n",
405
+ " except asyncio.TimeoutError:\n",
406
+ " pass # No response received yet\n",
407
+ "\n",
408
+ " # Simulate real-time by waiting for chunk duration\n",
409
+ " await asyncio.sleep(CHUNK_DURATION_MS / 1000.0)\n",
410
+ "\n",
411
+ " # Send a final empty message to indicate end of transmission\n",
412
+ " await websocket.send(b'')\n",
413
+ " print(\"Finished sending audio. Waiting for responses...\")\n",
414
+ "\n",
415
+ " # Receive any remaining responses\n",
416
+ " while True:\n",
417
+ " try:\n",
418
+ " response = await asyncio.wait_for(websocket.recv(), timeout=1)\n",
419
+ " if isinstance(response, bytes):\n",
420
+ " received_audio_data += response\n",
421
+ " print(f\"Received audio data of length {len(response)} bytes\")\n",
422
+ " except asyncio.TimeoutError:\n",
423
+ " print(\"No more responses. Closing connection.\")\n",
424
+ " break\n",
425
+ "\n",
426
+ " print(\"Connection closed.\")\n",
427
+ "\n",
428
+ " # Save received audio data to a file or play it\n",
429
+ " if received_audio_data:\n",
430
+ " # Convert bytes to numpy array\n",
431
+ " audio_array = np.frombuffer(received_audio_data, dtype=np.int16)\n",
432
+ "\n",
433
+ " # Play audio using simpleaudio\n",
434
+ " play_obj = sa.play_buffer(audio_array, 1, 2, 16000)\n",
435
+ " play_obj.wait_done()\n",
436
+ "\n",
437
+ " # Optionally, save to a WAV file\n",
438
+ " output_audio = AudioSegment(\n",
439
+ " data=received_audio_data,\n",
440
+ " sample_width=2, # 2 bytes for int16\n",
441
+ " frame_rate=16000,\n",
442
+ " channels=1\n",
443
+ " )\n",
444
+ " output_audio.export(\"output_response.wav\", format=\"wav\")\n",
445
+ " print(\"Saved response audio to 'output_response.wav'\")\n",
446
+ " else:\n",
447
+ " print(\"No audio data received.\")\n",
448
+ "\n",
449
+ "def main():\n",
450
+ " asyncio.run(send_audio_chunks(WEBSOCKET_URI))\n",
451
+ "\n",
452
+ "if __name__ == '__main__':\n",
453
+ " main()"
454
+ ]
455
+ }
456
+ ],
457
+ "metadata": {
458
+ "kernelspec": {
459
+ "display_name": ".venv",
460
+ "language": "python",
461
+ "name": "python3"
462
+ },
463
+ "language_info": {
464
+ "codemirror_mode": {
465
+ "name": "ipython",
466
+ "version": 3
467
+ },
468
+ "file_extension": ".py",
469
+ "mimetype": "text/x-python",
470
+ "name": "python",
471
+ "nbconvert_exporter": "python",
472
+ "pygments_lexer": "ipython3",
473
+ "version": "3.10.12"
474
+ }
475
+ },
476
+ "nbformat": 4,
477
+ "nbformat_minor": 2
478
+ }
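The notebooks and scripts above assume a working environment; roughly the following third-party packages are needed (a best-guess list, nothing is pinned by this commit):

    # Assumed environment (package names are a guess; the commit pins nothing):
    # pip install fastapi uvicorn numpy torch torchaudio silero-vad whisperx \
    #             edge-tts openai pydub websockets simpleaudio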
playground/testapp/test.py ADDED
@@ -0,0 +1,257 @@
1
+ import fastapi
2
+ import numpy as np
3
+ import torch
4
+ import torchaudio
5
+ from silero_vad import get_speech_timestamps, load_silero_vad
6
+ import whisperx
7
+ import edge_tts
8
+ import gc
9
+ import logging
10
+ import time
11
+ import os
12
+ from openai import OpenAI
13
+ import asyncio
14
+ from pydub import AudioSegment
15
+ from io import BytesIO
16
+ import threading
17
+
18
+ # Configure logging
19
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
20
+
21
+ # Configure FastAPI
22
+ app = fastapi.FastAPI()
23
+
24
+ # Load Silero VAD model
25
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
26
+ logging.info(f'Using device: {device}')
27
+ vad_model = load_silero_vad().to(device)
28
+ logging.info('Loaded Silero VAD model')
29
+
30
+ # Load WhisperX model
31
+ whisper_model = whisperx.load_model("tiny", device, compute_type="float16")
32
+ logging.info('Loaded WhisperX model')
33
+
34
+ OPENAI_API_KEY = "sk-proj-gcrtuxd5qzaRYT82Ii3eT3BlbkFJpVQHBc9ZJrmSksLbQc3C"
35
+ if not OPENAI_API_KEY:
36
+ logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
37
+ raise ValueError("OpenAI API key not found.")
38
+ logging.info('Initialized OpenAI client')
39
+ llm_client = OpenAI(api_key=OPENAI_API_KEY) # Corrected import
40
+
41
+ # TTS Voice
42
+ TTS_VOICE = "en-GB-SoniaNeural"
43
+
44
+ # Function to check voice activity using Silero VAD
45
+ def check_vad(audio_data, sample_rate):
46
+ logging.info('Checking voice activity')
47
+ target_sample_rate = 16000
48
+ if sample_rate != target_sample_rate:
49
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
50
+ audio_tensor = resampler(torch.from_numpy(audio_data))
51
+ else:
52
+ audio_tensor = torch.from_numpy(audio_data)
53
+ audio_tensor = audio_tensor.to(device)
54
+
55
+ speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)
56
+ logging.info(f'Found {len(speech_timestamps)} speech timestamps')
57
+ return len(speech_timestamps) > 0
58
+
59
+ # Async function to transcribe audio using WhisperX
60
+ def transcribe(audio_data, sample_rate):
61
+ logging.info('Transcribing audio')
62
+ target_sample_rate = 16000
63
+ if sample_rate != target_sample_rate:
64
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
65
+ audio_data = resampler(torch.from_numpy(audio_data)).numpy()
66
+ else:
67
+ audio_data = audio_data
68
+
69
+ batch_size = 16 # Adjust as needed
70
+ result = whisper_model.transcribe(audio_data, batch_size=batch_size)
71
+ text = result["segments"][0]["text"] if len(result["segments"]) > 0 else ""
72
+ logging.info(f'Transcription result: {text}')
73
+ del result
74
+ gc.collect()
75
+ if device == 'cuda':
76
+ torch.cuda.empty_cache()
77
+ return text
78
+
79
+ # Function to convert text to speech using Edge TTS and stream the audio
80
+ def tts_streaming(text_stream):
81
+ logging.info('Performing TTS')
82
+ buffer = ""
83
+ punctuation = {'.', '!', '?'}
84
+ for text_chunk in text_stream:
85
+ if text_chunk is not None:
86
+ buffer += text_chunk
87
+ # Check for sentence completion
88
+ sentences = []
89
+ start = 0
90
+ for i, char in enumerate(buffer):
91
+ if char in punctuation:
92
+ sentences.append(buffer[start:i+1].strip())
93
+ start = i+1
94
+ buffer = buffer[start:]
95
+
96
+ for sentence in sentences:
97
+ if sentence:
98
+ communicate = edge_tts.Communicate(sentence, TTS_VOICE)
99
+ for chunk in communicate.stream_sync():
100
+ if chunk["type"] == "audio":
101
+ yield chunk["data"]
102
+ # Process any remaining text
103
+ if buffer.strip():
104
+ communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)
105
+ for chunk in communicate.stream_sync():
106
+ if chunk["type"] == "audio":
107
+ yield chunk["data"]
108
+
109
+ # Function to perform language model completion using OpenAI API
110
+ def llm(text):
111
+ logging.info('Getting response from OpenAI API')
112
+ response = llm_client.chat.completions.create(
113
+ model="gpt-4o", # Updated to a more recent model
114
+ messages=[
115
+ {"role": "system", "content": "You respond to the following transcript from the conversation that you are having with the user."},
116
+ {"role": "user", "content": text}
117
+ ],
118
+ stream=True,
119
+ temperature=0.7,
120
+ top_p=0.9
121
+ )
122
+ for chunk in response:
123
+ yield chunk.choices[0].delta.content
124
+
125
+ class Conversation:
126
+ def __init__(self):
127
+ self.mode = 'idle' # idle, listening, speaking
128
+ self.audio_stream = []
129
+ self.valid_chunk_queue = []
130
+ self.first_valid_chunk = None
131
+ self.last_valid_chunks = []
132
+ self.valid_chunk_transcriptions = ''
133
+ self.in_transcription = False
134
+ self.llm_n_tts_task = None
135
+ self.stop_signal = False
136
+ self.sample_rate = 0
137
+ self.out_audio_stream = []
138
+ self.chunk_buffer = 0.5 # seconds
139
+
140
+ def llm_n_tts(self):
141
+ for text_chunk in llm(self.transcription):
142
+ if self.stop_signal:
143
+ break
144
+ for audio_chunk in tts_streaming([text_chunk]):
145
+ if self.stop_signal:
146
+ break
147
+ self.out_audio_stream.append(np.frombuffer(audio_chunk, dtype=np.int16))
148
+
149
+ def process_audio_chunk(self, audio_chunk):
150
+ # Construct audio stream
151
+ audio_data = AudioSegment.from_file(BytesIO(audio_chunk), format="wav")
152
+ audio_data = np.array(audio_data.get_array_of_samples())
153
+ self.sample_rate = audio_data.frame_rate
154
+
155
+ # Check for voice activity
156
+ vad = check_vad(audio_data, self.sample_rate)
157
+
158
+ if vad: # Voice activity detected
159
+ if self.first_valid_chunk is not None:
160
+ self.valid_chunk_queue.append(self.first_valid_chunk)
161
+ self.first_valid_chunk = None
162
+ self.valid_chunk_queue.append(audio_chunk)
163
+
164
+ if len(self.valid_chunk_queue) > 2:
165
+ # i.e. 3 chunks: 1 non valid chunk + 2 valid chunks
166
+ # this is to ensure that the speaker is speaking
167
+ if self.mode == 'idle':
168
+ self.mode = 'listening'
169
+ elif self.mode == 'speaking':
170
+ # Stop llm and tts
171
+ if self.llm_n_tts_task is not None:
172
+ self.stop_signal = True
173
+ self.llm_n_tts_task.join() # wait for the running LLM/TTS thread to stop
174
+ self.stop_signal = False
175
+ self.mode = 'listening'
176
+
177
+ else: # No voice activity
178
+ if self.mode == 'listening':
179
+ self.last_valid_chunks.append(audio_chunk)
180
+
181
+ if len(self.last_valid_chunks) > 2:
182
+ # i.e. 2 chunks where the speaker stopped speaking, but we account for natural pauses
183
+ # so on the 1.5th second of no voice activity, we append the first 2 of the last valid chunks to the valid chunk queue
184
+ # stop listening and start speaking
185
+ self.valid_chunk_queue.extend(self.last_valid_chunks[:2])
186
+ self.last_valid_chunks = []
187
+
188
+ while len(self.valid_chunk_queue) > 0:
189
+ time.sleep(0.1)
190
+
191
+ self.mode = 'speaking'
192
+ self.llm_n_tts_task = threading.Thread(target=self.llm_n_tts)
193
+ self.llm_n_tts_task.start()
194
+
195
+ def transcribe_loop(self):
196
+ while True:
197
+ if self.mode == 'listening':
198
+ if len(self.valid_chunk_queue) > 0:
199
+ accumulated_chunks = np.concatenate(self.valid_chunk_queue)
200
+ total_duration = len(accumulated_chunks) / self.sample_rate
201
+
202
+ if total_duration >= 3.0 and self.in_transcription == True:
203
+ # i.e. we have at least 3 seconds of audio so we can start transcribing to reduce latency
204
+ first_2s_audio = accumulated_chunks[:int(2 * self.sample_rate)]
205
+ transcribed_text = transcribe(first_2s_audio, self.sample_rate)
206
+ self.valid_chunk_transcriptions += transcribed_text
207
+ self.valid_chunk_queue = [accumulated_chunks[int(2 * self.sample_rate):]]
208
+
209
+ if self.mode == any(['idle', 'speaking']):
210
+ # i.e. the request to stop transcription has been made
211
+ # so process the remaining audio
212
+ transcribed_text = transcribe(accumulated_chunks, self.sample_rate)
213
+ self.valid_chunk_transcriptions += transcribed_text
214
+ self.valid_chunk_queue = []
215
+ else:
216
+ time.sleep(0.1)
217
+
218
+ def stream_out_audio(self):
219
+ while True:
220
+ if len(self.out_audio_stream) > 0:
221
+ yield AudioSegment(data=self.out_audio_stream.pop(0), sample_width=2, frame_rate=self.sample_rate, channels=1).raw_data
222
+
223
+ @app.websocket("/ws")
224
+ async def websocket_endpoint(websocket: fastapi.WebSocket):
225
+ # Accept connection
226
+ await websocket.accept()
227
+
228
+ # Initialize conversation
229
+ conversation = Conversation()
230
+
231
+ # Start conversation threads
232
+ transcribe_thread = threading.Thread(target=conversation.transcribe_loop)
233
+ transcribe_thread.start()
234
+
235
+ # Process audio chunks
236
+ chunk_buffer_size = conversation.chunk_buffer
237
+ while True:
238
+ try:
239
+ audio_chunk = await websocket.receive_bytes()
240
+ conversation.process_audio_chunk(audio_chunk)
241
+
242
+ if conversation.mode == 'speaking':
243
+ for audio_chunk in conversation.stream_out_audio():
244
+ await websocket.send_bytes(audio_chunk)
245
+ else:
246
+ await websocket.send_bytes(b'')
247
+ except Exception as e:
248
+ logging.error(e)
249
+ break
250
+
251
+ @app.get("/")
252
+ async def index():
253
+ return fastapi.responses.FileResponse("index.html")
254
+
255
+ if __name__ == '__main__':
256
+ import uvicorn
257
+ uvicorn.run(app, host='0.0.0.0', port=8000)
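Finally, a quick way to sanity-check the Silero VAD path against the reference audio committed under playground/refs/ (a sketch; it assumes audio.npy holds a 16 kHz mono float waveform, which the commit does not document):

    # Sketch: standalone VAD smoke test against the committed reference audio.
    import numpy as np
    import torch
    from silero_vad import load_silero_vad, get_speech_timestamps

    model = load_silero_vad()
    wav = torch.from_numpy(np.load('playground/refs/audio.npy')).float()  # assumed 16 kHz mono
    print(get_speech_timestamps(wav, model, sampling_rate=16000))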