{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import fastapi\n", "import numpy as np\n", "import torch\n", "import torchaudio\n", "from silero_vad import get_speech_timestamps, load_silero_vad\n", "import whisperx\n", "import edge_tts\n", "import gc\n", "import logging\n", "import time\n", "from openai import OpenAI\n", "import threading\n", "import asyncio\n", "\n", "# Configure logging\n", "logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n", "\n", "# Configure FastAPI\n", "app = fastapi.FastAPI()\n", "\n", "# Load Silero VAD model\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "logging.info(f'Using device: {device}')\n", "vad_model = load_silero_vad().to(device) # Ensure the model is on the correct device\n", "logging.info('Loaded Silero VAD model')\n", "\n", "# Load WhisperX model\n", "whisper_model = whisperx.load_model(\"tiny\", device, compute_type=\"float16\")\n", "logging.info('Loaded WhisperX model')\n", "\n", "OPENAI_API_KEY = \"\" # os.getenv(\"OPENAI_API_KEY\")\n", "if not OPENAI_API_KEY:\n", " logging.error(\"OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.\")\n", " raise ValueError(\"OpenAI API key not found.\")\n", "\n", "# Initialize OpenAI client\n", "openai_client = OpenAI(api_key=OPENAI_API_KEY)\n", "logging.info('Initialized OpenAI client')\n", "\n", "# TTS Voice\n", "TTS_VOICE = \"en-GB-SoniaNeural\"\n", "\n", "# Function to check voice activity using Silero VAD\n", "def check_vad(audio_data, sample_rate):\n", " logging.info('Checking voice activity')\n", " # Resample to 16000 Hz if necessary\n", " target_sample_rate = 16000\n", " if sample_rate != target_sample_rate:\n", " resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)\n", " audio_tensor = resampler(torch.from_numpy(audio_data))\n", " else:\n", " audio_tensor = torch.from_numpy(audio_data)\n", " audio_tensor = audio_tensor.to(device)\n", "\n", " # Log audio data details\n", " logging.info(f'Audio tensor shape: {audio_tensor.shape}, dtype: {audio_tensor.dtype}, device: {audio_tensor.device}')\n", "\n", " # Get speech timestamps\n", " speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)\n", " logging.info(f'Found {len(speech_timestamps)} speech timestamps')\n", " return len(speech_timestamps) > 0\n", "\n", "# Function to transcribe audio using WhisperX\n", "def transcript(audio_data, sample_rate):\n", " logging.info('Transcribing audio')\n", " # Resample to 16000 Hz if necessary\n", " target_sample_rate = 16000\n", " if sample_rate != target_sample_rate:\n", " resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)\n", " audio_data = resampler(torch.from_numpy(audio_data)).numpy()\n", " else:\n", " audio_data = audio_data\n", "\n", " # Transcribe\n", " batch_size = 16 # Adjust as needed\n", " result = whisper_model.transcribe(audio_data, batch_size=batch_size)\n", " text = result[\"segments\"][0][\"text\"] if len(result[\"segments\"]) > 0 else \"\"\n", " logging.info(f'Transcription result: {text}')\n", " # Clear GPU memory\n", " del result\n", " gc.collect()\n", " if device == 'cuda':\n", " torch.cuda.empty_cache()\n", " return text\n", "\n", "# Function to get streaming response from OpenAI API\n", "def llm(text):\n", " logging.info('Getting response from OpenAI API')\n", " response = openai_client.chat.completions.create(\n", " 
model=\"gpt-4o\", # Updated to a more recent model\n", " messages=[\n", " {\"role\": \"system\", \"content\": \"You respond to the following transcript from the conversation that you are having with the user.\"},\n", " {\"role\": \"user\", \"content\": text} \n", " ],\n", " stream=True,\n", " temperature=0.7, # Optional: Adjust as needed\n", " top_p=0.9, # Optional: Adjust as needed\n", " )\n", " for chunk in response:\n", " yield chunk.choices[0].delta.content\n", "\n", "# Function to perform TTS per sentence using Edge-TTS\n", "def tts_streaming(text_stream):\n", " logging.info('Performing TTS')\n", " buffer = \"\"\n", " punctuation = {'.', '!', '?'}\n", " for text_chunk in text_stream:\n", " if text_chunk is not None:\n", " buffer += text_chunk\n", " # Check for sentence completion\n", " sentences = []\n", " start = 0\n", " for i, char in enumerate(buffer):\n", " if (char in punctuation):\n", " sentences.append(buffer[start:i+1].strip())\n", " start = i+1\n", " buffer = buffer[start:]\n", "\n", " for sentence in sentences:\n", " if sentence:\n", " communicate = edge_tts.Communicate(sentence, TTS_VOICE)\n", " for chunk in communicate.stream_sync():\n", " if chunk[\"type\"] == \"audio\":\n", " yield chunk[\"data\"]\n", " # Process any remaining text\n", " if buffer.strip():\n", " communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)\n", " for chunk in communicate.stream_sync():\n", " if chunk[\"type\"] == \"audio\":\n", " yield chunk[\"data\"]\n", "\n", "# Function to handle LLM and TTS\n", "def llm_and_tts(transcribed_text):\n", " logging.info('Handling LLM and TTS')\n", " # Get streaming response from LLM\n", " for text_chunk in llm(transcribed_text):\n", " if state.get('stop_signal'):\n", " logging.info('LLM and TTS task stopped')\n", " break\n", " # Get audio data from TTS\n", " for audio_chunk in tts_streaming([text_chunk]):\n", " if state.get('stop_signal'):\n", " logging.info('LLM and TTS task stopped during TTS')\n", " break\n", " yield np.frombuffer(audio_chunk, dtype=np.int16)\n", "\n", "state = {\n", " 'mode': 'idle',\n", " 'chunk_queue': [],\n", " 'transcription': '',\n", " 'in_transcription': False,\n", " 'previous_no_vad_audio': [],\n", " 'llm_task': None,\n", " 'instream': None,\n", " 'stop_signal': False,\n", " 'args': {\n", " 'sample_rate': 16000,\n", " 'chunk_size': 0.5, # seconds\n", " 'transcript_chunk_size': 2, # seconds\n", " }\n", "}\n", "\n", "def transcript_loop():\n", " while True:\n", " if len(state['chunk_queue']) > 0:\n", " accumulated_audio = np.concatenate(state['chunk_queue'])\n", " total_samples = sum(len(chunk) for chunk in state['chunk_queue'])\n", " total_duration = total_samples / state['args']['sample_rate']\n", " \n", " # Run transcription on the first 2 seconds if len > 3 seconds\n", " if total_duration > 3.0 and state['in_transcription'] == True:\n", " first_two_seconds_samples = int(2.0 * state['args']['sample_rate'])\n", " first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]\n", " transcribed_text = transcript(first_two_seconds_audio, state['args']['sample_rate'])\n", " state['transcription'] += transcribed_text\n", " remaining_audio = accumulated_audio[first_two_seconds_samples:]\n", " state['chunk_queue'] = [remaining_audio]\n", " else: # Run transcription on the accumulated audio\n", " transcribed_text = transcript(accumulated_audio, state['args']['sample_rate'])\n", " state['transcription'] += transcribed_text\n", " state['chunk_queue'] = []\n", " state['in_transcription'] = False\n", " else:\n", " 
time.sleep(0.1)\n", "\n", " if len(state['chunk_queue']) == 0 and state['mode'] == any(['idle', 'processing']):\n", " state['in_transcription'] = False\n", " break\n", "\n", "def process_audio(audio_chunk):\n", " # returns output audio\n", " \n", " sample_rate, audio_data = audio_chunk\n", " audio_data = np.array(audio_data, dtype=np.float32)\n", " \n", " # convert to mono if necessary\n", " if audio_data.ndim > 1:\n", " audio_data = np.mean(audio_data, axis=1)\n", "\n", " mode = state['mode']\n", " chunk_queue = state['chunk_queue']\n", " transcription = state['transcription']\n", " in_transcription = state['in_transcription']\n", " previous_no_vad_audio = state['previous_no_vad_audio']\n", " llm_task = state['llm_task']\n", " instream = state['instream']\n", " stop_signal = state['stop_signal']\n", " args = state['args']\n", " \n", " args['sample_rate'] = sample_rate\n", " \n", " # check for voice activity\n", " vad = check_vad(audio_data, sample_rate)\n", " \n", " if vad:\n", " logging.info(f'Voice activity detected in mode: {mode}')\n", " if mode == 'idle':\n", " mode = 'listening'\n", " elif mode == 'speaking':\n", " # Stop llm and tts tasks\n", " if llm_task and llm_task.is_alive():\n", " # Implement task cancellation logic if possible\n", " logging.info('Stopping LLM and TTS tasks')\n", " # Since we cannot kill threads directly, we need to handle this in the tasks\n", " stop_signal = True\n", " llm_task.join()\n", " mode = 'listening'\n", "\n", " if mode == 'listening':\n", " if previous_no_vad_audio is not None:\n", " chunk_queue.append(previous_no_vad_audio)\n", " previous_no_vad_audio = None\n", " # Accumulate audio chunks\n", " chunk_queue.append(audio_data)\n", " \n", " # Start transcription thread if not already running\n", " if not in_transcription:\n", " in_transcription = True\n", " transcription_task = threading.Thread(target=transcript_loop)\n", " transcription_task.start()\n", " \n", " elif mode == 'speaking':\n", " # Continue accumulating audio chunks\n", " chunk_queue.append(audio_data)\n", " else:\n", " logging.info(f'No voice activity detected in mode: {mode}')\n", " if mode == 'listening':\n", " # Add the last chunk to queue\n", " chunk_queue.append(audio_data)\n", " \n", " # Change mode to processing\n", " mode = 'processing'\n", " \n", " # Wait for transcription to complete\n", " while in_transcription:\n", " time.sleep(0.1)\n", " \n", " # Check if transcription is complete\n", " if len(chunk_queue) == 0:\n", " # Start LLM and TTS tasks\n", " if not llm_task or not llm_task.is_alive():\n", " stop_signal = False\n", " llm_task = threading.Thread(target=llm_and_tts, args=(transcription))\n", " llm_task.start()\n", " \n", " if mode == 'processing':\n", " # Wait for LLM and TTS tasks to start yielding audio\n", " if llm_task and llm_task.is_alive():\n", " mode = 'responding'\n", " \n", " if mode == 'responding':\n", " for audio_chunk in llm_task:\n", " if instream is None:\n", " instream = audio_chunk\n", " else:\n", " instream = np.concatenate((instream, audio_chunk))\n", " \n", " # Send audio to output stream\n", " yield instream\n", " \n", " # Cleanup\n", " llm_task = None\n", " transcription = ''\n", " mode = 'idle'\n", " \n", " # Updaate state\n", " state['mode'] = mode\n", " state['chunk_queue'] = chunk_queue\n", " state['transcription'] = transcription\n", " state['in_transcription'] = in_transcription\n", " state['previous_no_vad_audio'] = previous_no_vad_audio\n", " state['llm_task'] = llm_task\n", " state['instream'] = instream\n", " state['stop_signal'] = 
stop_signal\n", " state['args'] = args\n", " \n", " # Store previous audio chunk with no voice activity\n", " previous_no_vad_audio = audio_data\n", " \n", " # Update state\n", " state['mode'] = mode\n", " state['chunk_queue'] = chunk_queue\n", " state['transcription'] = transcription\n", " state['in_transcription'] = in_transcription\n", " state['previous_no_vad_audio'] = previous_no_vad_audio\n", " state['llm_task'] = llm_task\n", " state['instream'] = instream\n", " state['stop_signal'] = stop_signal\n", " state['args'] = args" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 1. Load audio.mp3\n", "# 2. Split audio into chunks\n", "# 3. Process each chunk inside a loop\n", "\n", "# Split audio into chunks of 500 ms or less\n", "from pydub import AudioSegment\n", "audio_segment = AudioSegment.from_file('audio.mp3')\n", "chunks = [chunk for chunk in audio_segment[::500]]\n", "chunks[0]\n", "chunks = [(chunk.frame_rate, np.array(chunk.get_array_of_samples(), dtype=np.int16)) for chunk in chunks]\n", "\n", "output_audio = []\n", "# Process each chunk\n", "for chunk in chunks:\n", " for audio_chunk in process_audio(chunk):\n", " output_audio.append(audio_chunk)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "output_audio" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import asyncio\n", "import websockets\n", "from pydub import AudioSegment\n", "import numpy as np\n", "import simpleaudio as sa\n", "\n", "# Constants\n", "AUDIO_FILE = 'audio.mp3' # Input audio file\n", "CHUNK_DURATION_MS = 250 # Duration of each chunk in milliseconds\n", "WEBSOCKET_URI = 'ws://localhost:8000/ws' # WebSocket endpoint\n", "\n", "async def send_audio_chunks(uri):\n", " # Load audio file using pydub\n", " audio = AudioSegment.from_file(AUDIO_FILE)\n", "\n", " # Ensure audio is mono and 16kHz\n", " if audio.channels > 1:\n", " audio = audio.set_channels(1)\n", " if audio.frame_rate != 16000:\n", " audio = audio.set_frame_rate(16000)\n", " if audio.sample_width != 2: # 2 bytes for int16\n", " audio = audio.set_sample_width(2)\n", "\n", " # Split audio into chunks\n", " chunks = [audio[i:i+CHUNK_DURATION_MS] for i in range(0, len(audio), CHUNK_DURATION_MS)]\n", "\n", " # Store received audio data\n", " received_audio_data = b''\n", "\n", " async with websockets.connect(uri) as websocket:\n", " print(\"Connected to server.\")\n", " for idx, chunk in enumerate(chunks):\n", " # Get raw audio data\n", " raw_data = chunk.raw_data\n", "\n", " # Send audio chunk to server\n", " await websocket.send(raw_data)\n", " print(f\"Sent chunk {idx+1}/{len(chunks)}\")\n", "\n", " # Receive response (non-blocking)\n", " try:\n", " response = await asyncio.wait_for(websocket.recv(), timeout=0.1)\n", " if isinstance(response, bytes):\n", " received_audio_data += response\n", " print(f\"Received audio data of length {len(response)} bytes\")\n", " except asyncio.TimeoutError:\n", " pass # No response received yet\n", "\n", " # Simulate real-time by waiting for chunk duration\n", " await asyncio.sleep(CHUNK_DURATION_MS / 1000.0)\n", "\n", " # Send a final empty message to indicate end of transmission\n", " await websocket.send(b'')\n", " print(\"Finished sending audio. 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import asyncio\n", "import websockets\n", "from pydub import AudioSegment\n", "import numpy as np\n", "import simpleaudio as sa\n", "\n", "# Constants\n", "AUDIO_FILE = 'audio.mp3'  # Input audio file\n", "CHUNK_DURATION_MS = 250  # Duration of each chunk in milliseconds\n", "WEBSOCKET_URI = 'ws://localhost:8000/ws'  # WebSocket endpoint\n", "\n",
"async def send_audio_chunks(uri):\n", "    # Load audio file using pydub\n", "    audio = AudioSegment.from_file(AUDIO_FILE)\n", "\n", "    # Ensure audio is mono, 16 kHz and 16-bit\n", "    if audio.channels > 1:\n", "        audio = audio.set_channels(1)\n", "    if audio.frame_rate != 16000:\n", "        audio = audio.set_frame_rate(16000)\n", "    if audio.sample_width != 2:  # 2 bytes for int16\n", "        audio = audio.set_sample_width(2)\n", "\n", "    # Split audio into chunks\n", "    chunks = [audio[i:i+CHUNK_DURATION_MS] for i in range(0, len(audio), CHUNK_DURATION_MS)]\n", "\n", "    # Store received audio data\n", "    received_audio_data = b''\n", "\n",
"    async with websockets.connect(uri) as websocket:\n", "        print(\"Connected to server.\")\n", "        for idx, chunk in enumerate(chunks):\n", "            # Get raw audio data\n", "            raw_data = chunk.raw_data\n", "\n", "            # Send audio chunk to server\n", "            await websocket.send(raw_data)\n", "            print(f\"Sent chunk {idx+1}/{len(chunks)}\")\n", "\n", "            # Receive response (non-blocking)\n", "            try:\n", "                response = await asyncio.wait_for(websocket.recv(), timeout=0.1)\n", "                if isinstance(response, bytes):\n", "                    received_audio_data += response\n", "                    print(f\"Received audio data of length {len(response)} bytes\")\n", "            except asyncio.TimeoutError:\n", "                pass  # No response received yet\n", "\n", "            # Simulate real time by waiting for the chunk duration\n", "            await asyncio.sleep(CHUNK_DURATION_MS / 1000.0)\n", "\n", "        # Send a final empty message to indicate end of transmission\n", "        await websocket.send(b'')\n", "        print(\"Finished sending audio. Waiting for responses...\")\n", "\n", "        # Receive any remaining responses\n", "        while True:\n", "            try:\n", "                response = await asyncio.wait_for(websocket.recv(), timeout=1)\n", "                if isinstance(response, bytes):\n", "                    received_audio_data += response\n", "                    print(f\"Received audio data of length {len(response)} bytes\")\n", "            except asyncio.TimeoutError:\n", "                print(\"No more responses. Closing connection.\")\n", "                break\n", "\n", "    print(\"Connection closed.\")\n", "\n",
"    # Save received audio data to a file or play it\n", "    if received_audio_data:\n", "        # Convert bytes to numpy array\n", "        audio_array = np.frombuffer(received_audio_data, dtype=np.int16)\n", "\n", "        # Play audio using simpleaudio\n", "        play_obj = sa.play_buffer(audio_array, 1, 2, 16000)\n", "        play_obj.wait_done()\n", "\n", "        # Optionally, save to a WAV file\n", "        output_audio = AudioSegment(\n", "            data=received_audio_data,\n", "            sample_width=2,  # 2 bytes for int16\n", "            frame_rate=16000,\n", "            channels=1\n", "        )\n", "        output_audio.export(\"output_response.wav\", format=\"wav\")\n", "        print(\"Saved response audio to 'output_response.wav'\")\n", "    else:\n", "        print(\"No audio data received.\")\n", "\n",
"def main():\n", "    asyncio.run(send_audio_chunks(WEBSOCKET_URI))\n", "\n", "if __name__ == '__main__':\n", "    try:\n", "        # asyncio.run() cannot be used while Jupyter's event loop is already running;\n", "        # in that case run `await send_audio_chunks(WEBSOCKET_URI)` in a cell instead.\n", "        asyncio.get_running_loop()\n", "        print(\"Event loop already running; use `await send_audio_chunks(WEBSOCKET_URI)` instead.\")\n", "    except RuntimeError:\n", "        main()" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }