SalexAI committed on
Commit
9c5a6e7
·
verified ·
1 Parent(s): 277bec4

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +205 -313
app/main.py CHANGED
@@ -1,335 +1,227 @@
1
- # app/main.py
2
- import os
3
- import json
4
  import asyncio
5
- from typing import Any, Dict, Optional, List
 
 
 
6
 
7
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
8
- from fastapi.responses import JSONResponse
9
  from dotenv import load_dotenv
10
- import websockets
11
-
12
- load_dotenv()
13
-
14
- app = FastAPI(title="Gemini Live Native-Audio WS Proxy", version="2.1.0")
15
-
16
- GEMINI_LIVE_WS_URL = (
17
- "wss://generativelanguage.googleapis.com/ws/"
18
- "google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent"
19
  )
20
 
21
- API_KEY = os.getenv("GEMINI_API_KEY", "").strip()
22
-
23
- # IMPORTANT: pick a REAL default model here (must support Live + native audio)
24
- # Put your known-working native audio model id below:
25
- FALLBACK_NATIVE_AUDIO_MODEL = "models/gemini-2.5-flash-native-audio-preview-12-2025"
26
-
27
- DEFAULT_MODEL = os.getenv("GEMINI_MODEL", FALLBACK_NATIVE_AUDIO_MODEL)
28
- DEFAULT_SYSTEM = os.getenv(
29
- "GEMINI_SYSTEM_INSTRUCTION",
30
- "You are an AI voice assitent named arrow, Please attempt to understand all commands as best you can, Be helpful and use inbulit functions instead of guessing."
31
  )
32
- DEFAULT_TEMPERATURE = float(os.getenv("GEMINI_TEMPERATURE", "0.7"))
33
- DEFAULT_MAX_TOKENS = int(os.getenv("GEMINI_MAX_OUTPUT_TOKENS", "1024"))
34
 
35
- DEFAULT_VOICE = os.getenv("GEMINI_VOICE_NAME", "Kore")
36
- DEFAULT_INPUT_RATE = int(os.getenv("GEMINI_INPUT_AUDIO_RATE", "16000"))
37
- DEFAULT_OUTPUT_RATE = int(os.getenv("GEMINI_OUTPUT_AUDIO_RATE", "24000"))
38
 
39
- DEBUG_GEMINI_RAW = os.getenv("DEBUG_GEMINI_RAW", "0").strip() == "1"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- def _clean_str(x: Any) -> str:
43
- if not isinstance(x, str):
44
- return ""
45
- return x.strip()
46
 
 
 
47
 
48
- def _is_bad_model(s: str) -> bool:
49
- s2 = (s or "").strip().lower()
50
- return (not s2) or (s2 in {"undefined", "null", "none"})
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- def _safe_model(model: Any) -> str:
54
- m = _clean_str(model)
55
- if _is_bad_model(m):
56
- m = _clean_str(DEFAULT_MODEL)
57
- if _is_bad_model(m):
58
- m = FALLBACK_NATIVE_AUDIO_MODEL
59
- return m
60
 
61
 
62
  @app.get("/health")
63
  async def health():
