| """Smart model router for intelligent model selection and fallback.""" | |
| import asyncio | |
| import logging | |
| from dataclasses import dataclass, field | |
| from datetime import datetime, timezone | |
| from enum import Enum | |
| from typing import Any | |
| from pydantic import SecretStr | |
| from app.models.providers.base import ( | |
| BaseProvider, | |
| CompletionResponse, | |
| ModelInfo, | |
| ProviderError, | |
| RateLimitError, | |
| TaskType, | |
| TokenUsage, | |
| ) | |
| from app.models.providers.openai import OpenAIProvider | |
| from app.models.providers.anthropic import AnthropicProvider | |
| from app.models.providers.google import GoogleProvider | |
| from app.models.providers.groq import GroqProvider | |
| from app.models.providers.nvidia import NVIDIAProvider | |
| logger = logging.getLogger(__name__) | |
class RoutingStrategy(str, Enum):
    """Strategies the router can use when choosing a model.

    Each member's value is the plain string form used in configuration
    and over the wire.
    """

    BEST_QUALITY = "best_quality"  # prefer the highest-quality model
    BEST_SPEED = "best_speed"  # prefer the lowest-latency model
    BEST_VALUE = "best_value"  # balance quality against cost
    LOWEST_COST = "lowest_cost"  # prefer the cheapest model
    ROUND_ROBIN = "round_robin"  # rotate across all available models
@dataclass
class ModelScore:
    """Scoring for model routing decisions.

    All scores are in [0, 1]; higher is better along each axis.
    The @dataclass decorator is required: SmartModelRouter._score_model
    constructs instances with keyword arguments, which a plain class with
    only class-level attributes would reject with TypeError.
    """

    model_id: str
    provider: str
    quality_score: float = 0.0  # higher is better quality
    speed_score: float = 0.0  # higher is faster
    cost_score: float = 0.0  # higher is cheaper
    overall_score: float = 0.0  # strategy-weighted combination of the above
@dataclass
class RoutingConfig:
    """Configuration for model routing.

    The @dataclass decorator is required: without it,
    ``field(default_factory=...)`` leaves a bare ``dataclasses.Field``
    object as the class attribute and no ``__init__`` is generated, so
    ``RoutingConfig()`` instances would share a broken ``task_preferences``.
    """

    # Strategy applied when the caller does not specify one.
    default_strategy: RoutingStrategy = RoutingStrategy.BEST_VALUE
    # Maximum number of models tried per request (primary + fallbacks).
    max_fallback_attempts: int = 3
    # Delay between fallback attempts, in seconds.
    fallback_delay_seconds: float = 1.0
    enable_caching: bool = True
    cache_ttl_seconds: int = 300

    # Task-specific model preferences, tried in listed order before scoring.
    task_preferences: dict[TaskType, list[str]] = field(default_factory=lambda: {
        TaskType.GENERAL: ["gpt-4o", "claude-3-5-sonnet-20241022", "gemini-2.5-pro", "deepseek-r1"],
        TaskType.CODE: ["claude-3-5-sonnet-20241022", "gpt-4o", "devstral-2-123b", "gemini-2.5-pro"],
        TaskType.REASONING: ["claude-3-opus-20240229", "deepseek-r1", "gpt-4o", "step-3.5-flash"],
        TaskType.EXTRACTION: ["gpt-4o-mini", "claude-3-haiku-20240307", "gemini-2.5-flash"],
        TaskType.SUMMARIZATION: ["gpt-4o-mini", "claude-3-5-haiku-20241022", "gemini-2.5-flash"],
        TaskType.CLASSIFICATION: ["gpt-4o-mini", "claude-3-haiku-20240307", "llama-3.1-8b-instant"],
        TaskType.CREATIVE: ["claude-3-5-sonnet-20241022", "gpt-4o", "gemini-2.5-pro"],
        TaskType.FAST: ["llama-3.1-8b-instant", "gemini-2.5-flash", "gpt-4o-mini"],
    })
