import fastapi
import numpy as np
import torch
import torchaudio
from silero_vad import get_speech_timestamps, load_silero_vad
import whisperx
import edge_tts
import gc
import logging
import time
import os
from openai import OpenAI
import asyncio
from pydub import AudioSegment
from io import BytesIO
import threading
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Configure FastAPI
app = fastapi.FastAPI()
# Load Silero VAD model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.info(f'Using device: {device}')
vad_model = load_silero_vad().to(device)
logging.info('Loaded Silero VAD model')
# Load WhisperX model
compute_type = "float16" if device == "cuda" else "int8"  # float16 is not supported on CPU backends
whisper_model = whisperx.load_model("tiny", device, compute_type=compute_type)
logging.info('Loaded WhisperX model')
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
if not OPENAI_API_KEY:
    logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
    raise ValueError("OpenAI API key not found.")
llm_client = OpenAI(api_key=OPENAI_API_KEY)
logging.info('Initialized OpenAI client')
# TTS Voice
TTS_VOICE = "en-GB-SoniaNeural"
# Function to check voice activity using Silero VAD
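# Note: Silero VAD expects mono float32 audio in [-1, 1] at 8 or 16 kHz, so the input
# is resampled to 16 kHz before running the detector.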
def check_vad(audio_data, sample_rate):
logging.info('Checking voice activity')
target_sample_rate = 16000
if sample_rate != target_sample_rate:
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
audio_tensor = resampler(torch.from_numpy(audio_data))
else:
audio_tensor = torch.from_numpy(audio_data)
audio_tensor = audio_tensor.to(device)
speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)
logging.info(f'Found {len(speech_timestamps)} speech timestamps')
return len(speech_timestamps) > 0
# Function to transcribe audio using WhisperX
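# WhisperX's batched pipeline accepts a float32 NumPy array; the input is resampled
# to 16 kHz here before transcription.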
def transcribe(audio_data, sample_rate):
logging.info('Transcribing audio')
target_sample_rate = 16000
if sample_rate != target_sample_rate:
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
audio_data = resampler(torch.from_numpy(audio_data)).numpy()
    batch_size = 16  # Adjust as needed
    result = whisper_model.transcribe(audio_data, batch_size=batch_size)
    # Join every segment so longer utterances are not truncated to the first one
    text = " ".join(segment["text"].strip() for segment in result["segments"])
logging.info(f'Transcription result: {text}')
del result
gc.collect()
if device == 'cuda':
torch.cuda.empty_cache()
return text
# Function to convert text to speech using Edge TTS and stream the audio
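# edge_tts.Communicate.stream_sync() yields dicts: "audio" entries carry MP3-encoded
# bytes (24 kHz mono by default), "WordBoundary" entries carry timing metadata.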
def tts_streaming(text_stream):
logging.info('Performing TTS')
buffer = ""
punctuation = {'.', '!', '?'}
for text_chunk in text_stream:
if text_chunk is not None:
buffer += text_chunk
# Check for sentence completion
sentences = []
start = 0
for i, char in enumerate(buffer):
if char in punctuation:
sentences.append(buffer[start:i+1].strip())
start = i+1
buffer = buffer[start:]
for sentence in sentences:
if sentence:
communicate = edge_tts.Communicate(sentence, TTS_VOICE)
for chunk in communicate.stream_sync():
if chunk["type"] == "audio":
yield chunk["data"]
# Process any remaining text
if buffer.strip():
communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)
for chunk in communicate.stream_sync():
if chunk["type"] == "audio":
yield chunk["data"]
# Function to perform language model completion using OpenAI API
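# The completion is streamed; delta.content can be None on some chunks (e.g. the final
# one), which tts_streaming filters out before synthesis.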
def llm(text):
logging.info('Getting response from OpenAI API')
response = llm_client.chat.completions.create(
model="gpt-4o", # Updated to a more recent model
messages=[
{"role": "system", "content": "You respond to the following transcript from the conversation that you are having with the user."},
{"role": "user", "content": text}
],
stream=True,
temperature=0.7,
top_p=0.9
)
for chunk in response:
yield chunk.choices[0].delta.content
class Conversation:
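    """Per-connection state machine.

    mode moves between 'idle' (waiting for speech), 'listening' (buffering and
    transcribing speech), and 'speaking' (streaming the LLM + TTS response).
    Plain lists serve as queues shared between the websocket handler and the
    worker threads.
    """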
def __init__(self):
self.mode = 'idle' # idle, listening, speaking
self.audio_stream = []
self.valid_chunk_queue = []
self.first_valid_chunk = None
self.last_valid_chunks = []
self.valid_chunk_transcriptions = ''
self.in_transcription = False
self.llm_n_tts_task = None
self.stop_signal = False
self.sample_rate = 0
self.out_audio_stream = []
self.chunk_buffer = 0.5 # seconds
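        # Assumption: each incoming chunk is roughly chunk_buffer (0.5 s) long, so the
        # chunk-count thresholds in process_audio_chunk correspond to ~1-1.5 s of audio.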
def llm_n_tts(self):
        for text_chunk in llm(self.valid_chunk_transcriptions):
if self.stop_signal:
break
for audio_chunk in tts_streaming([text_chunk]):
if self.stop_signal:
break
                    # edge-tts chunks are already MP3-encoded; queue the raw bytes for streaming
                    self.out_audio_stream.append(audio_chunk)
def process_audio_chunk(self, audio_chunk):
        # Decode the incoming WAV chunk and normalise the samples to float32 in [-1, 1]
        segment = AudioSegment.from_file(BytesIO(audio_chunk), format="wav")
        self.sample_rate = segment.frame_rate
        audio_data = np.array(segment.get_array_of_samples()).astype(np.float32) / (1 << (8 * segment.sample_width - 1))
# Check for voice activity
vad = check_vad(audio_data, self.sample_rate)
if vad: # Voice activity detected
if self.first_valid_chunk is not None:
self.valid_chunk_queue.append(self.first_valid_chunk)
self.first_valid_chunk = None
            self.valid_chunk_queue.append(audio_data)
if len(self.valid_chunk_queue) > 2:
# i.e. 3 chunks: 1 non valid chunk + 2 valid chunks
# this is to ensure that the speaker is speaking
if self.mode == 'idle':
self.mode = 'listening'
elif self.mode == 'speaking':
# Stop llm and tts
if self.llm_n_tts_task is not None:
self.stop_signal = True
                        self.llm_n_tts_task.join()
self.stop_signal = False
self.mode = 'listening'
        else: # No voice activity
            if self.mode != 'listening':
                # Keep the most recent silent chunk so it can prefix the next utterance
                # (the "1 non valid chunk" referred to above)
                self.first_valid_chunk = audio_data
            if self.mode == 'listening':
                self.last_valid_chunks.append(audio_data)
if len(self.last_valid_chunks) > 2:
# i.e. 2 chunks where the speaker stopped speaking, but we account for natural pauses
# so on the 1.5th second of no voice activity, we append the first 2 of the last valid chunks to the valid chunk queue
# stop listening and start speaking
self.valid_chunk_queue.extend(self.last_valid_chunks[:2])
self.last_valid_chunks = []
                    # Leaving 'listening' first lets transcribe_loop flush the queued audio;
                    # then wait for the queue to drain before generating the reply
                    self.mode = 'speaking'
                    while len(self.valid_chunk_queue) > 0:
                        time.sleep(0.1)
self.llm_n_tts_task = threading.Thread(target=self.llm_n_tts)
self.llm_n_tts_task.start()
    def transcribe_loop(self):
        while True:
            if len(self.valid_chunk_queue) > 0:
                accumulated_chunks = np.concatenate(self.valid_chunk_queue)
                total_duration = len(accumulated_chunks) / self.sample_rate
                if self.mode == 'listening' and total_duration >= 3.0:
                    # i.e. we have at least 3 seconds of audio, so start transcribing early to reduce latency
                    first_2s_audio = accumulated_chunks[:int(2 * self.sample_rate)]
                    transcribed_text = transcribe(first_2s_audio, self.sample_rate)
                    self.valid_chunk_transcriptions += transcribed_text
                    self.valid_chunk_queue = [accumulated_chunks[int(2 * self.sample_rate):]]
                elif self.mode in ('idle', 'speaking'):
                    # i.e. the request to stop transcription has been made,
                    # so process the remaining audio and empty the queue
                    transcribed_text = transcribe(accumulated_chunks, self.sample_rate)
                    self.valid_chunk_transcriptions += transcribed_text
                    self.valid_chunk_queue = []
                else:
                    time.sleep(0.1)
            else:
                time.sleep(0.1)
    def stream_out_audio(self):
        # Yield whatever MP3-encoded TTS chunks are currently queued, then return so the
        # caller can go back to receiving microphone audio (and detect interruptions)
        while len(self.out_audio_stream) > 0:
            yield self.out_audio_stream.pop(0)
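# Websocket protocol assumed by this server: the client sends short WAV-encoded microphone
# chunks; while the assistant is speaking the server replies with MP3-encoded TTS audio,
# otherwise it replies with an empty frame as an acknowledgement.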
@app.websocket("/ws")
async def websocket_endpoint(websocket: fastapi.WebSocket):
# Accept connection
await websocket.accept()
# Initialize conversation
conversation = Conversation()
# Start conversation threads
    transcribe_thread = threading.Thread(target=conversation.transcribe_loop, daemon=True)
transcribe_thread.start()
# Process audio chunks
chunk_buffer_size = conversation.chunk_buffer
while True:
try:
            audio_chunk = await websocket.receive_bytes()
            # Run the blocking decode/VAD work off the event loop
            await asyncio.to_thread(conversation.process_audio_chunk, audio_chunk)
if conversation.mode == 'speaking':
for audio_chunk in conversation.stream_out_audio():
await websocket.send_bytes(audio_chunk)
else:
await websocket.send_bytes(b'')
except Exception as e:
logging.error(e)
break
@app.get("/")
async def index():
return fastapi.responses.FileResponse("index.html")
if __name__ == '__main__':
import uvicorn
uvicorn.run(app, host='0.0.0.0', port=8000)
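# Minimal client sketch (illustrative only, not part of the server). Assumes the
# `websockets` package and a local `chunk.wav` file containing ~0.5 s of mono audio:
#
#   import asyncio, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:8000/ws") as ws:
#           wav_bytes = open("chunk.wav", "rb").read()
#           await ws.send(wav_bytes)   # one microphone chunk
#           reply = await ws.recv()    # empty frame or MP3-encoded TTS audio
#           print(f"received {len(reply)} bytes")
#
#   asyncio.run(main())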