| | from __future__ import annotations |
| |
|
| | import asyncio |
| | import json |
| | import os |
| | import uuid |
| | from typing import Any, Dict, Optional |
| |
|
| | import numpy as np |
| | from fastapi import FastAPI, WebSocket, WebSocketDisconnect |
| | from fastapi.responses import JSONResponse |
| |
|
| | from fastrtc import Stream, ReplyOnPause, get_stt_model, get_tts_model |
| |
|
| | from .gemini_text import ( |
| | gemini_chat_turn, |
| | get_session, |
| | deliver_function_result, |
| | ) |
| |
|
# ASGI application; the routes and the FastRTC mount below are attached to it.
app = FastAPI()

# Speech model names; each can be overridden through its environment variable.
STT_MODEL_NAME = os.environ.get("FASTRTC_STT_MODEL", "moonshine/tiny")
TTS_MODEL_NAME = os.environ.get("FASTRTC_TTS_MODEL", "kokoro")

# Both models are created once, at import time, and reused by the voice handler.
stt = get_stt_model(model=STT_MODEL_NAME)
tts = get_tts_model(model=TTS_MODEL_NAME)
| |
|
| |
|
def _voice_reply_fn(audio: tuple[int, np.ndarray]):
    """Answer one spoken utterance with streamed TTS audio.

    Called when the user pauses (VAD). The captured audio is transcribed with
    ``stt``, the transcript is sent through ``gemini_chat_turn`` and the reply
    text is synthesised with ``tts``.

    Args:
        audio: The captured utterance, passed unchanged to ``stt.stt``
            (presumably ``(sample_rate, samples)`` -- confirm against FastRTC).

    Yields:
        The chunks produced by ``tts.stream_tts_sync`` for the reply text.
        Nothing is yielded when the transcript or the reply is empty.
    """
    user_text = stt.stt(audio).strip()
    if not user_text:
        return

    # Every voice caller currently shares this single chat session.
    voice_session_id = "voice-global"

    async def run():
        # Voice turns have no WebSocket to report events to, so drop them.
        async def noop_emit(_evt: dict):
            return

        return await gemini_chat_turn(
            session_id=voice_session_id,
            user_text=user_text,
            emit_event=noop_emit,
            model=os.getenv("GEMINI_TEXT_MODEL", "gemini-2.0-flash"),
        )

    # asyncio.run() creates and closes its own loop, so it also works in a
    # thread that has no current event loop, where get_event_loop() would raise.
    # NOTE(review): if gemini_chat_turn caches loop-bound objects between calls,
    # a fresh loop per turn may need revisiting -- confirm in gemini_text.
    text = asyncio.run(run())

    # Do not hand an empty/None reply to the TTS engine.
    if not text:
        return

    yield from tts.stream_tts_sync(text)
| |
|
| |
|
# Bidirectional audio stream whose replies are produced by _voice_reply_fn,
# wrapped in ReplyOnPause.
voice_stream = Stream(
    handler=ReplyOnPause(_voice_reply_fn),
    modality="audio",
    mode="send-receive",
)

# Attach the FastRTC routes to the FastAPI app under the /rtc prefix.
voice_stream.mount(app, path="/rtc")
| |
|
| |
|
| | |
| | |
| | |
| |
|
@app.get("/")
async def root():
    """Return a small JSON document describing the service and its endpoints."""
    usage_notes = [
        "Use /ws for Scratch JSON chat + function calling.",
        "Use /rtc for FastRTC voice chat endpoints (VAD/STT/TTS handled by FastRTC).",
    ]
    payload = {
        "ok": True,
        "service": "salexai-api",
        "ws": "/ws",
        "fastrtc": "/rtc",
        "notes": usage_notes,
    }
    return JSONResponse(payload)