@dataclass
class CostTracker:
    """Track LLM spend across providers and models.

    The @dataclass decorator is required for the ``field(...)`` defaults
    and the generated ``__init__``. ``start_time`` defaults to a
    timezone-aware UTC datetime so it matches what ``reset()`` assigns
    (the previous naive ``datetime.utcnow`` default is deprecated and
    cannot be compared against aware datetimes).
    """

    total_cost: float = 0.0  # cumulative cost in USD
    cost_by_provider: dict[str, float] = field(default_factory=dict)
    cost_by_model: dict[str, float] = field(default_factory=dict)
    request_count: int = 0
    total_tokens: TokenUsage = field(default_factory=TokenUsage)
    start_time: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def track(self, response: CompletionResponse) -> None:
        """Accumulate cost and token usage from one completion response."""
        self.total_cost += response.cost
        self.request_count += 1
        self.total_tokens = self.total_tokens + response.usage
        # Per-provider breakdown
        self.cost_by_provider[response.provider] = (
            self.cost_by_provider.get(response.provider, 0.0) + response.cost
        )
        # Per-model breakdown
        self.cost_by_model[response.model] = (
            self.cost_by_model.get(response.model, 0.0) + response.cost
        )

    def get_summary(self) -> dict[str, Any]:
        """Return a JSON-serializable summary of tracked spend."""
        return {
            "total_cost_usd": self.total_cost,
            "request_count": self.request_count,
            "total_tokens": {
                "prompt": self.total_tokens.prompt_tokens,
                "completion": self.total_tokens.completion_tokens,
                "total": self.total_tokens.total_tokens,
            },
            "cost_by_provider": self.cost_by_provider,
            "cost_by_model": self.cost_by_model,
            "avg_cost_per_request": (
                self.total_cost / self.request_count if self.request_count > 0 else 0
            ),
            "tracking_since": self.start_time.isoformat(),
        }

    def reset(self) -> None:
        """Zero all counters and restart the tracking window."""
        self.total_cost = 0.0
        self.cost_by_provider = {}
        self.cost_by_model = {}
        self.request_count = 0
        self.total_tokens = TokenUsage()
        self.start_time = datetime.now(timezone.utc)
