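"""Realtime voice-chat backend.

Incoming audio is gated by Silero VAD, transcribed with WhisperX, answered by the
OpenAI chat API (streaming), synthesized with edge-tts, and streamed back to the
client over a FastAPI websocket.
"""
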
import asyncio
import gc
import logging
import os
import threading
import time

import edge_tts
import fastapi
import numpy as np
import torch
import torchaudio
import whisperx
from fastapi.responses import FileResponse
from openai import OpenAI
from silero_vad import get_speech_timestamps, load_silero_vad

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Configure FastAPI
app = fastapi.FastAPI()

# Load Silero VAD model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.info(f'Using device: {device}')
vad_model = load_silero_vad().to(device)  # Ensure the model is on the correct device
logging.info('Loaded Silero VAD model')

# Load WhisperX model
# float16 is only supported on GPU; fall back to int8 when running on CPU
compute_type = "float16" if device == "cuda" else "int8"
whisper_model = whisperx.load_model("tiny", device, compute_type=compute_type)
logging.info('Loaded WhisperX model')

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
    raise ValueError("OpenAI API key not found.")

# Initialize OpenAI client
openai_client = OpenAI(api_key=OPENAI_API_KEY)
logging.info('Initialized OpenAI client')

# TTS Voice
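# (available voices can be listed with the `edge-tts --list-voices` CLI)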
TTS_VOICE = "en-GB-SoniaNeural"

# Function to check voice activity using Silero VAD
def check_vad(audio_data, sample_rate):
    logging.info('Checking voice activity')
    # Resample to 16000 Hz if necessary
    target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_tensor = resampler(torch.from_numpy(audio_data))
    else:
        audio_tensor = torch.from_numpy(audio_data)
    audio_tensor = audio_tensor.to(device)

    # Log audio data details
    logging.info(f'Audio tensor shape: {audio_tensor.shape}, dtype: {audio_tensor.dtype}, device: {audio_tensor.device}')

    # Get speech timestamps
    speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)
    logging.info(f'Found {len(speech_timestamps)} speech timestamps')
    return len(speech_timestamps) > 0

# Function to transcribe audio using WhisperX
def transcript(audio_data, sample_rate):
    logging.info('Transcribing audio')
    # Resample to 16000 Hz if necessary
    target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_data = resampler(torch.from_numpy(audio_data)).numpy()

    # Transcribe and join all returned segments
    batch_size = 16  # Adjust as needed
    result = whisper_model.transcribe(audio_data, batch_size=batch_size)
    text = " ".join(segment["text"].strip() for segment in result["segments"])
    logging.info(f'Transcription result: {text}')
    # Clear GPU memory
    del result
    gc.collect()
    if device == 'cuda':
        torch.cuda.empty_cache()
    return text

# Function to get streaming response from OpenAI API
def llm(text):
    logging.info('Getting response from OpenAI API')
    response = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "You are having a spoken conversation with the user. Respond to the following transcript of their speech."},
            {"role": "user", "content": text}
        ],
        stream=True,
        temperature=0.7,  # Optional: adjust as needed
        top_p=0.9,        # Optional: adjust as needed
    )
    for chunk in response:
        delta = chunk.choices[0].delta.content
        if delta:
            yield delta

# Function to perform TTS per sentence using Edge-TTS
def tts_streaming(text_stream):
    logging.info('Performing TTS')
    buffer = ""
    punctuation = {'.', '!', '?'}
    for text_chunk in text_stream:
        if text_chunk is not None:
            buffer += text_chunk
        # Check for sentence completion
        sentences = []
        start = 0
        for i, char in enumerate(buffer):
            if (char in punctuation):
                sentences.append(buffer[start:i+1].strip())
                start = i+1
        buffer = buffer[start:]

        for sentence in sentences:
            if sentence:
                communicate = edge_tts.Communicate(sentence, TTS_VOICE)
                for chunk in communicate.stream_sync():
                    if chunk["type"] == "audio":
                        yield chunk["data"]
    # Process any remaining text
    if buffer.strip():
        communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)
        for chunk in communicate.stream_sync():
            if chunk["type"] == "audio":
                yield chunk["data"]

# Function to handle LLM and TTS
def llm_and_tts(transcribed_text, state):
    logging.info('Handling LLM and TTS')
    # Get streaming response from LLM
    for text_chunk in llm(transcribed_text):
        if state.get('stop_signal'):
            logging.info('LLM and TTS task stopped')
            break
        # Get audio data from TTS
        for audio_chunk in tts_streaming([text_chunk]):
            if state.get('stop_signal'):
                logging.info('LLM and TTS task stopped during TTS')
                break
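            # Note: edge-tts streams MP3-encoded bytes; reinterpreting them as raw
            # int16 samples is an assumption, and the client may need to decode or
            # convert the audio it receives.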
            yield np.frombuffer(audio_chunk, dtype=np.int16)

state = {
    'mode': 'idle',
    'chunk_queue': [],
    'transcription': '',
    'in_transcription': False,
    'previous_no_vad_audio': [],
    'llm_task': None,
    'instream': None,
    'stop_signal': False,
    'args': {
        'sample_rate': 16000,
        'chunk_size': 0.5, # seconds
        'transcript_chunk_size': 2, # seconds
    }
}
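
# 'mode' cycles idle -> listening (speech detected) -> processing (silence; finish
# transcription) -> responding (assistant audio streamed back) -> idle. This
# module-level state assumes a single client at a time.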

def transcript_loop():
    # Background worker: repeatedly drains state['chunk_queue'] into the
    # accumulated transcription until the queue is empty and listening has stopped.
    while True:
        if len(state['chunk_queue']) > 0:
            sample_rate = state['args']['sample_rate']
            accumulated_audio = np.concatenate(state['chunk_queue'])
            total_duration = len(accumulated_audio) / sample_rate

            # If more than 3 seconds have accumulated, transcribe the first 2 seconds
            # and keep the remainder queued
            if total_duration > 3.0 and state['in_transcription']:
                first_two_seconds_samples = int(2.0 * sample_rate)
                first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]
                transcribed_text = transcript(first_two_seconds_audio, sample_rate)
                state['transcription'] += transcribed_text
                remaining_audio = accumulated_audio[first_two_seconds_samples:]
                # Mutate the queue in place so references held elsewhere stay valid
                state['chunk_queue'][:] = [remaining_audio]
            else:  # Otherwise transcribe everything accumulated so far
                transcribed_text = transcript(accumulated_audio, sample_rate)
                state['transcription'] += transcribed_text
                state['chunk_queue'].clear()
        else:
            time.sleep(0.1)

        # Exit once the queue is drained and the caller has stopped listening
        if len(state['chunk_queue']) == 0 and state['mode'] in ('idle', 'processing'):
            state['in_transcription'] = False
            break

def process_audio(audio_chunk):
    # Generator: consumes one (sample_rate, audio_data) chunk, advances the state
    # machine, and yields output audio while the assistant is responding.
    
    sample_rate, audio_data = audio_chunk
    audio_data = np.array(audio_data, dtype=np.float32)
    
    # convert to mono if necessary
    if audio_data.ndim > 1:
        audio_data = np.mean(audio_data, axis=1)

    mode = state['mode']
    chunk_queue = state['chunk_queue']
    transcription = state['transcription']
    in_transcription = state['in_transcription']
    previous_no_vad_audio = state['previous_no_vad_audio']
    llm_task = state['llm_task']
    instream = state['instream']
    stop_signal = state['stop_signal']
    args = state['args']
    
    args['sample_rate'] = sample_rate
    
    # check for voice activity
    vad = check_vad(audio_data, sample_rate)
    
    if vad:
        logging.info(f'Voice activity detected in mode: {mode}')
        if mode == 'idle':
            mode = 'listening'
        elif mode == 'responding':
            # Barge-in: stop the LLM/TTS generator so we can listen again
            if llm_task is not None:
                logging.info('Stopping LLM and TTS tasks')
                # llm_and_tts checks state['stop_signal'] between chunks
                stop_signal = True
                state['stop_signal'] = True
                llm_task.close()
                llm_task = None
            mode = 'listening'

        if mode == 'listening':
            if previous_no_vad_audio is not None:
                chunk_queue.append(previous_no_vad_audio)
                previous_no_vad_audio = None
            # Accumulate audio chunks
            chunk_queue.append(audio_data)

            # Start transcription thread if not already running
            if not in_transcription:
                in_transcription = True
                state['in_transcription'] = True
                transcription_task = threading.Thread(target=transcript_loop)
                transcription_task.start()

        elif mode == 'responding':
            # Continue accumulating audio chunks
            chunk_queue.append(audio_data)
    else:
        logging.info(f'No voice activity detected in mode: {mode}')
        if mode == 'listening':
            # Add the last chunk to the queue
            chunk_queue.append(audio_data)

            # Change mode to processing so the transcription thread can drain and exit
            mode = 'processing'
            state['mode'] = mode

            # Wait for transcription to complete (the thread clears the flag in state)
            while state['in_transcription']:
                time.sleep(0.1)
            in_transcription = False
            transcription = state['transcription']

            # Check if transcription is complete
            if len(chunk_queue) == 0:
                # Create the LLM/TTS generator for the transcribed utterance
                if llm_task is None:
                    stop_signal = False
                    state['stop_signal'] = False
                    llm_task = llm_and_tts(transcription, state)

        if mode == 'processing':
            # Move to responding once the LLM/TTS generator has been created
            if llm_task is not None:
                mode = 'responding'
        
        if mode == 'responding':
            # Stream audio from the LLM/TTS generator back to the caller
            for audio_chunk in llm_task:
                if instream is None:
                    instream = audio_chunk
                else:
                    instream = np.concatenate((instream, audio_chunk))

                # Send audio to output stream
                yield instream

            # Cleanup
            llm_task = None
            transcription = ''
            state['transcription'] = ''
            mode = 'idle'

        # Store previous audio chunk with no voice activity
        previous_no_vad_audio = audio_data

    # Update shared state. 'chunk_queue', 'transcription' and 'in_transcription'
    # are shared with the transcription thread and written through at the point
    # of change, so they are not overwritten wholesale here.
    state['mode'] = mode
    state['previous_no_vad_audio'] = previous_no_vad_audio
    state['llm_task'] = llm_task
    state['instream'] = instream
    state['stop_signal'] = stop_signal
    state['args'] = args
            

@app.websocket('/ws')
async def websocket_endpoint(websocket: fastapi.WebSocket):
    logging.info('WebSocket connection established')
    await websocket.accept()
    try:
        while True:
            await asyncio.sleep(state['args']['chunk_size'])
            audio_chunk = await websocket.receive_bytes()
            if not audio_chunk:
                break
            # Assumes the client streams raw little-endian float32 PCM at the
            # configured sample rate.
            audio_data = np.frombuffer(audio_chunk, dtype=np.float32)
            for output_audio in process_audio((state['args']['sample_rate'], audio_data)):
                await websocket.send_bytes(output_audio.tobytes())
    except Exception as e:
        logging.error(f'WebSocket error: {e}')
    finally:
        logging.info('WebSocket connection closed')
        await websocket.close()

@app.get('/')
def index():
    return FileResponse('index.html')
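

if __name__ == "__main__":
    # Minimal entry-point sketch: uvicorn is the usual ASGI server for FastAPI;
    # the host/port values here are placeholders.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)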