| """ | |
| API Client for Smart Auto-Complete | |
| Handles communication with OpenAI and Anthropic APIs | |
| """ | |
| import logging | |
| import time | |
| from typing import Dict, List, Optional, Union | |
| import anthropic | |
| import openai | |
| from .utils import validate_api_key | |
| logger = logging.getLogger(__name__) | |

class APIClient:
    """
    Unified API client for multiple AI providers.

    Supports OpenAI GPT and Anthropic Claude models.
    """

    def __init__(self, settings=None):
        """
        Initialize the API client with settings.

        Args:
            settings: Application settings object
        """
        self.settings = settings
        self.openai_client = None
        self.anthropic_client = None
        self.current_provider = None
        self.request_count = 0
        self.last_request_time = 0

        self._initialize_clients()

    def _initialize_clients(self):
        """Initialize API clients based on available keys."""
        try:
            # Initialize OpenAI client
            if (
                self.settings
                and hasattr(self.settings, "OPENAI_API_KEY")
                and self.settings.OPENAI_API_KEY
                and validate_api_key(self.settings.OPENAI_API_KEY, "openai")
            ):
                self.openai_client = openai.OpenAI(
                    api_key=self.settings.OPENAI_API_KEY
                )
                logger.info("OpenAI client initialized successfully")

            # Initialize Anthropic client
            if (
                self.settings
                and hasattr(self.settings, "ANTHROPIC_API_KEY")
                and self.settings.ANTHROPIC_API_KEY
                and validate_api_key(self.settings.ANTHROPIC_API_KEY, "anthropic")
            ):
                self.anthropic_client = anthropic.Anthropic(
                    api_key=self.settings.ANTHROPIC_API_KEY
                )
                logger.info("Anthropic client initialized successfully")

            # Set default provider
            if hasattr(self.settings, "DEFAULT_PROVIDER"):
                self.current_provider = self.settings.DEFAULT_PROVIDER
            elif self.openai_client:
                self.current_provider = "openai"
            elif self.anthropic_client:
                self.current_provider = "anthropic"
            else:
                logger.warning("No valid API clients initialized")

        except Exception as e:
            logger.error(f"Error initializing API clients: {str(e)}")
    def get_completion(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: int = 150,
        provider: Optional[str] = None,
    ) -> Optional[str]:
        """
        Get a completion from the specified provider.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            temperature: Sampling temperature (0.0 to 1.0)
            max_tokens: Maximum tokens in response
            provider: Specific provider to use ('openai' or 'anthropic')

        Returns:
            Generated completion text or None if failed
        """
        try:
            # Rate limiting check
            if not self._check_rate_limit():
                logger.warning("Rate limit exceeded, skipping request")
                return None

            # Determine which provider to use
            use_provider = provider or self.current_provider

            if use_provider == "openai" and self.openai_client:
                return self._get_openai_completion(messages, temperature, max_tokens)
            elif use_provider == "anthropic" and self.anthropic_client:
                return self._get_anthropic_completion(
                    messages, temperature, max_tokens
                )
            else:
                # Fall back to any available provider
                if self.openai_client:
                    return self._get_openai_completion(
                        messages, temperature, max_tokens
                    )
                elif self.anthropic_client:
                    return self._get_anthropic_completion(
                        messages, temperature, max_tokens
                    )
                else:
                    logger.error("No API clients available")
                    return None

        except Exception as e:
            logger.error(f"Error getting completion: {str(e)}")
            return None
    def _get_openai_completion(
        self, messages: List[Dict[str, str]], temperature: float, max_tokens: int
    ) -> Optional[str]:
        """Get completion from OpenAI API"""
        try:
            response = self.openai_client.chat.completions.create(
                model="gpt-3.5-turbo",  # Can be made configurable
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                n=1,
                stop=None,
                presence_penalty=0.1,
                frequency_penalty=0.1,
            )

            self._update_request_stats()

            if response.choices and len(response.choices) > 0:
                return response.choices[0].message.content.strip()
            else:
                logger.warning("No choices returned from OpenAI API")
                return None

        except openai.RateLimitError:
            logger.warning("OpenAI rate limit exceeded")
            return None
        except openai.APIError as e:
            logger.error(f"OpenAI API error: {str(e)}")
            return None
        except Exception as e:
            logger.error(f"Unexpected error with OpenAI: {str(e)}")
            return None
    def _get_anthropic_completion(
        self, messages: List[Dict[str, str]], temperature: float, max_tokens: int
    ) -> Optional[str]:
        """Get completion from Anthropic API"""
        try:
            # Convert messages format for Anthropic: the system prompt is passed
            # as a separate parameter rather than as a message in the list
            system_message = ""
            user_messages = []

            for msg in messages:
                if msg["role"] == "system":
                    system_message = msg["content"]
                else:
                    user_messages.append(msg)

            # Create the completion request
            response = self.anthropic_client.messages.create(
                model="claude-3-haiku-20240307",  # Can be made configurable
                max_tokens=max_tokens,
                temperature=temperature,
                system=system_message,
                messages=user_messages,
            )

            self._update_request_stats()

            if response.content and len(response.content) > 0:
                return response.content[0].text.strip()
            else:
                logger.warning("No content returned from Anthropic API")
                return None

        except anthropic.RateLimitError:
            logger.warning("Anthropic rate limit exceeded")
            return None
        except anthropic.APIError as e:
            logger.error(f"Anthropic API error: {str(e)}")
            return None
        except Exception as e:
            logger.error(f"Unexpected error with Anthropic: {str(e)}")
            return None
    def _check_rate_limit(self) -> bool:
        """
        Check if we're within rate limits.

        Simple implementation - can be enhanced with more sophisticated logic
        (see the sliding-window sketch below).
        """
        current_time = time.time()

        # Allow at most one request per second (roughly 60 requests per minute)
        if current_time - self.last_request_time < 1.0:
            return False

        return True
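    # --- Illustrative sketch (not part of the original client) ---------------
    # The check above only enforces a minimum spacing between calls. One
    # possible enhancement is a sliding-window limiter: keep timestamps of
    # recent requests and reject a call once the window is full. The method
    # name, the 60-requests-per-minute budget, and the lazily created
    # `self._request_times` attribute are assumptions made for this example,
    # not part of the original design.
    def _check_rate_limit_sliding_window(
        self, max_requests: int = 60, window_seconds: float = 60.0
    ) -> bool:
        """Hypothetical sliding-window variant of _check_rate_limit."""
        now = time.time()

        # Lazily create the timestamp history on first use
        if not hasattr(self, "_request_times"):
            self._request_times = []

        # Drop timestamps that have fallen out of the window
        self._request_times = [
            t for t in self._request_times if now - t < window_seconds
        ]

        if len(self._request_times) >= max_requests:
            return False

        self._request_times.append(now)
        return True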
    def _update_request_stats(self):
        """Update request statistics"""
        self.request_count += 1
        self.last_request_time = time.time()

    def get_available_providers(self) -> List[str]:
        """Get list of available providers"""
        providers = []
        if self.openai_client:
            providers.append("openai")
        if self.anthropic_client:
            providers.append("anthropic")
        return providers
    def switch_provider(self, provider: str) -> bool:
        """
        Switch to a different provider.

        Args:
            provider: Provider name ('openai' or 'anthropic')

        Returns:
            True if switch was successful, False otherwise
        """
        if provider == "openai" and self.openai_client:
            self.current_provider = "openai"
            logger.info("Switched to OpenAI provider")
            return True
        elif provider == "anthropic" and self.anthropic_client:
            self.current_provider = "anthropic"
            logger.info("Switched to Anthropic provider")
            return True
        else:
            logger.warning(f"Cannot switch to provider: {provider}")
            return False
    def get_stats(self) -> Dict[str, Union[int, float, str, List[str], None]]:
        """Get API usage statistics"""
        return {
            "request_count": self.request_count,
            "current_provider": self.current_provider,
            "available_providers": self.get_available_providers(),
            "last_request_time": self.last_request_time,
        }
    def test_connection(self, provider: Optional[str] = None) -> bool:
        """
        Test connection to the API provider.

        Args:
            provider: Specific provider to test, or None for current provider

        Returns:
            True if connection is successful, False otherwise
        """
        try:
            test_messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Say 'Hello' in one word."},
            ]

            result = self.get_completion(
                messages=test_messages,
                temperature=0.1,
                max_tokens=10,
                provider=provider,
            )

            return result is not None and len(result.strip()) > 0

        except Exception as e:
            logger.error(f"Connection test failed: {str(e)}")
            return False
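
# --- Illustrative usage sketch (not part of the original module) -------------
# A minimal example of how this client might be exercised. The `Settings`
# class below and its attribute values are assumptions for illustration only;
# in the real application the settings object is supplied by the surrounding
# code. Because this module uses a relative import, it would typically be run
# as part of its package (e.g. `python -m <package>.api_client`) rather than
# as a standalone script.
if __name__ == "__main__":
    import os

    class Settings:
        """Hypothetical settings object with the attributes APIClient expects."""

        OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
        ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", "")
        DEFAULT_PROVIDER = "openai"

    client = APIClient(settings=Settings())
    print("Available providers:", client.get_available_providers())

    if client.test_connection():
        completion = client.get_completion(
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Complete this sentence: The weather today is"},
            ],
            temperature=0.7,
            max_tokens=50,
        )
        print("Completion:", completion)
        print("Stats:", client.get_stats())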