import fastapi
import numpy as np
import torch
import torchaudio
from silero_vad import get_speech_timestamps, load_silero_vad
import whisperx
import edge_tts
import gc
import logging
import time
import os
from openai import AsyncOpenAI
import asyncio
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Configure FastAPI
app = fastapi.FastAPI()
# Load Silero VAD model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.info(f'Using device: {device}')
vad_model = load_silero_vad().to(device)
logging.info('Loaded Silero VAD model')
# Load WhisperX model (float16 requires CUDA; fall back to int8 on CPU)
compute_type = "float16" if device == "cuda" else "int8"
whisper_model = whisperx.load_model("tiny", device, compute_type=compute_type)
logging.info('Loaded WhisperX model')
# Read the OpenAI API key from the environment instead of hard-coding it
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
if not OPENAI_API_KEY:
    logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
    raise ValueError("OpenAI API key not found.")
aclient = AsyncOpenAI(api_key=OPENAI_API_KEY)
logging.info('Initialized OpenAI client')
# TTS Voice
TTS_VOICE = "en-GB-SoniaNeural"
# Function to check voice activity using Silero VAD
def check_vad(audio_data, sample_rate):
logging.info('Checking voice activity')
target_sample_rate = 16000
if sample_rate != target_sample_rate:
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
audio_tensor = resampler(torch.from_numpy(audio_data))
else:
audio_tensor = torch.from_numpy(audio_data)
audio_tensor = audio_tensor.to(device)
speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)
logging.info(f'Found {len(speech_timestamps)} speech timestamps')
return len(speech_timestamps) > 0
# Synchronous transcription with WhisperX (run in a worker thread by transcript() below)
def transcript_sync(audio_data, sample_rate):
logging.info('Transcribing audio')
target_sample_rate = 16000
if sample_rate != target_sample_rate:
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
audio_data = resampler(torch.from_numpy(audio_data)).numpy()
batch_size = 16 # Adjust as needed
result = whisper_model.transcribe(audio_data, batch_size=batch_size)
    text = " ".join(seg["text"].strip() for seg in result["segments"])
logging.info(f'Transcription result: {text}')
del result
gc.collect()
if device == 'cuda':
torch.cuda.empty_cache()
return text
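# Async wrapper: run the blocking WhisperX transcription in the default thread pool
# so the event loop stays responsive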
async def transcript(audio_data, sample_rate):
loop = asyncio.get_running_loop()
text = await loop.run_in_executor(None, transcript_sync, audio_data, sample_rate)
return text
# Async generator that streams the chat completion from the OpenAI API
async def llm(text):
    logging.info('Getting response from OpenAI API')
    response = await aclient.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "You respond to the following transcript from the conversation that you are having with the user."},
            {"role": "user", "content": text}
        ],
        stream=True,
        temperature=0.7,
        top_p=0.9,
    )
    async for chunk in response:
        # Skip keep-alive/finish chunks that carry no text
        if chunk.choices and chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content
# Async function to perform TTS using Edge-TTS
async def tts_streaming(text_stream):
logging.info('Performing TTS')
buffer = ""
punctuation = {'.', '!', '?'}
for text_chunk in text_stream:
if text_chunk is not None:
buffer += text_chunk
# Check for sentence completion
sentences = []
start = 0
for i, char in enumerate(buffer):
if char in punctuation:
sentences.append(buffer[start:i+1].strip())
start = i+1
buffer = buffer[start:]
for sentence in sentences:
if sentence:
communicate = edge_tts.Communicate(sentence, TTS_VOICE)
async for chunk in communicate.stream():
if chunk["type"] == "audio":
yield chunk["data"]
# Process any remaining text
if buffer.strip():
communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)
async for chunk in communicate.stream():
if chunk["type"] == "audio":
yield chunk["data"]
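# Conversation tracks a single client's state machine:
#   idle -> listening (voice detected) -> processing (finish transcription)
#        -> responding (stream LLM + TTS audio back) -> idle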
class Conversation:
def __init__(self):
self.mode = 'idle'
self.chunk_queue = []
self.transcription = ''
self.in_transcription = False
self.previous_no_vad_audio = None
self.llm_task = None
self.transcription_task = None
self.stop_signal = False
self.sample_rate = 16000 # default sample rate
self.instream = None
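    # Async generator: take one buffered audio chunk, advance the state machine,
    # and, while responding, yield the accumulated output audio stream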
async def process_audio(self, audio_chunk):
sample_rate, audio_data = audio_chunk
self.sample_rate = sample_rate
        audio_data = np.asarray(audio_data)
        # Normalize 16-bit PCM to float32 in [-1.0, 1.0], as Silero VAD and WhisperX expect
        if audio_data.dtype == np.int16:
            audio_data = audio_data.astype(np.float32) / 32768.0
        audio_data = audio_data.astype(np.float32)
        # Convert to mono if necessary
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=1)
# check for voice activity
vad = check_vad(audio_data, sample_rate)
if vad:
logging.info(f'Voice activity detected in mode: {self.mode}')
if self.mode == 'idle':
self.mode = 'listening'
            elif self.mode == 'responding':
                # Interrupt the in-progress LLM/TTS response (barge-in)
                if self.llm_task is not None:
                    logging.info('Stopping LLM and TTS tasks')
                    self.stop_signal = True
                    await self.llm_task.aclose()
                    self.llm_task = None
                self.mode = 'listening'
if self.mode == 'listening':
if self.previous_no_vad_audio is not None:
self.chunk_queue.append(self.previous_no_vad_audio)
self.previous_no_vad_audio = None
# Accumulate audio chunks
self.chunk_queue.append(audio_data)
# Start transcription task if not already running
if not self.in_transcription:
self.in_transcription = True
self.transcription_task = asyncio.create_task(self.transcript_loop())
else:
logging.info(f'No voice activity detected in mode: {self.mode}')
if self.mode == 'listening':
# Add the last chunk to queue
self.chunk_queue.append(audio_data)
# Change mode to processing
self.mode = 'processing'
# Wait for transcription to complete
while self.in_transcription:
await asyncio.sleep(0.1)
# Check if transcription is complete
if len(self.chunk_queue) == 0:
# Start LLM and TTS tasks
                    if self.llm_task is None:
self.stop_signal = False
self.llm_task = self.llm_and_tts()
self.mode = 'responding'
if self.mode == 'responding':
async for audio_chunk in self.llm_task:
if self.instream is None:
self.instream = audio_chunk
else:
self.instream = np.concatenate((self.instream, audio_chunk))
# Send audio to output stream
yield self.instream
# Cleanup
self.llm_task = None
self.transcription = ''
self.mode = 'idle'
self.instream = None
# Store previous audio chunk with no voice activity
self.previous_no_vad_audio = audio_data
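    # Background task: incrementally transcribe queued audio; once more than ~3 s
    # has accumulated, transcribe the first 2 s and keep the remainder queued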
async def transcript_loop(self):
while True:
if len(self.chunk_queue) > 0:
accumulated_audio = np.concatenate(self.chunk_queue)
total_samples = len(accumulated_audio)
total_duration = total_samples / self.sample_rate
                if total_duration > 3.0 and self.in_transcription:
first_two_seconds_samples = int(2.0 * self.sample_rate)
first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]
transcribed_text = await transcript(first_two_seconds_audio, self.sample_rate)
self.transcription += transcribed_text
remaining_audio = accumulated_audio[first_two_seconds_samples:]
self.chunk_queue = [remaining_audio]
else:
transcribed_text = await transcript(accumulated_audio, self.sample_rate)
self.transcription += transcribed_text
self.chunk_queue = []
self.in_transcription = False
else:
await asyncio.sleep(0.1)
if len(self.chunk_queue) == 0 and self.mode in ['idle', 'processing']:
self.in_transcription = False
break
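    # Stream each LLM text chunk through Edge-TTS; stop_signal aborts early, and the
    # returned audio bytes are reinterpreted as int16 samples before being yielded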
async def llm_and_tts(self):
logging.info('Handling LLM and TTS')
async for text_chunk in llm(self.transcription):
if self.stop_signal:
logging.info('LLM and TTS task stopped')
break
async for audio_chunk in tts_streaming([text_chunk]):
if self.stop_signal:
logging.info('LLM and TTS task stopped during TTS')
break
yield np.frombuffer(audio_chunk, dtype=np.int16)
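# WebSocket endpoint: receive binary audio chunks (interpreted as 16-bit PCM),
# buffer roughly 0.5 s, run them through the Conversation pipeline, and stream
# the synthesized reply audio back to the client as bytes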
@app.websocket('/ws')
async def websocket_endpoint(websocket: fastapi.WebSocket):
await websocket.accept()
logging.info('WebSocket connection established')
conversation = Conversation()
audio_buffer = []
buffer_duration = 0.5 # 500ms
try:
while True:
audio_chunk_bytes = await websocket.receive_bytes()
            # Interpret the incoming bytes as 16-bit PCM samples
            audio_buffer.append(np.frombuffer(audio_chunk_bytes, dtype=np.int16))
# Calculate the duration of the buffered audio
total_samples = sum(len(chunk) for chunk in audio_buffer)
total_duration = total_samples / conversation.sample_rate
if total_duration >= buffer_duration:
# Concatenate buffered audio chunks
buffered_audio = np.concatenate(audio_buffer)
audio_buffer = [] # Reset buffer
# Process the buffered audio
async for audio_data in conversation.process_audio((conversation.sample_rate, buffered_audio)):
if audio_data is not None:
await websocket.send_bytes(audio_data.tobytes())
    except fastapi.WebSocketDisconnect:
        logging.info('Client disconnected')
    except Exception as e:
        logging.error(f'WebSocket error: {e}')
finally:
logging.info('WebSocket connection closed')
await websocket.close()
@app.get('/')
def index():
return fastapi.responses.FileResponse('index.html')
if __name__ == '__main__':
import uvicorn
uvicorn.run(app, host='0.0.0.0', port=8000)