64
- model = _safe_model(DEFAULT_MODEL)
65
- ok = bool(API_KEY)
66
- return JSONResponse(
67
- {
68
- "ok": ok,
69
- "has_api_key": ok,
70
- "model": model,
71
- "voice": DEFAULT_VOICE,
72
- "input_rate": DEFAULT_INPUT_RATE,
73
- "output_rate": DEFAULT_OUTPUT_RATE,
74
- "debug_raw": DEBUG_GEMINI_RAW,
75
- }
76
- )
77
-
78
-
79
- def _extract_text_parts(content: Dict[str, Any]) -> str:
80
- parts = content.get("parts") or []
81
- out: List[str] = []
82
- for p in parts:
83
- if isinstance(p, dict) and isinstance(p.get("text"), str):
84
- out.append(p["text"])
85
- return "".join(out)
86
-
87
-
88
- def _extract_inline_audio_parts(content: Dict[str, Any]) -> List[Dict[str, str]]:
89
- parts = content.get("parts") or []
90
- out: List[Dict[str, str]] = []
91
- for p in parts:
92
- if not isinstance(p, dict):
93
- continue
94
- inline = p.get("inlineData")
95
- if isinstance(inline, dict):
96
- data = inline.get("data")
97
- mime = inline.get("mimeType")
98
- if isinstance(data, str) and isinstance(mime, str):
99
- out.append({"mime": mime, "data": data})
100
- return out
101
-
102
-
103
- async def _gemini_ws_connect(setup_payload: Dict[str, Any]):
104
- headers = {"x-goog-api-key": API_KEY}
105
- ws = await websockets.connect(
106
- GEMINI_LIVE_WS_URL,
107
- extra_headers=headers,
108
- max_size=32 * 1024 * 1024,
109
- ping_interval=20,
110
- ping_timeout=20,
111
- )
112
-
113
- await ws.send(json.dumps(setup_payload))
114
-
115
- while True:
116
- raw = await ws.recv()
117
- msg = json.loads(raw)
118
- if "setupComplete" in msg:
119
- return ws
120
- if "error" in msg:
121
- raise RuntimeError(f"Gemini setup error: {msg['error']}")
122
-
123
-
124
- @app.websocket("/ws")
125
- async def ws_proxy(client_ws: WebSocket):
126
- await client_ws.accept()
127
-
128
- if not API_KEY:
129
- await client_ws.send_text(json.dumps({"type": "error", "message": "Missing GEMINI_API_KEY on server."}))
130
- await client_ws.close(code=1011)
131
- return
132
-
133
- # Defaults per connection
134
- cfg = {
135
- "model": _safe_model(DEFAULT_MODEL),
136
- "system_instruction": _clean_str(DEFAULT_SYSTEM) or "You are helpful.",
137
- "temperature": DEFAULT_TEMPERATURE,
138
- "max_output_tokens": DEFAULT_MAX_TOKENS,
139
- "voice": _clean_str(DEFAULT_VOICE) or "Kore",
140
- "input_rate": DEFAULT_INPUT_RATE,
141
  }
142
 
143
- # Wait briefly for optional configure (FIRST message)
144
- pending_first: Optional[Dict[str, Any]] = None
145
- try:
146
- raw = await asyncio.wait_for(client_ws.receive_text(), timeout=1.2)
147
- first = json.loads(raw)
148
- if isinstance(first, dict) and first.get("type") == "configure":
149
- cfg["model"] = _safe_model(first.get("model"))
150
- si = _clean_str(first.get("system_instruction"))
151
- if si:
152
- cfg["system_instruction"] = si
153
- try:
154
- if first.get("temperature") is not None:
155
- cfg["temperature"] = float(first["temperature"])
156
- except Exception:
157
- pass
158
- try:
159
- if first.get("max_output_tokens") is not None:
160
- cfg["max_output_tokens"] = int(first["max_output_tokens"])
161
- except Exception:
162
- pass
163
- v = _clean_str(first.get("voice"))
164
- if v:
165
- cfg["voice"] = v
166
- try:
167
- if first.get("input_rate") is not None:
168
- cfg["input_rate"] = int(first["input_rate"])
169
- except Exception:
170
- pass
171
-
172
- await client_ws.send_text(json.dumps({"type": "configured"}))
173
- else:
174
- pending_first = first if isinstance(first, dict) else None
175
- except asyncio.TimeoutError:
176
- pass
177
- except Exception:
178
- pass
179
-
180
- # FINAL guard (this prevents “undefined” ever reaching Gemini)
181
- cfg["model"] = _safe_model(cfg["model"])
182
-
183
- # Build native-audio session setup
184
- setup_payload = {
185
- "setup": {
186
- "model": cfg["model"],
187
- "generationConfig": {
188
- "temperature": cfg["temperature"],
189
- "maxOutputTokens": cfg["max_output_tokens"],
190
- "responseModalities": ["AUDIO"],
191
- "speechConfig": {
192
- "voiceConfig": {
193
- "prebuiltVoiceConfig": {
194
- "voiceName": cfg["voice"],
195
- }
196
- }
197
- },
198
- },
199
- "inputAudioTranscription": {},
200
- "outputAudioTranscription": {},
201
- "systemInstruction": {
202
- "role": "system",
203
- "parts": [{"text": cfg["system_instruction"]}],
204
- },
205
- }
206
- }
207
-
208
- stop_event = asyncio.Event()
209
- gemini_ws = None
210
-
211
- try:
212
- gemini_ws = await _gemini_ws_connect(setup_payload)
213
- await client_ws.send_text(json.dumps({"type": "ready", "model": cfg["model"]}))
214
- except Exception as e:
215
- await client_ws.send_text(json.dumps({"type": "error", "message": f"Gemini setup failed: {e}"}))
216
- await client_ws.close(code=1011)
217
- return
218
-
219
- async def forward_client_to_gemini():
220
- nonlocal pending_first
221
- try:
222
- while not stop_event.is_set():
223
- if pending_first is not None:
224
- data = pending_first
225
- pending_first = None
226
- else:
227
- raw = await client_ws.receive_text()
228
- data = json.loads(raw)
229
-
230
- t = data.get("type")
231
-
232
- if t == "close":
233
- stop_event.set()
234
- return
235
-
236
- if t == "audio":
237
- b64 = data.get("data")
238
- rate = data.get("rate", cfg["input_rate"])
239
- if not isinstance(b64, str) or not b64:
240
- continue
241
- try:
242
- rate_i = int(rate)
243
- except Exception:
244
- rate_i = cfg["input_rate"]
245
-
246
- payload = {
247
- "realtimeInput": {
248
- "audio": {
249
- "data": b64,
250
- "mimeType": f"audio/pcm;rate={rate_i}",
251
- }
252
- }
253
- }
254
- await gemini_ws.send(json.dumps(payload))
255
- continue
256
-
257
- if t == "audio_end":
258
- await gemini_ws.send(json.dumps({"realtimeInput": {"audioStreamEnd": True}}))
259
- continue
260
-
261
- if t == "text":
262
- text = data.get("text", "")
263
- if isinstance(text, str) and text.strip():
264
- payload = {
265
- "clientContent": {
266
- "turns": [{"role": "user", "parts": [{"text": text.strip()}]}],
267
- "turnComplete": True,
268
- }
269
- }
270
- await gemini_ws.send(json.dumps(payload))
271
- continue
272
-
273
- await client_ws.send_text(json.dumps({"type": "error", "message": f"Unknown message type: {t}"}))
274
-
275
- except WebSocketDisconnect:
276
- stop_event.set()
277
- except Exception as e:
278
- stop_event.set()
279
- try:
280
- await client_ws.send_text(json.dumps({"type": "error", "message": str(e)}))
281
- except Exception:
282
- pass
283
 
