barreloflube committed · Commit 70eeaf7 · Parent(s): 099c588
Refactor code to update UI buttons in audio_tab()
- playground/refs/audio.m4a +0 -0
- playground/refs/audio.npy +0 -0
- playground/refs/test.ipynb +0 -0
- playground/refs/test.py +330 -0
- playground/testapp/audio.mp3 +0 -0
- playground/testapp/index.html +79 -0
- playground/testapp/main.py +292 -0
- playground/testapp/test.ipynb +478 -0
- playground/testapp/test.py +257 -0
playground/refs/audio.m4a
ADDED
Binary file (268 kB).
playground/refs/audio.npy
ADDED
Binary file (102 kB).
playground/refs/test.ipynb
ADDED
The diff for this file is too large to render.
playground/refs/test.py
ADDED
@@ -0,0 +1,330 @@
import fastapi
import numpy as np
import torch
import torchaudio
from silero_vad import get_speech_timestamps, load_silero_vad
import whisperx
import edge_tts
import gc
import logging
import time
from openai import OpenAI
import threading

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Configure FastAPI
app = fastapi.FastAPI()

# Load Silero VAD model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.info(f'Using device: {device}')
vad_model = load_silero_vad().to(device)  # Ensure the model is on the correct device
logging.info('Loaded Silero VAD model')

# Load WhisperX model
whisper_model = whisperx.load_model("tiny", device, compute_type="float16")
logging.info('Loaded WhisperX model')

OPENAI_API_KEY = "sk-proj-gcrtuxd5qzaRYT82Ii3eT3BlbkFJpVQHBc9ZJrmSksLbQc3C"  # os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
    raise ValueError("OpenAI API key not found.")

# Initialize OpenAI client
openai_client = OpenAI(api_key=OPENAI_API_KEY)
logging.info('Initialized OpenAI client')

# TTS Voice
TTS_VOICE = "en-GB-SoniaNeural"

# Function to check voice activity using Silero VAD
def check_vad(audio_data, sample_rate):
    logging.info('Checking voice activity')
    # Resample to 16000 Hz if necessary
    target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_tensor = resampler(torch.from_numpy(audio_data))
    else:
        audio_tensor = torch.from_numpy(audio_data)
    audio_tensor = audio_tensor.to(device)

    # Log audio data details
    logging.info(f'Audio tensor shape: {audio_tensor.shape}, dtype: {audio_tensor.dtype}, device: {audio_tensor.device}')

    # Get speech timestamps
    speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)
    logging.info(f'Found {len(speech_timestamps)} speech timestamps')
    return len(speech_timestamps) > 0

# Function to transcribe audio using WhisperX
def transcript(audio_data, sample_rate):
    logging.info('Transcribing audio')
    # Resample to 16000 Hz if necessary
    target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_data = resampler(torch.from_numpy(audio_data)).numpy()
    else:
        audio_data = audio_data

    # Transcribe
    batch_size = 16  # Adjust as needed
    result = whisper_model.transcribe(audio_data, batch_size=batch_size)
    text = result["segments"][0]["text"] if len(result["segments"]) > 0 else ""
    logging.info(f'Transcription result: {text}')
    # Clear GPU memory
    del result
    gc.collect()
    if device == 'cuda':
        torch.cuda.empty_cache()
    return text

# Function to get streaming response from OpenAI API
def llm(text):
    logging.info('Getting response from OpenAI API')
    response = openai_client.chat.completions.create(
        model="gpt-4o",  # Updated to a more recent model
        messages=[
            {"role": "system", "content": "You respond to the following transcript from the conversation that you are having with the user."},
            {"role": "user", "content": text}
        ],
        stream=True,
        temperature=0.7,  # Optional: Adjust as needed
        top_p=0.9,  # Optional: Adjust as needed
    )
    for chunk in response:
        yield chunk.choices[0].delta.content

# Function to perform TTS per sentence using Edge-TTS
def tts_streaming(text_stream):
    logging.info('Performing TTS')
    buffer = ""
    punctuation = {'.', '!', '?'}
    for text_chunk in text_stream:
        if text_chunk is not None:
            buffer += text_chunk
        # Check for sentence completion
        sentences = []
        start = 0
        for i, char in enumerate(buffer):
            if (char in punctuation):
                sentences.append(buffer[start:i+1].strip())
                start = i+1
        buffer = buffer[start:]

        for sentence in sentences:
            if sentence:
                communicate = edge_tts.Communicate(sentence, TTS_VOICE)
                for chunk in communicate.stream_sync():
                    if chunk["type"] == "audio":
                        yield chunk["data"]
    # Process any remaining text
    if buffer.strip():
        communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)
        for chunk in communicate.stream_sync():
            if chunk["type"] == "audio":
                yield chunk["data"]

# Function to handle LLM and TTS
def llm_and_tts(transcribed_text, state):
    logging.info('Handling LLM and TTS')
    # Get streaming response from LLM
    for text_chunk in llm(transcribed_text):
        if state.get('stop_signal'):
            logging.info('LLM and TTS task stopped')
            break
        # Get audio data from TTS
        for audio_chunk in tts_streaming([text_chunk]):
            if state.get('stop_signal'):
                logging.info('LLM and TTS task stopped during TTS')
                break
            yield np.frombuffer(audio_chunk, dtype=np.int16)

state = {
    'mode': 'idle',
    'chunk_queue': [],
    'transcription': '',
    'in_transcription': False,
    'previous_no_vad_audio': [],
    'llm_task': None,
    'instream': None,
    'stop_signal': False,
    'args': {
        'sample_rate': 16000,
        'chunk_size': 0.5,  # seconds
        'transcript_chunk_size': 2,  # seconds
    }
}

def transcript_loop():
    while True:
        if len(state['chunk_queue']) > 0:
            accumulated_audio = np.concatenate(state['chunk_queue'])
            total_samples = sum(len(chunk) for chunk in state['chunk_queue'])
            total_duration = total_samples / state['sample_rate']

            # Run transcription on the first 2 seconds if len > 3 seconds
            if total_duration > 3.0 and state['in_transcription'] == True:
                first_two_seconds_samples = int(2.0 * state['sample_rate'])
                first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]
                transcribed_text = transcript(first_two_seconds_audio, state['sample_rate'])
                state['transcription'] += transcribed_text
                remaining_audio = accumulated_audio[first_two_seconds_samples:]
                state['chunk_queue'] = [remaining_audio]
            else:  # Run transcription on the accumulated audio
                transcribed_text = transcript(accumulated_audio, state['sample_rate'])
                state['transcription'] += transcribed_text
                state['chunk_queue'] = []
                state['in_transcription'] = False
        else:
            time.sleep(0.1)

        if len(state['chunk_queue']) == 0 and state['mode'] == any(['idle', 'processing']):
            state['in_transcription'] = False
            break

def process_audio(audio_chunk):
    # returns output audio

    sample_rate, audio_data = audio_chunk
    audio_data = np.array(audio_data, dtype=np.float32)

    # convert to mono if necessary
    if audio_data.ndim > 1:
        audio_data = np.mean(audio_data, axis=1)

    mode = state['mode']
    chunk_queue = state['chunk_queue']
    transcription = state['transcription']
    in_transcription = state['in_transcription']
    previous_no_vad_audio = state['previous_no_vad_audio']
    llm_task = state['llm_task']
    instream = state['instream']
    stop_signal = state['stop_signal']
    args = state['args']

    args['sample_rate'] = sample_rate

    # check for voice activity
    vad = check_vad(audio_data, sample_rate)

    if vad:
        logging.info(f'Voice activity detected in mode: {mode}')
        if mode == 'idle':
            mode = 'listening'
        elif mode == 'speaking':
            # Stop llm and tts tasks
            if llm_task and llm_task.is_alive():
                # Implement task cancellation logic if possible
                logging.info('Stopping LLM and TTS tasks')
                # Since we cannot kill threads directly, we need to handle this in the tasks
                stop_signal = True
                llm_task.join()
            mode = 'listening'

        if mode == 'listening':
            if previous_no_vad_audio is not None:
                chunk_queue.append(previous_no_vad_audio)
                previous_no_vad_audio = None
            # Accumulate audio chunks
            chunk_queue.append(audio_data)

            # Start transcription thread if not already running
            if not in_transcription:
                in_transcription = True
                transcription_task = threading.Thread(target=transcript_loop, args=(chunk_queue, sample_rate))
                transcription_task.start()

        elif mode == 'speaking':
            # Continue accumulating audio chunks
            chunk_queue.append(audio_data)
    else:
        logging.info(f'No voice activity detected in mode: {mode}')
        if mode == 'listening':
            # Add the last chunk to queue
            chunk_queue.append(audio_data)

            # Change mode to processing
            mode = 'processing'

            # Wait for transcription to complete
            while in_transcription:
                time.sleep(0.1)

            # Check if transcription is complete
            if len(chunk_queue) == 0:
                # Start LLM and TTS tasks
                if not llm_task or not llm_task.is_alive():
                    stop_signal = False
                    llm_task = threading.Thread(target=llm_and_tts, args=(transcription, state))
                    llm_task.start()

        if mode == 'processing':
            # Wait for LLM and TTS tasks to start yielding audio
            if llm_task and llm_task.is_alive():
                mode = 'responding'

        if mode == 'responding':
            for audio_chunk in llm_task:
                if instream is None:
                    instream = audio_chunk
                else:
                    instream = np.concatenate((instream, audio_chunk))

                # Send audio to output stream
                yield instream

            # Cleanup
            llm_task = None
            transcription = ''
            mode = 'idle'

        # Updaate state
        state['mode'] = mode
        state['chunk_queue'] = chunk_queue
        state['transcription'] = transcription
        state['in_transcription'] = in_transcription
        state['previous_no_vad_audio'] = previous_no_vad_audio
        state['llm_task'] = llm_task
        state['instream'] = instream
        state['stop_signal'] = stop_signal
        state['args'] = args

        # Store previous audio chunk with no voice activity
        previous_no_vad_audio = audio_data

    # Update state
    state['mode'] = mode
    state['chunk_queue'] = chunk_queue
    state['transcription'] = transcription
    state['in_transcription'] = in_transcription
    state['previous_no_vad_audio'] = previous_no_vad_audio
    state['llm_task'] = llm_task
    state['instream'] = instream
    state['stop_signal'] = stop_signal
    state['args'] = args


@app.websocket('/ws')
def websocket_endpoint(websocket: fastapi.WebSocket):
    logging.info('WebSocket connection established')
    try:
        while True:
            time.sleep(state['args']['chunk_size'])
            audio_chunk = websocket.receive_bytes()
            if audio_chunk is None:
                break
            for audio_data in process_audio(audio_chunk):
                websocket.send_bytes(audio_data.tobytes())
    except Exception as e:
        logging.error(f'WebSocket error: {e}')
    finally:
        logging.info('WebSocket connection closed')
        websocket.close()

@app.get('/')
def index():
    return fastapi.FileResponse('index.html')
playground/testapp/audio.mp3
ADDED
Binary file (386 kB).
playground/testapp/index.html
ADDED
@@ -0,0 +1,79 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Voice Assistant</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 20px;
        }
        #transcription {
            margin-top: 20px;
            padding: 10px;
            border: 1px solid #ccc;
            height: 150px;
            overflow-y: auto;
        }
        #audio-player {
            margin-top: 20px;
        }
    </style>
</head>
<body>
    <h1>Voice Assistant</h1>
    <button id="start-btn">Start Recording</button>
    <button id="stop-btn" disabled>Stop Recording</button>
    <div id="transcription"></div>
    <audio id="audio-player" controls></audio>

    <script>
        const startBtn = document.getElementById('start-btn');
        const stopBtn = document.getElementById('stop-btn');
        const transcriptionDiv = document.getElementById('transcription');
        const audioPlayer = document.getElementById('audio-player');
        let websocket;
        let mediaRecorder;
        let audioChunks = [];

        startBtn.addEventListener('click', async () => {
            startBtn.disabled = true;
            stopBtn.disabled = false;

            websocket = new WebSocket('ws://localhost:8000/ws');
            websocket.binaryType = 'arraybuffer';

            websocket.onmessage = (event) => {
                if (event.data instanceof ArrayBuffer) {
                    const audioBlob = new Blob([event.data], { type: 'audio/wav' });
                    audioPlayer.src = URL.createObjectURL(audioBlob);
                    audioPlayer.play();
                } else {
                    transcriptionDiv.innerText += event.data + '\n';
                }
            };

            const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
            mediaRecorder = new MediaRecorder(stream);

            mediaRecorder.ondataavailable = (event) => {
                if (event.data.size > 0) {
                    audioChunks.push(event.data);
                    websocket.send(event.data);
                }
            };

            mediaRecorder.start(1000); // Send audio data every second
        });

        stopBtn.addEventListener('click', () => {
            startBtn.disabled = false;
            stopBtn.disabled = true;

            mediaRecorder.stop();
            websocket.close();
        });
    </script>
</body>
</html>
playground/testapp/main.py
ADDED
@@ -0,0 +1,292 @@
import fastapi
import numpy as np
import torch
import torchaudio
from silero_vad import get_speech_timestamps, load_silero_vad
import whisperx
import edge_tts
import gc
import logging
import time
import os
from openai import AsyncOpenAI
import asyncio

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Configure FastAPI
app = fastapi.FastAPI()

# Load Silero VAD model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.info(f'Using device: {device}')
vad_model = load_silero_vad().to(device)
logging.info('Loaded Silero VAD model')

# Load WhisperX model
whisper_model = whisperx.load_model("tiny", device, compute_type="float16")
logging.info('Loaded WhisperX model')

OPENAI_API_KEY = "sk-proj-gcrtuxd5qzaRYT82Ii3eT3BlbkFJpVQHBc9ZJrmSksLbQc3C"
if not OPENAI_API_KEY:
    logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
    raise ValueError("OpenAI API key not found.")
logging.info('Initialized OpenAI client')
aclient = AsyncOpenAI(api_key=OPENAI_API_KEY)  # Corrected import

# TTS Voice
TTS_VOICE = "en-GB-SoniaNeural"

# Function to check voice activity using Silero VAD
def check_vad(audio_data, sample_rate):
    logging.info('Checking voice activity')
    target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_tensor = resampler(torch.from_numpy(audio_data))
    else:
        audio_tensor = torch.from_numpy(audio_data)
    audio_tensor = audio_tensor.to(device)

    speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)
    logging.info(f'Found {len(speech_timestamps)} speech timestamps')
    return len(speech_timestamps) > 0

# Async function to transcribe audio using WhisperX
def transcript_sync(audio_data, sample_rate):
    logging.info('Transcribing audio')
    target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_data = resampler(torch.from_numpy(audio_data)).numpy()
    else:
        audio_data = audio_data

    batch_size = 16  # Adjust as needed
    result = whisper_model.transcribe(audio_data, batch_size=batch_size)
    text = result["segments"][0]["text"] if len(result["segments"]) > 0 else ""
    logging.info(f'Transcription result: {text}')
    del result
    gc.collect()
    if device == 'cuda':
        torch.cuda.empty_cache()
    return text

async def transcript(audio_data, sample_rate):
    loop = asyncio.get_running_loop()
    text = await loop.run_in_executor(None, transcript_sync, audio_data, sample_rate)
    return text

# Async function to get streaming response from OpenAI API
async def llm(text):
    logging.info('Getting response from OpenAI API')
    response = await aclient.chat.completions.create(model="gpt-4",  # Updated to a more recent model
        messages=[
            {"role": "system", "content": "You respond to the following transcript from the conversation that you are having with the user."},
            {"role": "user", "content": text}
        ],
        stream=True,
        temperature=0.7,
        top_p=0.9)
    async for chunk in response:
        yield chunk.choices[0].delta.content

# Async function to perform TTS using Edge-TTS
async def tts_streaming(text_stream):
    logging.info('Performing TTS')
    buffer = ""
    punctuation = {'.', '!', '?'}
    for text_chunk in text_stream:
        if text_chunk is not None:
            buffer += text_chunk
        # Check for sentence completion
        sentences = []
        start = 0
        for i, char in enumerate(buffer):
            if char in punctuation:
                sentences.append(buffer[start:i+1].strip())
                start = i+1
        buffer = buffer[start:]

        for sentence in sentences:
            if sentence:
                communicate = edge_tts.Communicate(sentence, TTS_VOICE)
                async for chunk in communicate.stream():
                    if chunk["type"] == "audio":
                        yield chunk["data"]
    # Process any remaining text
    if buffer.strip():
        communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)
        async for chunk in communicate.stream():
            if chunk["type"] == "audio":
                yield chunk["data"]

class Conversation:
    def __init__(self):
        self.mode = 'idle'
        self.chunk_queue = []
        self.transcription = ''
        self.in_transcription = False
        self.previous_no_vad_audio = None
        self.llm_task = None
        self.transcription_task = None
        self.stop_signal = False
        self.sample_rate = 16000  # default sample rate
        self.instream = None

    async def process_audio(self, audio_chunk):
        sample_rate, audio_data = audio_chunk
        self.sample_rate = sample_rate
        audio_data = np.array(audio_data, dtype=np.float32)

        # convert to mono if necessary
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=1)

        # check for voice activity
        vad = check_vad(audio_data, sample_rate)

        if vad:
            logging.info(f'Voice activity detected in mode: {self.mode}')
            if self.mode == 'idle':
                self.mode = 'listening'
            elif self.mode == 'speaking':
                # Stop llm and tts tasks
                if self.llm_task and not self.llm_task.done():
                    logging.info('Stopping LLM and TTS tasks')
                    self.stop_signal = True
                    await self.llm_task
                self.mode = 'listening'

            if self.mode == 'listening':
                if self.previous_no_vad_audio is not None:
                    self.chunk_queue.append(self.previous_no_vad_audio)
                    self.previous_no_vad_audio = None
                # Accumulate audio chunks
                self.chunk_queue.append(audio_data)

                # Start transcription task if not already running
                if not self.in_transcription:
                    self.in_transcription = True
                    self.transcription_task = asyncio.create_task(self.transcript_loop())

        else:
            logging.info(f'No voice activity detected in mode: {self.mode}')
            if self.mode == 'listening':
                # Add the last chunk to queue
                self.chunk_queue.append(audio_data)

                # Change mode to processing
                self.mode = 'processing'

                # Wait for transcription to complete
                while self.in_transcription:
                    await asyncio.sleep(0.1)

                # Check if transcription is complete
                if len(self.chunk_queue) == 0:
                    # Start LLM and TTS tasks
                    if not self.llm_task or self.llm_task.done():
                        self.stop_signal = False
                        self.llm_task = self.llm_and_tts()
                        self.mode = 'responding'

            if self.mode == 'responding':
                async for audio_chunk in self.llm_task:
                    if self.instream is None:
                        self.instream = audio_chunk
                    else:
                        self.instream = np.concatenate((self.instream, audio_chunk))
                    # Send audio to output stream
                    yield self.instream

                # Cleanup
                self.llm_task = None
                self.transcription = ''
                self.mode = 'idle'
                self.instream = None

            # Store previous audio chunk with no voice activity
            self.previous_no_vad_audio = audio_data

    async def transcript_loop(self):
        while True:
            if len(self.chunk_queue) > 0:
                accumulated_audio = np.concatenate(self.chunk_queue)
                total_samples = len(accumulated_audio)
                total_duration = total_samples / self.sample_rate

                if total_duration > 3.0 and self.in_transcription == True:
                    first_two_seconds_samples = int(2.0 * self.sample_rate)
                    first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]
                    transcribed_text = await transcript(first_two_seconds_audio, self.sample_rate)
                    self.transcription += transcribed_text
                    remaining_audio = accumulated_audio[first_two_seconds_samples:]
                    self.chunk_queue = [remaining_audio]
                else:
                    transcribed_text = await transcript(accumulated_audio, self.sample_rate)
                    self.transcription += transcribed_text
                    self.chunk_queue = []
                    self.in_transcription = False
            else:
                await asyncio.sleep(0.1)

            if len(self.chunk_queue) == 0 and self.mode in ['idle', 'processing']:
                self.in_transcription = False
                break

    async def llm_and_tts(self):
        logging.info('Handling LLM and TTS')
        async for text_chunk in llm(self.transcription):
            if self.stop_signal:
                logging.info('LLM and TTS task stopped')
                break
            async for audio_chunk in tts_streaming([text_chunk]):
                if self.stop_signal:
                    logging.info('LLM and TTS task stopped during TTS')
                    break
                yield np.frombuffer(audio_chunk, dtype=np.int16)

@app.websocket('/ws')
async def websocket_endpoint(websocket: fastapi.WebSocket):
    await websocket.accept()
    logging.info('WebSocket connection established')
    conversation = Conversation()
    audio_buffer = []
    buffer_duration = 0.5  # 500ms
    try:
        while True:
            audio_chunk_bytes = await websocket.receive_bytes()
            if audio_chunk_bytes is None:
                break

            audio_chunk = (conversation.sample_rate, np.frombuffer(audio_chunk_bytes, dtype=np.int16))
            audio_buffer.append(audio_chunk[1])

            # Calculate the duration of the buffered audio
            total_samples = sum(len(chunk) for chunk in audio_buffer)
            total_duration = total_samples / conversation.sample_rate

            if total_duration >= buffer_duration:
                # Concatenate buffered audio chunks
                buffered_audio = np.concatenate(audio_buffer)
                audio_buffer = []  # Reset buffer

                # Process the buffered audio
                async for audio_data in conversation.process_audio((conversation.sample_rate, buffered_audio)):
                    if audio_data is not None:
                        await websocket.send_bytes(audio_data.tobytes())
    except Exception as e:
        logging.error(f'WebSocket error: {e}')
    finally:
        logging.info('WebSocket connection closed')
        await websocket.close()

@app.get('/')
def index():
    return fastapi.responses.FileResponse('index.html')

if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=8000)
playground/testapp/test.ipynb
ADDED
@@ -0,0 +1,478 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import fastapi\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchaudio\n",
    "from silero_vad import get_speech_timestamps, load_silero_vad\n",
    "import whisperx\n",
    "import edge_tts\n",
    "import gc\n",
    "import logging\n",
    "import time\n",
    "from openai import OpenAI\n",
    "import threading\n",
    "import asyncio\n",
    "\n",
    "# Configure logging\n",
    "logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n",
    "\n",
    "# Configure FastAPI\n",
    "app = fastapi.FastAPI()\n",
    "\n",
    "# Load Silero VAD model\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "logging.info(f'Using device: {device}')\n",
    "vad_model = load_silero_vad().to(device) # Ensure the model is on the correct device\n",
    "logging.info('Loaded Silero VAD model')\n",
    "\n",
    "# Load WhisperX model\n",
    "whisper_model = whisperx.load_model(\"tiny\", device, compute_type=\"float16\")\n",
    "logging.info('Loaded WhisperX model')\n",
    "\n",
    "# OpenAI API Key from environment variable for security\n",
    "OPENAI_API_KEY = \"sk-proj-gcrtuxd5qzaRYT82Ii3eT3BlbkFJpVQHBc9ZJrmSksLbQc3C\" # os.getenv(\"OPENAI_API_KEY\")\n",
    "if not OPENAI_API_KEY:\n",
    "    logging.error(\"OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.\")\n",
    "    raise ValueError(\"OpenAI API key not found.\")\n",
    "\n",
    "# Initialize OpenAI client\n",
    "openai_client = OpenAI(api_key=OPENAI_API_KEY)\n",
    "logging.info('Initialized OpenAI client')\n",
    "\n",
    "# TTS Voice\n",
    "TTS_VOICE = \"en-GB-SoniaNeural\"\n",
    "\n",
    "# Function to check voice activity using Silero VAD\n",
    "def check_vad(audio_data, sample_rate):\n",
    "    logging.info('Checking voice activity')\n",
    "    # Resample to 16000 Hz if necessary\n",
    "    target_sample_rate = 16000\n",
    "    if sample_rate != target_sample_rate:\n",
    "        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)\n",
    "        audio_tensor = resampler(torch.from_numpy(audio_data))\n",
    "    else:\n",
    "        audio_tensor = torch.from_numpy(audio_data)\n",
    "    audio_tensor = audio_tensor.to(device)\n",
    "\n",
    "    # Log audio data details\n",
    "    logging.info(f'Audio tensor shape: {audio_tensor.shape}, dtype: {audio_tensor.dtype}, device: {audio_tensor.device}')\n",
    "\n",
    "    # Get speech timestamps\n",
    "    speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)\n",
    "    logging.info(f'Found {len(speech_timestamps)} speech timestamps')\n",
    "    return len(speech_timestamps) > 0\n",
    "\n",
    "# Function to transcribe audio using WhisperX\n",
    "def transcript(audio_data, sample_rate):\n",
    "    logging.info('Transcribing audio')\n",
    "    # Resample to 16000 Hz if necessary\n",
    "    target_sample_rate = 16000\n",
    "    if sample_rate != target_sample_rate:\n",
    "        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)\n",
    "        audio_data = resampler(torch.from_numpy(audio_data)).numpy()\n",
    "    else:\n",
    "        audio_data = audio_data\n",
    "\n",
    "    # Transcribe\n",
    "    batch_size = 16 # Adjust as needed\n",
    "    result = whisper_model.transcribe(audio_data, batch_size=batch_size)\n",
    "    text = result[\"segments\"][0][\"text\"] if len(result[\"segments\"]) > 0 else \"\"\n",
    "    logging.info(f'Transcription result: {text}')\n",
    "    # Clear GPU memory\n",
    "    del result\n",
    "    gc.collect()\n",
    "    if device == 'cuda':\n",
    "        torch.cuda.empty_cache()\n",
    "    return text\n",
    "\n",
    "# Function to get streaming response from OpenAI API\n",
    "def llm(text):\n",
    "    logging.info('Getting response from OpenAI API')\n",
    "    response = openai_client.chat.completions.create(\n",
    "        model=\"gpt-4o\", # Updated to a more recent model\n",
    "        messages=[\n",
    "            {\"role\": \"system\", \"content\": \"You respond to the following transcript from the conversation that you are having with the user.\"},\n",
    "            {\"role\": \"user\", \"content\": text}\n",
    "        ],\n",
    "        stream=True,\n",
    "        temperature=0.7, # Optional: Adjust as needed\n",
    "        top_p=0.9, # Optional: Adjust as needed\n",
    "    )\n",
    "    for chunk in response:\n",
    "        yield chunk.choices[0].delta.content\n",
    "\n",
    "# Function to perform TTS per sentence using Edge-TTS\n",
    "def tts_streaming(text_stream):\n",
    "    logging.info('Performing TTS')\n",
    "    buffer = \"\"\n",
    "    punctuation = {'.', '!', '?'}\n",
    "    for text_chunk in text_stream:\n",
    "        if text_chunk is not None:\n",
    "            buffer += text_chunk\n",
    "        # Check for sentence completion\n",
    "        sentences = []\n",
    "        start = 0\n",
    "        for i, char in enumerate(buffer):\n",
    "            if (char in punctuation):\n",
    "                sentences.append(buffer[start:i+1].strip())\n",
    "                start = i+1\n",
    "        buffer = buffer[start:]\n",
    "\n",
    "        for sentence in sentences:\n",
    "            if sentence:\n",
    "                communicate = edge_tts.Communicate(sentence, TTS_VOICE)\n",
    "                for chunk in communicate.stream_sync():\n",
    "                    if chunk[\"type\"] == \"audio\":\n",
    "                        yield chunk[\"data\"]\n",
    "    # Process any remaining text\n",
    "    if buffer.strip():\n",
    "        communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)\n",
    "        for chunk in communicate.stream_sync():\n",
    "            if chunk[\"type\"] == \"audio\":\n",
    "                yield chunk[\"data\"]\n",
    "\n",
    "# Function to handle LLM and TTS\n",
    "def llm_and_tts(transcribed_text):\n",
    "    logging.info('Handling LLM and TTS')\n",
    "    # Get streaming response from LLM\n",
    "    for text_chunk in llm(transcribed_text):\n",
    "        if state.get('stop_signal'):\n",
    "            logging.info('LLM and TTS task stopped')\n",
    "            break\n",
    "        # Get audio data from TTS\n",
    "        for audio_chunk in tts_streaming([text_chunk]):\n",
    "            if state.get('stop_signal'):\n",
    "                logging.info('LLM and TTS task stopped during TTS')\n",
    "                break\n",
    "            yield np.frombuffer(audio_chunk, dtype=np.int16)\n",
    "\n",
    "state = {\n",
    "    'mode': 'idle',\n",
    "    'chunk_queue': [],\n",
    "    'transcription': '',\n",
    "    'in_transcription': False,\n",
    "    'previous_no_vad_audio': [],\n",
    "    'llm_task': None,\n",
    "    'instream': None,\n",
    "    'stop_signal': False,\n",
    "    'args': {\n",
    "        'sample_rate': 16000,\n",
    "        'chunk_size': 0.5, # seconds\n",
    "        'transcript_chunk_size': 2, # seconds\n",
    "    }\n",
    "}\n",
    "\n",
    "def transcript_loop():\n",
    "    while True:\n",
    "        if len(state['chunk_queue']) > 0:\n",
    "            accumulated_audio = np.concatenate(state['chunk_queue'])\n",
    "            total_samples = sum(len(chunk) for chunk in state['chunk_queue'])\n",
    "            total_duration = total_samples / state['args']['sample_rate']\n",
    "\n",
    "            # Run transcription on the first 2 seconds if len > 3 seconds\n",
    "            if total_duration > 3.0 and state['in_transcription'] == True:\n",
    "                first_two_seconds_samples = int(2.0 * state['args']['sample_rate'])\n",
    "                first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]\n",
    "                transcribed_text = transcript(first_two_seconds_audio, state['args']['sample_rate'])\n",
    "                state['transcription'] += transcribed_text\n",
    "                remaining_audio = accumulated_audio[first_two_seconds_samples:]\n",
    "                state['chunk_queue'] = [remaining_audio]\n",
    "            else: # Run transcription on the accumulated audio\n",
    "                transcribed_text = transcript(accumulated_audio, state['args']['sample_rate'])\n",
    "                state['transcription'] += transcribed_text\n",
    "                state['chunk_queue'] = []\n",
    "                state['in_transcription'] = False\n",
    "        else:\n",
    "            time.sleep(0.1)\n",
    "\n",
    "        if len(state['chunk_queue']) == 0 and state['mode'] == any(['idle', 'processing']):\n",
    "            state['in_transcription'] = False\n",
    "            break\n",
    "\n",
    "def process_audio(audio_chunk):\n",
    "    # returns output audio\n",
    "\n",
    "    sample_rate, audio_data = audio_chunk\n",
    "    audio_data = np.array(audio_data, dtype=np.float32)\n",
    "\n",
    "    # convert to mono if necessary\n",
    "    if audio_data.ndim > 1:\n",
    "        audio_data = np.mean(audio_data, axis=1)\n",
    "\n",
    "    mode = state['mode']\n",
    "    chunk_queue = state['chunk_queue']\n",
    "    transcription = state['transcription']\n",
    "    in_transcription = state['in_transcription']\n",
    "    previous_no_vad_audio = state['previous_no_vad_audio']\n",
    "    llm_task = state['llm_task']\n",
    "    instream = state['instream']\n",
    "    stop_signal = state['stop_signal']\n",
    "    args = state['args']\n",
    "\n",
    "    args['sample_rate'] = sample_rate\n",
    "\n",
    "    # check for voice activity\n",
    "    vad = check_vad(audio_data, sample_rate)\n",
    "\n",
    "    if vad:\n",
    "        logging.info(f'Voice activity detected in mode: {mode}')\n",
    "        if mode == 'idle':\n",
    "            mode = 'listening'\n",
    "        elif mode == 'speaking':\n",
    "            # Stop llm and tts tasks\n",
    "            if llm_task and llm_task.is_alive():\n",
    "                # Implement task cancellation logic if possible\n",
    "                logging.info('Stopping LLM and TTS tasks')\n",
    "                # Since we cannot kill threads directly, we need to handle this in the tasks\n",
    "                stop_signal = True\n",
    "                llm_task.join()\n",
    "            mode = 'listening'\n",
    "\n",
    "        if mode == 'listening':\n",
    "            if previous_no_vad_audio is not None:\n",
    "                chunk_queue.append(previous_no_vad_audio)\n",
    "                previous_no_vad_audio = None\n",
    "            # Accumulate audio chunks\n",
    "            chunk_queue.append(audio_data)\n",
    "\n",
    "            # Start transcription thread if not already running\n",
    "            if not in_transcription:\n",
    "                in_transcription = True\n",
    "                transcription_task = threading.Thread(target=transcript_loop)\n",
    "                transcription_task.start()\n",
    "\n",
    "        elif mode == 'speaking':\n",
    "            # Continue accumulating audio chunks\n",
    "            chunk_queue.append(audio_data)\n",
    "    else:\n",
    "        logging.info(f'No voice activity detected in mode: {mode}')\n",
    "        if mode == 'listening':\n",
    "            # Add the last chunk to queue\n",
    "            chunk_queue.append(audio_data)\n",
    "\n",
    "            # Change mode to processing\n",
    "            mode = 'processing'\n",
    "\n",
    "            # Wait for transcription to complete\n",
    "            while in_transcription:\n",
    "                time.sleep(0.1)\n",
    "\n",
    "            # Check if transcription is complete\n",
    "            if len(chunk_queue) == 0:\n",
    "                # Start LLM and TTS tasks\n",
    "                if not llm_task or not llm_task.is_alive():\n",
    "                    stop_signal = False\n",
    "                    llm_task = threading.Thread(target=llm_and_tts, args=(transcription))\n",
    "                    llm_task.start()\n",
    "\n",
    "        if mode == 'processing':\n",
    "            # Wait for LLM and TTS tasks to start yielding audio\n",
    "            if llm_task and llm_task.is_alive():\n",
    "                mode = 'responding'\n",
    "\n",
    "        if mode == 'responding':\n",
    "            for audio_chunk in llm_task:\n",
    "                if instream is None:\n",
    "                    instream = audio_chunk\n",
    "                else:\n",
    "                    instream = np.concatenate((instream, audio_chunk))\n",
    "\n",
    "                # Send audio to output stream\n",
    "                yield instream\n",
    "\n",
    "            # Cleanup\n",
    "            llm_task = None\n",
    "            transcription = ''\n",
    "            mode = 'idle'\n",
    "\n",
    "        # Updaate state\n",
    "        state['mode'] = mode\n",
    "        state['chunk_queue'] = chunk_queue\n",
    "        state['transcription'] = transcription\n",
    "        state['in_transcription'] = in_transcription\n",
    "        state['previous_no_vad_audio'] = previous_no_vad_audio\n",
    "        state['llm_task'] = llm_task\n",
    "        state['instream'] = instream\n",
    "        state['stop_signal'] = stop_signal\n",
    "        state['args'] = args\n",
    "\n",
    "        # Store previous audio chunk with no voice activity\n",
    "        previous_no_vad_audio = audio_data\n",
    "\n",
    "    # Update state\n",
    "    state['mode'] = mode\n",
    "    state['chunk_queue'] = chunk_queue\n",
    "    state['transcription'] = transcription\n",
    "    state['in_transcription'] = in_transcription\n",
    "    state['previous_no_vad_audio'] = previous_no_vad_audio\n",
    "    state['llm_task'] = llm_task\n",
    "    state['instream'] = instream\n",
    "    state['stop_signal'] = stop_signal\n",
    "    state['args'] = args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. Load audio.mp3\n",
    "# 2. Split audio into chunks\n",
    "# 3. Process each chunk inside a loop\n",
    "\n",
    "# Split audio into chunks of 500 ms or less\n",
    "from pydub import AudioSegment\n",
    "audio_segment = AudioSegment.from_file('audio.mp3')\n",
    "chunks = [chunk for chunk in audio_segment[::500]]\n",
    "chunks[0]\n",
    "chunks = [(chunk.frame_rate, np.array(chunk.get_array_of_samples(), dtype=np.int16)) for chunk in chunks]\n",
    "\n",
    "output_audio = []\n",
    "# Process each chunk\n",
    "for chunk in chunks:\n",
    "    for audio_chunk in process_audio(chunk):\n",
    "        output_audio.append(audio_chunk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_audio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import asyncio\n",
    "import websockets\n",
    "from pydub import AudioSegment\n",
    "import numpy as np\n",
    "import simpleaudio as sa\n",
    "\n",
    "# Constants\n",
    "AUDIO_FILE = 'audio.mp3' # Input audio file\n",
    "CHUNK_DURATION_MS = 250 # Duration of each chunk in milliseconds\n",
    "WEBSOCKET_URI = 'ws://localhost:8000/ws' # WebSocket endpoint\n",
    "\n",
    "async def send_audio_chunks(uri):\n",
    "    # Load audio file using pydub\n",
    "    audio = AudioSegment.from_file(AUDIO_FILE)\n",
    "\n",
    "    # Ensure audio is mono and 16kHz\n",
    "    if audio.channels > 1:\n",
    "        audio = audio.set_channels(1)\n",
    "    if audio.frame_rate != 16000:\n",
    "        audio = audio.set_frame_rate(16000)\n",
    "    if audio.sample_width != 2: # 2 bytes for int16\n",
    "        audio = audio.set_sample_width(2)\n",
    "\n",
    "    # Split audio into chunks\n",
    "    chunks = [audio[i:i+CHUNK_DURATION_MS] for i in range(0, len(audio), CHUNK_DURATION_MS)]\n",
    "\n",
    "    # Store received audio data\n",
    "    received_audio_data = b''\n",
    "\n",
    "    async with websockets.connect(uri) as websocket:\n",
    "        print(\"Connected to server.\")\n",
    "        for idx, chunk in enumerate(chunks):\n",
    "            # Get raw audio data\n",
    "            raw_data = chunk.raw_data\n",
    "\n",
    "            # Send audio chunk to server\n",
    "            await websocket.send(raw_data)\n",
    "            print(f\"Sent chunk {idx+1}/{len(chunks)}\")\n",
    "\n",
    "            # Receive response (non-blocking)\n",
    "            try:\n",
    "                response = await asyncio.wait_for(websocket.recv(), timeout=0.1)\n",
    "                if isinstance(response, bytes):\n",
    "                    received_audio_data += response\n",
    "                    print(f\"Received audio data of length {len(response)} bytes\")\n",
    "            except asyncio.TimeoutError:\n",
    "                pass # No response received yet\n",
    "\n",
    "            # Simulate real-time by waiting for chunk duration\n",
    "            await asyncio.sleep(CHUNK_DURATION_MS / 1000.0)\n",
    "\n",
    "        # Send a final empty message to indicate end of transmission\n",
    "        await websocket.send(b'')\n",
    "        print(\"Finished sending audio. Waiting for responses...\")\n",
    "\n",
    "        # Receive any remaining responses\n",
    "        while True:\n",
    "            try:\n",
    "                response = await asyncio.wait_for(websocket.recv(), timeout=1)\n",
    "                if isinstance(response, bytes):\n",
    "                    received_audio_data += response\n",
    "                    print(f\"Received audio data of length {len(response)} bytes\")\n",
    "            except asyncio.TimeoutError:\n",
    "                print(\"No more responses. Closing connection.\")\n",
    "                break\n",
    "\n",
    "        print(\"Connection closed.\")\n",
    "\n",
    "        # Save received audio data to a file or play it\n",
    "        if received_audio_data:\n",
    "            # Convert bytes to numpy array\n",
    "            audio_array = np.frombuffer(received_audio_data, dtype=np.int16)\n",
    "\n",
    "            # Play audio using simpleaudio\n",
    "            play_obj = sa.play_buffer(audio_array, 1, 2, 16000)\n",
    "            play_obj.wait_done()\n",
    "\n",
    "            # Optionally, save to a WAV file\n",
    "            output_audio = AudioSegment(\n",
    "                data=received_audio_data,\n",
    "                sample_width=2, # 2 bytes for int16\n",
    "                frame_rate=16000,\n",
    "                channels=1\n",
    "            )\n",
    "            output_audio.export(\"output_response.wav\", format=\"wav\")\n",
    "            print(\"Saved response audio to 'output_response.wav'\")\n",
    "        else:\n",
    "            print(\"No audio data received.\")\n",
    "\n",
    "def main():\n",
    "    asyncio.run(send_audio_chunks(WEBSOCKET_URI))\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
playground/testapp/test.py
ADDED
@@ -0,0 +1,257 @@
import fastapi
import numpy as np
import torch
import torchaudio
from silero_vad import get_speech_timestamps, load_silero_vad
import whisperx
import edge_tts
import gc
import logging
import time
import os
from openai import OpenAI
import asyncio
from pydub import AudioSegment
from io import BytesIO
import threading

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Configure FastAPI
app = fastapi.FastAPI()

# Load Silero VAD model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.info(f'Using device: {device}')
vad_model = load_silero_vad().to(device)
logging.info('Loaded Silero VAD model')

# Load WhisperX model
whisper_model = whisperx.load_model("tiny", device, compute_type="float16")
logging.info('Loaded WhisperX model')

OPENAI_API_KEY = "sk-proj-gcrtuxd5qzaRYT82Ii3eT3BlbkFJpVQHBc9ZJrmSksLbQc3C"
if not OPENAI_API_KEY:
    logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
    raise ValueError("OpenAI API key not found.")
logging.info('Initialized OpenAI client')
llm_client = OpenAI(api_key=OPENAI_API_KEY)  # Corrected import

# TTS Voice
TTS_VOICE = "en-GB-SoniaNeural"

# Function to check voice activity using Silero VAD
def check_vad(audio_data, sample_rate):
    logging.info('Checking voice activity')
    target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_tensor = resampler(torch.from_numpy(audio_data))
    else:
        audio_tensor = torch.from_numpy(audio_data)
    audio_tensor = audio_tensor.to(device)

    speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)
    logging.info(f'Found {len(speech_timestamps)} speech timestamps')
    return len(speech_timestamps) > 0

# Async function to transcribe audio using WhisperX
def transcribe(audio_data, sample_rate):
    logging.info('Transcribing audio')
    target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_data = resampler(torch.from_numpy(audio_data)).numpy()
    else:
        audio_data = audio_data

    batch_size = 16  # Adjust as needed
    result = whisper_model.transcribe(audio_data, batch_size=batch_size)
    text = result["segments"][0]["text"] if len(result["segments"]) > 0 else ""
    logging.info(f'Transcription result: {text}')
    del result
    gc.collect()
    if device == 'cuda':
        torch.cuda.empty_cache()
    return text

# Function to convert text to speech using Edge TTS and stream the audio
def tts_streaming(text_stream):
    logging.info('Performing TTS')
    buffer = ""
    punctuation = {'.', '!', '?'}
    for text_chunk in text_stream:
        if text_chunk is not None:
            buffer += text_chunk
        # Check for sentence completion
        sentences = []
        start = 0
        for i, char in enumerate(buffer):
            if char in punctuation:
                sentences.append(buffer[start:i+1].strip())
                start = i+1
        buffer = buffer[start:]

        for sentence in sentences:
            if sentence:
                communicate = edge_tts.Communicate(sentence, TTS_VOICE)
                for chunk in communicate.stream_sync():
                    if chunk["type"] == "audio":
                        yield chunk["data"]
    # Process any remaining text
    if buffer.strip():
        communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)
        for chunk in communicate.stream_sync():
            if chunk["type"] == "audio":
                yield chunk["data"]

# Function to perform language model completion using OpenAI API
def llm(text):
    logging.info('Getting response from OpenAI API')
    response = llm_client.chat.completions.create(
        model="gpt-4o",  # Updated to a more recent model
        messages=[
            {"role": "system", "content": "You respond to the following transcript from the conversation that you are having with the user."},
            {"role": "user", "content": text}
        ],
        stream=True,
        temperature=0.7,
        top_p=0.9
    )
    for chunk in response:
        yield chunk.choices[0].delta.content

class Conversation:
    def __init__(self):
        self.mode = 'idle'  # idle, listening, speaking
        self.audio_stream = []
        self.valid_chunk_queue = []
        self.first_valid_chunk = None
        self.last_valid_chunks = []
        self.valid_chunk_transcriptions = ''
        self.in_transcription = False
        self.llm_n_tts_task = None
        self.stop_signal = False
        self.sample_rate = 0
        self.out_audio_stream = []
        self.chunk_buffer = 0.5  # seconds

    def llm_n_tts(self):
        for text_chunk in llm(self.transcription):
            if self.stop_signal:
                break
            for audio_chunk in tts_streaming([text_chunk]):
                if self.stop_signal:
                    break
                self.out_audio_stream.append(np.frombuffer(audio_chunk, dtype=np.int16))

    def process_audio_chunk(self, audio_chunk):
        # Construct audio stream
        audio_data = AudioSegment.from_file(BytesIO(audio_chunk), format="wav")
|
152 |
+
audio_data = np.array(audio_data.get_array_of_samples())
|
153 |
+
self.sample_rate = audio_data.frame_rate
|
154 |
+
|
155 |
+
# Check for voice activity
|
156 |
+
vad = check_vad(audio_data, self.sample_rate)
|
157 |
+
|
158 |
+
if vad: # Voice activity detected
|
159 |
+
if self.first_valid_chunk is not None:
|
160 |
+
self.valid_chunk_queue.append(self.first_valid_chunk)
|
161 |
+
self.first_valid_chunk = None
|
162 |
+
self.valid_chunk_queue.append(audio_chunk)
|
163 |
+
|
164 |
+
if len(self.valid_chunk_queue) > 2:
|
165 |
+
# i.e. 3 chunks: 1 non valid chunk + 2 valid chunks
|
166 |
+
# this is to ensure that the speaker is speaking
|
167 |
+
if self.mode == 'idle':
|
168 |
+
self.mode = 'listening'
|
169 |
+
elif self.mode == 'speaking':
|
170 |
+
# Stop llm and tts
|
171 |
+
if self.llm_n_tts_task is not None:
|
172 |
+
self.stop_signal = True
|
173 |
+
self.llm_n_tts_task
|
174 |
+
self.stop_signal = False
|
175 |
+
self.mode = 'listening'
|
176 |
+
|
177 |
+
else: # No voice activity
|
178 |
+
if self.mode == 'listening':
|
179 |
+
self.last_valid_chunks.append(audio_chunk)
|
180 |
+
|
181 |
+
if len(self.last_valid_chunks) > 2:
|
182 |
+
# i.e. 2 chunks where the speaker stopped speaking, but we account for natural pauses
|
183 |
+
# so on the 1.5th second of no voice activity, we append the first 2 of the last valid chunks to the valid chunk queue
|
184 |
+
# stop listening and start speaking
|
185 |
+
self.valid_chunk_queue.extend(self.last_valid_chunks[:2])
|
186 |
+
self.last_valid_chunks = []
|
187 |
+
|
188 |
+
while len(self.valid_chunk_queue) > 0:
|
189 |
+
time.sleep(0.1)
|
190 |
+
|
191 |
+
self.mode = 'speaking'
|
192 |
+
self.llm_n_tts_task = threading.Thread(target=self.llm_n_tts)
|
193 |
+
self.llm_n_tts_task.start()
|
194 |
+
|
195 |
+
def transcribe_loop(self):
|
196 |
+
while True:
|
197 |
+
if self.mode == 'listening':
|
198 |
+
if len(self.valid_chunk_queue) > 0:
|
199 |
+
accumulated_chunks = np.concatenate(self.valid_chunk_queue)
|
200 |
+
total_duration = len(accumulated_chunks) / self.sample_rate
|
201 |
+
|
202 |
+
if total_duration >= 3.0 and self.in_transcription == True:
|
203 |
+
# i.e. we have at least 3 seconds of audio so we can start transcribing to reduce latency
|
204 |
+
first_2s_audio = accumulated_chunks[:int(2 * self.sample_rate)]
|
205 |
+
transcribed_text = transcribe(first_2s_audio, self.sample_rate)
|
206 |
+
self.valid_chunk_transcriptions += transcribed_text
|
207 |
+
self.valid_chunk_queue = [accumulated_chunks[int(2 * self.sample_rate):]]
|
208 |
+
|
209 |
+
if self.mode == any(['idle', 'speaking']):
|
210 |
+
# i.e. the request to stop transcription has been made
|
211 |
+
# so process the remaining audio
|
212 |
+
transcribed_text = transcribe(accumulated_chunks, self.sample_rate)
|
213 |
+
self.valid_chunk_transcriptions += transcribed_text
|
214 |
+
self.valid_chunk_queue = []
|
215 |
+
else:
|
216 |
+
time.sleep(0.1)
|
217 |
+
|
218 |
+
def stream_out_audio(self):
|
219 |
+
while True:
|
220 |
+
if len(self.out_audio_stream) > 0:
|
221 |
+
yield AudioSegment(data=self.out_audio_stream.pop(0), sample_width=2, frame_rate=self.sample_rate, channels=1).raw_data
|
222 |
+
|
223 |
+
@app.websocket("/ws")
|
224 |
+
async def websocket_endpoint(websocket: fastapi.WebSocket):
|
225 |
+
# Accept connection
|
226 |
+
await websocket.accept()
|
227 |
+
|
228 |
+
# Initialize conversation
|
229 |
+
conversation = Conversation()
|
230 |
+
|
231 |
+
# Start conversation threads
|
232 |
+
transcribe_thread = threading.Thread(target=conversation.transcribe_loop)
|
233 |
+
transcribe_thread.start()
|
234 |
+
|
235 |
+
# Process audio chunks
|
236 |
+
chunk_buffer_size = conversation.chunk_buffer
|
237 |
+
while True:
|
238 |
+
try:
|
239 |
+
audio_chunk = await websocket.receive_bytes()
|
240 |
+
conversation.process_audio_chunk(audio_chunk)
|
241 |
+
|
242 |
+
if conversation.mode == 'speaking':
|
243 |
+
for audio_chunk in conversation.stream_out_audio():
|
244 |
+
await websocket.send_bytes(audio_chunk)
|
245 |
+
else:
|
246 |
+
await websocket.send_bytes(b'')
|
247 |
+
except Exception as e:
|
248 |
+
logging.error(e)
|
249 |
+
break
|
250 |
+
|
251 |
+
@app.get("/")
|
252 |
+
async def index():
|
253 |
+
return fastapi.responses.FileResponse("index.html")
|
254 |
+
|
255 |
+
if __name__ == '__main__':
|
256 |
+
import uvicorn
|
257 |
+
uvicorn.run(app, host='0.0.0.0', port=8000)
|
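
For local testing of the /ws endpoint, a minimal client sketch is shown below. It is not part of the commit: the file name sample.wav, the chunking interval, and the use of the third-party websockets package are assumptions for illustration only, and the one-reply-per-chunk handling is a simplification of the server's streaming behavior.

# Hypothetical test client for the /ws endpoint above (not part of the commit).
# Assumes: the server is running on localhost:8000, a mono 16-bit WAV file named
# sample.wav exists, and the "websockets" package is installed (pip install websockets).
import asyncio
from io import BytesIO

import websockets
from pydub import AudioSegment

CHUNK_SECONDS = 0.5  # mirrors Conversation.chunk_buffer in the server sketch

async def stream_wav(path="sample.wav", uri="ws://localhost:8000/ws"):
    audio = AudioSegment.from_file(path, format="wav")
    chunk_ms = int(CHUNK_SECONDS * 1000)
    async with websockets.connect(uri) as ws:
        for start in range(0, len(audio), chunk_ms):
            # Re-encode each slice as a standalone WAV container, since the server
            # decodes every chunk with AudioSegment.from_file(..., format="wav")
            buf = BytesIO()
            audio[start:start + chunk_ms].export(buf, format="wav")
            await ws.send(buf.getvalue())
            reply = await ws.recv()
            if reply:
                print(f"received {len(reply)} bytes of synthesized audio")

if __name__ == "__main__":
    asyncio.run(stream_wav())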