import fastapi
import numpy as np
import torch
import torchaudio
from silero_vad import get_speech_timestamps, load_silero_vad
import whisperx
import edge_tts
import gc
import logging
import os
from openai import AsyncOpenAI
import asyncio

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Configure FastAPI
app = fastapi.FastAPI()

# Load Silero VAD model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.info(f'Using device: {device}')
vad_model = load_silero_vad().to(device)
logging.info('Loaded Silero VAD model')

# Load WhisperX model (float16 is only supported on GPU)
compute_type = "float16" if device == 'cuda' else "int8"
whisper_model = whisperx.load_model("tiny", device, compute_type=compute_type)
logging.info('Loaded WhisperX model')

# Read the OpenAI API key from the environment
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
if not OPENAI_API_KEY:
    logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
    raise ValueError("OpenAI API key not found.")
aclient = AsyncOpenAI(api_key=OPENAI_API_KEY)
logging.info('Initialized OpenAI client')

# TTS voice
TTS_VOICE = "en-GB-SoniaNeural"

# Check voice activity using Silero VAD
def check_vad(audio_data, sample_rate):
    logging.info('Checking voice activity')
    target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_tensor = resampler(torch.from_numpy(audio_data))
    else:
        audio_tensor = torch.from_numpy(audio_data)
    audio_tensor = audio_tensor.to(device)

    speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)
    logging.info(f'Found {len(speech_timestamps)} speech timestamps')
    return len(speech_timestamps) > 0

# Transcribe audio using WhisperX (blocking; wrapped by the async transcript() below)
def transcript_sync(audio_data, sample_rate):
    logging.info('Transcribing audio')
    target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_data = resampler(torch.from_numpy(audio_data)).numpy()

    batch_size = 16  # Adjust as needed
    result = whisper_model.transcribe(audio_data, batch_size=batch_size)
    text = result["segments"][0]["text"] if len(result["segments"]) > 0 else ""
    logging.info(f'Transcription result: {text}')

    del result
    gc.collect()
    if device == 'cuda':
        torch.cuda.empty_cache()
    return text

async def transcript(audio_data, sample_rate):
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, transcript_sync, audio_data, sample_rate)

# Stream a chat completion from the OpenAI API
async def llm(text):
    logging.info('Getting response from OpenAI API')
    response = await aclient.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "You respond to the following transcript from the conversation that you are having with the user."},
            {"role": "user", "content": text}
        ],
        stream=True,
        temperature=0.7,
        top_p=0.9,
    )
    async for chunk in response:
        yield chunk.choices[0].delta.content

# Stream TTS audio for the given text chunks using Edge-TTS
async def tts_streaming(text_stream):
    logging.info('Performing TTS')
    buffer = ""
    punctuation = {'.', '!', '?'}
    for text_chunk in text_stream:
        if text_chunk is not None:
            buffer += text_chunk

            # Split the buffer into completed sentences
            sentences = []
            start = 0
            for i, char in enumerate(buffer):
                if char in punctuation:
                    sentences.append(buffer[start:i + 1].strip())
                    start = i + 1
            buffer = buffer[start:]

            for sentence in sentences:
                if sentence:
                    communicate = edge_tts.Communicate(sentence, TTS_VOICE)
                    async for chunk in communicate.stream():
                        if chunk["type"] == "audio":
                            yield chunk["data"]

    # Process any remaining text
    if buffer.strip():
        communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)
        async for chunk in communicate.stream():
            if chunk["type"] == "audio":
                yield chunk["data"]

class Conversation:
    def __init__(self):
        self.mode = 'idle'
        self.chunk_queue = []
        self.transcription = ''
        self.in_transcription = False
        self.previous_no_vad_audio = None
        self.llm_task = None
        self.transcription_task = None
        self.stop_signal = False
        self.sample_rate = 16000  # default sample rate
        self.instream = None

    async def process_audio(self, audio_chunk):
        sample_rate, audio_data = audio_chunk
        self.sample_rate = sample_rate

        # Incoming audio is int16 PCM; convert to float32 in [-1, 1] as expected
        # by Silero VAD and WhisperX
        audio_data = np.asarray(audio_data, dtype=np.float32) / 32768.0

        # Convert to mono if necessary
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=1)

        # Check for voice activity
        vad = check_vad(audio_data, sample_rate)

        if vad:
            logging.info(f'Voice activity detected in mode: {self.mode}')
            if self.mode == 'idle':
                self.mode = 'listening'
            elif self.mode == 'speaking':
                # Stop LLM and TTS tasks
                if self.llm_task and not self.llm_task.done():
                    logging.info('Stopping LLM and TTS tasks')
                    self.stop_signal = True
                    await self.llm_task
                self.mode = 'listening'

            if self.mode == 'listening':
                if self.previous_no_vad_audio is not None:
                    self.chunk_queue.append(self.previous_no_vad_audio)
                    self.previous_no_vad_audio = None

                # Accumulate audio chunks
                self.chunk_queue.append(audio_data)

                # Start transcription task if not already running
                if not self.in_transcription:
                    self.in_transcription = True
                    self.transcription_task = asyncio.create_task(self.transcript_loop())
        else:
            logging.info(f'No voice activity detected in mode: {self.mode}')
            if self.mode == 'listening':
                # Add the last chunk to the queue
                self.chunk_queue.append(audio_data)

                # Change mode to processing
                self.mode = 'processing'

                # Wait for transcription to complete
                while self.in_transcription:
                    await asyncio.sleep(0.1)

                # Transcription is complete once the queue is empty
                if len(self.chunk_queue) == 0:
                    # Start LLM and TTS
                    if not self.llm_task or self.llm_task.done():
                        self.stop_signal = False
                        self.llm_task = self.llm_and_tts()
                        self.mode = 'responding'

                    if self.mode == 'responding':
                        async for audio_chunk in self.llm_task:
                            if self.instream is None:
                                self.instream = audio_chunk
                            else:
                                self.instream = np.concatenate((self.instream, audio_chunk))
                            # Send audio to the output stream
                            yield self.instream

                        # Cleanup
                        self.llm_task = None
                        self.transcription = ''
                        self.mode = 'idle'
                        self.instream = None

            # Store the previous audio chunk with no voice activity
            self.previous_no_vad_audio = audio_data

    async def transcript_loop(self):
        while True:
            if len(self.chunk_queue) > 0:
                accumulated_audio = np.concatenate(self.chunk_queue)
                total_samples = len(accumulated_audio)
                total_duration = total_samples / self.sample_rate

                if total_duration > 3.0 and self.in_transcription:
                    # Transcribe the first two seconds and keep the rest queued
                    first_two_seconds_samples = int(2.0 * self.sample_rate)
                    first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]
                    transcribed_text = await transcript(first_two_seconds_audio, self.sample_rate)
                    self.transcription += transcribed_text
                    remaining_audio = accumulated_audio[first_two_seconds_samples:]
                    self.chunk_queue = [remaining_audio]
                else:
                    transcribed_text = await transcript(accumulated_audio, self.sample_rate)
                    self.transcription += transcribed_text
                    self.chunk_queue = []
                    self.in_transcription = False
            else:
                await asyncio.sleep(0.1)

            if len(self.chunk_queue) == 0 and self.mode in ['idle', 'processing']:
                self.in_transcription = False
                break
    async def llm_and_tts(self):
        logging.info('Handling LLM and TTS')

        async for text_chunk in llm(self.transcription):
            if self.stop_signal:
                logging.info('LLM and TTS task stopped')
                break
            async for audio_chunk in tts_streaming([text_chunk]):
                if self.stop_signal:
                    logging.info('LLM and TTS task stopped during TTS')
                    break
                yield np.frombuffer(audio_chunk, dtype=np.int16)

@app.websocket('/ws')
async def websocket_endpoint(websocket: fastapi.WebSocket):
    await websocket.accept()
    logging.info('WebSocket connection established')

    conversation = Conversation()
    audio_buffer = []
    buffer_duration = 0.5  # 500 ms

    try:
        while True:
            audio_chunk_bytes = await websocket.receive_bytes()
            audio_buffer.append(np.frombuffer(audio_chunk_bytes, dtype=np.int16))

            # Calculate the duration of the buffered audio
            total_samples = sum(len(chunk) for chunk in audio_buffer)
            total_duration = total_samples / conversation.sample_rate

            if total_duration >= buffer_duration:
                # Concatenate buffered audio chunks
                buffered_audio = np.concatenate(audio_buffer)
                audio_buffer = []  # Reset buffer

                # Process the buffered audio
                async for audio_data in conversation.process_audio((conversation.sample_rate, buffered_audio)):
                    if audio_data is not None:
                        await websocket.send_bytes(audio_data.tobytes())
    except fastapi.WebSocketDisconnect:
        logging.info('WebSocket client disconnected')
    except Exception as e:
        logging.error(f'WebSocket error: {e}')
    finally:
        logging.info('WebSocket connection closed')

@app.get('/')
def index():
    return fastapi.responses.FileResponse('index.html')

if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=8000)
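
# ------------------------------------------------------------------------------
# Illustrative test client (a sketch only; nothing above uses it). It shows the
# assumed wire protocol: stream raw 16 kHz mono int16 PCM frames to /ws and
# collect whatever audio bytes the server sends back. The `websockets` package
# and the `input.pcm` / `response.raw` file names are assumptions made for this
# example, not part of the server. Run it separately, e.g. with
# asyncio.run(example_client()).
async def example_client(uri="ws://localhost:8000/ws",
                         pcm_path="input.pcm",
                         out_path="response.raw"):
    import websockets  # assumed extra dependency, only needed for this sketch

    chunk_size = 16000 * 2 // 10  # ~100 ms of 16 kHz int16 audio
    async with websockets.connect(uri) as ws:
        with open(pcm_path, "rb") as f, open(out_path, "wb") as out:
            while chunk := f.read(chunk_size):
                await ws.send(chunk)
                # Drain any audio the server has produced so far
                try:
                    while True:
                        reply = await asyncio.wait_for(ws.recv(), timeout=0.05)
                        if isinstance(reply, (bytes, bytearray)):
                            out.write(reply)
                except asyncio.TimeoutError:
                    pass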