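# Real-time voice assistant prototype: audio arrives over a WebSocket, Silero VAD detects
# speech, WhisperX transcribes it, an OpenAI chat model streams a reply, and Edge-TTS turns
# the reply into audio that is streamed back to the client.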
import fastapi
import numpy as np
import torch
import torchaudio
from silero_vad import get_speech_timestamps, load_silero_vad
import whisperx
import edge_tts
import gc
import logging
import os
import time
from openai import OpenAI
import threading
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Configure FastAPI
app = fastapi.FastAPI()
# Load Silero VAD model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.info(f'Using device: {device}')
vad_model = load_silero_vad().to(device) # Ensure the model is on the correct device
logging.info('Loaded Silero VAD model')
# Load WhisperX model
compute_type = "float16" if device == "cuda" else "int8"  # float16 is not supported on CPU
whisper_model = whisperx.load_model("tiny", device, compute_type=compute_type)
logging.info('Loaded WhisperX model')
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # never hard-code API keys; read from the environment
if not OPENAI_API_KEY:
    logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
    raise ValueError("OpenAI API key not found.")
# Initialize OpenAI client
openai_client = OpenAI(api_key=OPENAI_API_KEY)
logging.info('Initialized OpenAI client')
# TTS Voice
TTS_VOICE = "en-GB-SoniaNeural"
# Function to check voice activity using Silero VAD
def check_vad(audio_data, sample_rate):
logging.info('Checking voice activity')
# Resample to 16000 Hz if necessary
target_sample_rate = 16000
if sample_rate != target_sample_rate:
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
audio_tensor = resampler(torch.from_numpy(audio_data))
else:
audio_tensor = torch.from_numpy(audio_data)
audio_tensor = audio_tensor.to(device)
# Log audio data details
logging.info(f'Audio tensor shape: {audio_tensor.shape}, dtype: {audio_tensor.dtype}, device: {audio_tensor.device}')
# Get speech timestamps
speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)
logging.info(f'Found {len(speech_timestamps)} speech timestamps')
return len(speech_timestamps) > 0
# Function to transcribe audio using WhisperX
def transcript(audio_data, sample_rate):
logging.info('Transcribing audio')
# Resample to 16000 Hz if necessary
target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_data = resampler(torch.from_numpy(audio_data)).numpy()
# Transcribe
batch_size = 16 # Adjust as needed
result = whisper_model.transcribe(audio_data, batch_size=batch_size)
text = result["segments"][0]["text"] if len(result["segments"]) > 0 else ""
logging.info(f'Transcription result: {text}')
# Clear GPU memory
del result
gc.collect()
if device == 'cuda':
torch.cuda.empty_cache()
return text
# Function to get streaming response from OpenAI API
def llm(text):
logging.info('Getting response from OpenAI API')
response = openai_client.chat.completions.create(
model="gpt-4o", # Updated to a more recent model
messages=[
{"role": "system", "content": "You respond to the following transcript from the conversation that you are having with the user."},
{"role": "user", "content": text}
],
stream=True,
temperature=0.7, # Optional: Adjust as needed
top_p=0.9, # Optional: Adjust as needed
)
    for chunk in response:
        # Skip chunks without a content delta (e.g. the final chunk that only carries finish_reason)
        if chunk.choices and chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content
# Function to perform TTS per sentence using Edge-TTS
def tts_streaming(text_stream):
logging.info('Performing TTS')
buffer = ""
punctuation = {'.', '!', '?'}
for text_chunk in text_stream:
if text_chunk is not None:
buffer += text_chunk
# Check for sentence completion
sentences = []
start = 0
for i, char in enumerate(buffer):
                if char in punctuation:
sentences.append(buffer[start:i+1].strip())
start = i+1
buffer = buffer[start:]
for sentence in sentences:
if sentence:
communicate = edge_tts.Communicate(sentence, TTS_VOICE)
for chunk in communicate.stream_sync():
if chunk["type"] == "audio":
yield chunk["data"]
# Process any remaining text
if buffer.strip():
communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)
for chunk in communicate.stream_sync():
if chunk["type"] == "audio":
yield chunk["data"]
# Function to handle LLM and TTS
def llm_and_tts(transcribed_text, state):
logging.info('Handling LLM and TTS')
# Get streaming response from LLM
for text_chunk in llm(transcribed_text):
if state.get('stop_signal'):
logging.info('LLM and TTS task stopped')
break
# Get audio data from TTS
for audio_chunk in tts_streaming([text_chunk]):
if state.get('stop_signal'):
logging.info('LLM and TTS task stopped during TTS')
break
yield np.frombuffer(audio_chunk, dtype=np.int16)
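# Shared pipeline state, read and mutated by process_audio() and the background transcription thread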
state = {
'mode': 'idle',
'chunk_queue': [],
'transcription': '',
'in_transcription': False,
'previous_no_vad_audio': [],
'llm_task': None,
'instream': None,
'stop_signal': False,
'args': {
'sample_rate': 16000,
'chunk_size': 0.5, # seconds
'transcript_chunk_size': 2, # seconds
}
}
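# Background worker: drains state['chunk_queue'], appends transcribed text to
# state['transcription'], and clears state['in_transcription'] when it finishes.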
def transcript_loop():
while True:
if len(state['chunk_queue']) > 0:
accumulated_audio = np.concatenate(state['chunk_queue'])
total_samples = sum(len(chunk) for chunk in state['chunk_queue'])
            total_duration = total_samples / state['args']['sample_rate']
# Run transcription on the first 2 seconds if len > 3 seconds
            if total_duration > 3.0 and state['in_transcription']:
                first_two_seconds_samples = int(2.0 * state['args']['sample_rate'])
                first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]
                transcribed_text = transcript(first_two_seconds_audio, state['args']['sample_rate'])
state['transcription'] += transcribed_text
remaining_audio = accumulated_audio[first_two_seconds_samples:]
state['chunk_queue'] = [remaining_audio]
else: # Run transcription on the accumulated audio
                transcribed_text = transcript(accumulated_audio, state['args']['sample_rate'])
state['transcription'] += transcribed_text
state['chunk_queue'] = []
state['in_transcription'] = False
else:
time.sleep(0.1)
        if len(state['chunk_queue']) == 0 and state['mode'] in ('idle', 'processing'):
state['in_transcription'] = False
break
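# Core state machine driven by incoming audio chunks:
# idle -> listening (speech detected) -> processing (speech ended, waiting for transcription)
# -> responding (streaming LLM/TTS audio back to the client) -> idle.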
def process_audio(audio_chunk):
    # Generator: consumes a (sample_rate, audio_data) chunk and yields output audio for the client
sample_rate, audio_data = audio_chunk
audio_data = np.array(audio_data, dtype=np.float32)
# convert to mono if necessary
if audio_data.ndim > 1:
audio_data = np.mean(audio_data, axis=1)
mode = state['mode']
chunk_queue = state['chunk_queue']
transcription = state['transcription']
in_transcription = state['in_transcription']
previous_no_vad_audio = state['previous_no_vad_audio']
llm_task = state['llm_task']
instream = state['instream']
stop_signal = state['stop_signal']
args = state['args']
args['sample_rate'] = sample_rate
# check for voice activity
vad = check_vad(audio_data, sample_rate)
if vad:
logging.info(f'Voice activity detected in mode: {mode}')
if mode == 'idle':
mode = 'listening'
        elif mode == 'speaking':
            # Interrupt the ongoing LLM/TTS generation
            if llm_task is not None:
                logging.info('Stopping LLM and TTS generation')
                # The generator checks state['stop_signal'] cooperatively
                state['stop_signal'] = True
                stop_signal = True
                llm_task.close()
                llm_task = None
            mode = 'listening'
if mode == 'listening':
if previous_no_vad_audio is not None:
chunk_queue.append(previous_no_vad_audio)
previous_no_vad_audio = None
# Accumulate audio chunks
chunk_queue.append(audio_data)
# Start transcription thread if not already running
        if not in_transcription:
            in_transcription = True
            state['in_transcription'] = True
            # transcript_loop takes no arguments; it reads everything from the shared state
            transcription_task = threading.Thread(target=transcript_loop, daemon=True)
            transcription_task.start()
elif mode == 'speaking':
# Continue accumulating audio chunks
chunk_queue.append(audio_data)
else:
logging.info(f'No voice activity detected in mode: {mode}')
        if mode == 'listening':
            # Add the last chunk to the queue
            chunk_queue.append(audio_data)
            # Change mode to processing
            mode = 'processing'
            # Wait for the background transcription thread to finish (it clears the shared flag)
            while state['in_transcription']:
                time.sleep(0.1)
            # Pick up the transcription and the drained queue from the shared state
            transcription = state['transcription']
            chunk_queue = state['chunk_queue']
            # Check if transcription is complete
            if len(chunk_queue) == 0:
                # Start LLM and TTS generation (llm_and_tts is a generator consumed below)
                if llm_task is None:
                    stop_signal = False
                    state['stop_signal'] = False
                    llm_task = llm_and_tts(transcription, state)
        if mode == 'processing':
            # Move on as soon as the LLM/TTS generator is ready to stream
            if llm_task is not None:
                mode = 'responding'
if mode == 'responding':
for audio_chunk in llm_task:
if instream is None:
instream = audio_chunk
else:
instream = np.concatenate((instream, audio_chunk))
# Send audio to output stream
yield instream
# Cleanup
llm_task = None
transcription = ''
mode = 'idle'
        # Update state
state['mode'] = mode
state['chunk_queue'] = chunk_queue
state['transcription'] = transcription
state['in_transcription'] = in_transcription
state['previous_no_vad_audio'] = previous_no_vad_audio
state['llm_task'] = llm_task
state['instream'] = instream
state['stop_signal'] = stop_signal
state['args'] = args
# Store previous audio chunk with no voice activity
previous_no_vad_audio = audio_data
# Update state
state['mode'] = mode
state['chunk_queue'] = chunk_queue
state['transcription'] = transcription
state['in_transcription'] = in_transcription
state['previous_no_vad_audio'] = previous_no_vad_audio
state['llm_task'] = llm_task
state['instream'] = instream
state['stop_signal'] = stop_signal
state['args'] = args
@app.websocket('/ws')
async def websocket_endpoint(websocket: fastapi.WebSocket):
    await websocket.accept()
    logging.info('WebSocket connection established')
    try:
        while True:
            audio_bytes = await websocket.receive_bytes()
            if not audio_bytes:
                break
            # Assumes the client sends raw 16-bit PCM mono at the configured sample rate
            audio_data = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0
            for output_audio in process_audio((state['args']['sample_rate'], audio_data)):
                await websocket.send_bytes(output_audio.tobytes())
    except fastapi.WebSocketDisconnect:
        logging.info('WebSocket client disconnected')
    except Exception as e:
        logging.error(f'WebSocket error: {e}')
        await websocket.close()
    finally:
        logging.info('WebSocket connection closed')
@app.get('/')
def index():
    return fastapi.responses.FileResponse('index.html')
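# Optional local entry point. Running via uvicorn is an assumption about how this app is
# launched; the hosting environment may start it differently.
if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=8000)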