# Luigi's picture
# add endpoint detection
# 548b7ed
from fastapi import FastAPI, WebSocket
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from app.asr_worker import create_recognizer, stream_audio
import json
from starlette.websockets import WebSocketDisconnect
# ASGI application object; the route decorators in this module register
# their handlers on it.
app = FastAPI()
# Expose the contents of the app/static directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="app/static"), name="static")
@app.get("/")
async def root():
    """Serve the front-end page.

    Returns:
        HTMLResponse: the contents of ``app/static/index.html``.
    """
    # Decode explicitly as UTF-8 so the page is read the same way on every
    # host instead of depending on the locale's default encoding.
    with open("app/static/index.html", encoding="utf-8") as f:
        return HTMLResponse(f.read())
def _parse_config_message(raw_text):
    """Decode a text frame and return it if it is a ``config`` message.

    Args:
        raw_text: The text payload of a received WebSocket frame.

    Returns:
        The decoded dict when the payload is a JSON object whose ``"type"``
        is ``"config"``; otherwise ``None`` (invalid JSON is logged).
    """
    try:
        message = json.loads(raw_text)
    except (ValueError, RecursionError) as e:
        # json.JSONDecodeError is a ValueError subclass.
        print(f"[ERROR main] JSON parse failed: {e}")
        return None
    # Valid JSON that is not an object (list, number, ...) has no .get().
    if not isinstance(message, dict) or message.get("type") != "config":
        return None
    return message


def _build_recognizer(config_msg, default_sr):
    """Create a recognizer and a stream from a client config message.

    Args:
        config_msg: Decoded config dict. Keys read: ``sampleRate``, ``model``,
            ``precision``, ``hotwords``, ``hotwordsScore``, ``epRule1``,
            ``epRule2``, ``epRule3``.
        default_sr: Sample rate used when ``sampleRate`` is absent.

    Returns:
        A ``(recognizer, stream, sample_rate)`` tuple.

    Raises:
        ValueError, TypeError: If a numeric field cannot be converted.
        Any exception raised by ``create_recognizer`` is propagated.
    """
    # 1) sample rate (fall back to the current value when not supplied)
    sample_rate = int(config_msg.get("sampleRate", default_sr))
    print(f"[INFO main] Set original sample rate to {sample_rate}")

    # 2) model & precision
    model_id = config_msg.get("model")
    precision = config_msg.get("precision")
    print(f"[INFO main] Selected model: {model_id}, precision: {precision}")

    # 3) hotwords & boost score
    hotwords = config_msg.get("hotwords", [])
    hotwords_score = float(config_msg.get("hotwordsScore", 0.0))
    print(f"[INFO main] Hotwords: {hotwords}, score: {hotwords_score}")

    # 4) endpoint detection rules
    ep1 = float(config_msg.get("epRule1", 2.4))
    ep2 = float(config_msg.get("epRule2", 1.2))
    ep3 = int(config_msg.get("epRule3", 300))
    print(f"[INFO main] Endpoint rules: rule1={ep1}s, rule2={ep2}s, rule3={ep3}ms")

    # 5) create recognizer with endpoint settings & biasing
    recognizer = create_recognizer(
        model_id,
        precision,
        hotwords=hotwords,
        hotwords_score=hotwords_score,
        ep_rule1=ep1,
        ep_rule2=ep2,
        ep_rule3=ep3,
    )
    stream = recognizer.create_stream()
    print("[INFO main] WebSocket connection accepted; created a streaming context.")
    return recognizer, stream, sample_rate


async def _process_audio_chunk(websocket, raw_audio, stream, recognizer, orig_sr):
    """Feed one audio chunk to the recognizer and push the results to the client.

    Sends a ``{"partial", "volume"}`` message for every chunk (volume is
    capped at 1.0). When the recognizer reports an endpoint, a ``{"final"}``
    message is sent if the text is non-blank, and the stream is reset.
    """
    result, rms = stream_audio(raw_audio, stream, recognizer, orig_sr)
    await websocket.send_json({"partial": result, "volume": min(rms, 1.0)})

    if recognizer.is_endpoint(stream):
        if result.strip():
            print(f"[DEBUG main] Emitting final: {result!r}")
            await websocket.send_json({"final": result})
        # Reset at every endpoint, even if the segment produced no text.
        recognizer.reset(stream)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Streaming speech-recognition WebSocket handler.

    Protocol as implemented here:
      * A text frame holding a JSON object with ``"type": "config"``
        (re)creates the recognizer and stream. Other text frames are ignored.
      * Binary frames are audio chunks; they are dropped until a config has
        been received. Each chunk yields a ``partial``/``volume`` message and,
        on endpoint detection, a ``final`` message.

    The loop ends when the client disconnects or an error occurs; the socket
    is then closed on a best-effort basis.
    """
    print("[DEBUG main] ▶ Attempting to accept WebSocket…")
    await websocket.accept()
    print("[DEBUG main] ▶ WebSocket.accept() returned → client is connected!")

    recognizer = None
    stream = None
    orig_sr = 48000  # default fallback until the client sends a config

    try:
        while True:
            data = await websocket.receive()
            kind = data.get("type")

            # A disconnect message ends the session; calling receive() again
            # after it would raise RuntimeError.
            if kind == "websocket.disconnect":
                print(f"[INFO main] Client disconnected (code={data.get('code')})")
                break

            # "websocket.receive" carries both text and binary frames.
            if kind != "websocket.receive":
                print(f"[DEBUG main] Received control/frame: {data}")
                continue

            # The unused payload key may be absent or present as None, so
            # test the value rather than key membership.
            text = data.get("text")
            if text is not None:
                config_msg = _parse_config_message(text)
                if config_msg is not None:
                    recognizer, stream, orig_sr = _build_recognizer(config_msg, orig_sr)
                continue

            raw_audio = data.get("bytes")
            # Don't process audio until after config
            if raw_audio is None or recognizer is None or stream is None:
                continue

            await _process_audio_chunk(websocket, raw_audio, stream, recognizer, orig_sr)
    except WebSocketDisconnect:
        # NOTE(review): some Starlette versions raise this from send when the
        # peer has gone away — treated as a normal disconnect, not an error.
        print("[INFO main] Client disconnected during send")
    except Exception as e:
        print(f"[ERROR main] Unexpected exception: {e}")
    finally:
        try:
            await websocket.close()
        except Exception:
            # Best-effort: the socket is usually already closed here.
            # Not a bare except, so task cancellation still propagates.
            pass
        print("[INFO main] WebSocket closed, cleanup complete.")