SalexAI committed on
Commit
dccec3c
·
verified ·
1 Parent(s): 0a9dfed

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +42 -79
app/main.py CHANGED
@@ -1,26 +1,21 @@
1
  import asyncio
2
  import base64
3
  import json
4
- import os
5
  import uuid
6
- from typing import AsyncGenerator, Literal, Optional
7
 
8
  import numpy as np
9
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
10
- from fastapi.responses import JSONResponse, StreamingResponse
11
- from dotenv import load_dotenv
12
 
13
- from fastrtc import AdditionalOutputs, AsyncStreamHandler, Stream, wait_for_item
14
-
15
- # ---- Gemini (optional for later; right now we keep your echo handler working) ----
16
- # You can plug Gemini back in once bridge works.
17
-
18
- load_dotenv()
19
 
20
  app = FastAPI()
21
 
 
22
  # ---------------------------
23
- # Minimal VAD echo handler (server is already booting with this)
 
24
  # ---------------------------
25
  class EchoHandler(AsyncStreamHandler):
26
  def __init__(self, expected_layout: Literal["mono"] = "mono", output_sample_rate: int = 24000):
@@ -38,22 +33,24 @@ class EchoHandler(AsyncStreamHandler):
38
  if audio.dtype != np.int16:
39
  audio = audio.astype(np.int16)
40
 
41
- # Echo back immediately as "audio"
42
  self.out_q.put_nowait((sr, audio.reshape(1, -1)))
43
 
44
  async def emit(self):
45
  return await wait_for_item(self.out_q)
46
 
47
 
 
48
  stream = Stream(
49
  handler=EchoHandler(),
50
  modality="audio",
51
  mode="send-receive",
52
- additional_inputs=["voice_name"], # placeholder for later
53
  )
54
 
 
55
  stream.mount(app)
56
 
 
57
  # ---------------------------
58
  # Helpers
59
  # ---------------------------
@@ -67,78 +64,56 @@ def int16_to_b64(audio: np.ndarray) -> str:
67
  return base64.b64encode(audio.tobytes()).decode("utf-8")
68
 
69
 
70
- # ---------------------------
71
- # Basic endpoints
72
- # ---------------------------
73
  @app.get("/")
74
  async def root():
75
- return {"ok": True, "message": "FastRTC mounted. Use the mounted endpoints for WebRTC/WebSocket."}
76
 
77
  @app.get("/health")
78
  async def health():
79
  return {"ok": True}
80
 
81
- @app.get("/webrtc/new")
82
- async def webrtc_new():
83
- """
84
- Mint a webrtc_id to use with /outputs or /ws bridge.
85
- """
86
- webrtc_id = str(uuid.uuid4())
87
- # Initialize internal connection state so output_stream has something to bind to later
88
- # (FastRTC will create it lazily when first used, but we create a stable id for the client.)
89
- return {"webrtc_id": webrtc_id}
90
-
91
- @app.get("/outputs")
92
- async def outputs(webrtc_id: str):
93
- async def event_stream():
94
- async for out in stream.output_stream(webrtc_id):
95
- payload = json.dumps(out.args[0] if out.args else None)
96
- yield f"event: output\ndata: {payload}\n\n"
97
- return StreamingResponse(event_stream(), media_type="text/event-stream")
98
-
99
 
100
  # ---------------------------
101
- # Scratch-friendly WebSocket bridge
102
  # ---------------------------
103
  @app.websocket("/ws")
104
  async def ws_bridge(ws: WebSocket):
105
  await ws.accept()
106
 
107
- webrtc_id: Optional[str] = None
108
- out_task: Optional[asyncio.Task] = None
109
 
110
- async def send_outputs_loop():
111
- # Stream AdditionalOutputs + audio coming out of FastRTC
 
 
 
112
  try:
113
- async for item in stream.output_stream(webrtc_id):
114
- # item is AdditionalOutputs; forward as JSON
115
- msg = item.args[0] if item.args else None
116
- await ws.send_text(json.dumps({"type": "output", "data": msg}))
117
- except Exception:
118
- pass
119
 
120
- async def send_audio_loop():
121
- # Also poll the "audio" output if your handler emits raw audio tuples.
122
- # FastRTC output_stream yields AdditionalOutputs only.
123
- # So for audio we use stream.fetch_output(...) style by calling internal generator:
124
- try:
125
- async for out in stream.stream_output(webrtc_id):
126
- # out can be (sr, np.ndarray) or AdditionalOutputs
127
  if isinstance(out, AdditionalOutputs):
 
 
128
  continue
 
129
  sr, audio = out
130
  audio = np.asarray(audio)
131
  if audio.ndim == 2:
132
  audio = audio.squeeze()
133
  if audio.dtype != np.int16:
134
  audio = audio.astype(np.int16)
 
135
  await ws.send_text(json.dumps({
136
  "type": "audio_delta",
137
  "rate": int(sr),
138
- "data": int16_to_b64(audio)
139
  }))
140
  except Exception:
141
- pass
142
 
143
  try:
144
  while True:
@@ -147,39 +122,27 @@ async def ws_bridge(ws: WebSocket):
147
  t = msg.get("type")
148
 
149
  if t == "connect":
150
- # create or use provided webrtc_id
151
- webrtc_id = msg.get("webrtc_id") or str(uuid.uuid4())
152
-
153
- # optionally set voice / other inputs (stored for handler)
154
- voice = msg.get("voice") or "Puck"
155
- try:
156
- await stream.set_input(webrtc_id, voice)
157
- except Exception:
158
- # if set_input isn't supported in your exact FastRTC build, ignore
159
- pass
160
-
161
- # start output loops once
162
- if out_task is None:
163
- out_task = asyncio.gather(send_audio_loop(), send_outputs_loop())
164
-
165
- await ws.send_text(json.dumps({"type": "ready", "webrtc_id": webrtc_id}))
166
  continue
