Commit 6ee9d08: changes made in requirements.txt, the Dockerfile, and main.py to handle the Twilio WebSocket connection.

import os
import base64
import logging
import json
import re
from contextlib import asynccontextmanager
from typing import Optional
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, status, Depends, Header, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from dotenv import load_dotenv
from openai import OpenAI
from elevenlabs.client import ElevenLabs
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_postgres.vectorstores import PGVector
from sqlalchemy import create_engine

# --- NEW IMPORTS FOR TWILIO INTEGRATION ---
import asyncio
import audioop
import wave
import io
from pydub import AudioSegment

# --- SETUP ---
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
logging.getLogger('tensorflow').setLevel(logging.ERROR)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load environment variables
load_dotenv()
NEON_DATABASE_URL = os.getenv("NEON_DATABASE_URL")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ELEVENLABS_API_KEY = os.getenv("ELEVENLABS_API_KEY")
SHARED_SECRET = os.getenv("SHARED_SECRET")

# --- CONFIGURATION ---
COLLECTION_NAME = "real_estate_embeddings"
EMBEDDING_MODEL = "hkunlp/instructor-large"
ELEVENLABS_VOICE_NAME = "Leo"
PLANNER_MODEL = "gpt-4o-mini"
ANSWERER_MODEL = "gpt-4o"
TABLE_DESCRIPTIONS = """
- "ongoing_projects_source": Details about projects currently under construction.
- "upcoming_projects_source": Information on future planned projects.
- "completed_projects_source": Facts about projects that are already finished.
- "historical_sales_source": Specific sales records, including price, date, and property ID.
- "past_customers_source": Information about previous customers.
- "feedback_source": Customer feedback and ratings for projects.
"""

# VAD Configuration
SILENCE_THRESHOLD = 1000  # RMS threshold for speech detection (tune based on testing)
MAX_AUDIO_BYTES = 80000   # Max buffer ~10 s at 8 kHz mulaw (prevents overflow)
# Safety cap on the receive loop to avoid an infinite hang. Twilio delivers
# roughly 50 media frames per second, so this bounds messages, not seconds.
MAX_LOOP_COUNT = 90000  # ~30 minutes of streaming at ~50 frames/sec
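
# A minimal sketch of the RMS gate used in the receive loop below; the frame
# value is an assumption for illustration only and this snippet is not executed:
#
#   frame = b"\xff" * 160                                 # 20 ms of mulaw near-silence at 8 kHz
#   pcm = audioop.ulaw2lin(frame, 2)                      # decode to 16-bit linear PCM
#   is_speech = audioop.rms(pcm, 2) > SILENCE_THRESHOLD   # False -> frame is dropped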

# --- GLOBAL VARIABLES FOR LIFESPAN ---
embeddings = None
vector_store = None

# Initialize clients (used after load_dotenv)
client_openai = OpenAI(api_key=OPENAI_API_KEY)
client_elevenlabs = ElevenLabs(api_key=ELEVENLABS_API_KEY)

# --- LIFESPAN / STARTUP ---
@asynccontextmanager  # FastAPI expects the yield-based lifespan to be an async context manager
async def lifespan(app: FastAPI):
    global embeddings, vector_store
    logging.info(f"Initializing embedding model: '{EMBEDDING_MODEL}'...")
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
    logging.info("Embedding model loaded successfully.")
    logging.info(f"Connecting to vector store '{COLLECTION_NAME}'...")
    engine = create_engine(NEON_DATABASE_URL, pool_pre_ping=True)
    vector_store = PGVector(
        connection=engine,
        collection_name=COLLECTION_NAME,
        embeddings=embeddings,
    )
    logging.info("Successfully connected to the vector store.")
    yield
    logging.info("Application shutting down.")

# --- FASTAPI APP ---
app = FastAPI(lifespan=lifespan)

# --- PROMPTS ---
QUERY_FORMULATION_PROMPT = f"""
You are a query analysis agent. Your task is to transform a user's query into a precise search query for a vector database and determine the correct table to filter by.
**Available Tables:**
{TABLE_DESCRIPTIONS}
**User's Query:** "{{user_query}}"
**Your Task:**
1. Rephrase the user's query into a clear, keyword-focused English question suitable for a database search.
2. Analyze the user's query for keywords indicating project status (e.g., "ongoing", "under construction", "completed", "finished", "upcoming", "new launch").
3. If such status keywords are present, identify the single most relevant table from the list above to filter by.
4. If no specific status keywords are mentioned (e.g., the user asks generally about projects in a location), set the filter table to null.
5. Respond ONLY with a JSON object containing "search_query" and "filter_table" (which should be the table name string or null).
"""

ANSWER_SYSTEM_PROMPT = """
You are an expert AI assistant for a premier real estate developer.
## YOUR PERSONA
- You are professional, helpful, and highly knowledgeable. Your tone should be polite and articulate.
## CORE BUSINESS KNOWLEDGE
- **Operational Cities:** We are currently operational in Pune, Mumbai, Bengaluru, Delhi, Chennai, Hyderabad, Goa, Gurgaon, Kolkata.
- **Property Types:** We offer luxury apartments, villas, and commercial properties.
- **Budget Range:** Our residential properties typically range from 45 lakhs to 5 crores.
## CORE RULES
1. **Language Adaptation:** If the user's original query was in Hinglish, respond in Hinglish. If in English, respond in English.
2. **Fact-Based Answers:** Use the provided CONTEXT to answer the user's question. If the context is empty, use your Core Business Knowledge.
3. **Stay on Topic:** Only answer questions related to real estate.
"""

# --- HELPER FUNCTIONS (sync helpers executed in threadpool) ---
def convert_mulaw_to_wav_bytes(mulaw_bytes: bytes) -> bytes:
    """Converts raw mulaw audio bytes (8 kHz) to in-memory WAV file bytes."""
    try:
        pcm_bytes = audioop.ulaw2lin(mulaw_bytes, 2)
        with io.BytesIO() as wav_buffer:
            with wave.open(wav_buffer, 'wb') as wav_file:
                wav_file.setnchannels(1)
                wav_file.setsampwidth(2)
                wav_file.setframerate(8000)
                wav_file.writeframes(pcm_bytes)
            return wav_buffer.getvalue()
    except Exception as e:
        logging.error(f"Error converting mulaw to WAV: {e}", exc_info=True)
        return b''
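
# The resulting WAV is 8 kHz, mono, 16-bit PCM; Whisper accepts this directly,
# so no resampling is needed before transcription.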
def transcribe_audio_sync(audio_wav_bytes: bytes) -> str:
    """Synchronous transcription using the OpenAI client (to be called inside threadpool)."""
    for attempt in range(3):
        try:
            audio_file = io.BytesIO(audio_wav_bytes)
            audio_file.name = "stream.wav"
            transcript = client_openai.audio.transcriptions.create(model="whisper-1", file=audio_file)
            text = transcript.text
            # If Hindi script is present, transliterate to Roman (Hinglish)
            if re.search(r'[\u0900-\u097F]', text):
                translit_prompt = f"Transliterate this Hindi text to Roman script (Hinglish style): {text}"
                response = client_openai.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=[{"role": "user", "content": translit_prompt}],
                    temperature=0.0
                )
                text = response.choices[0].message.content
            return text
        except Exception as e:
            logging.error(f"Error during transcription (attempt {attempt+1}): {e}", exc_info=True)
            if attempt == 2:
                return ""

def convert_audio_to_mulaw_sync(audio_bytes: bytes) -> bytes:
    """Synchronous conversion of arbitrary audio bytes to 8 kHz mulaw (for Twilio)."""
    for attempt in range(3):
        try:
            audio_segment = AudioSegment.from_file(io.BytesIO(audio_bytes))
            audio_segment = audio_segment.set_frame_rate(8000)
            audio_segment = audio_segment.set_channels(1)
            audio_segment = audio_segment.set_sample_width(2)  # lin2ulaw below assumes 16-bit samples
            pcm_data = audio_segment.raw_data
            mulaw_data = audioop.lin2ulaw(pcm_data, 2)
            return mulaw_data
        except Exception as e:
            logging.error(f"Error converting audio to mulaw (attempt {attempt+1}): {e}", exc_info=True)
            if attempt == 2:
                return b''
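
# Twilio Media Streams carries audio as base64-encoded 8 kHz mono mulaw frames,
# which is why both converters above normalize to that format.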
def generate_elevenlabs_sync(text: str, voice: str, model: str = "eleven_multilingual_v2", output_format: str = "mp3_44100_128") -> bytes:
    """Synchronous ElevenLabs generation wrapper for run_in_threadpool."""
    for attempt in range(3):
        try:
            # The ElevenLabs client call is synchronous in this codebase.
            audio = client_elevenlabs.generate(
                text=text,
                voice=voice,
                model=model,
                output_format=output_format
            )
            # Some SDK versions return an iterator of audio chunks rather than
            # raw bytes; normalize to bytes either way.
            if isinstance(audio, (bytes, bytearray)):
                return bytes(audio)
            return b"".join(audio)
        except Exception as e:
            logging.error(f"Error in ElevenLabs generate (attempt {attempt+1}): {e}", exc_info=True)
            if attempt == 2:
                return b''

# --- LLM / RAG HELPERS (async; blocking SDK calls run in the threadpool) ---
async def formulate_search_plan(user_query: str) -> dict:
    logging.info("Formulating search plan with Planner LLM...")
    for attempt in range(3):
        try:
            # The OpenAI client is synchronous, so push the call to the
            # threadpool to avoid blocking the event loop.
            response = await run_in_threadpool(
                lambda: client_openai.chat.completions.create(
                    model=PLANNER_MODEL,
                    messages=[{"role": "user", "content": QUERY_FORMULATION_PROMPT.format(user_query=user_query)}],
                    response_format={"type": "json_object"},
                    temperature=0.0
                )
            )
            plan = json.loads(response.choices[0].message.content)
            logging.info(f"Search plan received: {plan}")
            return plan
        except Exception as e:
            logging.error(f"Error in Planner LLM call (attempt {attempt+1}): {e}", exc_info=True)
            if attempt == 2:
                return {"search_query": user_query, "filter_table": None}

async def get_agent_response(user_text: str) -> str:
    """Runs RAG and generation logic for a given text query with retries."""
    for attempt in range(3):
        try:
            search_plan = await formulate_search_plan(user_text)
            search_query = search_plan.get("search_query", user_text)
            filter_table = search_plan.get("filter_table")
            search_filter = {"source_table": filter_table} if filter_table else {}
            if search_filter:
                logging.info(f"Applying initial filter: {search_filter}")
            # similarity_search is a blocking call; run it in the threadpool.
            retrieved_docs = await run_in_threadpool(
                vector_store.similarity_search, search_query, k=3, filter=search_filter
            )
            if not retrieved_docs:
                logging.info("Initial search returned no results. Performing a broader fallback search.")
                retrieved_docs = await run_in_threadpool(vector_store.similarity_search, search_query, k=3)
            context_text = "\n\n".join([doc.page_content for doc in retrieved_docs])
            logging.info(f"Retrieved Context (preview): {context_text[:500]}...")
            final_prompt_messages = [
                {"role": "system", "content": ANSWER_SYSTEM_PROMPT},
                {"role": "system", "content": f"Use the following CONTEXT to answer:\n{context_text}"},
                {"role": "user", "content": f"My original question was: '{user_text}'"}
            ]
            final_response = await run_in_threadpool(
                lambda: client_openai.chat.completions.create(
                    model=ANSWERER_MODEL,
                    messages=final_prompt_messages
                )
            )
            return final_response.choices[0].message.content
        except Exception as e:
            logging.error(f"Error in get_agent_response (attempt {attempt+1}): {e}", exc_info=True)
            if attempt == 2:
                return "Sorry, I couldn't generate a response. Please try again."

# --- AUTH DEPENDENCY ---
async def verify_token(x_auth_token: str = Header(...)):
    """Dependency to verify the shared secret token."""
    if not SHARED_SECRET or x_auth_token != SHARED_SECRET:
        logging.warning("Authentication failed for /test-text-query.")
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or missing authentication token")
    logging.info("Authentication successful for /test-text-query.")

# --- API Endpoints ---
class TextQuery(BaseModel):
    query: str

@app.post("/test-text-query", dependencies=[Depends(verify_token)])  # path taken from the log messages in verify_token
async def test_text_query_endpoint(query: TextQuery):
    logging.info(f"Received text query: {query.query}")
    response_text = await get_agent_response(query.query)
    logging.info(f"Generated text response: {response_text}")
    return {"response": response_text}

# --- WEBHOOK / WEBSOCKET FOR TWILIO STREAMING ---
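# A Twilio <Stream> is typically wired to this endpoint with TwiML like the
# following (illustrative; the wss URL and token value are placeholders). Each
# <Parameter> surfaces in customParameters of the "start" event handled below:
#
#   <Response>
#     <Connect>
#       <Stream url="wss://your-host/ws">
#         <Parameter name="x-auth-token" value="YOUR_SHARED_SECRET"/>
#       </Stream>
#     </Connect>
#   </Response>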
@app.websocket("/ws")  # route path is an assumption; align it with the TwiML <Stream> url
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    logging.info("WebSocket connection accepted from Twilio.")
    stream_sid: Optional[str] = None
    try:
        # Twilio sends a "connected" event before "start"; skip past it.
        first_message = await websocket.receive_json()
        while first_message.get("event") == "connected":
            first_message = await websocket.receive_json()
        event = first_message.get("event")
        if event != "start":
            logging.error("Expected 'start' message. Closing.")
            await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
            return
        start_data = first_message.get("start", {})
        custom_params = start_data.get("customParameters", {})
        if not custom_params:
            logging.error("Missing customParameters in start event. Closing.")
            await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
            return
        auth_token = custom_params.get("x-auth-token")
        stream_sid = start_data.get("streamSid")
        if not SHARED_SECRET or auth_token != SHARED_SECRET:
            logging.warning("Authentication failed. Invalid token. Closing connection.")
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
            return
        logging.info(f"Authentication successful. Stream SID: {stream_sid}")
        logging.debug(f"Full start message: {first_message}")
        # Main loop with VAD using timeout + RMS threshold
        accumulated_audio_mulaw = b''
        loop_counter = 0
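        # Twilio "media" frames look roughly like this (illustrative values):
        #   {"event": "media", "streamSid": "MZ...",
        #    "media": {"track": "inbound", "chunk": "2", "timestamp": "20",
        #              "payload": "<base64-encoded mulaw>"}}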
        while True:
            loop_counter += 1
            if loop_counter > MAX_LOOP_COUNT:
                logging.info("Max loop count reached. Exiting to prevent hang.")
                break
            try:
                message_str = await asyncio.wait_for(websocket.receive_text(), timeout=1.0)
                message = json.loads(message_str)
                event = message.get("event")
                if event == "media":
                    payload = message['media']['payload']
                    mulaw_chunk = base64.b64decode(payload)
                    # Compute RMS to avoid buffering pure silence / static
                    try:
                        pcm_chunk = audioop.ulaw2lin(mulaw_chunk, 2)
                        rms = audioop.rms(pcm_chunk, 2)
                    except Exception as e:
                        logging.debug(f"Could not compute RMS on chunk: {e}")
                        rms = 0
                    if rms > SILENCE_THRESHOLD:
                        accumulated_audio_mulaw += mulaw_chunk
                        logging.debug(f"Buffered audio chunk; RMS={rms}, total_bytes={len(accumulated_audio_mulaw)}")
                    else:
                        logging.debug(f"Ignored low-energy chunk; RMS={rms}")
                    # Safety: if the buffer grows too large, process it now
                    if len(accumulated_audio_mulaw) > MAX_AUDIO_BYTES:
                        logging.info(f"Max audio buffer reached ({len(accumulated_audio_mulaw)} bytes). Processing buffer.")
                        await process_audio_buffer(websocket, stream_sid or "", accumulated_audio_mulaw)
                        accumulated_audio_mulaw = b''
                elif event == "stop":
                    logging.info("Twilio stream sent 'stop' event.")
                    # Process remaining buffered audio before breaking
                    if accumulated_audio_mulaw:
                        logging.info(f"Processing remaining audio on stop event ({len(accumulated_audio_mulaw)} bytes).")
                        await process_audio_buffer(websocket, stream_sid or "", accumulated_audio_mulaw)
                        accumulated_audio_mulaw = b''
                    break
                else:
                    logging.debug(f"Ignored unknown event type: {event}")
            except asyncio.TimeoutError:
                # VAD trigger: no new data within the timeout -> treat as end of speech
                if accumulated_audio_mulaw:
                    logging.info(f"End of speech detected (timeout). Processing {len(accumulated_audio_mulaw)} bytes.")
                    await process_audio_buffer(websocket, stream_sid or "", accumulated_audio_mulaw)
                    accumulated_audio_mulaw = b''
                # No buffered audio: simply loop again
            except (ValueError, json.JSONDecodeError) as e:
                logging.warning(f"Invalid message received: {e}. Skipping this message.")
            except WebSocketDisconnect:
                logging.info("WebSocket disconnected by client.")
                break
    except WebSocketDisconnect:
        logging.info("Call disconnected during start phase.")
    except Exception as e:
        logging.error(f"An error occurred in the main loop: {e}", exc_info=True)
    finally:
        try:
            await websocket.close()
        except Exception:
            pass

# --- PROCESS AUDIO BUFFER (async wrapper that uses sync helpers in threadpool) ---
async def process_audio_buffer(websocket: WebSocket, stream_sid: str, accumulated_audio_mulaw: bytes):
    logging.info(f"Processing audio buffer of {len(accumulated_audio_mulaw)} bytes...")
    # 1. Convert accumulated mulaw audio to WAV (in threadpool)
    wav_bytes = await run_in_threadpool(convert_mulaw_to_wav_bytes, accumulated_audio_mulaw)
    if not wav_bytes:
        logging.warning("WAV conversion produced no bytes. Skipping processing.")
        return
    # 2. Transcribe the WAV audio (in threadpool)
    user_text = await run_in_threadpool(transcribe_audio_sync, wav_bytes)
    if not user_text or not user_text.strip():
        logging.info("Transcription empty; skipping further processing.")
        return
    user_text = user_text.strip()
    logging.info(f"User said: {user_text}")
    # 3. Get AI agent response (async)
    agent_response_text = await get_agent_response(user_text)
    logging.info(f"AI Responded (preview): {agent_response_text[:200]}")
    if not agent_response_text or not agent_response_text.strip():
        logging.warning("Agent generated empty response; skipping TTS.")
        return
    # 4. Generate AI speech with ElevenLabs (threadpool wrapper; retries happen inside)
    ai_audio_bytes = await run_in_threadpool(generate_elevenlabs_sync, agent_response_text, ELEVENLABS_VOICE_NAME)
    if not ai_audio_bytes:
        logging.error("ElevenLabs returned no audio bytes; skipping sending audio.")
        return
    # 5. Convert AI speech to 8 kHz mulaw for Twilio (in threadpool)
    mulaw_payload_bytes = await run_in_threadpool(convert_audio_to_mulaw_sync, ai_audio_bytes)
    if not mulaw_payload_bytes:
        logging.error("Conversion to mulaw failed; skipping sending audio.")
        return
    # 6. Base64 encode and send the audio back to Twilio
    try:
        base64_payload = base64.b64encode(mulaw_payload_bytes).decode('utf-8')
        await websocket.send_json({
            "event": "media",
            "streamSid": stream_sid,
            "media": {"payload": base64_payload}
        })
        logging.info("Sent AI audio response back to Twilio.")
    except Exception as e:
        logging.error(f"Failed to send AI audio to Twilio: {e}", exc_info=True)
        return
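    # Caution: per Twilio Media Streams semantics, "clear" discards audio that
    # is still buffered for playback, so sending it right after the media frame
    # above can truncate the reply; it is more often sent before new audio to
    # interrupt earlier playback.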
    # 7. Send 'clear' to flush Twilio's buffer
    try:
        await websocket.send_json({"event": "clear", "streamSid": stream_sid})
        logging.info("Sent clear event to Twilio.")
    except Exception as e:
        logging.error(f"Failed to send 'clear' event: {e}", exc_info=True)

# End of file