""" | |
WebSocket Handler for Real-time STT/TTS | |
""" | |
from fastapi import WebSocket, WebSocketDisconnect
from typing import Dict, Any, Optional
import base64
from datetime import datetime
import numpy as np
from enum import Enum
from session import Session, session_store
from config_provider import ConfigProvider
from chat_handler import handle_new_message, handle_parameter_followup
from stt_factory import STTFactory
from tts_factory import TTSFactory
from utils import log

# ========================= CONSTANTS =========================
SILENCE_THRESHOLD_MS = 2000
AUDIO_CHUNK_SIZE = 4096
ENERGY_THRESHOLD = 0.01
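# Incoming audio chunks are expected as base64-encoded 16-bit PCM
# (SilenceDetector below assumes a 16 kHz sample rate by default).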

# ========================= ENUMS =========================
class ConversationState(Enum):
    IDLE = "idle"
    LISTENING = "listening"
    PROCESSING_STT = "processing_stt"
    PROCESSING_LLM = "processing_llm"
    PROCESSING_TTS = "processing_tts"
    PLAYING_AUDIO = "playing_audio"

# ========================= CLASSES =========================
class AudioBuffer:
    """Buffer for accumulating audio chunks"""

    def __init__(self):
        self.chunks = []
        self.total_size = 0

    def add_chunk(self, chunk_data: str):
        """Add a base64-encoded audio chunk"""
        decoded = base64.b64decode(chunk_data)
        self.chunks.append(decoded)
        self.total_size += len(decoded)

    def get_audio(self) -> bytes:
        """Get the concatenated audio data"""
        return b''.join(self.chunks)

    def clear(self):
        """Clear the buffer"""
        self.chunks.clear()
        self.total_size = 0

class SilenceDetector:
    """Detect silence in an audio stream"""

    def __init__(self, threshold_ms: int = SILENCE_THRESHOLD_MS, energy_threshold: float = ENERGY_THRESHOLD):
        self.threshold_ms = threshold_ms
        self.energy_threshold = energy_threshold
        self.silence_start = None
        self.sample_rate = 16000  # Default sample rate

    def is_silence(self, audio_chunk: bytes) -> bool:
        """Check whether an audio chunk is silence"""
        try:
            # Convert bytes to a numpy array (assuming 16-bit PCM)
            audio_data = np.frombuffer(audio_chunk, dtype=np.int16)
            if audio_data.size == 0:
                return True
            # Calculate RMS energy in float to avoid int16 overflow when squaring
            rms = np.sqrt(np.mean(audio_data.astype(np.float64) ** 2))
            normalized_rms = rms / 32768.0  # Normalize for 16-bit audio
            return normalized_rms < self.energy_threshold
        except Exception as e:
            log(f"Silence detection error: {e}")
            return False

    def update(self, audio_chunk: bytes) -> int:
        """Update silence detection and return the current silence duration in ms"""
        is_silent = self.is_silence(audio_chunk)
        if is_silent:
            if self.silence_start is None:
                # Silence has just started; duration is effectively zero
                self.silence_start = datetime.now()
                log("Silence started")
                return 0
            silence_duration = (datetime.now() - self.silence_start).total_seconds() * 1000
            return int(silence_duration)
        if self.silence_start is not None:
            log("Speech detected, silence broken")
            self.silence_start = None
        return 0

class BargeInHandler:
    """Handle barge-in (interruption) logic"""

    def __init__(self):
        self.interrupted_at_state: Optional[ConversationState] = None
        self.accumulated_text: str = ""
        self.pending_audio_chunks = []

    def handle_interruption(self, current_state: ConversationState):
        """Record a user interruption"""
        self.interrupted_at_state = current_state
        log(f"Barge-in detected at state: {current_state.value}")

    def should_preserve_context(self) -> bool:
        """Check whether context should be preserved after an interruption"""
        # Preserve context if interrupted during LLM/TTS processing or audio playback
        return self.interrupted_at_state in [
            ConversationState.PROCESSING_LLM,
            ConversationState.PROCESSING_TTS,
            ConversationState.PLAYING_AUDIO
        ]

class ConversationManager:
    """Manage conversation state and flow"""

    def __init__(self, session: Session):
        self.session = session
        self.state = ConversationState.IDLE
        self.audio_buffer = AudioBuffer()
        self.silence_detector = SilenceDetector()
        self.barge_in_handler = BargeInHandler()
        self.stt_manager = None
        self.current_transcription = ""
        self.is_streaming = False

    async def initialize_stt(self) -> bool:
        """Initialize the STT provider"""
        try:
            self.stt_manager = STTFactory.create_provider()
            if not self.stt_manager:
                return False
            config = ConfigProvider.get().global_config.stt_settings
            await self.stt_manager.start_streaming({
                "language": config.get("language", "tr-TR"),
                "interim_results": config.get("interim_results", True),
                "single_utterance": False,  # Important for continuous listening
                "enable_punctuation": config.get("enable_punctuation", True)
            })
            log("STT manager initialized")
            return True
        except Exception as e:
            log(f"Failed to initialize STT: {e}")
            return False

    def change_state(self, new_state: ConversationState):
        """Change the conversation state"""
        old_state = self.state
        self.state = new_state
        log(f"State change: {old_state.value} → {new_state.value}")

    def handle_barge_in(self):
        """Handle a user interruption"""
        self.barge_in_handler.handle_interruption(self.state)
        self.change_state(ConversationState.LISTENING)

    def reset_audio_buffer(self):
        """Reset the audio buffer for a new utterance"""
        self.audio_buffer.clear()
        self.silence_detector.silence_start = None
        self.current_transcription = ""

# ========================= WEBSOCKET HANDLER =========================
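# Message protocol handled below (all frames are JSON):
#   Client -> server:
#     {"type": "audio_chunk", "data": "<base64 audio>"}
#     {"type": "control", "action": "start_session" | "end_session" | "interrupt" | "reset"}
#     {"type": "ping"}
#   Server -> client:
#     "transcription", "assistant_response", "tts_audio", "state_change",
#     "session_started", "control", "pong", "error"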

async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """Main WebSocket endpoint for real-time conversation"""
    await websocket.accept()
    log(f"WebSocket connected for session: {session_id}")

    # Get session
    session = session_store.get_session(session_id)
    if not session:
        await websocket.send_json({
            "type": "error",
            "message": "Session not found"
        })
        await websocket.close()
        return

    # Initialize conversation manager
    conversation = ConversationManager(session)

    # Initialize STT; the connection stays open even if this fails,
    # but incoming audio will not be transcribed
    stt_initialized = await conversation.initialize_stt()
    if not stt_initialized:
        await websocket.send_json({
            "type": "error",
            "message": "STT initialization failed"
        })

    try:
        while True:
            # Receive message
            message = await websocket.receive_json()
            message_type = message.get("type")

            if message_type == "audio_chunk":
                await handle_audio_chunk(websocket, conversation, message)
            elif message_type == "control":
                await handle_control_message(websocket, conversation, message)
            elif message_type == "ping":
                # Keep-alive ping
                await websocket.send_json({"type": "pong"})
    except WebSocketDisconnect:
        log(f"WebSocket disconnected for session: {session_id}")
        await cleanup_conversation(conversation)
    except Exception as e:
        log(f"WebSocket error: {e}")
        try:
            await websocket.send_json({
                "type": "error",
                "message": str(e)
            })
        except Exception:
            # Socket may already be closed
            pass
        await cleanup_conversation(conversation)

# ========================= MESSAGE HANDLERS =========================
async def handle_audio_chunk(websocket: WebSocket, conversation: ConversationManager, message: Dict[str, Any]):
    """Handle an incoming audio chunk"""
    try:
        audio_data = message.get("data")
        if not audio_data:
            return

        # Check for barge-in
        if conversation.state in [ConversationState.PLAYING_AUDIO, ConversationState.PROCESSING_TTS]:
            conversation.handle_barge_in()
            await websocket.send_json({
                "type": "control",
                "action": "stop_playback"
            })

        # Change state to listening if idle
        if conversation.state == ConversationState.IDLE:
            conversation.change_state(ConversationState.LISTENING)
            await websocket.send_json({
                "type": "state_change",
                "from": "idle",
                "to": "listening"
            })

        # Add to buffer
        conversation.audio_buffer.add_chunk(audio_data)

        # Decode for processing
        decoded_audio = base64.b64decode(audio_data)

        # Check silence
        silence_duration = conversation.silence_detector.update(decoded_audio)

        # Stream to STT if available
        if conversation.stt_manager and conversation.state == ConversationState.LISTENING:
            async for result in conversation.stt_manager.stream_audio(decoded_audio):
                # Send interim results
                await websocket.send_json({
                    "type": "transcription",
                    "text": result.text,
                    "is_final": result.is_final,
                    "confidence": result.confidence
                })
                if result.is_final:
                    conversation.current_transcription = result.text

        # Check if user stopped speaking (2 seconds of silence)
        if silence_duration > SILENCE_THRESHOLD_MS and conversation.current_transcription:
            log(f"User stopped speaking after {silence_duration}ms of silence")
            await process_user_input(websocket, conversation)

    except Exception as e:
        log(f"Audio chunk handling error: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Audio processing error: {str(e)}"
        })

async def handle_control_message(websocket: WebSocket, conversation: ConversationManager, message: Dict[str, Any]):
    """Handle control messages"""
    action = message.get("action")

    if action == "start_session":
        # Session already started
        await websocket.send_json({
            "type": "session_started",
            "session_id": conversation.session.session_id
        })
    elif action == "end_session":
        # Clean up and close
        await cleanup_conversation(conversation)
        await websocket.close()
    elif action == "interrupt":
        # Handle explicit interrupt
        conversation.handle_barge_in()
        await websocket.send_json({
            "type": "control",
            "action": "interrupt_acknowledged"
        })
    elif action == "reset":
        # Reset conversation state; capture the previous state before changing it
        previous_state = conversation.state.value
        conversation.reset_audio_buffer()
        conversation.change_state(ConversationState.IDLE)
        await websocket.send_json({
            "type": "state_change",
            "from": previous_state,
            "to": "idle"
        })

# ========================= PROCESSING FUNCTIONS =========================
async def process_user_input(websocket: WebSocket, conversation: ConversationManager):
    """Process a complete user utterance"""
    try:
        user_text = conversation.current_transcription
        if not user_text:
            conversation.reset_audio_buffer()
            conversation.change_state(ConversationState.IDLE)
            return

        log(f"Processing user input: {user_text}")

        # Change state to processing
        conversation.change_state(ConversationState.PROCESSING_STT)
        await websocket.send_json({
            "type": "state_change",
            "from": "listening",
            "to": "processing_stt"
        })

        # Send final transcription
        await websocket.send_json({
            "type": "transcription",
            "text": user_text,
            "is_final": True,
            "confidence": 0.95
        })

        # Process with LLM
        conversation.change_state(ConversationState.PROCESSING_LLM)
        await websocket.send_json({
            "type": "state_change",
            "from": "processing_stt",
            "to": "processing_llm"
        })

        # Add to session history
        conversation.session.add_turn("user", user_text)

        # Get response based on session state
        if conversation.session.state == "await_param":
            response_text = await handle_parameter_followup(conversation.session, user_text)
        else:
            response_text = await handle_new_message(conversation.session, user_text)

        # Add response to history
        conversation.session.add_turn("assistant", response_text)

        # Send text response
        await websocket.send_json({
            "type": "assistant_response",
            "text": response_text
        })

        # Generate TTS if enabled
        tts_provider = TTSFactory.create_provider()
        if tts_provider:
            conversation.change_state(ConversationState.PROCESSING_TTS)
            await websocket.send_json({
                "type": "state_change",
                "from": "processing_llm",
                "to": "processing_tts"
            })

            # Generate audio
            audio_data = await tts_provider.synthesize(response_text)

            # Send audio in chunks
            chunk_size = AUDIO_CHUNK_SIZE
            for i in range(0, len(audio_data), chunk_size):
                chunk = audio_data[i:i + chunk_size]
                await websocket.send_json({
                    "type": "tts_audio",
                    "data": base64.b64encode(chunk).decode('utf-8'),
                    "chunk_index": i // chunk_size,
                    "is_last": i + chunk_size >= len(audio_data)
                })

            conversation.change_state(ConversationState.PLAYING_AUDIO)
            await websocket.send_json({
                "type": "state_change",
                "from": "processing_tts",
                "to": "playing_audio"
            })
        else:
            # No TTS, go back to idle
            conversation.change_state(ConversationState.IDLE)
            await websocket.send_json({
                "type": "state_change",
                "from": "processing_llm",
                "to": "idle"
            })

        # Reset for next input
        conversation.reset_audio_buffer()

    except Exception as e:
        log(f"Error processing user input: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Processing error: {str(e)}"
        })
        conversation.reset_audio_buffer()
        conversation.change_state(ConversationState.IDLE)

# ========================= CLEANUP =========================
async def cleanup_conversation(conversation: ConversationManager):
    """Clean up conversation resources"""
    try:
        if conversation.stt_manager:
            await conversation.stt_manager.stop_streaming()
        log(f"Cleaned up conversation for session: {conversation.session.session_id}")
    except Exception as e:
        log(f"Cleanup error: {e}")