class SmartModelRouter:
    """Intelligent model router with automatic fallback and cost tracking.

    Picks a model per request according to a :class:`RoutingStrategy`
    (and per-task preferences from :class:`RoutingConfig`), retries on
    alternate models when a provider fails or rate-limits, and records
    per-request spend in a :class:`CostTracker`.
    """

    # Model quality rankings (subjective, based on benchmarks); 0-1, higher is better.
    MODEL_QUALITY_SCORES: dict[str, float] = {
        # OpenAI
        "gpt-4o": 0.95,
        "gpt-4-turbo": 0.92,
        "gpt-4": 0.90,
        "gpt-4o-mini": 0.80,
        "gpt-3.5-turbo": 0.70,
        # Anthropic
        "claude-3-opus-20240229": 0.97,
        "claude-3-5-sonnet-20241022": 0.94,
        "claude-3-sonnet-20240229": 0.88,
        "claude-3-5-haiku-20241022": 0.82,
        "claude-3-haiku-20240307": 0.75,
        # Google Gemini 2.5 & 3.0
        "gemini-2.5-pro": 0.93,
        "gemini-2.5-flash": 0.85,
        "gemini-3-flash-preview": 0.87,
        "gemini-3.1-flash-lite-preview": 0.82,
        # Google Gemini 2.0
        "gemini-2.0-flash": 0.88,
        "gemini-2.0-flash-lite": 0.80,
        # Google Gemini 1.5
        "gemini-1.5-pro": 0.91,
        "gemini-1.5-flash": 0.78,
        "gemini-pro": 0.75,
        # Groq
        "llama-3.3-70b-versatile": 0.85,
        "llama-3.2-90b-vision-preview": 0.84,
        "llama-3.1-70b-versatile": 0.84,
        "llama3-70b-8192": 0.82,
        "mixtral-8x7b-32768": 0.78,
        "llama-3.1-8b-instant": 0.65,
        "llama3-8b-8192": 0.60,
        "gemma2-9b-it": 0.62,
        # NVIDIA
        "deepseek-r1": 0.92,
        "deepseek-v3.2": 0.90,
        "step-3.5-flash": 0.88,
        "glm4.7": 0.87,
        "devstral-2-123b": 0.86,
        "llama-3.3-70b": 0.85,
        "nemotron-70b": 0.83,
    }

    # Model speed rankings (relative, based on typical latency); 0-1, higher is faster.
    MODEL_SPEED_SCORES: dict[str, float] = {
        # Groq is fastest
        "llama-3.1-8b-instant": 0.98,
        "llama3-8b-8192": 0.97,
        "gemma2-9b-it": 0.96,
        "mixtral-8x7b-32768": 0.94,
        "llama3-70b-8192": 0.92,
        "llama-3.1-70b-versatile": 0.91,
        "llama-3.3-70b-versatile": 0.90,
        "llama-3.2-90b-vision-preview": 0.89,
        # Google Flash models
        "gemini-2.5-flash": 0.90,
        "gemini-3-flash-preview": 0.89,
        "gemini-2.0-flash": 0.88,
        "gemini-1.5-flash": 0.88,
        "gemini-2.0-flash-lite": 0.87,
        "gemini-3.1-flash-lite-preview": 0.86,
        # NVIDIA models
        "step-3.5-flash": 0.85,
        "devstral-2-123b": 0.84,
        "llama-3.3-70b": 0.83,
        "nemotron-70b": 0.82,
        "glm4.7": 0.81,
        "deepseek-v3.2": 0.80,
        "deepseek-r1": 0.79,
        # Mini models
        "gpt-4o-mini": 0.85,
        "claude-3-haiku-20240307": 0.84,
        "claude-3-5-haiku-20241022": 0.83,
        "gpt-3.5-turbo": 0.82,
        # Pro models
        "gemini-pro": 0.75,
        "gemini-2.5-pro": 0.72,
        "gemini-1.5-pro": 0.70,
        "gpt-4o": 0.68,
        "claude-3-5-sonnet-20241022": 0.65,
        "claude-3-sonnet-20240229": 0.62,
        "gpt-4-turbo": 0.55,
        "gpt-4": 0.50,
        "claude-3-opus-20240229": 0.40,
    }

    def __init__(
        self,
        openai_api_key: str | SecretStr | None = None,
        anthropic_api_key: str | SecretStr | None = None,
        google_api_key: str | SecretStr | None = None,
        groq_api_key: str | SecretStr | None = None,
        nvidia_api_key: str | SecretStr | None = None,
        config: RoutingConfig | None = None,
    ):
        """Store API keys and routing config; providers start uninitialized.

        Args:
            openai_api_key: OpenAI key (plain string or pydantic SecretStr).
            anthropic_api_key: Anthropic key.
            google_api_key: Google key.
            groq_api_key: Groq key.
            nvidia_api_key: NVIDIA key.
            config: Routing configuration; defaults to RoutingConfig().
        """
        self.config = config or RoutingConfig()
        self.providers: dict[str, BaseProvider] = {}
        self.cost_tracker = CostTracker()
        self._initialized = False
        self._round_robin_index = 0
        # Store API keys as plain strings (SecretStr values are unwrapped).
        self._api_keys = {
            "openai": self._get_key_value(openai_api_key),
            "anthropic": self._get_key_value(anthropic_api_key),
            "google": self._get_key_value(google_api_key),
            "groq": self._get_key_value(groq_api_key),
            "nvidia": self._get_key_value(nvidia_api_key),
        }

    @staticmethod
    def _get_key_value(key: str | SecretStr | None) -> str | None:
        """Extract the string value from a SecretStr if needed.

        Must be a @staticmethod: it is invoked as ``self._get_key_value(...)``
        in __init__ but takes no ``self`` parameter, so without the decorator
        every call raised TypeError.
        """
        if key is None:
            return None
        if isinstance(key, SecretStr):
            return key.get_secret_value()
        return key

    async def initialize(self) -> None:
        """Initialize a provider for every API key that was supplied.

        Idempotent: subsequent calls return immediately.
        """
        if self._initialized:
            return

        # Registry of supported providers: internal name -> (class, display name).
        provider_registry: dict[str, tuple[type[BaseProvider], str]] = {
            "openai": (OpenAIProvider, "OpenAI"),
            "anthropic": (AnthropicProvider, "Anthropic"),
            "google": (GoogleProvider, "Google"),
            "groq": (GroqProvider, "Groq"),
            "nvidia": (NVIDIAProvider, "NVIDIA"),
        }
        for name, (provider_cls, display_name) in provider_registry.items():
            api_key = self._api_keys.get(name)
            if not api_key:
                continue
            provider = provider_cls(api_key=api_key)
            await provider.initialize()
            self.providers[name] = provider
            logger.info("Initialized %s provider", display_name)

        if not self.providers:
            logger.warning("No LLM providers configured")
        self._initialized = True

    async def shutdown(self) -> None:
        """Shutdown all providers and return to the uninitialized state."""
        for provider in self.providers.values():
            await provider.shutdown()
        self.providers.clear()
        self._initialized = False

    def list_providers(self) -> list[str]:
        """Get list of initialized provider names."""
        return list(self.providers.keys())

    def get_available_models(self) -> list[ModelInfo]:
        """Get all available models across initialized providers."""
        models = []
        for provider in self.providers.values():
            models.extend(provider.get_models())
        return models

    def get_provider_for_model(self, model: str) -> BaseProvider | None:
        """Get the provider that serves a specific model.

        Supports both formats:
          - "gemini-1.5-flash" (bare model name)
          - "google/gemini-1.5-flash" (provider/model format)

        Returns None when no initialized provider knows the model.
        """
        # Strip provider prefix if present ("google/gemini-1.5-flash" -> "gemini-1.5-flash").
        model_name = model
        if "/" in model:
            provider_prefix, model_name = model.split("/", 1)
            # Fast path: the prefix names an initialized provider directly.
            if provider_prefix in self.providers:
                provider = self.providers[provider_prefix]
                try:
                    if provider.get_model_info(model_name):
                        return provider
                except Exception:
                    # get_model_info may raise for unknown models; fall through.
                    pass
                # Providers may expose alternate names via MODEL_ALIASES.
                if hasattr(provider, "MODEL_ALIASES"):
                    if model_name in provider.MODEL_ALIASES:  # type: ignore
                        return provider
        # Fallback: try all providers with both original and stripped names.
        for provider in self.providers.values():
            for name in [model, model_name]:
                try:
                    if provider.get_model_info(name):
                        return provider
                except Exception:
                    pass
                # Check aliases
                if hasattr(provider, "MODEL_ALIASES"):
                    if name in provider.MODEL_ALIASES:  # type: ignore
                        return provider
        return None

    def _score_model(
        self,
        model_info: ModelInfo,
        strategy: RoutingStrategy,
    ) -> ModelScore:
        """Score a model for the given routing strategy.

        Unknown models get neutral 0.5 quality/speed scores; the cost score
        is the inverse of average per-1K-token cost, capped at a $0.10
        reference price.
        """
        model_id = model_info.id
        quality = self.MODEL_QUALITY_SCORES.get(model_id, 0.5)
        speed = self.MODEL_SPEED_SCORES.get(model_id, 0.5)
        # Cost score: 1.0 for free, 0.0 at/above the reference price.
        max_cost = 0.1  # $0.10 per 1K tokens as reference
        avg_cost = (model_info.cost_per_1k_input + model_info.cost_per_1k_output) / 2
        cost_score = 1.0 - min(avg_cost / max_cost, 1.0)
        # Weight the three axes according to the strategy.
        if strategy == RoutingStrategy.BEST_QUALITY:
            overall = quality * 0.8 + speed * 0.1 + cost_score * 0.1
        elif strategy == RoutingStrategy.BEST_SPEED:
            overall = quality * 0.1 + speed * 0.8 + cost_score * 0.1
        elif strategy == RoutingStrategy.LOWEST_COST:
            overall = quality * 0.1 + speed * 0.1 + cost_score * 0.8
        else:  # BEST_VALUE
            overall = quality * 0.4 + speed * 0.3 + cost_score * 0.3
        return ModelScore(
            model_id=model_id,
            provider=model_info.provider,
            quality_score=quality,
            speed_score=speed,
            cost_score=cost_score,
            overall_score=overall,
        )

    def route(
        self,
        task_type: TaskType = TaskType.GENERAL,
        strategy: RoutingStrategy | None = None,
        required_features: list[str] | None = None,
    ) -> tuple[str, BaseProvider] | None:
        """Route to the best model for the task.

        Args:
            task_type: Type of task to perform.
            strategy: Routing strategy (uses the config default if omitted).
            required_features: Required model features (e.g. 'functions', 'vision').

        Returns:
            Tuple of (model_id, provider), or None if no suitable model found.
        """
        if not self.providers:
            return None
        strategy = strategy or self.config.default_strategy

        # ROUND_ROBIN ignores scoring entirely and cycles through the pool.
        if strategy == RoutingStrategy.ROUND_ROBIN:
            models = self.get_available_models()
            if not models:
                return None
            if required_features:
                models = self._filter_by_features(models, required_features)
                if not models:
                    return None
            model = models[self._round_robin_index % len(models)]
            self._round_robin_index += 1
            provider = self.get_provider_for_model(model.id)
            return (model.id, provider) if provider else None

        # Task-specific preferences win over scoring when available.
        preferred_models = self.config.task_preferences.get(task_type, [])
        for model_id in preferred_models:
            provider = self.get_provider_for_model(model_id)
            if provider:
                model_info = provider.get_model_info(model_id)
                if model_info and self._meets_requirements(model_info, required_features):
                    return (model_id, provider)

        # Otherwise score every eligible model and pick the best.
        scored_models: list[tuple[ModelScore, BaseProvider]] = []
        for provider in self.providers.values():
            for model_info in provider.get_models():
                if self._meets_requirements(model_info, required_features):
                    score = self._score_model(model_info, strategy)
                    scored_models.append((score, provider))
        if not scored_models:
            return None
        scored_models.sort(key=lambda x: x[0].overall_score, reverse=True)
        best_score, best_provider = scored_models[0]
        return (best_score.model_id, best_provider)

    def _meets_requirements(
        self,
        model_info: ModelInfo,
        required_features: list[str] | None,
    ) -> bool:
        """Check whether a model supports every required feature."""
        if not required_features:
            return True
        for feature in required_features:
            if feature == "functions" and not model_info.supports_functions:
                return False
            if feature == "vision" and not model_info.supports_vision:
                return False
            if feature == "streaming" and not model_info.supports_streaming:
                return False
        return True

    def _filter_by_features(
        self,
        models: list[ModelInfo],
        required_features: list[str],
    ) -> list[ModelInfo]:
        """Filter models down to those supporting every required feature."""
        return [m for m in models if self._meets_requirements(m, required_features)]

    async def complete(
        self,
        messages: list[dict[str, Any]],
        model: str | None = None,
        task_type: TaskType = TaskType.GENERAL,
        strategy: RoutingStrategy | None = None,
        required_features: list[str] | None = None,
        fallback: bool = True,
        **kwargs: Any,
    ) -> CompletionResponse:
        """Generate a completion with automatic routing and fallback.

        Args:
            messages: List of message dicts.
            model: Specific model to use (overrides routing).
            task_type: Type of task for routing.
            strategy: Routing strategy.
            required_features: Required model features.
            fallback: Enable fallback on failure.
            **kwargs: Additional completion parameters.

        Returns:
            CompletionResponse from the first model that succeeds.

        Raises:
            ProviderError: If the requested model is unknown or all models fail.
        """
        if not self._initialized:
            await self.initialize()

        # Build the ordered candidate list: explicit model or routed choice,
        # padded with task-preference fallbacks up to max_fallback_attempts.
        models_to_try: list[tuple[str, BaseProvider]] = []
        if model:
            provider = self.get_provider_for_model(model)
            if provider:
                models_to_try.append((model, provider))
            else:
                raise ProviderError(f"Model {model} not found", "router")
        else:
            route_result = self.route(task_type, strategy, required_features)
            if route_result:
                models_to_try.append(route_result)

        if fallback and len(models_to_try) < self.config.max_fallback_attempts:
            preferred = self.config.task_preferences.get(task_type, [])
            for fallback_model in preferred:
                if len(models_to_try) >= self.config.max_fallback_attempts:
                    break
                provider = self.get_provider_for_model(fallback_model)
                if provider and (fallback_model, provider) not in models_to_try:
                    models_to_try.append((fallback_model, provider))

        if not models_to_try:
            raise ProviderError("No suitable models available", "router")

        # Try candidates in order; remember the last failure for the final error.
        last_error: Exception | None = None
        for i, (model_id, provider) in enumerate(models_to_try):
            try:
                # Strip provider prefix if present ("google/gemini-1.5-flash" -> "gemini-1.5-flash").
                model_name = model_id.split("/", 1)[1] if "/" in model_id else model_id
                logger.info(
                    "Attempting completion with %s/%s", provider.PROVIDER_NAME, model_name
                )
                logger.debug(
                    "Router: model_id=%s, model_name=%s, provider=%s",
                    model_id, model_name, provider.PROVIDER_NAME,
                )
                response = await provider.complete(messages, model_name, **kwargs)
                self.cost_tracker.track(response)
                return response
            except RateLimitError as e:
                logger.warning("Rate limited by %s: %s", provider.PROVIDER_NAME, e)
                last_error = e
                # Pause before the next candidate (skip after the last one).
                if i < len(models_to_try) - 1:
                    await asyncio.sleep(self.config.fallback_delay_seconds)
            except ProviderError as e:
                logger.warning("Provider error from %s: %s", provider.PROVIDER_NAME, e)
                last_error = e
                if i < len(models_to_try) - 1:
                    await asyncio.sleep(self.config.fallback_delay_seconds)
            except Exception as e:
                logger.error("Unexpected error from %s: %s", provider.PROVIDER_NAME, e)
                last_error = e

        # All candidates failed.
        raise ProviderError(
            f"All models failed. Last error: {last_error}",
            "router",
        )

    def get_cost_summary(self) -> dict[str, Any]:
        """Get cost tracking summary."""
        return self.cost_tracker.get_summary()

    def reset_cost_tracking(self) -> None:
        """Reset cost tracking."""
        self.cost_tracker.reset()

    def available_providers(self) -> list[str]:
        """List of initialized provider names (alias of list_providers)."""
        return list(self.providers.keys())

    def __repr__(self) -> str:
        return (
            f"SmartModelRouter(providers={list(self.providers.keys())}, "
            f"requests={self.cost_tracker.request_count}, "
            f"cost=${self.cost_tracker.total_cost:.4f})"
        )