# api/app/gemini_text.py
# Repository page residue: uploader "SalexAI"; commit "Create gemini_text.py" (ed9acde, verified).
from __future__ import annotations
import asyncio
import json
import os
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, List
from google import genai
# ----------------------------
# Session state
# ----------------------------
@dataclass
class ToolCallAwaiter:
    """Holder for the future that is resolved when a tool call's result arrives."""

    # Completed via deliver_function_result(); awaited by the chat turn.
    fut: asyncio.Future
@dataclass
class SessionState:
    """Per-session conversation state.

    Attributes:
        history: Conversation turns as role/parts dicts, in the order they
            are sent to the model.
        functions: Function name -> schema dict (Scratch-provided, you
            control the format).
        pending_calls: call_id -> awaiter for tool calls whose result has
            not been delivered yet.
    """

    history: List[dict] = field(default_factory=list)
    functions: Dict[str, dict] = field(default_factory=dict)
    pending_calls: Dict[str, ToolCallAwaiter] = field(default_factory=dict)
# Process-wide registry: session_id -> SessionState.
SESSIONS: Dict[str, SessionState] = {}


def get_session(session_id: str) -> SessionState:
    """Return the state for ``session_id``, creating and registering it on first use."""
    try:
        return SESSIONS[session_id]
    except KeyError:
        state = SessionState()
        SESSIONS[session_id] = state
        return state
# ----------------------------
# Gemini client
# ----------------------------
def _get_genai_client() -> genai.Client:
    """Build a Gemini client from the ``GEMINI_API_KEY`` environment variable.

    Raises:
        RuntimeError: If the variable is unset or empty.
    """
    key = os.environ.get("GEMINI_API_KEY")
    if key:
        return genai.Client(api_key=key)
    raise RuntimeError("Missing GEMINI_API_KEY env var.")
def _scratch_schema_to_gemini_decl(name: str, schema: dict) -> dict:
    """
    Convert a Scratch-side function schema into a Gemini-compatible function declaration.

    Expected Scratch schema (example):
    {
        "description": "Open the settings page",
        "parameters": {
            "type": "object",
            "properties": {
                "tab": {"type": "string", "description": "Which tab to open"}
            },
            "required": ["tab"]
        }
    }

    Args:
        name: Function name to declare.
        schema: Scratch-provided schema dict; ``None`` or an empty dict is
            treated as "no description, no parameters".

    Returns:
        A dict with ``name``, ``description`` and ``parameters`` keys.
    """
    schema = schema or {}
    # Use `or` for both fields so that an explicit null (e.g. JSON
    # `"description": null`) falls back to the default just like a missing key.
    description = schema.get("description") or ""
    parameters = schema.get("parameters") or {"type": "object", "properties": {}}
    return {
        "name": name,
        "description": description,
        "parameters": parameters,
    }
def _split_response_parts(resp: Any) -> tuple[List[str], List[dict]]:
    """Split the first candidate of a Gemini response into text chunks and tool calls."""
    # google-genai response parsing varies across versions; read defensively:
    #   - normal text in resp.candidates[].content.parts[].text
    #   - tool call in resp.candidates[].content.parts[].function_call
    candidates = getattr(resp, "candidates", None) or []
    candidate = candidates[0] if candidates else None
    content = getattr(candidate, "content", None) if candidate else None
    parts = (getattr(content, "parts", None) if content else None) or []

    text_chunks: List[str] = []
    tool_calls: List[dict] = []
    for part in parts:
        text = getattr(part, "text", None)
        if text:
            text_chunks.append(text)
        call = getattr(part, "function_call", None)
        if call:
            args = getattr(call, "args", None)
            if isinstance(args, str):
                # Arguments arrived as a string: try JSON, else keep the raw text.
                try:
                    args = json.loads(args)
                except (ValueError, RecursionError):
                    args = {"_raw": args}
            tool_calls.append({"name": getattr(call, "name", None), "args": args or {}})
    return text_chunks, tool_calls