284
- async def forward_gemini_to_client():
285
- try:
286
- while not stop_event.is_set():
287
- raw = await gemini_ws.recv()
288
- msg = json.loads(raw)
289
-
290
- if DEBUG_GEMINI_RAW:
291
- await client_ws.send_text(json.dumps({"type": "gemini_raw", "message": msg}))
292
-
293
- server_content = msg.get("serverContent")
294
- if isinstance(server_content, dict):
295
- model_turn = server_content.get("modelTurn")
296
- if isinstance(model_turn, dict):
297
- txt = _extract_text_parts(model_turn)
298
- if txt:
299
- await client_ws.send_text(json.dumps({"type": "text_delta", "text": txt}))
300
-
301
- audios = _extract_inline_audio_parts(model_turn)
302
- for a in audios:
303
- await client_ws.send_text(
304
- json.dumps({"type": "audio_delta", "mime": a["mime"], "data": a["data"]})
305
- )
306
-
307
- out_tx = server_content.get("outputTranscription")
308
- if isinstance(out_tx, dict) and isinstance(out_tx.get("text"), str):
309
- await client_ws.send_text(
310
- json.dumps({"type": "output_transcript_delta", "text": out_tx["text"]})
311
- )
312
-
313
- if server_content.get("generationComplete") is True:
314
- await client_ws.send_text(json.dumps({"type": "turn_complete"}))
315
-
316
- except Exception as e:
317
- stop_event.set()
318
- try:
319
- await client_ws.send_text(json.dumps({"type": "error", "message": f"Gemini link error: {e}"}))
320
- except Exception:
321
- pass
322
 
323
- try:
324
- await asyncio.gather(forward_client_to_gemini(), forward_gemini_to_client())
325
- finally:
326
- stop_event.set()
327
- try:
328
- if gemini_ws is not None:
329
- await gemini_ws.close()
330
- except Exception:
331
- pass
332
- try:
333
- await client_ws.close()
334
- except Exception:
335
- pass
 
 
 
 
1
  import asyncio
2
+ import base64
3
+ import json
4
+ import os
5
+ from typing import AsyncGenerator, Literal
6
 
7
+ import numpy as np
 
8
  from dotenv import load_dotenv
9
+ from fastapi import FastAPI
10
+ from fastapi.responses import StreamingResponse
11
+
12
+ from fastrtc import (
13
+ AdditionalOutputs,
14
+ AsyncStreamHandler,
15
+ Stream,
16
+ wait_for_item,
 
17
  )
18
 
19
+ from google import genai
20
+ from google.genai.types import (
21
+ LiveConnectConfig,
22
+ PrebuiltVoiceConfig,
23
+ SpeechConfig,
24
+ VoiceConfig,
 
 
 
 
25
  )
 
 
