# api / app/main.py
# Last change by SalexAI: "Update app/main.py" (commit fc77df5, verified)
from __future__ import annotations
import asyncio
import json
import os
import uuid
from typing import Any, Dict, Optional
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
from fastrtc import Stream, ReplyOnPause, get_stt_model, get_tts_model
from .gemini_text import (
gemini_chat_turn,
get_session,
deliver_function_result,
)
app = FastAPI()

# ----------------------------
# FastRTC Voice Chat (VAD + STT + TTS)
# ----------------------------
# The speech models run on CPU but are still heavy on Spaces, so they are
# loaded exactly once at import time and shared as module globals.
# Either model can be swapped via its environment variable.
STT_MODEL_NAME = os.environ.get("FASTRTC_STT_MODEL", "moonshine/tiny")
TTS_MODEL_NAME = os.environ.get("FASTRTC_TTS_MODEL", "kokoro")

stt = get_stt_model(model=STT_MODEL_NAME)
tts = get_tts_model(model=TTS_MODEL_NAME)
def _voice_reply_fn(audio: tuple[int, np.ndarray]):
    """
    Voice turn handler, called when the user pauses (VAD).

    Transcribes ``audio`` with the global STT model, runs one Gemini chat turn
    on the transcript, and streams the reply back as TTS audio frames.

    Args:
        audio: ``(sample_rate, samples)`` tuple as supplied by FastRTC; it is
            forwarded to ``stt.stt`` unchanged.

    Yields:
        Audio chunks from ``tts.stream_tts_sync``. Nothing is yielded when the
        transcript is empty/whitespace or when the model reply is empty.
    """
    user_text = stt.stt(audio).strip()
    if not user_text:
        return

    # Voice sessions use a synthetic session_id (not a Scratch ws session)
    # because FastRTC's ReplyOnPause fn signature doesn't expose the RTC
    # session id. Conversation state is therefore per process, not per user:
    # every voice caller shares this one history. A stateful StreamHandler
    # would be needed for per-user memory.
    voice_session_id = "voice-global"

    async def noop_emit(_evt: dict):
        # There is no client channel for tool-call events on the voice path,
        # so any event passed here is discarded.
        return

    # This handler is synchronous, so drive the async chat turn on a private
    # event loop. asyncio.run creates, runs and closes that loop. By contrast,
    # asyncio.get_event_loop().run_until_complete(...) raises RuntimeError in
    # any thread that has no current event loop (e.g. a worker thread) and is
    # deprecated for this purpose.
    text = asyncio.run(
        gemini_chat_turn(
            session_id=voice_session_id,
            user_text=user_text,
            emit_event=noop_emit,
            model=os.getenv("GEMINI_TEXT_MODEL", "gemini-2.0-flash"),
        )
    )
    # Don't hand an empty/None reply to the TTS engine.
    if not text:
        return

    # Stream TTS back; each chunk is passed through to FastRTC unchanged.
    yield from tts.stream_tts_sync(text)
# Two-way audio stream. ReplyOnPause invokes _voice_reply_fn each time the
# speaker pauses.
voice_stream = Stream(
    handler=ReplyOnPause(_voice_reply_fn),
    modality="audio",
    mode="send-receive",
)