async def gemini_chat_turn(
    *,
    session_id: str,
    user_text: str,
    emit_event,  # async fn(dict) -> None (send to ws client)
    model: str = "gemini-2.0-flash",
) -> str:
    """
    Sends one user turn to Gemini Flash (text), supports tool calling by bouncing tool calls to the WS client.

    Args:
        session_id: Key of the session whose history and functions are used.
        user_text: The user's message; appended to the session history.
        emit_event: Async callable taking one dict; called with a
            ``function_called`` event for every tool call the model makes.
        model: Gemini model name.

    Returns:
        The assistant's final text (whitespace-stripped), or
        ``"(No response.)"`` when the model returned neither text nor tool calls.

    Raises:
        RuntimeError: If ``GEMINI_API_KEY`` is not set.
    """
    s = get_session(session_id)
    client = _get_genai_client()

    # Build tool declarations from session functions
    tool_decls = [
        _scratch_schema_to_gemini_decl(fname, fschema)
        for fname, fschema in s.functions.items()
    ]
    config: Dict[str, Any] = {
        # Keep responses short-ish for Scratch club usage
        "temperature": 0.6,
    }
    if tool_decls:
        config["tools"] = [{"function_declarations": tool_decls}]

    # Build content. Keep it simple + stable.
    # Note: google-genai accepts "contents" as a list of role/content dicts.
    s.history.append({"role": "user", "parts": [{"text": user_text}]})

    loop = asyncio.get_running_loop()

    # We run a loop because Gemini might call tools then continue.
    while True:
        # generate_content is a blocking network call; run it in a worker
        # thread so the event loop stays free to receive tool results.
        # A snapshot of the history is passed because the list may be
        # appended to on the loop thread while the request is in flight.
        resp = await asyncio.to_thread(
            client.models.generate_content,
            model=model,
            contents=list(s.history),
            config=config,
        )
        text_chunks, tool_calls = _split_response_parts(resp)

        if not tool_calls:
            # Plain answer (or nothing at all): record it and finish the turn.
            if text_chunks:
                assistant_text = "".join(text_chunks).strip()
            else:
                assistant_text = "(No response.)"
            s.history.append({"role": "model", "parts": [{"text": assistant_text}]})
            return assistant_text

        # Tools were requested: execute each via the WS client.
        model_parts: List[dict] = [{"text": chunk} for chunk in text_chunks]
        response_parts: List[dict] = []
        for tc in tool_calls:
            fname = tc["name"] or "unknown_function"
            fargs = tc["args"] or {}
            call_id = str(uuid.uuid4())
            fut = loop.create_future()
            s.pending_calls[call_id] = ToolCallAwaiter(fut=fut)
            try:
                await emit_event(
                    {
                        "type": "function_called",
                        "call_id": call_id,
                        "name": fname,
                        "arguments": fargs,
                    }
                )
                # Wait for Scratch to respond with function_result
                result = await fut
            finally:
                # Never leave a stale awaiter behind if emitting fails or the
                # task is cancelled while waiting.
                s.pending_calls.pop(call_id, None)

            model_parts.append({"function_call": {"name": fname, "args": fargs}})
            response_parts.append(
                {
                    "function_response": {
                        "name": fname,
                        "response": {"result": result},
                    }
                }
            )

        # Record the model's function-call turn followed by ONE tool turn with
        # every result, so each function response directly follows the call it
        # answers. Both are appended only after all results are in, so a
        # failure above cannot leave a half-finished exchange in the history.
        s.history.append({"role": "model", "parts": model_parts})
        s.history.append({"role": "tool", "parts": response_parts})
        # Loop continues to let Gemini produce final text after tools
def deliver_function_result(session_id: str, call_id: str, result: Any) -> bool:
    """Hand a tool-call result to the chat turn waiting on ``call_id``.

    Args:
        session_id: Session the call belongs to.
        call_id: Identifier that was sent in the ``function_called`` event.
        result: Value to resolve the pending future with.

    Returns:
        ``True`` if ``call_id`` was pending in that session (its entry is
        removed), ``False`` if the session or the call is unknown.
    """
    # Look the session up without get_session(): that would create and keep
    # an empty SessionState for every unknown or stale session_id.
    s = SESSIONS.get(session_id)
    if s is None:
        return False
    aw = s.pending_calls.pop(call_id, None)
    if aw is None:
        return False
    # The future may already be done (e.g. cancelled); setting a result on a
    # done future would raise InvalidStateError.
    if not aw.fut.done():
        aw.fut.set_result(result)
    return True