# File: api/src/genai/utils.py
# Eli Safra — 938949f: Deploy SolarWine API (FastAPI + Docker, port 7860)
"""
Shared utilities for working with the Google Gemini (genai) client.
Centralises:
- GOOGLE_API_KEY resolution (Streamlit secrets → environment variable).
- genai.Client construction.
- Robust JSON object extraction from model responses.
"""
from __future__ import annotations
import json
import os
import re
from typing import Optional
def get_google_api_key(explicit: Optional[str] = None) -> str:
    """
    Resolve the Google API key used for Gemini.

    Resolution order (the first non-blank value wins; leading/trailing
    whitespace is stripped from every source, not just the environment):
    1. Explicit argument.
    2. Streamlit secrets["GOOGLE_API_KEY"] (if streamlit is importable and
       the secret is set).
    3. Environment variable GOOGLE_API_KEY.

    Parameters
    ----------
    explicit : str, optional
        Key supplied directly by the caller. Blank/whitespace-only values are
        treated as "not provided" and resolution continues.

    Returns
    -------
    str
        The resolved key with surrounding whitespace removed.

    Raises
    ------
    ValueError
        If no non-blank key can be found.
    """
    explicit_key = (explicit or "").strip()
    if explicit_key:
        return explicit_key

    # Streamlit is an optional dependency. Any failure here (not installed,
    # secrets unavailable, ...) is deliberately treated as "not configured"
    # so resolution can fall through to the environment variable.
    try:
        import streamlit as st  # type: ignore

        secret_key = str(st.secrets.get("GOOGLE_API_KEY", "") or "").strip()
    except Exception:
        secret_key = ""
    if secret_key:
        return secret_key

    env_key = os.environ.get("GOOGLE_API_KEY", "").strip()
    if not env_key:
        raise ValueError(
            "GOOGLE_API_KEY not found. Set it as an environment variable or in "
            "Streamlit secrets."
        )
    return env_key
def get_genai_client(api_key: Optional[str] = None):
    """
    Build a ``google.genai.Client`` authenticated with a resolved API key.

    The SDK import is attempted first, so a missing ``google-genai`` install
    is reported before any key lookup happens.

    Parameters
    ----------
    api_key : str, optional
        Key to use directly. When None or empty, resolution falls back to
        get_google_API_key_safe() (explicit -> Streamlit secrets -> env var).

    Returns
    -------
    google.genai.Client
        A client constructed with ``api_key=<resolved key>``.

    Raises
    ------
    ImportError
        If ``google.genai`` cannot be imported (chained to the original error).
    ValueError
        Propagated from key resolution when no key is available.
    """
    try:
        from google import genai  # type: ignore
    except ImportError as exc:
        hint = (
            "Could not import 'google.genai'. Install the Gemini SDK with:\n"
            " pip install google-genai\n"
            "Then run Streamlit using the same Python environment (e.g. activate "
            "your venv or conda env before 'streamlit run app.py')."
        )
        raise ImportError(hint) from exc

    resolved_key = get_google_API_key_safe(api_key)
    client = genai.Client(api_key=resolved_key)
    return client
def get_google_API_key_safe(explicit: Optional[str] = None) -> str:
    """
    Thin indirection over get_google_api_key().

    get_genai_client() resolves its key through this function rather than
    calling get_google_api_key() directly. That gives tests a single seam to
    patch or override. The argument is forwarded unchanged.

    Parameters
    ----------
    explicit : str, optional
        Passed as-is to get_google_api_key().

    Returns
    -------
    str
        Whatever get_google_api_key() returns (it raises ValueError when no
        key is available).
    """
    resolved = get_google_api_key(explicit)
    return resolved
def extract_json_object(text: str) -> dict:
    """
    Extract a JSON object from raw model text.

    Markdown code fences (```` ``` ```` / ```` ```json ````) are stripped.
    Then exactly one JSON object is decoded, starting at the first "{".
    Any text after that object (explanations, stray braces) is ignored.

    Parameters
    ----------
    text : str
        Raw model output. ``None`` is treated as an empty response.

    Returns
    -------
    dict
        The decoded JSON object.

    Raises
    ------
    ValueError
        If no "{ ... }" span can be found. json.JSONDecodeError (a ValueError
        subclass) is raised if the text starting at the first "{" is not a
        valid JSON object.
    """
    # Strip markdown fences like ```json ... ```
    cleaned = (
        re.sub(r"```(?:json)?\s*", "", text or "").strip().rstrip("`").strip()
    )
    start = cleaned.find("{")
    if start == -1 or cleaned.rfind("}") < start:
        raise ValueError(f"No JSON object found in LLM response:\n{cleaned[:300]}")
    # raw_decode parses a single value and stops at its closing brace.
    # Trailing prose containing "}" therefore no longer corrupts the input,
    # as it did with the first-"{"-to-last-"}" slice. Only the first "{" is
    # tried: retrying at later braces could silently return a nested
    # sub-object of a truncated response.
    obj, _ = json.JSONDecoder().raw_decode(cleaned[start:])
    return obj