import asyncio
import gc
import logging
import os
import threading
import time

import edge_tts
import fastapi
import numpy as np
import torch
import torchaudio
import whisperx
from openai import OpenAI
from silero_vad import get_speech_timestamps, load_silero_vad

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

app = fastapi.FastAPI()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.info(f'Using device: {device}')

vad_model = load_silero_vad().to(device)
logging.info('Loaded Silero VAD model')

# float16 compute requires CUDA; fall back to int8 when running on CPU.
compute_type = 'float16' if device == 'cuda' else 'int8'
whisper_model = whisperx.load_model("tiny", device, compute_type=compute_type)
logging.info('Loaded WhisperX model')
# Never hard-code the API key; read it from the environment, as the error message suggests.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
    raise ValueError("OpenAI API key not found.")

openai_client = OpenAI(api_key=OPENAI_API_KEY)
logging.info('Initialized OpenAI client')

TTS_VOICE = "en-GB-SoniaNeural"
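# Other voices can be listed with the edge-tts CLI (`edge-tts --list-voices`)
# if a different accent or language is preferred.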


def check_vad(audio_data, sample_rate):
    logging.info('Checking voice activity')

    # Silero VAD expects 16 kHz audio; resample if the input differs.
    target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_tensor = resampler(torch.from_numpy(audio_data))
    else:
        audio_tensor = torch.from_numpy(audio_data)
    audio_tensor = audio_tensor.to(device)

    logging.info(f'Audio tensor shape: {audio_tensor.shape}, dtype: {audio_tensor.dtype}, device: {audio_tensor.device}')

    speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)
    logging.info(f'Found {len(speech_timestamps)} speech timestamps')
    return len(speech_timestamps) > 0


def transcript(audio_data, sample_rate):
    logging.info('Transcribing audio')

    # WhisperX expects 16 kHz float32 audio; resample if the input differs.
    target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_data = resampler(torch.from_numpy(audio_data)).numpy()

    batch_size = 16
    result = whisper_model.transcribe(audio_data, batch_size=batch_size)
    # Join all segments rather than keeping only the first one.
    text = " ".join(segment["text"].strip() for segment in result["segments"])
    logging.info(f'Transcription result: {text}')

    # Release the transcription result and any cached GPU memory.
    del result
    gc.collect()
    if device == 'cuda':
        torch.cuda.empty_cache()
    return text


def llm(text):
    logging.info('Getting response from OpenAI API')
    response = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "You respond to the following transcript from the conversation that you are having with the user."},
            {"role": "user", "content": text},
        ],
        stream=True,
        temperature=0.7,
        top_p=0.9,
    )
    # Forward only the text deltas; the final stream chunk carries no content.
    for chunk in response:
        if chunk.choices and chunk.choices[0].delta.content is not None:
            yield chunk.choices[0].delta.content


def tts_streaming(text_stream):
    logging.info('Performing TTS')
    buffer = ""
    punctuation = {'.', '!', '?'}
    for text_chunk in text_stream:
        if text_chunk is not None:
            buffer += text_chunk

        # Split the buffer into complete sentences and keep the trailing fragment.
        sentences = []
        start = 0
        for i, char in enumerate(buffer):
            if char in punctuation:
                sentences.append(buffer[start:i + 1].strip())
                start = i + 1
        buffer = buffer[start:]

        # Synthesize each complete sentence and stream the audio frames out.
        for sentence in sentences:
            if sentence:
                communicate = edge_tts.Communicate(sentence, TTS_VOICE)
                for chunk in communicate.stream_sync():
                    if chunk["type"] == "audio":
                        yield chunk["data"]

    # Flush whatever remains once the text stream is exhausted.
    if buffer.strip():
        communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)
        for chunk in communicate.stream_sync():
            if chunk["type"] == "audio":
                yield chunk["data"]


def llm_and_tts(transcribed_text, state):
    logging.info('Handling LLM and TTS')

    for text_chunk in llm(transcribed_text):
        if state.get('stop_signal'):
            logging.info('LLM and TTS task stopped')
            break

        for audio_chunk in tts_streaming([text_chunk]):
            if state.get('stop_signal'):
                logging.info('LLM and TTS task stopped during TTS')
                break
            # NOTE: edge-tts yields encoded (MP3) frames by default; interpreting them
            # as raw int16 PCM assumes the client decodes or tolerates that format.
            yield np.frombuffer(audio_chunk, dtype=np.int16)


# Shared conversation state. `previous_no_vad_audio` holds the last silent chunk so
# it can be prepended once speech starts.
state = {
    'mode': 'idle',
    'chunk_queue': [],
    'transcription': '',
    'in_transcription': False,
    'previous_no_vad_audio': None,
    'llm_task': None,
    'instream': None,
    'stop_signal': False,
    'args': {
        'sample_rate': 16000,
        'chunk_size': 0.5,
        'transcript_chunk_size': 2,
    },
}
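# Conversation state machine (as implemented below):
#   idle       -> listening   when voice activity is detected
#   listening  -> processing  when the speaker pauses
#   processing -> responding  once the LLM/TTS generator is available
#   responding -> idle        after the response audio has been streamed
# A 'speaking' mode is also checked for barge-in handling, although nothing in this
# file currently sets it.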


def transcript_loop():
    while True:
        if len(state['chunk_queue']) > 0:
            sample_rate = state['args']['sample_rate']
            accumulated_audio = np.concatenate(state['chunk_queue'])
            total_samples = sum(len(chunk) for chunk in state['chunk_queue'])
            total_duration = total_samples / sample_rate

            if total_duration > 3.0 and state['in_transcription']:
                # Transcribe only the first two seconds and keep the remainder queued.
                first_two_seconds_samples = int(2.0 * sample_rate)
                first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]
                transcribed_text = transcript(first_two_seconds_audio, sample_rate)
                state['transcription'] += transcribed_text
                remaining_audio = accumulated_audio[first_two_seconds_samples:]
                state['chunk_queue'] = [remaining_audio]
            else:
                transcribed_text = transcript(accumulated_audio, sample_rate)
                state['transcription'] += transcribed_text
                state['chunk_queue'] = []
                state['in_transcription'] = False
        else:
            time.sleep(0.1)

        if len(state['chunk_queue']) == 0 and state['mode'] in ('idle', 'processing'):
            state['in_transcription'] = False
            break


def process_audio(audio_chunk):
    # `audio_chunk` is a (sample_rate, samples) pair; samples are converted to mono float32.
    sample_rate, audio_data = audio_chunk
    audio_data = np.array(audio_data, dtype=np.float32)

    # Downmix multi-channel audio to mono.
    if audio_data.ndim > 1:
        audio_data = np.mean(audio_data, axis=1)

    mode = state['mode']
    chunk_queue = state['chunk_queue']
    transcription = state['transcription']
    in_transcription = state['in_transcription']
    previous_no_vad_audio = state['previous_no_vad_audio']
    llm_task = state['llm_task']
    instream = state['instream']
    stop_signal = state['stop_signal']
    args = state['args']

    args['sample_rate'] = sample_rate

    vad = check_vad(audio_data, sample_rate)

    if vad:
        logging.info(f'Voice activity detected in mode: {mode}')
        if mode == 'idle':
            mode = 'listening'
        elif mode == 'speaking':
            # Barge-in: the user started talking while a response is playing.
            if llm_task is not None:
                logging.info('Stopping LLM and TTS tasks')
                stop_signal = True
                state['stop_signal'] = True
                llm_task.close()
                llm_task = None
            mode = 'listening'

        if mode == 'listening':
            if previous_no_vad_audio is not None:
                chunk_queue.append(previous_no_vad_audio)
                previous_no_vad_audio = None

            chunk_queue.append(audio_data)

            # Start the background transcription loop if it is not already running.
            if not in_transcription:
                in_transcription = True
                state['in_transcription'] = True
                transcription_task = threading.Thread(target=transcript_loop)
                transcription_task.start()

        elif mode == 'speaking':
            chunk_queue.append(audio_data)
    else:
        logging.info(f'No voice activity detected in mode: {mode}')
        if mode == 'listening':
            # The speaker paused: queue the final chunk and move on to processing.
            chunk_queue.append(audio_data)
            mode = 'processing'

            # Wait for the background transcription loop to drain the queue, then
            # refresh the local copies it updates.
            while state['in_transcription']:
                time.sleep(0.1)
            in_transcription = False
            transcription = state['transcription']
            chunk_queue = state['chunk_queue']

            if len(chunk_queue) == 0:
                if llm_task is None:
                    # llm_and_tts is a generator; keep the handle so it can be
                    # closed if the user barges in.
                    stop_signal = False
                    state['stop_signal'] = False
                    llm_task = llm_and_tts(transcription, state)

            if mode == 'processing':
                if llm_task is not None:
                    mode = 'responding'

            if mode == 'responding':
                # Stream the growing response buffer back to the caller.
                for audio_chunk in llm_task:
                    if instream is None:
                        instream = audio_chunk
                    else:
                        instream = np.concatenate((instream, audio_chunk))

                    yield instream

                # The response finished: reset for the next utterance.
                llm_task = None
                transcription = ''
                mode = 'idle'

            # Persist the updated state before the next chunk arrives.
            state['mode'] = mode
            state['chunk_queue'] = chunk_queue
            state['transcription'] = transcription
            state['in_transcription'] = in_transcription
            state['previous_no_vad_audio'] = previous_no_vad_audio
            state['llm_task'] = llm_task
            state['instream'] = instream
            state['stop_signal'] = stop_signal
            state['args'] = args

        # Remember the last silent chunk so it can be prepended when speech resumes.
        previous_no_vad_audio = audio_data

    state['mode'] = mode
    state['chunk_queue'] = chunk_queue
    state['transcription'] = transcription
    state['in_transcription'] = in_transcription
    state['previous_no_vad_audio'] = previous_no_vad_audio
    state['llm_task'] = llm_task
    state['instream'] = instream
    state['stop_signal'] = stop_signal
    state['args'] = args


@app.websocket('/ws')
async def websocket_endpoint(websocket: fastapi.WebSocket):
    logging.info('WebSocket connection established')
    await websocket.accept()
    try:
        while True:
            await asyncio.sleep(state['args']['chunk_size'])
            audio_chunk = await websocket.receive_bytes()
            if not audio_chunk:
                break
            # Assumes the client streams raw float32 PCM at the configured sample rate;
            # the server replies with the int16 PCM produced by the TTS pipeline.
            audio_data = np.frombuffer(audio_chunk, dtype=np.float32)
            for out_chunk in process_audio((state['args']['sample_rate'], audio_data)):
                await websocket.send_bytes(out_chunk.tobytes())
    except Exception as e:
        logging.error(f'WebSocket error: {e}')
    finally:
        logging.info('WebSocket connection closed')
        await websocket.close()
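# Hedged client sketch (not part of the server), assuming the `websockets` package
# and that the server listens on localhost:8000 with the raw-PCM framing above:
#
#   import asyncio, numpy as np, websockets
#
#   async def main():
#       async with websockets.connect('ws://localhost:8000/ws') as ws:
#           chunk = np.zeros(8000, dtype=np.float32)  # 0.5 s of 16 kHz silence
#           await ws.send(chunk.tobytes())
#
#   asyncio.run(main())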


@app.get('/')
def index():
    # FileResponse lives under fastapi.responses (re-exported from Starlette).
    return fastapi.responses.FileResponse('index.html')
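
# Minimal entry-point sketch, assuming uvicorn is installed and index.html sits next
# to this file; any ASGI server and port would do.
if __name__ == '__main__':
    import uvicorn

    uvicorn.run(app, host='0.0.0.0', port=8000)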