""" | |
Multi-LLM Handler with failover support | |
Uses Groq, Gemini, and OpenAI with automatic failover for reliability | |
""" | |
import asyncio | |
import re | |
import time | |
from typing import Optional, Dict, Any, List | |
import os | |
import requests | |
import google.generativeai as genai | |
import openai | |
from dotenv import load_dotenv | |
from config.config import get_provider_configs | |
load_dotenv() | |


class MultiLLMHandler:
    """Multi-LLM handler with automatic failover across providers."""

    def __init__(self):
        """Initialize the multi-LLM handler with all available providers."""
        self.providers = get_provider_configs()
        self.current_provider = None
        self.current_config = None

        # Initialize the first available provider (prefer Gemini/OpenAI for general RAG)
        self._initialize_provider()
        print(f"✅ Initialized Multi-LLM Handler with {self.provider.upper()}: {self.model_name}")

    def _initialize_provider(self):
        """Initialize the first available provider."""
        # Prefer Gemini first for general text tasks
        if self.providers["gemini"]:
            self.current_provider = "gemini"
            self.current_config = self.providers["gemini"][0]
            genai.configure(api_key=self.current_config["api_key"])
        # Then OpenAI
        elif self.providers["openai"]:
            self.current_provider = "openai"
            self.current_config = self.providers["openai"][0]
            openai.api_key = self.current_config["api_key"]
        # Finally Groq
        elif self.providers["groq"]:
            self.current_provider = "groq"
            self.current_config = self.providers["groq"][0]
        else:
            raise ValueError("No LLM providers available with valid API keys")
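
    # Expected shape of get_provider_configs(), inferred from how this module
    # consumes it (keys "api_key", "model", "name"), not from config.config
    # itself. The model names and config names below are illustrative only:
    #
    # {
    #     "gemini": [{"name": "gemini-primary", "api_key": "...", "model": "gemini-1.5-flash"}],
    #     "openai": [{"name": "openai-primary", "api_key": "...", "model": "gpt-4o-mini"}],
    #     "groq":   [{"name": "groq-primary",   "api_key": "...", "model": "qwen/qwen3-32b"}],
    # }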

    @property
    def provider(self) -> str:
        """Get current provider name (read as an attribute, e.g. self.provider)."""
        return self.current_provider

    @property
    def model_name(self) -> str:
        """Get current model name (read as an attribute, e.g. self.model_name)."""
        return self.current_config["model"] if self.current_config else "unknown"

    async def _call_groq(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Call the Groq chat completions API."""
        headers = {
            "Authorization": f"Bearer {self.current_config['api_key']}",
            "Content-Type": "application/json"
        }
        data = {
            "model": self.current_config["model"],
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
            "max_tokens": max_tokens
        }

        # Hide reasoning tokens (e.g., <think>) for Qwen reasoning models.
        # Groq exposes this via the "reasoning_format" chat completion parameter;
        # "reasoning_effort" takes values like "none"/"default", not "hidden".
        try:
            model_name = (self.current_config.get("model") or "").lower()
            if "qwen" in model_name:
                data["reasoning_format"] = "hidden"
        except Exception:
            # Be resilient if the config shape changes
            pass

        # Run the blocking HTTP call in a worker thread so it does not stall the event loop
        response = await asyncio.to_thread(
            requests.post,
            "https://api.groq.com/openai/v1/chat/completions",
            headers=headers,
            json=data,
            timeout=30
        )
        response.raise_for_status()
        result = response.json()
        text = result["choices"][0]["message"]["content"].strip()

        # Safety net: strip any <think>...</think> blocks if present
        try:
            model_name = (self.current_config.get("model") or "").lower()
            if "qwen" in model_name and "<think>" in text.lower():
                text = re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE).strip()
        except Exception:
            pass

        return text

    async def _call_gemini(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Call the Gemini API."""
        model = genai.GenerativeModel(self.current_config["model"])
        generation_config = genai.types.GenerationConfig(
            temperature=temperature,
            max_output_tokens=max_tokens
        )
        response = await asyncio.to_thread(
            model.generate_content,
            prompt,
            generation_config=generation_config
        )
        return response.text.strip()

    async def _call_openai(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Call the OpenAI API (legacy pre-1.0 SDK interface, matching the openai.api_key usage above)."""
        response = await asyncio.to_thread(
            openai.ChatCompletion.create,
            model=self.current_config["model"],
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content.strip()

    async def _try_with_failover(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Try to generate text, failing over across providers and configs in order."""
        # Build the full provider order: prefer Gemini -> OpenAI -> Groq for general text
        provider_order = []
        if self.providers["gemini"]:
            provider_order.extend([("gemini", config) for config in self.providers["gemini"]])
        if self.providers["openai"]:
            provider_order.extend([("openai", config) for config in self.providers["openai"]])
        if self.providers["groq"]:
            provider_order.extend([("groq", config) for config in self.providers["groq"]])

        last_error = None
        for provider_name, config in provider_order:
            try:
                # Remember the current provider so it can be restored on failure
                old_provider = self.current_provider
                old_config = self.current_config
                self.current_provider = provider_name
                self.current_config = config

                # Configure the SDK if needed
                if provider_name == "gemini":
                    genai.configure(api_key=config["api_key"])
                elif provider_name == "openai":
                    openai.api_key = config["api_key"]

                # Try the API call
                if provider_name == "groq":
                    return await self._call_groq(prompt, temperature, max_tokens)
                elif provider_name == "gemini":
                    return await self._call_gemini(prompt, temperature, max_tokens)
                elif provider_name == "openai":
                    return await self._call_openai(prompt, temperature, max_tokens)
            except Exception as e:
                print(f"⚠️ {provider_name.upper()} ({config['name']}) failed: {str(e)}")
                last_error = e
                # Restore the previous provider before trying the next one
                self.current_provider = old_provider
                self.current_config = old_config
                continue

        # If all providers failed
        raise RuntimeError(f"All LLM providers failed. Last error: {last_error}")

    async def generate_text(self,
                            prompt: Optional[str] = None,
                            system_prompt: Optional[str] = None,
                            user_prompt: Optional[str] = None,
                            temperature: Optional[float] = 0.4,
                            max_tokens: Optional[int] = 1200) -> str:
        """Generate text using the multi-LLM handler with failover."""
        # Handle both single-prompt and system/user prompt formats
        if prompt:
            final_prompt = prompt
        elif system_prompt and user_prompt:
            final_prompt = f"{system_prompt}\n\n{user_prompt}"
        elif user_prompt:
            final_prompt = user_prompt
        else:
            raise ValueError("Must provide either 'prompt' or 'user_prompt'")

        return await self._try_with_failover(
            final_prompt,
            temperature or 0.4,
            max_tokens or 1200
        )

    async def generate_simple(self,
                              prompt: str,
                              temperature: Optional[float] = 0.4,
                              max_tokens: Optional[int] = 1200) -> str:
        """Simple text generation (alias for generate_text, kept for compatibility)."""
        return await self.generate_text(prompt=prompt, temperature=temperature, max_tokens=max_tokens)

    def get_provider_info(self) -> Dict[str, Any]:
        """Get information about the current provider."""
        return {
            "provider": self.current_provider,
            "model": self.model_name,
            "config_name": self.current_config["name"] if self.current_config else "none",
            "available_providers": {
                "groq": len(self.providers["groq"]),
                "gemini": len(self.providers["gemini"]),
                "openai": len(self.providers["openai"])
            }
        }

    async def test_connection(self) -> bool:
        """Test the connection to the current LLM provider."""
        try:
            test_prompt = "Say 'Hello' if you can read this."
            response = await self.generate_simple(test_prompt, temperature=0.1, max_tokens=10)
            return "hello" in response.lower()
        except Exception as e:
            print(f"❌ Connection test failed: {str(e)}")
            return False


# Create a global instance
llm_handler = MultiLLMHandler()
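
# Minimal usage sketch (not part of the original module): it assumes the API keys
# read by config.config.get_provider_configs() are available via the environment
# or a .env file, so importing this module succeeds. The prompts are illustrative.
if __name__ == "__main__":
    async def _demo():
        print(llm_handler.get_provider_info())
        if await llm_handler.test_connection():
            answer = await llm_handler.generate_text(
                system_prompt="You are a concise assistant.",
                user_prompt="Explain in one sentence what provider failover means.",
                temperature=0.2,
                max_tokens=100,
            )
            print(answer)

    asyncio.run(_demo())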