# AI-Package-Doctor / src/config.py
# Author: Yash030
# Commit 3a3d624: "Fixed Openrouter issues and added settings"
"""
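Configuration module for model initialization and environment setup.
Configures OpenRouter primary models with a Groq fallback (via LiteLLM) and a
Gemini model, both with optional per-user API keys, plus the session and
memory services.
NOTE: an earlier header mentioned an "Ollama integration fix for Google ADK";
no Ollama model is configured in this module.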
"""
import os
from dotenv import load_dotenv
from google.adk.models.lite_llm import LiteLlm
from google.adk.sessions import DatabaseSessionService
from google.genai import types
from .utils import logger
from contextvars import ContextVar
from typing import Optional, Any, Dict
# Context-local id of the user the current request runs for. The model
# wrappers below read it to look up that user's stored API keys; it is set
# outside this module (presumably once per request — verify with callers).
context_user_id: ContextVar[Optional[str]] = ContextVar("context_user_id", default=None)

# Load environment variables from .env before anything below reads os.environ.
load_dotenv()

# ===== SSL CONFIGURATION =====
# Fix for SSL certificate errors on Windows: point SSL_CERT_FILE at certifi's
# CA bundle. A pre-existing SSL_CERT_FILE (shell or .env) is kept when it names
# an actual file, so a deliberately configured bundle (e.g. a corporate proxy
# CA) is not clobbered; an unset or stale path is still replaced.
import certifi

_existing_cert_file = os.environ.get("SSL_CERT_FILE")
if not _existing_cert_file or not os.path.isfile(_existing_cert_file):
    os.environ['SSL_CERT_FILE'] = certifi.where()
logger.info(f"SSL Cert File configured: {os.environ['SSL_CERT_FILE']}")
# ===== MODEL INITIALIZATION =====
# Using OpenRouter models via LiteLLM, with a Groq model as fallback
from google.adk.models.lite_llm import LiteLlm
import asyncio
from typing import AsyncGenerator
import litellm
from google.adk.models.llm_request import LlmRequest
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy import text
# Singleton engine for credential lookups (created lazily on first use).
_cred_engine = None


async def get_user_credentials(user_id: str) -> Dict[str, Optional[str]]:
    """Fetch the custom API keys stored for ``user_id``.

    Best-effort: any failure (unusable DATABASE_URL, missing table, connection
    error) is logged and treated as "no custom keys".

    Args:
        user_id: Value matched against ``user_credentials.user_id``. A falsy
            value returns ``{}`` without touching the database.

    Returns:
        ``{"gemini_api_key": ..., "openrouter_api_key": ...}`` when a row
        exists (values are None where the columns are NULL), otherwise ``{}``.
    """
    global _cred_engine
    if not user_id:
        return {}
    try:
        if _cred_engine is None:
            # Same default URL as get_session_service, so both use one database.
            db_url = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///legacy_solver.db")
            # Built inside the try: create_async_engine raises for an unusable
            # URL (e.g. a non-async driver), and that must degrade to "no custom
            # keys" instead of failing the model call that asked for them.
            _cred_engine = create_async_engine(db_url)
        async with _cred_engine.connect() as conn:
            result = await conn.execute(
                text("SELECT gemini_api_key, openrouter_api_key FROM user_credentials WHERE user_id = :user_id"),
                {"user_id": user_id},
            )
            row = result.fetchone()
    except Exception as e:
        logger.error(f"Error fetching credentials for user {user_id}: {e}")
        return {}
    if row is None:
        return {}
    return {
        "gemini_api_key": row[0],
        "openrouter_api_key": row[1],
    }
class ResilientLiteLlm(LiteLlm):
    """
    A LiteLlm wrapper that rotates through primary models, then falls back.

    For each primary model, errors whose text contains "429" or whose type name
    contains "RateLimitError" are retried (2 attempts, 3 s apart); any other
    error moves straight on to the next primary model. When every primary model
    has failed, the fallback model is tried once and its error is re-raised.

    Each attempt is served by a short-lived ``LiteLlm`` delegate built for that
    model. ADK's ``LiteLlm`` takes extra LiteLLM options (``api_key``,
    ``base_url``, ...) through its constructor rather than as call-time keyword
    arguments, so per-user keys are applied there. ``self.model`` is never
    mutated, so concurrent requests sharing this instance cannot switch each
    other's model or key.
    """

    def __init__(self, primary_model_names: list[str], fallback_model_name: str = "groq/llama3-70b-8192", **kwargs):
        """
        Args:
            primary_model_names: LiteLLM model ids tried in order; must be
                non-empty. The first one is reported as ``self.model``.
            fallback_model_name: Model tried once after all primaries fail.
            **kwargs: Extra LiteLlm/LiteLLM arguments, forwarded to the base
                class and to every per-request delegate.

        Raises:
            ValueError: if ``primary_model_names`` is empty.
        """
        if not primary_model_names:
            raise ValueError("ResilientLiteLlm needs at least one primary model name")
        # Initialize with the first model as default
        super().__init__(model=primary_model_names[0], **kwargs)
        self._primary_models = list(primary_model_names)
        self._fallback_model_name = fallback_model_name
        # Private copy so later per-request overrides never touch the originals.
        self._delegate_kwargs = dict(kwargs)
        # Ensure Groq API Key is present for fallback
        if not os.getenv("GROQ_API_KEY"):
            logger.warning("GROQ_API_KEY not found. Fallback to Groq will not work.")

    def _delegate_for(self, model_name: str, user_api_key: Optional[str]) -> LiteLlm:
        """Build a LiteLlm for ``model_name``; OpenRouter models get ``user_api_key`` if given."""
        delegate_kwargs = dict(self._delegate_kwargs)
        if user_api_key and model_name.startswith("openrouter/"):
            # The stored key is an OpenRouter key, so only OpenRouter-routed
            # models receive it; e.g. a Groq fallback keeps its system-wide key.
            delegate_kwargs["api_key"] = user_api_key
            delegate_kwargs["base_url"] = "https://openrouter.ai/api/v1"
        return LiteLlm(model=model_name, **delegate_kwargs)

    @staticmethod
    def _is_rate_limit(error: Exception) -> bool:
        """Heuristic rate-limit check: "429" in the message or a *RateLimitError* type name."""
        return "429" in str(error) or "RateLimitError" in type(error).__name__

    async def generate_content_async(self, contents, **kwargs) -> AsyncGenerator:
        """
        Stream a response, rotating through primary models and then the fallback.

        Args:
            contents: An ``LlmRequest``, or a plain string that is wrapped into one.
            **kwargs: Forwarded unchanged to each delegate's
                ``generate_content_async`` (e.g. ``stream``).

        Yields:
            Whatever the underlying ``LiteLlm`` yields.

        Raises:
            The primary model's error if it fails after part of a response was
            already yielded, or the fallback model's error if every model fails.
        """
        # Auto-wrap string inputs into the LlmRequest structure ADK expects
        if isinstance(contents, str):
            contents = LlmRequest(
                contents=[types.Content(parts=[types.Part(text=contents)])]
            )

        max_retries = 2
        retry_delay = 3

        # Per-user OpenRouter key, if the current user stored one.
        user_id = context_user_id.get()
        creds = await get_user_credentials(user_id) if user_id else {}
        custom_or_key = creds.get("openrouter_api_key")

        # Try rotating through primary models
        for model_name in self._primary_models:
            delegate = self._delegate_for(model_name, custom_or_key)
            for attempt in range(max_retries):
                emitted = False
                try:
                    async for chunk in delegate.generate_content_async(contents, **kwargs):
                        emitted = True
                        yield chunk
                    return  # Success!
                except Exception as e:
                    if emitted:
                        # Part of this response already reached the caller;
                        # retrying or rotating would duplicate output.
                        logger.error(f"Model {model_name} failed mid-response; not retrying: {e}")
                        raise
                    if not self._is_rate_limit(e):
                        logger.error(f"Primary model ({model_name}) failed with non-retryable error: {e}")
                        break  # Try next model
                    if attempt < max_retries - 1:
                        logger.warning(f"Model {model_name} rate limited. Retrying in {retry_delay}s... (Attempt {attempt + 1}/{max_retries})")
                        await asyncio.sleep(retry_delay)
                    else:
                        logger.warning(f"Model {model_name} exhausted. Rotating to next model...")

        # Fallback Logic if all primary models fail
        logger.info(f"All primary models exhausted. Switching to fallback: {self._fallback_model_name}")
        fallback = self._delegate_for(self._fallback_model_name, custom_or_key)
        try:
            async for chunk in fallback.generate_content_async(contents, **kwargs):
                yield chunk
        except Exception as fallback_error:
            logger.error(f"Fallback model ({self._fallback_model_name}) also failed: {fallback_error}")
            raise
        logger.info("Fallback successful")
def get_model():
    """
    Build the default ResilientLiteLlm: OpenRouter primaries, Groq fallback.

    Side effects on the process environment:
      * OPENAI_API_BASE is set to the OpenRouter endpoint.
      * OPENAI_API_KEY is set to OPENROUTER_API_KEY when that variable is set;
        otherwise a warning is logged and OPENAI_API_KEY is left untouched.

    Returns:
        A ResilientLiteLlm rotating through ``primary_models``.
    """
    # Configure OpenRouter endpoint
    os.environ["OPENAI_API_BASE"] = "https://openrouter.ai/api/v1"
    openrouter_key = os.getenv("OPENROUTER_API_KEY")
    if openrouter_key:
        os.environ["OPENAI_API_KEY"] = openrouter_key
    else:
        # os.environ only accepts str values; assigning the missing key (None)
        # raised TypeError and aborted model creation entirely.
        logger.warning("OPENROUTER_API_KEY not found. OpenRouter models may fail without a per-user key.")

    # Free OpenRouter models, tried in order.
    primary_models = [
        # Working model with tool calling support and no rate limiting
        "openrouter/nvidia/nemotron-3-super-120b-a12b:free",
        # Previously tried and found not working:
        # "openrouter/google/gemma-4-26b-a4b-it:free",
        # "openrouter/google/gemma-3-27b-it:free",
        # "openrouter/meta-llama/llama-3.3-70b-instruct:free",
        # "openrouter/qwen/qwen3-coder:free",
    ]

    # NOTE(review): confirm Groq still serves "llama3-70b-8192".
    model = ResilientLiteLlm(
        primary_model_names=primary_models,
        fallback_model_name="groq/llama3-70b-8192",
    )
    logger.info(f"Model initialized: ResilientLiteLlm (Rotating through: {', '.join(primary_models)})")
    return model
# ===== GEMINI MODEL INITIALIZATION =====
# Using Google Gemini for Search Agents
from google.adk.models.google_llm import Gemini
# Gemini model id used by get_gemini_model() (the search agents' model).
Model = "gemini-2.5-flash"
class ContextAwareGemini(Gemini):
    """
    A Gemini model that swaps in the current user's stored Gemini API key.

    When ``context_user_id`` is set and that user has a ``gemini_api_key``,
    GOOGLE_API_KEY is replaced for the duration of the request and restored
    afterwards (including an originally empty value). Otherwise the call
    behaves exactly like plain ``Gemini``.
    """

    async def generate_content_async(self, contents, **kwargs) -> AsyncGenerator:
        """
        Stream a Gemini response, using the user's own API key when one is stored.

        ``contents`` and ``kwargs`` are forwarded unchanged to
        ``Gemini.generate_content_async``; its chunks are yielded as-is.
        """
        user_id = context_user_id.get()
        custom_key = None
        if user_id:
            creds = await get_user_credentials(user_id)
            custom_key = creds.get("gemini_api_key")

        if not custom_key:
            # Default behavior: use the process-wide key.
            async for chunk in super().generate_content_async(contents, **kwargs):
                yield chunk
            return

        # NOTE(review): os.environ is process-wide, so concurrent requests may
        # observe this user's key while it is swapped in. Also confirm that the
        # ADK Gemini class caches its client under ``_client``; if it does not,
        # nothing is reset and an already-built client keeps its old key.
        old_key = os.environ.get("GOOGLE_API_KEY")
        try:
            os.environ["GOOGLE_API_KEY"] = custom_key
            self._reset_client()  # Force re-init with the user's key
            async for chunk in super().generate_content_async(contents, **kwargs):
                yield chunk
        finally:
            # `is not None` so an originally empty value is restored, not deleted.
            if old_key is not None:
                os.environ["GOOGLE_API_KEY"] = old_key
            else:
                os.environ.pop("GOOGLE_API_KEY", None)
            # Drop the client built with the user's key so the next request
            # (possibly another user's) re-creates it from the restored env.
            self._reset_client()

    def _reset_client(self) -> None:
        """Clear the cached ``_client`` (when present) so it is rebuilt from the environment."""
        if hasattr(self, "_client"):
            self._client = None
def get_gemini_model():
    """
    Build a ContextAwareGemini for ``Model`` with HTTP retries enabled.

    Logs a warning, but still builds the model, when neither GOOGLE_API_KEY
    nor GEMINI_API_KEY is set; a per-user key may still be supplied per request.

    Returns:
        A ContextAwareGemini whose HTTP retry options are
        ``attempts=10, initial_delay=10``.
    """
    # google-genai reads either GOOGLE_API_KEY or GEMINI_API_KEY, so only warn
    # when neither is available.
    if not (os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")):
        logger.warning("Neither GOOGLE_API_KEY nor GEMINI_API_KEY found in environment. Gemini may fail.")

    model = ContextAwareGemini(
        model=Model,
        generate_content_config=types.GenerateContentConfig(
            http_options=types.HttpOptions(
                retry_options=types.HttpRetryOptions(initial_delay=10, attempts=10)
            )
        ),
    )
    logger.info(f"Model initialized: {Model} (Context-Aware) with Retry Options")
    return model
# ===== SESSION SERVICE INITIALIZATION =====
# Using LazyDatabaseSessionService to prevent empty sessions on load
from .lazy_session import LazyDatabaseSessionService
def get_session_service(db_url=None):
    """
    Create the SQL-backed session service.

    The returned object is a LazyDatabaseSessionService, chosen so that merely
    loading the app does not create empty sessions.

    Args:
        db_url: Database connection string. When falsy, the DATABASE_URL
            environment variable is used; if that is unset too, the local
            ``legacy_solver.db`` SQLite file, which holds the existing sessions.

    Returns:
        The configured LazyDatabaseSessionService.
    """
    # Argument first, then environment, then the local SQLite default.
    resolved_url = db_url or os.getenv("DATABASE_URL", "sqlite+aiosqlite:///legacy_solver.db")
    service = LazyDatabaseSessionService(db_url=resolved_url)
    # Log only the scheme so anything else in the URL (e.g. credentials) stays out of logs.
    scheme = resolved_url.split('://')[0]
    logger.info(f"Session service initialized (Lazy): {scheme}://...")
    return service
# ===== MEMORY SERVICE INITIALIZATION =====
# Pinecone-backed memory when PINECONE_API_KEY is set, otherwise InMemoryMemoryService (DatabaseMemoryService is not available in this ADK version)
from google.adk.memory import InMemoryMemoryService
# Global cache for memory service (module-level singleton, built on first call).
_memory_service_instance = None


def get_memory_service():
    """
    Return the cached memory service, creating it on first use.

    Uses PineconeMemoryService when PINECONE_API_KEY is set and construction
    succeeds; otherwise (or if Pinecone setup raises) falls back to
    InMemoryMemoryService. The instance is cached (singleton) to avoid
    reloading embeddings.

    Returns:
        The cached memory service instance.
    """
    global _memory_service_instance
    # Identity check rather than truthiness: a service object that happens to
    # be falsy (defines __len__/__bool__) must not be rebuilt on every call.
    if _memory_service_instance is not None:
        return _memory_service_instance

    pinecone_key = os.getenv("PINECONE_API_KEY")
    logger.info(f"Checking PINECONE_API_KEY: {'Found' if pinecone_key else 'Missing'}")
    if pinecone_key:
        try:
            from .memory import PineconeMemoryService
            _memory_service_instance = PineconeMemoryService(api_key=pinecone_key)
            logger.info("Memory service initialized: Pinecone (Long-Term Vector Store)")
            return _memory_service_instance
        except Exception as e:
            # Best-effort: a Pinecone problem degrades to ephemeral memory.
            logger.error(f"Failed to init Pinecone, falling back to InMemory: {e}")

    _memory_service_instance = InMemoryMemoryService()
    logger.info("Memory service initialized: InMemory (Ephemeral)")
    return _memory_service_instance