|
import asyncio
import os
import re
import time
from functools import wraps
from threading import Lock
from typing import Dict, List, Literal, Optional, Union

from anthropic import APIError as AnthropicAPIError, RateLimitError as AnthropicRateLimitError
from dotenv import load_dotenv
from google.api_core.exceptions import BadRequest, InvalidArgument, ResourceExhausted
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_xai import ChatXAI
from openai import OpenAIError, RateLimitError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential
|
|
|
ModelProvider = Literal["openai", "anthropic", "google", "xai"] |
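# Keep this Literal in sync with the keys of APIKeyManager.SUPPORTED_MODELS below.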
|
|
|
class APIKeyManager: |
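    """Thread-safe singleton that loads provider API keys from environment variables
    (a single OPENAI_API_KEY-style variable or numbered OPENAI_API_KEY_1, OPENAI_API_KEY_2, ...)
    and rotates them round-robin when building LangChain chat clients.

    Model selection is driven by the MODEL_PROVIDER and MODEL_NAME environment variables,
    with optional MODEL_TEMPERATURE and MODEL_TOP_P overrides.
    """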
|
_instance = None |
|
_lock = Lock() |
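
    # Allowlist of model names per provider; MODEL_NAME must appear under the
    # configured MODEL_PROVIDER (validated in _get_provider_for_model).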
|
|
|
|
|
SUPPORTED_MODELS = { |
|
"openai": [ |
|
"o1-mini", |
|
"o1", |
|
"o1-pro", |
|
"o3-mini", |
|
"o3", |
|
"o4-mini", |
|
"gpt-4o-mini-2024-07-18", |
|
"gpt-4o-mini", |
|
"chatgpt-4o-latest", |
|
"gpt-4o-2024-05-13", |
|
"gpt-4o-2024-08-06", |
|
"gpt-4o-2024-11-20", |
|
"gpt-4o", |
|
"gpt-4.1-nano", |
|
"gpt-4.1-mini", |
|
"gpt-4.1" |
|
], |
|
"google": [ |
|
"gemini-2.0-pro-exp-02-05", |
|
"gemini-2.0-flash-lite-preview-02-05", |
|
"gemini-2.0-flash-exp", |
|
"gemini-2.0-flash", |
|
"gemini-2.0-flash-thinking-exp-1219", |
|
"gemini-2.5-flash-lite-preview-06-17", |
|
"gemini-2.5-flash", |
|
"gemini-2.5-pro" |
|
], |
|
"xai": [ |
|
"grok-2", |
|
"grok-3-mini-latest", |
|
"grok-3-mini-fast-latest", |
|
"grok-3-latest", |
|
"grok-3-fast-latest" |
|
], |
|
"anthropic": [ |
|
"claude-opus-4-20250514", |
|
"claude-sonnet-4-20250514", |
|
"claude-3-7-sonnet-20250219", |
|
"claude-3-5-sonnet-20241022", |
|
"claude-3-5-sonnet-latest", |
|
"claude-3-5-haiku-20241022", |
|
"claude-3-5-haiku-latest", |
|
"claude-3-opus-20240229", |
|
"claude-3-opus-latest", |
|
"claude-3-sonnet-20240229", |
|
"claude-3-haiku-20240307" |
|
] |
|
} |
|
|
|
def __new__(cls): |
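        """Create the shared instance on first use; creation is guarded by a class-level lock."""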
|
with cls._lock: |
|
if cls._instance is None: |
|
cls._instance = super(APIKeyManager, cls).__new__(cls) |
|
cls._instance._initialized = False |
|
return cls._instance |
|
|
|
def __init__(self): |
|
if not self._initialized: |
|
self._initialized = True |
|
|
|
|
|
load_dotenv(override=True) |
|
|
|
self._current_indices = { |
|
"openai": 0, |
|
"anthropic": 0, |
|
"google": 0, |
|
"xai": 0 |
|
} |
|
self._lock = Lock() |
|
|
|
|
|
self._api_keys = self._load_api_keys() |
|
self._llm = None |
|
self._current_provider = None |
|
|
|
|
|
            # MODEL_PROVIDER is read and validated later in _get_provider_for_model;
            # default MODEL_NAME to a model that actually appears in SUPPORTED_MODELS.
            self.model_name = os.getenv("MODEL_NAME", "gpt-4o-mini").strip()
            temp_str = os.getenv("MODEL_TEMPERATURE", "0")
            topp_str = os.getenv("MODEL_TOP_P", "1")
|
|
|
try: |
|
self.temperature = float(temp_str) |
|
except ValueError: |
|
self.temperature = 0.0 |
|
try: |
|
self.top_p = float(topp_str) |
|
except ValueError: |
|
self.top_p = 1.0 |
|
|
|
def _reinit(self): |
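        """Force the singleton to reload its configuration and API keys from the environment."""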
|
self._initialized = False |
|
self.__init__() |
|
|
|
def _load_api_keys(self) -> Dict[str, List[str]]: |
|
"""Load API keys from environment variables dynamically.""" |
|
api_keys = { |
|
"openai": [], |
|
"anthropic": [], |
|
"google": [], |
|
"xai": [] |
|
} |
|
|
|
|
|
env_vars = dict(os.environ) |
|
|
|
|
|
        # For each provider, prefer numbered keys (OPENAI_API_KEY_1, OPENAI_API_KEY_2, ...)
        # ordered by their numeric suffix, and fall back to the single unnumbered variable.
        env_prefixes = {
            "openai": "OPENAI",
            "google": "GOOGLE",
            "xai": "XAI",
            "anthropic": "ANTHROPIC",
        }
        for provider, prefix in env_prefixes.items():
            pattern = re.compile(rf"{prefix}_API_KEY_\d+$")
            numbered_keys = {k: v for k, v in env_vars.items() if pattern.match(k) and v.strip()}

            if not numbered_keys:
                default_key = os.getenv(f"{prefix}_API_KEY")
                if default_key and default_key.strip():
                    api_keys[provider].append(default_key)
            else:
                sorted_names = sorted(numbered_keys.keys(), key=lambda x: int(x.split("_")[-1]))
                for key_name in sorted_names:
                    api_keys[provider].append(numbered_keys[key_name])
|
|
|
if not any(api_keys.values()): |
|
raise Exception("No valid API keys found in environment variables") |
|
|
|
for provider, keys in api_keys.items(): |
|
if keys: |
|
print(f"Loaded {len(keys)} {provider} API keys for rotation") |
|
|
|
return api_keys |
|
|
|
def get_next_api_key(self, provider: ModelProvider) -> str: |
|
"""Get the next API key in round-robin fashion for the specified provider.""" |
|
with self._lock: |
|
if not self._api_keys.get(provider) or len(self._api_keys[provider]) == 0: |
|
raise Exception(f"No API key found for {provider}") |
|
|
|
if provider not in self._current_indices: |
|
self._current_indices[provider] = 0 |
|
|
|
current_key = self._api_keys[provider][self._current_indices[provider]] |
|
self._current_indices[provider] = (self._current_indices[provider] + 1) % len(self._api_keys[provider]) |
|
return current_key |
|
|
|
    def _get_provider_for_model(self) -> ModelProvider:
        """Read MODEL_PROVIDER from the environment and check that the configured model belongs to it."""
        load_dotenv(override=True)
        provider_env = os.getenv("MODEL_PROVIDER", "openai").lower().strip()
|
|
|
if provider_env not in self.SUPPORTED_MODELS: |
|
raise Exception( |
|
f"Invalid or missing MODEL_PROVIDER in env: '{provider_env}'. " |
|
f"Must be one of: {list(self.SUPPORTED_MODELS.keys())}" |
|
) |
|
|
|
|
|
if self.model_name not in self.SUPPORTED_MODELS[provider_env]: |
|
available = self.SUPPORTED_MODELS[provider_env] |
|
raise Exception( |
|
f"Model '{self.model_name}' is not available under provider '{provider_env}'. " |
|
f"Available: {available}" |
|
) |
|
|
|
return provider_env |
|
|
|
|
|
def _initialize_llm( |
|
self, |
|
model_name: Optional[str] = None, |
|
temperature: Optional[float] = None, |
|
top_p: Optional[float] = None, |
|
max_tokens: Optional[int] = None, |
|
streaming: bool = False |
|
): |
|
"""Initialize LLM with the next API key in rotation.""" |
|
load_dotenv(override=True) |
|
provider = self._get_provider_for_model() |
|
        model_name = model_name if model_name is not None else self.model_name
        temperature = temperature if temperature is not None else self.temperature
        top_p = top_p if top_p is not None else self.top_p

        api_key = self.get_next_api_key(provider)
        # Avoid logging the full secret: only a short prefix of the key is printed.
        print(f"Using provider={provider}, model_name={model_name}, "
              f"temperature={temperature}, top_p={top_p}, key={api_key[:8]}...")
|
|
|
kwargs = { |
|
"model": model_name, |
|
"temperature": temperature, |
|
"top_p": top_p, |
|
"max_retries": 0, |
|
"streaming": streaming, |
|
"api_key": api_key |
|
} |
|
|
|
if max_tokens is not None: |
|
kwargs["max_tokens"] = max_tokens |
|
|
|
if provider == "openai": |
|
self._llm = ChatOpenAI(**kwargs) |
|
elif provider == "google": |
|
self._llm = ChatGoogleGenerativeAI(**kwargs) |
|
elif provider == "anthropic": |
|
self._llm = ChatAnthropic(**kwargs) |
|
else: |
|
self._llm = ChatXAI(**kwargs) |
|
|
|
self._current_provider = provider |
|
|
|
def get_llm( |
|
self, |
|
model_name: Optional[str] = None, |
|
temperature: Optional[float] = None, |
|
top_p: Optional[float] = None, |
|
max_tokens: Optional[int] = None, |
|
streaming: bool = False |
|
) -> Union[ChatOpenAI, ChatGoogleGenerativeAI, ChatAnthropic, ChatXAI]: |
|
"""Get LLM instance with the current API key.""" |
|
        provider = self._get_provider_for_model()
        model_name = model_name if model_name is not None else self.model_name
        temperature = temperature if temperature is not None else self.temperature
        top_p = top_p if top_p is not None else self.top_p

        # The cached client is reused as long as the provider is unchanged; per-call
        # overrides take effect when the client is (re)built, e.g. after rotate_key().
        if self._llm is None or provider != self._current_provider:
            self._initialize_llm(model_name, temperature, top_p, max_tokens, streaming)
        return self._llm
|
|
|
    def rotate_key(self, provider: Optional[ModelProvider] = None, streaming: bool = False) -> None:
        """Manually rotate to the next API key by re-initializing the LLM client.

        The provider argument is kept for API symmetry; the active provider is
        re-derived from the environment inside _initialize_llm.
        """
        self._initialize_llm(streaming=streaming)
|
|
|
def get_all_api_keys(self, provider: Optional[ModelProvider] = None) -> Union[Dict[str, List[str]], List[str]]: |
|
"""Get all available API keys.""" |
|
if provider: |
|
return self._api_keys[provider].copy() |
|
return {k: v.copy() for k, v in self._api_keys.items()} |
|
|
|
def get_key_count(self, provider: Optional[ModelProvider] = None) -> Union[Dict[str, int], int]: |
|
"""Get the total number of available API keys.""" |
|
if provider: |
|
return len(self._api_keys[provider]) |
|
return {k: len(v) for k, v in self._api_keys.items()} |
|
|
|
    def __len__(self) -> int:
        """Total number of API keys loaded across all providers (``len()`` must return an int)."""
        return sum(len(keys) for keys in self._api_keys.values())
|
|
|
def __bool__(self) -> bool: |
|
"""Check if there are any API keys available.""" |
|
return any(bool(keys) for keys in self._api_keys.values()) |
|
|
|
def with_api_manager( |
|
model_name: Optional[str] = None, |
|
temperature: Optional[float] = None, |
|
top_p: Optional[float] = None, |
|
max_tokens: Optional[int] = None, |
|
streaming: bool = False, |
|
delay_on_timeout: int = 20, |
|
max_token_reduction_attempts: int = 0 |
|
): |
|
"""Decorator for automatic key rotation on error with delay on timeout.""" |
|
manager = APIKeyManager() |
|
provider = manager._get_provider_for_model() |
|
model_name = model_name if model_name else manager.model_name |
|
temperature = temperature if temperature else manager.temperature |
|
top_p = top_p if top_p else manager.top_p |
|
key_count = manager.get_key_count(provider) |
|
|
|
def decorator(func): |
|
if asyncio.iscoroutinefunction(func): |
|
@wraps(func) |
|
async def wrapper(*args, **kwargs): |
|
if key_count > 1: |
|
all_keys = manager.get_all_api_keys(provider) |
|
tried_keys = set() |
|
|
|
current_max_tokens = max_tokens |
|
token_reduction_attempts = 0 |
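
                    # Try each key at most once per pass: rotate on rate-limit errors,
                    # optionally shrink max_tokens on token-limit errors, and once every
                    # key has been tried either sleep and start over or give up.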
|
|
|
while len(tried_keys) < len(all_keys): |
|
try: |
|
llm = manager.get_llm( |
|
model_name=model_name, |
|
temperature=temperature, |
|
top_p=top_p, |
|
max_tokens=current_max_tokens, |
|
streaming=streaming |
|
) |
|
result = await func(*args, **kwargs, llm=llm) |
|
return result |
|
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError) as e: |
|
current_key = manager._api_keys[provider][(manager._current_indices[provider] - 1) % len(all_keys)] |
|
print(f"Rate limit error with {provider} API key {current_key}: {str(e)}") |
|
tried_keys.add(current_key) |
|
if len(tried_keys) < len(all_keys): |
|
manager.rotate_key(provider=provider, streaming=streaming) |
|
print(f"Using next available {provider} API key") |
|
else: |
|
if delay_on_timeout > 0: |
|
print(f"Waiting for {delay_on_timeout} seconds before retrying with the first key...") |
|
time.sleep(delay_on_timeout) |
|
manager._current_indices[provider] = 0 |
|
else: |
|
print(f"All {provider} API keys failed due to rate limits: {str(e)}") |
|
raise |
|
except (OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e: |
|
error_str = str(e) |
|
if "token" in error_str.lower() or "context length" in error_str.lower(): |
|
print(f"Token limit error encountered: {error_str}") |
|
if max_token_reduction_attempts > 0 and max_tokens is not None and token_reduction_attempts < max_token_reduction_attempts: |
|
current_max_tokens = int(current_max_tokens * 0.8) |
|
token_reduction_attempts += 1 |
|
print(f"Retrying with reduced max_tokens: {current_max_tokens}") |
|
continue |
|
else: |
|
print("Max token reduction attempts reached or token reduction disabled. Proceeding with key rotation.") |
|
current_key = manager._api_keys[provider][(manager._current_indices[provider] - 1) % len(all_keys)] |
|
tried_keys.add(current_key) |
|
if len(tried_keys) < len(all_keys): |
|
manager.rotate_key(provider=provider, streaming=streaming) |
|
print(f"Using next available {provider} API key after token limit error.") |
|
else: |
|
raise |
|
else: |
|
|
|
raise |
|
|
|
|
|
try: |
|
llm = manager.get_llm( |
|
model_name=model_name, |
|
temperature=temperature, |
|
top_p=top_p, |
|
max_tokens=current_max_tokens, |
|
streaming=streaming |
|
) |
|
result = await func(*args, **kwargs, llm=llm) |
|
return result |
|
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, |
|
OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e: |
|
print(f"Error after retrying all {provider} API keys: {str(e)}") |
|
raise |
|
|
|
elif key_count == 1: |
|
@retry( |
|
wait=wait_random_exponential(min=10, max=60), |
|
stop=stop_after_attempt(6), |
|
retry=retry_if_exception_type(( |
|
RateLimitError, ResourceExhausted, AnthropicRateLimitError, |
|
OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument)) |
|
) |
|
async def attempt_function_call(): |
|
llm = manager.get_llm( |
|
model_name=model_name, |
|
temperature=temperature, |
|
top_p=top_p, |
|
max_tokens=max_tokens, |
|
streaming=streaming |
|
) |
|
return await func(*args, **kwargs, llm=llm) |
|
|
|
try: |
|
return await attempt_function_call() |
|
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, |
|
OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e: |
|
print(f"Error encountered for {provider} after multiple retries: {str(e)}") |
|
raise |
|
                else:
                    print(f"No API keys found for provider: {provider}")
                    raise RuntimeError(f"No API keys found for provider: {provider}")
|
|
|
else: |
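            # Synchronous mirror of the async wrapper above: identical key-rotation,
            # delay, and max_tokens-reduction logic for regular (non-async) functions.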
|
@wraps(func) |
|
def wrapper(*args, **kwargs): |
|
if key_count > 1: |
|
all_keys = manager.get_all_api_keys(provider) |
|
tried_keys = set() |
|
current_max_tokens = max_tokens |
|
token_reduction_attempts = 0 |
|
|
|
while len(tried_keys) < len(all_keys): |
|
try: |
|
llm = manager.get_llm( |
|
model_name=model_name, |
|
temperature=temperature, |
|
top_p=top_p, |
|
max_tokens=current_max_tokens, |
|
streaming=streaming |
|
) |
|
result = func(*args, **kwargs, llm=llm) |
|
return result |
|
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError) as e: |
|
current_key = manager._api_keys[provider][(manager._current_indices[provider] - 1) % len(all_keys)] |
|
print(f"Rate limit error with {provider} API key {current_key}: {str(e)}") |
|
tried_keys.add(current_key) |
|
if len(tried_keys) < len(all_keys): |
|
manager.rotate_key(provider=provider, streaming=streaming) |
|
print(f"Using next available {provider} API key") |
|
else: |
|
if delay_on_timeout > 0: |
|
print(f"Waiting for {delay_on_timeout} seconds before retrying with the first key...") |
|
time.sleep(delay_on_timeout) |
|
manager._current_indices[provider] = 0 |
|
else: |
|
print(f"All {provider} API keys failed due to rate limits: {str(e)}") |
|
raise |
|
except (OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e: |
|
error_str = str(e) |
|
if "token" in error_str.lower() or "context length" in error_str.lower(): |
|
print(f"Token limit error encountered: {error_str}") |
|
if max_token_reduction_attempts > 0 and max_tokens is not None and token_reduction_attempts < max_token_reduction_attempts: |
|
current_max_tokens = int(current_max_tokens * 0.8) |
|
token_reduction_attempts += 1 |
|
print(f"Retrying with reduced max_tokens: {current_max_tokens}") |
|
continue |
|
else: |
|
print("Max token reduction attempts reached or token reduction disabled. Proceeding with key rotation.") |
|
current_key = manager._api_keys[provider][(manager._current_indices[provider] - 1) % len(all_keys)] |
|
tried_keys.add(current_key) |
|
if len(tried_keys) < len(all_keys): |
|
manager.rotate_key(provider=provider, streaming=streaming) |
|
print(f"Using next available {provider} API key after token limit error.") |
|
else: |
|
raise |
|
else: |
|
|
|
raise |
|
|
|
|
|
try: |
|
llm = manager.get_llm( |
|
model_name=model_name, |
|
temperature=temperature, |
|
top_p=top_p, |
|
max_tokens=current_max_tokens, |
|
streaming=streaming |
|
) |
|
result = func(*args, **kwargs, llm=llm) |
|
return result |
|
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, |
|
OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e: |
|
print(f"Error after retrying all {provider} API keys: {str(e)}") |
|
raise |
|
|
|
elif key_count == 1: |
|
@retry( |
|
wait=wait_random_exponential(min=10, max=60), |
|
stop=stop_after_attempt(6), |
|
retry=retry_if_exception_type(( |
|
RateLimitError, ResourceExhausted, AnthropicRateLimitError, |
|
OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument)) |
|
) |
|
def attempt_function_call(): |
|
llm = manager.get_llm( |
|
model_name=model_name, |
|
temperature=temperature, |
|
top_p=top_p, |
|
max_tokens=max_tokens, |
|
streaming=streaming |
|
) |
|
return func(*args, **kwargs, llm=llm) |
|
|
|
try: |
|
return attempt_function_call() |
|
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, |
|
OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument) as e: |
|
print(f"Error encountered for {provider} after multiple retries: {str(e)}") |
|
raise |
|
                else:
                    print(f"No API keys found for provider: {provider}")
                    raise RuntimeError(f"No API keys found for provider: {provider}")
|
|
|
return wrapper |
|
return decorator |
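
# Minimal usage sketch (hypothetical function name): the decorated callable must accept
# the injected client through a keyword-only ``llm`` parameter.
#
#     @with_api_manager(temperature=0.2)
#     async def summarize(text: str, *, llm):
#         response = await llm.ainvoke(f"Summarize: {text}")
#         return response.content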
|
|
|
if __name__ == "__main__": |
|
import asyncio |
|
|
|
prompt = "What is the capital of France?" |
|
|
|
|
|
async def test_load_balancing(prompt: str, test_count: int = 10, stream: bool = False): |
|
@with_api_manager(streaming=stream) |
|
async def test(prompt: str, test_count: int = 10, *, llm): |
|
print("="*50) |
|
for i in range(test_count): |
|
try: |
|
print(f"\nTest {i+1} of {test_count}") |
|
if stream: |
|
async for chunk in llm.astream(prompt): |
|
print(chunk.content, end="", flush=True) |
|
print("\n" + "-"*50 if i != test_count - 1 else "\n" + "="*50) |
|
else: |
|
response = await llm.ainvoke(prompt) |
|
print(f"Response: {response.content.strip()}") |
|
print("-"*50) if i != test_count - 1 else print("="*50) |
|
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError) as e: |
|
print(f"Error: {str(e)}") |
|
raise |
|
|
|
await test(prompt, test_count=test_count) |
|
|
|
|
|
def test_without_load_balancing(model_name: str, prompt: str, test_count: int = 10): |
|
manager = APIKeyManager() |
|
print(f"Using model: {model_name}") |
|
print("="*50) |
|
i = 0 |
|
while i < test_count: |
|
try: |
|
print(f"Test {i+1} of {test_count}") |
|
llm = manager.get_llm(model_name=model_name) |
|
response = llm.invoke(prompt) |
|
print(f"Response: {response.content.strip()}") |
|
print("-"*50) if i != test_count - 1 else print("="*50) |
|
i += 1 |
|
except Exception as e: |
|
raise Exception(f"Error with {model_name}: {str(e)}") |
|
|
|
test_without_load_balancing(model_name="gemini-2.5-flash-lite-preview-06-17", prompt=prompt, test_count=50) |
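
    # To exercise the decorator-driven rotation path instead, run the async test, e.g.:
    # asyncio.run(test_load_balancing(prompt, test_count=10, stream=False))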
|
|