# Expose the FastRTC endpoints (WebRTC + WebSocket) under the /rtc prefix.
voice_stream.mount(app, path="/rtc")
# ----------------------------
# Scratch-friendly WebSocket API (text + function calling)
# ----------------------------
@app.get("/")
async def root():
    """Describe the service and point clients at its two entry points."""
    usage_notes = [
        "Use /ws for Scratch JSON chat + function calling.",
        "Use /rtc for FastRTC voice chat endpoints (VAD/STT/TTS handled by FastRTC).",
    ]
    info = {
        "ok": True,
        "service": "salexai-api",
        "ws": "/ws",
        "fastrtc": "/rtc",
        "notes": usage_notes,
    }
    return JSONResponse(info)
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
    """Scratch-facing JSON chat protocol over a WebSocket.

    Every frame in either direction is a JSON object with a ``type`` field.
    Client -> server types handled here:

    * ``connect``         -- bind to ``session_id`` (or a fresh UUID); replies ``ready``.
    * ``add_function``    -- register ``schema`` under ``name`` for the session.
    * ``remove_function`` -- unregister ``name``.
    * ``list_functions``  -- reply with the registered function names.
    * ``function_result`` -- pass ``result`` for ``call_id`` to
                             ``deliver_function_result``; replies with an ack.
    * ``send``            -- run one Gemini chat turn on ``text``; replies
                             ``assistant`` (or ``error``) when the turn finishes.

    Every type except ``connect`` requires a prior ``connect``. Bad or
    unexpected messages are answered with ``error``/``warning`` frames and the
    connection stays open.
    """
    await ws.accept()
    session_id: Optional[str] = None

    # Chat turns run as background tasks (see _run_turn), so the receive loop
    # and a running turn may write to the socket at the same time; serialize
    # the writes.
    send_lock = asyncio.Lock()
    # Run turns one at a time, in arrival order, like the sequential loop did.
    turn_lock = asyncio.Lock()
    # Strong references to in-flight turn tasks (so they aren't GC'd) and the
    # set of tasks to cancel when the connection ends.
    turn_tasks: set[asyncio.Task] = set()

    async def emit(evt: dict):
        async with send_lock:
            await ws.send_text(json.dumps(evt))

    async def _run_turn(turn_session_id: str, text: str) -> None:
        """Run one Gemini chat turn and send its result (or error) to the client."""
        async with turn_lock:
            try:
                assistant_text = await gemini_chat_turn(
                    session_id=turn_session_id,
                    user_text=text,
                    emit_event=emit,  # this is where tool calls get emitted
                    model=os.getenv("GEMINI_TEXT_MODEL", "gemini-2.0-flash"),
                )
            except Exception as e:
                reply = {"type": "error", "message": f"Gemini error: {e}"}
            else:
                reply = {"type": "assistant", "text": assistant_text}
            try:
                await emit(reply)
            except Exception:
                # Best effort: the socket is already gone, and the receive
                # loop owns teardown.
                pass

    try:
        while True:
            raw = await ws.receive_text()
            # One malformed message should not end the whole session.
            try:
                msg = json.loads(raw) if raw else {}
            except json.JSONDecodeError as e:
                await emit({"type": "error", "message": f"Invalid JSON: {e}"})
                continue
            if not isinstance(msg, dict):
                await emit({"type": "error", "message": "Message must be a JSON object"})
                continue
            mtype = msg.get("type")

            if mtype == "connect":
                session_id = msg.get("session_id") or str(uuid.uuid4())
                get_session(session_id)  # ensure exists
                await emit({"type": "ready", "session_id": session_id})
                continue

            if not session_id:
                await emit({"type": "error", "message": "Not connected. Send {type:'connect'} first."})
                continue

            # -------- function registry --------
            if mtype == "add_function":
                name = str(msg.get("name") or "").strip()
                schema = msg.get("schema") or {}
                if not name:
                    await emit({"type": "error", "message": "add_function missing name"})
                    continue
                s = get_session(session_id)
                s.functions[name] = schema
                await emit({"type": "function_added", "name": name})
                continue

            if mtype == "remove_function":
                name = str(msg.get("name") or "").strip()
                s = get_session(session_id)
                if name in s.functions:
                    s.functions.pop(name, None)
                    await emit({"type": "function_removed", "name": name})
                else:
                    await emit({"type": "warning", "message": f"Function not found: {name}"})
                continue

            if mtype == "list_functions":
                s = get_session(session_id)
                await emit({"type": "functions", "items": list(s.functions.keys())})
                continue

            # Client returns tool results
            if mtype == "function_result":
                call_id = msg.get("call_id")
                result = msg.get("result")
                if not call_id:
                    await emit({"type": "error", "message": "function_result missing call_id"})
                    continue
                ok = deliver_function_result(session_id, call_id, result)
                if not ok:
                    await emit({"type": "warning", "message": f"No pending call_id: {call_id}"})
                else:
                    await emit({"type": "function_result_ack", "call_id": call_id})
                continue

            # -------- chat --------
            if mtype == "send":
                text = str(msg.get("text") or "")
                if not text.strip():
                    await emit({"type": "error", "message": "Empty text"})
                    continue
                # Don't await the turn inline. When Gemini requests a tool
                # call, the client's function_result comes back on this same
                # socket, and this loop must stay free to receive it.
                task = asyncio.create_task(_run_turn(session_id, text))
                turn_tasks.add(task)
                task.add_done_callback(turn_tasks.discard)
                continue

            await emit({"type": "error", "message": f"Unknown type: {mtype}"})
    except WebSocketDisconnect:
        return
    except Exception as e:
        try:
            await emit({"type": "error", "message": f"WS crashed: {e}"})
        except Exception:
            # Best effort: the socket may already be unusable.
            pass
    finally:
        # Don't leave turns (possibly waiting on tool results) running after
        # the connection is gone.
        for task in list(turn_tasks):
            task.cancel()