26
 
27
+ load_dotenv()
 
 
28
 
29
+ # ---------------------------
30
+ # Config (env vars)
31
+ # ---------------------------
32
+ # Put this in your HF Space "Secrets":
33
+ # GEMINI_API_KEY = "..."
34
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
35
+
36
+ # Gemini realtime model (this is the one FastRTC uses in their Gemini demo Space)
37
+ # You can change this later to another Live-capable model.
38
+ GEMINI_LIVE_MODEL = os.getenv("GEMINI_LIVE_MODEL", "gemini-2.0-flash-exp")
39
+
40
+ # Voice name (FastRTC Gemini demo uses "Puck" by default)
41
+ DEFAULT_VOICE = os.getenv("GEMINI_VOICE", "Puck")
42
+
43
+ # Sample rates
44
+ OUTPUT_SAMPLE_RATE = int(os.getenv("OUTPUT_SAMPLE_RATE", "24000"))
45
+ INPUT_SAMPLE_RATE = int(os.getenv("INPUT_SAMPLE_RATE", "16000")) # matches the demo Space
46
+
47
+
48
+ def _encode_pcm16_mono_to_b64(data: np.ndarray) -> str:
49
+ """
50
+ Encodes int16 mono PCM to base64 for any custom debug endpoints.
51
+ """
52
+ if data.dtype != np.int16:
53
+ data = data.astype(np.int16)
54
+ return base64.b64encode(data.tobytes()).decode("utf-8")
55
+
56
+
57
+ class GeminiLiveAudioHandler(AsyncStreamHandler):
58
+ """
59
+ FastRTC AsyncStreamHandler that connects to Gemini Live and streams AUDIO back.
60
+
61
+ This is adapted from the official FastRTC Gemini demo Space code. :contentReference[oaicite:5]{index=5}
62
+ """
63
+
64
+ def __init__(
65
+ self,
66
+ expected_layout: Literal["mono"] = "mono",
67
+ output_sample_rate: int = OUTPUT_SAMPLE_RATE,
68
+ ) -> None:
69
+ super().__init__(
70
+ expected_layout=expected_layout,
71
+ output_sample_rate=output_sample_rate,
72
+ input_sample_rate=INPUT_SAMPLE_RATE,
73
+ )
74
+
75
+ self.input_queue: asyncio.Queue[bytes] = asyncio.Queue()
76
+ self.output_queue: asyncio.Queue[tuple[int, np.ndarray] | AdditionalOutputs] = asyncio.Queue()
77
+ self.quit = asyncio.Event()
78
+
79
+ def copy(self) -> "GeminiLiveAudioHandler":
80
+ # FastRTC uses .copy() to clone per-connection handlers
81
+ return GeminiLiveAudioHandler(
82
+ expected_layout="mono",
83
+ output_sample_rate=self.output_sample_rate,
84
+ )
85
+
86
+ async def start_up(self) -> None:
87
+ """
88
+ Connect to Gemini Live, then continuously:
89
+ - read user audio from self.stream()
90
+ - receive model audio chunks and push them to output_queue
91
+ """
92
+ # Optional: allow per-connection overrides via "additional_inputs"
93
+ # We wait for args to be set (FastRTC API docs show wait_for_args usage). :contentReference[oaicite:6]{index=6}
94
+ await self.wait_for_args()
95
+ # latest_args includes metadata at [0]; any custom inputs start at [1]
96
+ # We'll accept: voice_name (str) as the single custom arg, fallback to DEFAULT_VOICE.
97
+ voice_name = DEFAULT_VOICE
98
+ try:
99
+ if len(self.latest_args) >= 2 and isinstance(self.latest_args[1], str) and self.latest_args[1].strip():
100
+ voice_name = self.latest_args[1].strip()
101
+ except Exception:
102
+ pass
103
+
104
+ api_key = GEMINI_API_KEY
105
+ if not api_key:
106
+ # Fail early with a helpful message in the client.
107
+ await self.output_queue.put(
108
+ AdditionalOutputs({"type": "error", "message": "Missing GEMINI_API_KEY env var on the server."})
109
+ )
110
+ return
111
+
112
+ client = genai.Client(
113
+ api_key=api_key,
114
+ http_options={"api_version": "v1alpha"}, # matches FastRTC Gemini demo Space :contentReference[oaicite:7]{index=7}
115
+ )
116
+
117
+ config = LiveConnectConfig(
118
+ response_modalities=["AUDIO"], # AUDIO-only mode :contentReference[oaicite:8]{index=8}
119
+ speech_config=SpeechConfig(
120
+ voice_config=VoiceConfig(
121
+ prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name)
122
+ )
123
+ ),
124
+ )
125
+
126
+ async with client.aio.live.connect(model=GEMINI_LIVE_MODEL, config=config) as session:
127
+ # session.start_stream takes an async generator of bytes
128
+ async for audio in session.start_stream(stream=self._stream_pcm(), mime_type="audio/pcm"):
129
+ if audio.data:
130
+ # Gemini returns pcm16 bytes; convert to int16 array
131
+ arr = np.frombuffer(audio.data, dtype=np.int16)
132
+ # FastRTC expects (sample_rate, np.ndarray) shaped like (1, n) or (n,) depending on handler usage.
133
+ self.output_queue.put_nowait((self.output_sample_rate, arr.reshape(1, -1)))
134
+
135
+ async def _stream_pcm(self) -> AsyncGenerator[bytes, None]:
136
+ """
137
+ Provides PCM bytes to Gemini Live continuously.
138
+ """
139
+ while not self.quit.is_set():
140
+ try:
141
+ chunk = await asyncio.wait_for(self.input_queue.get(), timeout=0.1)
142
+ yield chunk
143
+ except (asyncio.TimeoutError, TimeoutError):
144
+ pass
145
 