167
 
168
  if t == "audio":
169
- if not webrtc_id:
170
- await ws.send_text(json.dumps({"type": "error", "message": "Not connected. Send {type:'connect'} first."}))
171
  continue
172
 
173
  b64 = msg.get("data")
174
  rate = int(msg.get("rate") or 16000)
175
-
176
  if not isinstance(b64, str) or not b64:
177
  continue
178
 
179
  audio = b64_to_int16(b64)
180
-
181
- # FastRTC expects (sample_rate, np.ndarray)
182
- await stream.send_input(webrtc_id, (rate, audio.reshape(1, -1)))
183
  continue
184
 
185
  if t == "close":
@@ -192,7 +155,7 @@ async def ws_bridge(ws: WebSocket):
192
  pass
193
  finally:
194
  try:
195
- if out_task:
196
- out_task.cancel()
197
  except Exception:
198
  pass
 
1
  import asyncio
2
  import base64
3
  import json
 
4
  import uuid
5
+ from typing import Optional, Literal
6
 
7
  import numpy as np
8
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
9
+ from fastapi.responses import JSONResponse
 
10
 
11
+ from fastrtc import Stream, AsyncStreamHandler, wait_for_item, AdditionalOutputs
 
 
 
 
 
12
 
13
  app = FastAPI()
14
 
15
+
16
  # ---------------------------
17
+ # A tiny headless audio handler (echo) to validate the pipe.
18
+ # Swap this out later for Gemini / other realtime models.
19
  # ---------------------------
20
  class EchoHandler(AsyncStreamHandler):
21
  def __init__(self, expected_layout: Literal["mono"] = "mono", output_sample_rate: int = 24000):
 
33
  if audio.dtype != np.int16:
34
  audio = audio.astype(np.int16)
35
 
36
+ # Echo straight back
37
  self.out_q.put_nowait((sr, audio.reshape(1, -1)))
38
 
39
  async def emit(self):
40
  return await wait_for_item(self.out_q)
41
 
42
 
43
+ # IMPORTANT: no additional_inputs here (strings crash in 0.0.34)
44
  stream = Stream(
45
  handler=EchoHandler(),
46
  modality="audio",
47
  mode="send-receive",
 
48
  )
49
 
50
+ # This mounts FastRTC’s internal routes; we’re also adding /ws below.
51
  stream.mount(app)
52
 
53
+
54
  # ---------------------------
55
  # Helpers
56
  # ---------------------------
 
64
  return base64.b64encode(audio.tobytes()).decode("utf-8")
65
 
66
 
 
 
 
67
  @app.get("/")
68
  async def root():
69
+ return {"ok": True, "message": "FastRTC mounted. Headless mode. Use /ws for Scratch."}
70
 
71
  @app.get("/health")
72
  async def health():
73
  return {"ok": True}
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  # ---------------------------
77
+ # Scratch-friendly WS bridge (no WebRTC needed client-side)
78
  # ---------------------------
79
  @app.websocket("/ws")
80
  async def ws_bridge(ws: WebSocket):
81
  await ws.accept()
82
 
83
+ session_id: Optional[str] = None
84
+ pump_task: Optional[asyncio.Task] = None
85
 
86
+ async def pump_outputs():
87
+ """
88
+ Pull audio outputs from FastRTC and forward to client.
89
+ NOTE: FastRTC 0.0.34 uses fetch_output() polling style.
90
+ """
91
  try:
92
+ while True:
93
+ out = await stream.fetch_output(session_id)
94
+ if out is None:
95
+ await asyncio.sleep(0.01)
96
+ continue
 
97
 
 
 
 
 
 
 
 
98
  if isinstance(out, AdditionalOutputs):
99
+ payload = out.args[0] if out.args else None
100
+ await ws.send_text(json.dumps({"type": "output", "data": payload}))
101
  continue
102
+
103
  sr, audio = out
104
  audio = np.asarray(audio)
105
  if audio.ndim == 2:
106
  audio = audio.squeeze()
107
  if audio.dtype != np.int16:
108
  audio = audio.astype(np.int16)
109
+
110
  await ws.send_text(json.dumps({
111
  "type": "audio_delta",
112
  "rate": int(sr),
113
+ "data": int16_to_b64(audio),
114
  }))
115
  except Exception:
116
+ return
117
 
118
  try:
119
  while True:
 
122
  t = msg.get("type")
123
 
124
  if t == "connect":
125
+ session_id = msg.get("session_id") or str(uuid.uuid4())
126
+
127
+ # start pump once
128
+ if pump_task is None:
129
+ pump_task = asyncio.create_task(pump_outputs())
130
+
131
+ await ws.send_text(json.dumps({"type": "ready", "session_id": session_id}))
 
 
 
 
 
 
 
 
 
132
  continue
133
 
134
  if t == "audio":
135
+ if not session_id:
136
+ await ws.send_text(json.dumps({"type": "error", "message": "Send {type:'connect'} first."}))
137
  continue
138
 
139
  b64 = msg.get("data")
140
  rate = int(msg.get("rate") or 16000)
 
141
  if not isinstance(b64, str) or not b64:
142
  continue
143
 
144
  audio = b64_to_int16(b64)
145
+ await stream.send_input(session_id, (rate, audio.reshape(1, -1)))
 
 
146
  continue
147
 
148
  if t == "close":
 
155
  pass
156
  finally:
157
  try:
158
+ if pump_task:
159
+ pump_task.cancel()
160
  except Exception:
161
  pass