| |
|
| |
|
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
    """Serve the JSON chat / function-calling protocol over one WebSocket.

    Each incoming text frame is a JSON object with a ``type`` field:

    * ``connect`` -- bind the socket to ``session_id`` (a new UUID when the
      client sends none) and reply with ``ready``.
    * ``add_function`` / ``remove_function`` / ``list_functions`` -- edit or
      list the session's ``functions`` mapping.
    * ``function_result`` -- pass a result to ``deliver_function_result``.
    * ``send`` -- run ``gemini_chat_turn`` and reply with ``assistant``.

    Any other type, or any message sent before ``connect``, gets an ``error``
    event. Malformed JSON and non-object payloads also get an ``error`` event
    and the connection stays open.

    Args:
        ws: The WebSocket connection to serve.
    """
    await ws.accept()

    session_id: Optional[str] = None

    async def emit(evt: dict):
        # Single place where events are serialised and sent to the client.
        await ws.send_text(json.dumps(evt))

    try:
        while True:
            raw = await ws.receive_text()

            # A bad frame must not tear down the whole connection.
            try:
                msg = json.loads(raw) if raw else {}
            except json.JSONDecodeError as e:
                await emit({"type": "error", "message": f"Invalid JSON: {e}"})
                continue
            if not isinstance(msg, dict):
                await emit({"type": "error", "message": "Message must be a JSON object"})
                continue

            mtype = msg.get("type")

            if mtype == "connect":
                session_id = msg.get("session_id") or str(uuid.uuid4())
                get_session(session_id)
                await emit({"type": "ready", "session_id": session_id})
                continue

            # Everything below needs a session.
            if not session_id:
                await emit({"type": "error", "message": "Not connected. Send {type:'connect'} first."})
                continue

            if mtype == "add_function":
                name = str(msg.get("name") or "").strip()
                schema = msg.get("schema") or {}
                if not name:
                    await emit({"type": "error", "message": "add_function missing name"})
                    continue
                s = get_session(session_id)
                s.functions[name] = schema
                await emit({"type": "function_added", "name": name})
                continue

            if mtype == "remove_function":
                name = str(msg.get("name") or "").strip()
                s = get_session(session_id)
                if name in s.functions:
                    s.functions.pop(name, None)
                    await emit({"type": "function_removed", "name": name})
                else:
                    await emit({"type": "warning", "message": f"Function not found: {name}"})
                continue

            if mtype == "list_functions":
                s = get_session(session_id)
                await emit({"type": "functions", "items": list(s.functions.keys())})
                continue

            if mtype == "function_result":
                call_id = msg.get("call_id")
                result = msg.get("result")
                if not call_id:
                    await emit({"type": "error", "message": "function_result missing call_id"})
                    continue
                ok = deliver_function_result(session_id, call_id, result)
                if not ok:
                    await emit({"type": "warning", "message": f"No pending call_id: {call_id}"})
                else:
                    await emit({"type": "function_result_ack", "call_id": call_id})
                continue

            if mtype == "send":
                text = str(msg.get("text") or "")
                if not text.strip():
                    await emit({"type": "error", "message": "Empty text"})
                    continue

                # NOTE(review): the turn is awaited inline, so no further frame
                # (including a function_result) is read until it returns. If
                # gemini_chat_turn waits for a result delivered over this same
                # socket, that would stall -- confirm against gemini_text.
                try:
                    assistant_text = await gemini_chat_turn(
                        session_id=session_id,
                        user_text=text,
                        emit_event=emit,
                        model=os.getenv("GEMINI_TEXT_MODEL", "gemini-2.0-flash"),
                    )
                    await emit({"type": "assistant", "text": assistant_text})
                except Exception as e:
                    await emit({"type": "error", "message": f"Gemini error: {e}"})
                continue

            await emit({"type": "error", "message": f"Unknown type: {mtype}"})

    except WebSocketDisconnect:
        return
    except Exception as e:
        # Best-effort report: the socket may already be unusable, in which
        # case there is nobody left to notify.
        try:
            await emit({"type": "error", "message": f"WS crashed: {e}"})
        except Exception:
            pass
| |
|