146
+ async def receive(self, frame: tuple[int, np.ndarray]) -> None:
147
+ """
148
+ Called by FastRTC as audio frames arrive from the client.
149
+ """
150
+ _, audio = frame
151
+ # Expect mono, int16-ish. Convert safely.
152
+ audio = np.asarray(audio)
153
+ if audio.ndim == 2:
154
+ audio = audio.squeeze()
155
+ if audio.dtype != np.int16:
156
+ audio = audio.astype(np.int16)
157
+
158
+ # Push raw PCM16 bytes to Gemini stream
159
+ self.input_queue.put_nowait(audio.tobytes())
160
+
161
+ async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
162
+ """
163
+ Called by FastRTC to get the next outbound chunk (audio or structured outputs).
164
+ """
165
+ return await wait_for_item(self.output_queue)
166
+
167
+ async def shutdown(self) -> None:
168
+ self.quit.set()
169
+
170
+
171
+ # ---------------------------
172
+ # FastRTC Stream + FastAPI
173
+ # ---------------------------
174
+
175
+ # We expose one additional input: voice name
176
+ # Clients can set it via Stream.set_input(...) patterns described in the FastRTC API docs. :contentReference[oaicite:9]{index=9}
177
+ stream = Stream(
178
+ handler=GeminiLiveAudioHandler(),
179
+ modality="audio",
180
+ mode="send-receive",
181
+ additional_inputs=[
182
+ # Keep it simple: one string
183
+ # (FastRTC examples often use Gradio components here; in API mode we’ll set via set_input)
184
+ # We still define it so handler.wait_for_args() has something to wait on.
185
+ "voice_name"
186
+ ],
187
+ )
188
 
189
+ app = FastAPI()
 
 
 
190
 
191
+ # Mount FastRTC endpoints onto FastAPI (this is the core feature). :contentReference[oaicite:10]{index=10}
192
+ stream.mount(app)
193
 
 
 
 
194
 
195
# ---------------------------
# Optional: server-side outputs stream (SSE)
# Lets non-WebRTC clients (e.g. Scratch/JS) receive text/meta for a
# given connection via stream.output_stream(webrtc_id).
# ---------------------------
@app.get("/outputs")
async def outputs(webrtc_id: str):
    """SSE endpoint relaying each AdditionalOutputs item as an 'output' event."""

    async def event_stream():
        async for item in stream.output_stream(webrtc_id):
            # Each item is an AdditionalOutputs; forward its first arg as JSON.
            first_arg = item.args[0] if item.args else None
            yield "event: output\ndata: " + json.dumps(first_arg) + "\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")
 
 
 
 
 
 
211
 
212
 
213
@app.get("/health")
async def health():
    """Liveness probe reporting the active provider and audio configuration."""
    status = {"ok": True, "provider": "gemini_live_audio"}
    status["model"] = GEMINI_LIVE_MODEL
    status["output_sample_rate"] = OUTPUT_SAMPLE_RATE
    status["input_sample_rate"] = INPUT_SAMPLE_RATE
    return status
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
# Local/dev entry point; port 7860 is the Hugging Face Spaces default.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))