# main.py
import asyncio
import base64
import io
import logging
import os
from threading import Thread, Event  # Event is used for cooperative TTS cancellation
from typing import Optional
import time  # For timeout checks

import soundfile as sf
import torch
import uvicorn
import whisper
from fastapi import FastAPI, File, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
from transformers import AutoTokenizer, GenerationConfig  # Keep transformers.GenerationConfig
import google.generativeai as genai
import numpy as np
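# Assumed (not pinned) dependencies for the imports above; adjust to your environment.
# Whisper additionally needs the ffmpeg binary on PATH to decode uploaded audio.
#   pip install fastapi uvicorn openai-whisper torch soundfile numpy transformers \
#       google-generativeai git+https://github.com/huggingface/parler-tts.git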
# --- Configuration ---
WHISPER_MODEL_SIZE = os.getenv("WHISPER_MODEL_SIZE", "tiny")
TTS_MODEL_NAME = "ai4bharat/indic-parler-tts"
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")  # Never hard-code API keys; supply via environment
GEMINI_MODEL_NAME = "gemini-1.5-flash-latest"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
attn_implementation = "flash_attention_2" if torch.cuda.is_available() else "eager"
torch_dtype_tts = torch.bfloat16 if DEVICE == "cuda" and torch.cuda.is_bf16_supported() else (torch.float16 if DEVICE == "cuda" else torch.float32)
torch_dtype_whisper = torch.float16 if DEVICE == "cuda" else torch.float32

TTS_DEFAULT_PARAMS = {
    "do_sample": True,
    "temperature": 1.0,
    "top_k": 50,
    "top_p": 0.95,
    "min_new_tokens": 5,  # Reduced for quicker start with streamer
    # "max_new_tokens": 256,  # Optional global cap
}
# --- Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- FastAPI App Initialization ---
app = FastAPI(title="Conversational AI Chatbot with Enhanced Stream Abortion")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Global Model Variables ---
whisper_model = None
gemini_model_instance = None
tts_model = None
tts_tokenizer = None
# The GenerationConfig object is built from TTS_DEFAULT_PARAMS inside the functions below
# (it could also be stored globally), using transformers.GenerationConfig.
# --- Model Loading & API Configuration ---
@app.on_event("startup")  # startup hook (assumed wiring) so models load when the app boots
async def load_resources():
    global whisper_model, tts_model, tts_tokenizer, gemini_model_instance
    logger.info(f"Loading local models. Whisper on {DEVICE} with {torch_dtype_whisper}, TTS on {DEVICE} with {torch_dtype_tts}")
    try:
        logger.info(f"Loading Whisper model: {WHISPER_MODEL_SIZE}")
        whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, device=DEVICE)
        logger.info("Whisper model loaded successfully.")

        logger.info(f"Loading IndicParler-TTS model: {TTS_MODEL_NAME}")
        tts_model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL_NAME, attn_implementation=attn_implementation).to(DEVICE, dtype=torch_dtype_tts)
        tts_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_NAME)
        if tts_tokenizer and tts_tokenizer.pad_token_id is not None:
            TTS_DEFAULT_PARAMS["pad_token_id"] = tts_tokenizer.pad_token_id
        # ParlerTTS signals the end of generation with its own special tokens, so eos_token_id
        # (a text-model concept) is intentionally not set here.
        logger.info(f"IndicParler-TTS model loaded. Default generation params: {TTS_DEFAULT_PARAMS}")

        if not GEMINI_API_KEY:
            logger.warning("GEMINI_API_KEY not found. LLM functionality will be limited.")
        else:
            try:
                genai.configure(api_key=GEMINI_API_KEY)
                gemini_model_instance = genai.GenerativeModel(GEMINI_MODEL_NAME)
                logger.info(f"Gemini API configured with model: {GEMINI_MODEL_NAME}")
            except Exception as e:
                logger.error(f"Failed to configure Gemini API: {e}", exc_info=True)
                gemini_model_instance = None
    except Exception as e:
        logger.error(f"Error loading models: {e}", exc_info=True)
    logger.info("Local models and API configurations loaded.")
# --- Helper Functions ---
async def transcribe_audio_bytes(audio_bytes: bytes) -> str:
    if not whisper_model:
        raise RuntimeError("Whisper model not loaded.")
    temp_audio_path = f"temp_audio_main_{os.urandom(4).hex()}.wav"
    try:
        with open(temp_audio_path, "wb") as f:
            f.write(audio_bytes)
        result = whisper_model.transcribe(temp_audio_path, fp16=(DEVICE == "cuda" and torch_dtype_whisper == torch.float16))
        transcribed_text = result["text"].strip()
        logger.info(f"Transcription: {transcribed_text}")
        return transcribed_text
    except Exception as e:
        logger.error(f"Error during transcription: {e}", exc_info=True)
        return ""
    finally:
        if os.path.exists(temp_audio_path):
            try:
                os.remove(temp_audio_path)
            except Exception as e_del:
                logger.error(f"Error deleting temp audio file {temp_audio_path}: {e_del}")
async def generate_gemini_response(text: str) -> str:
    if not gemini_model_instance:
        logger.error("Gemini model instance not available.")
        return "Sorry, the language model is currently unavailable."
    try:
        full_prompt = f"User: {text}\nAssistant:"
        loop = asyncio.get_event_loop()
        response = await loop.run_in_executor(None, gemini_model_instance.generate_content, full_prompt)
        response_text = "I'm sorry, I couldn't generate a response for that."
        if hasattr(response, 'text') and response.text:  # Simple text responses
            response_text = response.text.strip()
        elif response.parts:  # Part-based access (gemini-1.5-flash and pro)
            response_text = "".join(part.text for part in response.parts).strip()
        elif response.candidates and response.candidates[0].content.parts:  # Older access path
            response_text = response.candidates[0].content.parts[0].text.strip()
        else:
            safety_feedback = ""
            if hasattr(response, 'prompt_feedback') and response.prompt_feedback:
                safety_feedback = f" Safety Feedback: {response.prompt_feedback}"
            elif response.candidates and hasattr(response.candidates[0], 'finish_reason') and response.candidates[0].finish_reason != "STOP":
                safety_feedback = f" Finish Reason: {response.candidates[0].finish_reason}"
            logger.warning(f"Gemini response might be empty or blocked.{safety_feedback}")
        logger.info(f"Gemini LLM Response: {response_text}")
        return response_text
    except Exception as e:
        logger.error(f"Error during Gemini LLM generation: {e}", exc_info=True)
        return "Sorry, I encountered an error trying to respond."
async def synthesize_speech_streaming(text: str, description: str = "A clear, female voice speaking in English.", play_steps_in_s: float = 0.4, cancellation_event: Optional[Event] = None):
    if not tts_model or not tts_tokenizer:
        logger.error("TTS model or tokenizer not loaded.")
        yield b""
        return
    if not text or not text.strip():
        logger.warning("TTS input text is empty. Yielding empty audio.")
        yield b""
        return
    streamer = None
    thread = None
    try:
        logger.info(f"Starting TTS streaming with ParlerTTSStreamer for: \"{text[:50]}...\"")
        # For ParlerTTS the output sampling rate lives on the model's audio encoder config.
        if hasattr(tts_model.config, 'audio_encoder') and hasattr(tts_model.config.audio_encoder, 'sampling_rate'):
            sampling_rate = tts_model.config.audio_encoder.sampling_rate
        else:
            logger.warning("Could not find tts_model.config.audio_encoder.sampling_rate, defaulting to 24000")
            sampling_rate = 24000  # A common default for ParlerTTS if not found
        try:
            frame_rate = getattr(tts_model.config.audio_encoder, 'frame_rate', 100)
        except AttributeError:
            logger.warning("frame_rate not found in tts_model.config.audio_encoder. Using default of 100 Hz for play_steps calculation.")
            frame_rate = 100
        play_steps = max(1, int(frame_rate * play_steps_in_s))
        logger.info(f"Streamer params: sampling_rate={sampling_rate}, frame_rate={frame_rate}, play_steps_in_s={play_steps_in_s}, play_steps={play_steps}")

        streamer = ParlerTTSStreamer(tts_model, device=DEVICE, play_steps=play_steps)
        description_inputs = tts_tokenizer(description, return_tensors="pt")
        prompt_inputs = tts_tokenizer(text, return_tensors="pt")

        gen_config_dict = TTS_DEFAULT_PARAMS.copy()
        # ParlerTTS's generate() takes generation parameters as individual kwargs (as in the
        # streamer examples) rather than a GenerationConfig object here. pad_token_id is set
        # if the tokenizer has one; eos_token_id is not, since ParlerTTS does not end
        # generation the way text models do.
        if tts_tokenizer.pad_token_id is not None:
            gen_config_dict["pad_token_id"] = tts_tokenizer.pad_token_id

        thread_generation_kwargs = {
            "input_ids": description_inputs.input_ids.to(DEVICE),
            "prompt_input_ids": prompt_inputs.input_ids.to(DEVICE),
            "attention_mask": description_inputs.attention_mask.to(DEVICE) if hasattr(description_inputs, 'attention_mask') else None,
            "streamer": streamer,
            **gen_config_dict,  # Spread the generation parameters
        }
        if thread_generation_kwargs["attention_mask"] is None:
            del thread_generation_kwargs["attention_mask"]

        def _generate_in_thread():
            try:
                logger.info("TTS generation thread started.")
                with torch.no_grad():
                    tts_model.generate(**thread_generation_kwargs)
                logger.info("TTS generation thread finished model.generate().")
            except Exception as e_thread:
                logger.error(f"Error in TTS generation thread: {e_thread}", exc_info=True)
            finally:
                if streamer:
                    streamer.end()
                logger.info("TTS generation thread called streamer.end().")

        thread = Thread(target=_generate_in_thread)
        thread.daemon = True
        thread.start()
        loop = asyncio.get_event_loop()
        while True:
            if cancellation_event and cancellation_event.is_set():
                logger.info("TTS streaming cancelled by event.")
                break
            try:
                # Run the blocking streamer.__next__() in an executor
                audio_chunk_tensor = await loop.run_in_executor(None, streamer.__next__)
                if audio_chunk_tensor is None:
                    logger.info("Streamer yielded None explicitly, ending stream.")
                    break
                # The streamer may yield empty tensors; skip them rather than ending the stream.
                # End-of-stream is signalled by StopIteration (or an explicit None), not by a
                # method on the streamer.
                if not isinstance(audio_chunk_tensor, torch.Tensor) or audio_chunk_tensor.numel() == 0:
                    await asyncio.sleep(0.01)  # Small sleep if empty but not done
                    continue
                audio_chunk_np = audio_chunk_tensor.cpu().to(torch.float32).numpy().squeeze()
                if audio_chunk_np.size == 0:
                    continue
                audio_chunk_int16 = np.clip(audio_chunk_np * 32767, -32768, 32767).astype(np.int16)
                yield audio_chunk_int16.tobytes()
            except StopIteration:
                logger.info("Streamer finished (StopIteration).")
                break
            except Exception as e_stream_iter:
                logger.error(f"Error iterating streamer: {e_stream_iter}", exc_info=True)
                break
logger.info(f"Finished TTS streaming iteration for: \"{text[:50]}...\"") | |
except Exception as e: | |
logger.error(f"Error in synthesize_speech_streaming function: {e}", exc_info=True) | |
yield b"" | |
finally: | |
logger.info("Exiting synthesize_speech_streaming. Ensuring streamer is ended and thread is joined.") | |
if streamer: | |
streamer.end() | |
if thread and thread.is_alive(): | |
logger.info("Waiting for TTS generation thread to complete in finally block...") | |
final_join_start_time = time.time() | |
thread.join(timeout=2.0) | |
if thread.is_alive(): | |
logger.warning(f"TTS generation thread still alive after {time.time() - final_join_start_time:.2f}s in finally block.") | |
# --- FastAPI HTTP Endpoints ---
@app.post("/speech-to-text")  # route path assumed; adjust to match your client
async def speech_to_text_endpoint(file: UploadFile = File(...)):
    if not whisper_model:
        return JSONResponse(content={"error": "Whisper model not loaded"}, status_code=503)
    try:
        audio_bytes = await file.read()
        transcribed_text = await transcribe_audio_bytes(audio_bytes)
        return {"transcription": transcribed_text}
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
@app.post("/llm")  # route path assumed; adjust to match your client
async def llm_endpoint(payload: dict):
    if not gemini_model_instance:
        return JSONResponse(content={"error": "Gemini LLM not configured or API key missing"}, status_code=503)
    try:
        text = payload.get("text")
        if not text:
            return JSONResponse(content={"error": "No text provided"}, status_code=400)
        response = await generate_gemini_response(text)
        return {"response": response}
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
@app.post("/text-to-speech")  # route path assumed; adjust to match your client
async def text_to_speech_endpoint(payload: dict):
    if not tts_model or not tts_tokenizer:
        return JSONResponse(content={"error": "TTS model/tokenizer not loaded"}, status_code=503)
    try:
        text = payload.get("text")
        description = payload.get("description", "A clear, female voice speaking in English.")
        if not text:
            return JSONResponse(content={"error": "No text provided"}, status_code=400)
        description_inputs = tts_tokenizer(description, return_tensors="pt")
        prompt_inputs = tts_tokenizer(text, return_tensors="pt")

        # Use a GenerationConfig object for clarity and consistency
        gen_config_dict = TTS_DEFAULT_PARAMS.copy()
        if tts_tokenizer.pad_token_id is not None:
            gen_config_dict["pad_token_id"] = tts_tokenizer.pad_token_id
        # eos_token_id is intentionally not set; ParlerTTS does not use a standard text EOS.
        generation_config_obj = GenerationConfig(**gen_config_dict)

        with torch.no_grad():
            generation = tts_model.generate(
                input_ids=description_inputs.input_ids.to(DEVICE),
                prompt_input_ids=prompt_inputs.input_ids.to(DEVICE),
                attention_mask=description_inputs.attention_mask.to(DEVICE) if hasattr(description_inputs, 'attention_mask') else None,
                generation_config=generation_config_obj,  # Pass the config object
            ).cpu().to(torch.float32).numpy().squeeze()

        audio_io = io.BytesIO()
        scaled_generation = np.clip(generation * 32767, -32768, 32767).astype(np.int16)
        current_sampling_rate = tts_model.config.audio_encoder.sampling_rate if hasattr(tts_model.config, 'audio_encoder') else 24000
        sf.write(audio_io, scaled_generation, samplerate=current_sampling_rate, format='WAV', subtype='PCM_16')
        audio_io.seek(0)
        audio_bytes = audio_io.read()
        if not audio_bytes:
            return JSONResponse(content={"error": "TTS failed to generate audio"}, status_code=500)
        audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
        return {"audio_base64": audio_base64, "format": "wav", "sample_rate": current_sampling_rate}
    except Exception as e:
        logger.error(f"TTS endpoint error: {e}", exc_info=True)
        return JSONResponse(content={"error": str(e)}, status_code=500)
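# Example client call for the endpoint above (the /text-to-speech path is the assumed one
# from the decorator; 'requests' is not a dependency of this server):
#
#   import base64, requests
#   r = requests.post("http://localhost:8000/text-to-speech",
#                     json={"text": "Hello!", "description": "A clear, female voice speaking in English."})
#   with open("reply.wav", "wb") as f:
#       f.write(base64.b64decode(r.json()["audio_base64"]))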
# --- WebSocket Endpoint for Real-time Conversation ---
@app.websocket("/ws/conversation")  # path matches the client JavaScript served below
async def conversation_websocket(websocket: WebSocket):
    await websocket.accept()
    logger.info(f"WebSocket connection accepted from: {websocket.client}")
    tts_cancellation_event = Event()  # For this specific connection
    try:
        while True:
            if websocket.client_state.name != 'CONNECTED':  # Client disconnected before receive
                logger.info(f"WebSocket client {websocket.client} disconnected before receive.")
                break
            audio_data = await websocket.receive_bytes()
            logger.info(f"Received {len(audio_data)} bytes of user audio data from {websocket.client}.")
            if not audio_data:
                logger.warning(f"Received empty audio data from user {websocket.client}.")
                continue

            transcribed_text = await transcribe_audio_bytes(audio_data)
            if not transcribed_text:
                logger.warning(f"Transcription failed for {websocket.client}.")
                await websocket.send_text("SYSTEM_ERROR: Transcription failed.")
                continue
            await websocket.send_text(f"USER_TRANSCRIPT: {transcribed_text}")

            llm_response_text = await generate_gemini_response(transcribed_text)
            if not llm_response_text or "Sorry, I encountered an error" in llm_response_text or "unavailable" in llm_response_text:
                logger.warning(f"LLM (Gemini) failed for {websocket.client}: {llm_response_text}")
                await websocket.send_text(f"SYSTEM_ERROR: LLM failed. ({llm_response_text})")
                continue
            await websocket.send_text(f"ASSISTANT_RESPONSE_TEXT: {llm_response_text}")

            tts_description = "A clear, female voice speaking in English."
            current_sampling_rate = tts_model.config.audio_encoder.sampling_rate if hasattr(tts_model.config, 'audio_encoder') else 24000
            audio_params_msg = f"TTS_STREAM_START:{{\"sample_rate\": {current_sampling_rate}, \"channels\": 1, \"bit_depth\": 16}}"
            await websocket.send_text(audio_params_msg)
            logger.info(f"Sent to client {websocket.client}: {audio_params_msg}")

            chunk_count = 0
            tts_cancellation_event.clear()  # Reset event for the new TTS task
            async for audio_chunk_bytes in synthesize_speech_streaming(llm_response_text, tts_description, cancellation_event=tts_cancellation_event):
                if not audio_chunk_bytes:
                    logger.debug(f"Received empty bytes from streaming generator for {websocket.client}; may be end of stream or a generator error.")
                    continue
                try:
                    if websocket.client_state.name != 'CONNECTED':
                        logger.warning(f"Client {websocket.client} disconnected during TTS stream. Aborting TTS.")
                        tts_cancellation_event.set()  # Signal TTS thread to stop
                        break
                    await websocket.send_bytes(audio_chunk_bytes)
                    chunk_count += 1
                except Exception as send_err:
                    logger.warning(f"Error sending audio chunk to {websocket.client}: {send_err}. Client likely disconnected.")
                    tts_cancellation_event.set()  # Signal TTS thread to stop
                    break

            if not tts_cancellation_event.is_set():  # Only send END if not cancelled
                logger.info(f"Sent {chunk_count} TTS audio chunks to client {websocket.client}.")
                await websocket.send_text("TTS_STREAM_END")
                logger.info(f"Sent TTS_STREAM_END to client {websocket.client}.")
            else:
                logger.info(f"TTS stream for {websocket.client} was cancelled. Sent {chunk_count} chunks before cancellation.")
    except WebSocketDisconnect:
        logger.info(f"WebSocket connection closed by client {websocket.client}.")
        tts_cancellation_event.set()  # Signal any active TTS to stop
    except Exception as e:
        logger.error(f"Error in WebSocket conversation with {websocket.client}: {e}", exc_info=True)
        tts_cancellation_event.set()  # Signal any active TTS to stop
        try:
            if websocket.client_state.name == 'CONNECTED':
                await websocket.send_text(f"SYSTEM_ERROR: An unexpected error occurred: {str(e)}")
        except Exception:
            pass
    finally:
        logger.info(f"Cleaning up WebSocket connection for {websocket.client}.")
        tts_cancellation_event.set()  # Ensure the event is set on any exit path
        if websocket.client_state.name in ('CONNECTED', 'CONNECTING'):
            try:
                await websocket.close()
            except Exception:
                pass
        logger.info(f"WebSocket connection resources cleaned up for {websocket.client}.")
# --- HTML Client Page & Main Execution ---
@app.get("/")  # assumed root route for serving the demo page below
async def get_home():
html_content = """ | |
<!DOCTYPE html> | |
<html> | |
<head> | |
<title>Conversational AI Chatbot (Streaming)</title> | |
<style> | |
body { font-family: Arial, sans-serif; margin: 20px; background-color: #f4f4f4; color: #333; } | |
#chatbox { width: 80%; max-width: 600px; margin: auto; background-color: #fff; padding: 20px; box-shadow: 0 0 10px rgba(0,0,0,0.1); border-radius: 8px; } | |
.message { padding: 10px; margin-bottom: 10px; border-radius: 5px; } | |
.user { background-color: #e1f5fe; text-align: right; } | |
.assistant { background-color: #f1f8e9; } | |
.system { background-color: #ffebee; color: #c62828; font-style: italic;} | |
#audioPlayerContainer { margin-top: 10px; } | |
#audioPlayer { display: none; width: 100%; } | |
button { padding: 10px 15px; background-color: #007bff; color: white; border: none; border-radius: 5px; cursor: pointer; margin-top:10px; } | |
button:disabled { background-color: #ccc; } | |
#status { margin-top: 10px; font-style: italic; color: #666; } | |
#transcriptionArea, #llmResponseArea { margin-top: 10px; padding: 5px; border: 1px solid #eee; background: #fafafa; word-wrap: break-word;} | |
</style> | |
</head> | |
<body> | |
<div id="chatbox"> | |
<h2>Real-time AI Chatbot (Streaming TTS)</h2> | |
<div id="messages"></div> | |
<div id="transcriptionArea"><strong>You (transcribed):</strong> <span id="userTranscript">...</span></div> | |
<div id="llmResponseArea"><strong>Assistant (text):</strong> <span id="assistantTranscript">...</span></div> | |
<button id="startRecordButton">Start Recording</button> | |
<button id="stopRecordButton" disabled>Stop Recording</button> | |
<p id="status">Status: Idle</p> | |
<div id="audioPlayerContainer"> | |
<audio id="audioPlayer" controls></audio> | |
</div> | |
</div> | |
<script> | |
const startRecordButton = document.getElementById('startRecordButton'); | |
const stopRecordButton = document.getElementById('stopRecordButton'); | |
const audioPlayer = document.getElementById('audioPlayer'); | |
const messagesDiv = document.getElementById('messages'); | |
const statusDiv = document.getElementById('status'); | |
const userTranscriptSpan = document.getElementById('userTranscript'); | |
const assistantTranscriptSpan = document.getElementById('assistantTranscript'); | |
let websocket; | |
let mediaRecorder; | |
let userAudioChunks = []; | |
let assistantAudioBufferQueue = []; | |
let audioContext; | |
let expectedSampleRate; | |
let ttsStreaming = false; | |
let audioPlaying = false; | |
let sourceNode = null; | |
function initAudioContext() { | |
if (!audioContext || audioContext.state === 'closed') { | |
try { | |
audioContext = new (window.AudioContext || window.webkitAudioContext)(); | |
console.log("AudioContext initialized or re-initialized."); | |
} catch (e) { | |
console.error("Web Audio API is not supported in this browser.", e); | |
addMessage("Error: Web Audio API not supported. Cannot play streamed audio.", "system"); | |
audioContext = null; | |
} | |
} | |
} | |
function connectWebSocket() { | |
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; | |
const wsUrl = `${protocol}//${window.location.host}/ws/conversation`; | |
websocket = new WebSocket(wsUrl); | |
websocket.binaryType = 'arraybuffer'; | |
websocket.onopen = () => { | |
statusDiv.textContent = 'Status: Connected. Ready to record.'; | |
startRecordButton.disabled = false; | |
initAudioContext(); | |
}; | |
websocket.onmessage = (event) => { | |
if (event.data instanceof ArrayBuffer) { | |
if (ttsStreaming && audioContext && expectedSampleRate) { | |
const pcmDataInt16 = new Int16Array(event.data); | |
if (pcmDataInt16.length > 0) { | |
assistantAudioBufferQueue.push(pcmDataInt16); | |
playNextChunkFromQueue(); | |
} | |
} else { | |
console.warn("Received ArrayBuffer data but not in TTS streaming mode or AudioContext not ready."); | |
} | |
} else { | |
const messageText = event.data; | |
if (messageText.startsWith("USER_TRANSCRIPT:")) { | |
const transcript = messageText.substring("USER_TRANSCRIPT:".length).trim(); | |
userTranscriptSpan.textContent = transcript; | |
} else if (messageText.startsWith("ASSISTANT_RESPONSE_TEXT:")) { | |
const llmResponse = messageText.substring("ASSISTANT_RESPONSE_TEXT:".length).trim(); | |
assistantTranscriptSpan.textContent = llmResponse; | |
addMessage(`Assistant: ${llmResponse}`, 'assistant'); | |
} else if (messageText.startsWith("TTS_STREAM_START:")) { | |
ttsStreaming = true; | |
assistantAudioBufferQueue = []; | |
audioPlaying = false; | |
if (sourceNode) { | |
try { sourceNode.stop(); } catch(e) { console.warn("Error stopping previous sourceNode:", e); } | |
sourceNode = null; | |
} | |
audioPlayer.style.display = 'none'; | |
audioPlayer.src = ""; | |
try { | |
const paramsText = messageText.substring("TTS_STREAM_START:".length); | |
const params = JSON.parse(paramsText); | |
expectedSampleRate = params.sample_rate; | |
initAudioContext(); | |
statusDiv.textContent = 'Status: Receiving audio stream...'; | |
addMessage('Assistant (Audio stream starting...)', 'assistant'); | |
} catch (e) { | |
console.error("Could not parse TTS_STREAM_START params:", e); | |
statusDiv.textContent = 'Error: Could not parse audio stream parameters.'; | |
ttsStreaming = false; | |
} | |
} else if (messageText === "TTS_STREAM_END") { | |
ttsStreaming = false; | |
if (!audioPlaying && assistantAudioBufferQueue.length === 0) { | |
statusDiv.textContent = 'Status: Audio stream finished (or was empty).'; | |
} else if (!audioPlaying && assistantAudioBufferQueue.length > 0) { | |
playNextChunkFromQueue(); | |
statusDiv.textContent = 'Status: Audio stream finished. Playing remaining...'; | |
} else { | |
statusDiv.textContent = 'Status: Audio stream finished. Playing remaining...'; | |
} | |
addMessage('Assistant (Audio stream ended)', 'assistant'); | |
} else if (messageText.startsWith("SYSTEM_ERROR:")) { | |
const errorMsg = messageText.substring("SYSTEM_ERROR:".length).trim(); | |
addMessage(`System Error: ${errorMsg}`, 'system'); | |
statusDiv.textContent = `Error: ${errorMsg}`; | |
ttsStreaming = false; | |
assistantAudioBufferQueue = []; | |
} else { | |
addMessage(messageText, 'system'); | |
} | |
} | |
}; | |
websocket.onerror = (error) => { | |
console.error('WebSocket Error:', error); | |
statusDiv.textContent = 'Status: WebSocket error. Try reconnecting.'; | |
addMessage('WebSocket Error. Check console.', 'system'); | |
ttsStreaming = false; | |
}; | |
websocket.onclose = () => { | |
statusDiv.textContent = 'Status: Disconnected. Please refresh to reconnect.'; | |
startRecordButton.disabled = true; | |
stopRecordButton.disabled = true; | |
addMessage('Disconnected from server.', 'system'); | |
ttsStreaming = false; | |
if (audioContext && audioContext.state !== 'closed') { | |
audioContext.close().catch(e => console.warn("Error closing AudioContext:", e)); | |
audioContext = null; | |
console.log("AudioContext closed."); | |
} | |
}; | |
} | |
function playNextChunkFromQueue() { | |
if (audioPlaying || assistantAudioBufferQueue.length === 0 || !audioContext || audioContext.state !== 'running' || !expectedSampleRate) { | |
if (assistantAudioBufferQueue.length === 0 && !ttsStreaming && !audioPlaying) { | |
console.log("Queue empty, not streaming, not playing: Playback complete."); | |
statusDiv.textContent = 'Status: Audio playback complete.'; | |
} | |
return; | |
} | |
audioPlaying = true; | |
const pcmDataInt16 = assistantAudioBufferQueue.shift(); | |
const float32Pcm = new Float32Array(pcmDataInt16.length); | |
for (let i = 0; i < pcmDataInt16.length; i++) { | |
float32Pcm[i] = pcmDataInt16[i] / 32768.0; | |
} | |
const audioBuffer = audioContext.createBuffer(1, float32Pcm.length, expectedSampleRate); | |
audioBuffer.getChannelData(0).set(float32Pcm); | |
sourceNode = audioContext.createBufferSource(); | |
sourceNode.buffer = audioBuffer; | |
sourceNode.connect(audioContext.destination); | |
sourceNode.onended = () => { | |
audioPlaying = false; | |
if (ttsStreaming || assistantAudioBufferQueue.length > 0) { | |
playNextChunkFromQueue(); | |
} else { | |
statusDiv.textContent = 'Status: Audio playback finished.'; | |
console.log("All queued audio chunks played."); | |
} | |
}; | |
sourceNode.start(); | |
statusDiv.textContent = 'Status: Playing audio chunk...'; | |
} | |
function addMessage(text, type) { | |
const messageElement = document.createElement('div'); | |
messageElement.classList.add('message', type); | |
messageElement.textContent = text; | |
messagesDiv.appendChild(messageElement); | |
messagesDiv.scrollTop = messagesDiv.scrollHeight; | |
} | |
startRecordButton.onclick = async () => { | |
if (!websocket || websocket.readyState !== WebSocket.OPEN) { | |
alert("WebSocket is not connected. Please wait or refresh."); | |
return; | |
} | |
if (audioContext && audioContext.state === 'suspended') { | |
audioContext.resume().catch(e => console.error("Error resuming AudioContext:", e)); | |
} | |
initAudioContext(); | |
try { | |
const stream = await navigator.mediaDevices.getUserMedia({ audio: true }); | |
let options = { mimeType: 'audio/webm;codecs=opus' }; | |
if (!MediaRecorder.isTypeSupported(options.mimeType)) { | |
console.warn(`${options.mimeType} is not supported, trying default.`); | |
options = {}; | |
} | |
mediaRecorder = new MediaRecorder(stream, options); | |
userAudioChunks = []; | |
mediaRecorder.ondataavailable = event => { | |
if (event.data.size > 0) userAudioChunks.push(event.data); | |
}; | |
mediaRecorder.onstop = () => { | |
if (userAudioChunks.length === 0) { | |
console.log("No audio data recorded."); | |
statusDiv.textContent = 'Status: No audio data recorded. Try again.'; | |
startRecordButton.disabled = false; | |
stopRecordButton.disabled = true; | |
return; | |
} | |
const audioBlob = new Blob(userAudioChunks, { type: mediaRecorder.mimeType }); | |
if (websocket && websocket.readyState === WebSocket.OPEN) { | |
websocket.send(audioBlob); | |
statusDiv.textContent = 'Status: Audio sent. Waiting for response...'; | |
} else { | |
statusDiv.textContent = 'Status: WebSocket not open. Cannot send audio.'; | |
} | |
userAudioChunks = []; | |
}; | |
mediaRecorder.start(250); | |
startRecordButton.disabled = true; | |
stopRecordButton.disabled = false; | |
statusDiv.textContent = 'Status: Recording...'; | |
userTranscriptSpan.textContent = "..."; | |
assistantTranscriptSpan.textContent = "..."; | |
audioPlayer.style.display = 'none'; | |
audioPlayer.src = ''; | |
assistantAudioBufferQueue = []; | |
if (sourceNode) { try {sourceNode.stop();} catch(e){} sourceNode = null; } | |
} catch (err) { | |
console.error('Error accessing microphone:', err); | |
statusDiv.textContent = 'Status: Error accessing microphone.'; | |
alert('Could not access microphone: ' + err.message); | |
} | |
}; | |
stopRecordButton.onclick = () => { | |
if (mediaRecorder && mediaRecorder.state === "recording") { | |
mediaRecorder.stop(); | |
startRecordButton.disabled = false; | |
stopRecordButton.disabled = true; | |
} | |
}; | |
connectWebSocket(); | |
</script> | |
</body> | |
</html> | |
""" | |
return HTMLResponse(content=html_content) | |
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
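# Run directly (python main.py) or, equivalently, with the uvicorn CLI:
#   uvicorn main:app --host 0.0.0.0 --port 8000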