from typing import Optional, List, Dict, Any, Tuple
from .registry import registry
from .models import Model, ModelCapability
from utils.logger import logger
from .registry import DEFAULT_PREMIUM_MODEL, DEFAULT_FREE_MODEL


class ModelManager:
    """Facade over the model ``registry``.

    Provides lookup, alias resolution, validation, selection, cost
    calculation and serialization helpers for registered models.
    """

    def __init__(self):
        # Module-level registry instance imported from ``.registry``.
        self.registry = registry

    def get_model(self, model_id: str) -> Optional[Model]:
        """Return the registry entry for ``model_id``, or ``None`` if unknown."""
        return self.registry.get(model_id)

    def resolve_model_id(self, model_id: str) -> str:
        """Resolve an alias to its canonical model ID.

        Returns the registry's resolution when there is one. Otherwise logs a
        warning and returns ``model_id`` unchanged so the caller can still try it.
        """
        logger.debug(f"resolve_model_id called with: '{model_id}' (type: {type(model_id)})")
        resolved = self.registry.resolve_model_id(model_id)
        if resolved:
            logger.debug(f"Resolved model '{model_id}' to '{resolved}'")
            return resolved

        # ``_aliases`` is a private registry attribute. Read it defensively so
        # this diagnostic-only fallback path can never raise by itself.
        aliases = getattr(self.registry, "_aliases", None) or {}
        all_aliases = list(aliases.keys())
        logger.warning(f"Could not resolve model ID: '{model_id}'. Available aliases: {all_aliases[:10]}...")
        return model_id

    def validate_model(self, model_id: str) -> Tuple[bool, str]:
        """Check that ``model_id`` exists and is enabled.

        Returns:
            ``(True, "")`` when usable, otherwise ``(False, reason)``.
        """
        model = self.get_model(model_id)
        if not model:
            return False, f"Model '{model_id}' not found"
        if not model.enabled:
            return False, f"Model '{model.name}' is currently disabled"
        return True, ""

    def calculate_cost(
        self,
        model_id: str,
        input_tokens: int,
        output_tokens: int
    ) -> Optional[float]:
        """Return the total cost of ``input_tokens`` + ``output_tokens``.

        Uses the model's per-token input/output prices. Returns ``None`` (and
        logs a warning) when the model is unknown or has no pricing.
        """
        model = self.get_model(model_id)
        if not model or not model.pricing:
            logger.warning(f"No pricing available for model: {model_id}")
            return None

        input_cost = input_tokens * model.pricing.input_cost_per_token
        output_cost = output_tokens * model.pricing.output_cost_per_token
        total_cost = input_cost + output_cost

        logger.debug(
            f"Cost calculation for {model.name}: "
            f"{input_tokens} input tokens (${input_cost:.6f}) + "
            f"{output_tokens} output tokens (${output_cost:.6f}) = "
            f"${total_cost:.6f}"
        )
        return total_cost

    def get_models_for_tier(self, tier: str) -> List[Model]:
        """Return the enabled models registered for ``tier``."""
        return self.registry.get_by_tier(tier, enabled_only=True)

    def get_models_with_capability(self, capability: ModelCapability) -> List[Model]:
        """Return the enabled models that declare ``capability``."""
        return self.registry.get_by_capability(capability, enabled_only=True)

    def select_best_model(
        self,
        tier: str,
        required_capabilities: Optional[List[ModelCapability]] = None,
        min_context_window: Optional[int] = None,
        prefer_cheaper: bool = False
    ) -> Optional[Model]:
        """Pick the best enabled model for ``tier`` matching the constraints.

        Args:
            tier: Tier whose enabled models are considered.
            required_capabilities: Every listed capability must be present.
            min_context_window: Minimum ``context_window``; falsy means no limit.
            prefer_cheaper: When any candidate has pricing, choose the one with
                the lowest input price per million tokens (unpriced candidates
                are ignored). Otherwise rank by highest priority, then by
                recommended.

        Returns:
            The selected model, or ``None`` if no candidate matches.
        """
        models = self.get_models_for_tier(tier)

        if required_capabilities:
            models = [
                m for m in models
                if all(cap in m.capabilities for cap in required_capabilities)
            ]

        if min_context_window:
            models = [m for m in models if m.context_window >= min_context_window]

        if not models:
            return None

        if prefer_cheaper:
            priced = [m for m in models if m.pricing]
            if priced:
                # min() returns the first of equal keys, matching sorted(...)[0].
                return min(priced, key=lambda m: m.pricing.input_cost_per_million_tokens)

        return min(models, key=lambda m: (-m.priority, not m.recommended))

    def get_default_model(self, tier: str = "free") -> Optional[Model]:
        """Return the default model for ``tier``.

        Picks the highest-priority recommended model. If none is recommended,
        picks the highest-priority enabled model. Returns ``None`` when the
        tier has no enabled models.
        """
        models = self.get_models_for_tier(tier)
        if not models:
            return None
        recommended = [m for m in models if m.recommended]
        # max() returns the first of equal keys, matching a stable sort's head.
        return max(recommended or models, key=lambda m: m.priority)

    def get_context_window(self, model_id: str, default: int = 31_000) -> int:
        """Return the context window for ``model_id`` as reported by the registry.

        ``default`` is forwarded to the registry, presumably as the fallback
        for unknown models.
        """
        return self.registry.get_context_window(model_id, default)

    def check_token_limit(
        self,
        model_id: str,
        token_count: int,
        is_input: bool = True
    ) -> Tuple[bool, int]:
        """Check ``token_count`` against the model's limit.

        Input tokens are checked against ``context_window``. Output tokens are
        checked against ``max_output_tokens``, or ``context_window`` when that
        is unset.

        Returns:
            ``(within_limit, max_allowed)``; ``(False, 0)`` for unknown models.
        """
        model = self.get_model(model_id)
        if not model:
            return False, 0

        if is_input:
            max_allowed = model.context_window
        else:
            max_allowed = model.max_output_tokens or model.context_window

        return token_count <= max_allowed, max_allowed

    @staticmethod
    def _serialize_model(model: Model) -> Dict[str, Any]:
        """Build the public dict representation of ``model``."""
        return {
            "id": model.id,
            "name": model.name,
            "provider": model.provider.value,
            "context_window": model.context_window,
            "max_output_tokens": model.max_output_tokens,
            "capabilities": [cap.value for cap in model.capabilities],
            "pricing": {
                "input_per_million": model.pricing.input_cost_per_million_tokens,
                "output_per_million": model.pricing.output_cost_per_million_tokens,
            } if model.pricing else None,
            "enabled": model.enabled,
            "beta": model.beta,
            "tier_availability": model.tier_availability,
            "priority": model.priority,
            "recommended": model.recommended,
        }

    def format_model_info(self, model_id: str) -> Dict[str, Any]:
        """Return a dict describing ``model_id``, or ``{"error": ...}`` if unknown."""
        model = self.get_model(model_id)
        if not model:
            return {"error": f"Model '{model_id}' not found"}
        return self._serialize_model(model)

    def list_available_models(
        self,
        tier: Optional[str] = None,
        include_disabled: bool = False
    ) -> List[Dict[str, Any]]:
        """List serialized models, optionally restricted to ``tier``.

        Results are ordered with free-tier models first, then by descending
        priority, then by name.
        """
        logger.debug(f"list_available_models called with tier='{tier}', include_disabled={include_disabled}")

        if tier:
            models = self.registry.get_by_tier(tier, enabled_only=not include_disabled)
            logger.debug(f"Found {len(models)} models for tier '{tier}'")
        else:
            models = self.registry.get_all(enabled_only=not include_disabled)
            logger.debug(f"Found {len(models)} total models")

        if models:
            model_names = [m.name for m in models]
            logger.debug(f"Models: {model_names}")
        else:
            # Avoid logging "tier 'None'" when no tier filter was requested.
            scope = f"tier '{tier}'" if tier else "any tier"
            logger.warning(f"No models found for {scope} - this might indicate a configuration issue")

        models = sorted(
            models,
            key=lambda m: (not m.is_free_tier, -m.priority, m.name)
        )

        # Serialize the already-fetched models directly instead of
        # re-resolving each one by id through the registry.
        return [self._serialize_model(m) for m in models]

    def get_legacy_constants(self) -> Dict:
        """Return the registry contents in the registry's legacy format."""
        return self.registry.to_legacy_format()

    @staticmethod
    def _extract_price_id(subscription: Dict[str, Any]) -> Optional[str]:
        """Return the price ID from a subscription payload.

        Prefers the first entry of ``subscription['items']['data']`` and falls
        back to the top-level ``price_id`` key. A malformed first item raises
        ``KeyError``/``TypeError``, which the caller treats as "not paid".
        """
        items = subscription.get('items')
        data = items.get('data') if items else None
        if data:
            return data[0]['price']['id']
        return subscription.get('price_id')

    async def get_default_model_for_user(self, client, user_id: str) -> str:
        """Return the default model ID for ``user_id`` based on subscription tier.

        Local environments always get ``DEFAULT_PREMIUM_MODEL``. Otherwise paid
        users get ``DEFAULT_PREMIUM_MODEL``. Everyone else gets
        ``DEFAULT_FREE_MODEL``, as does any user whose tier lookup fails.
        ``client`` is accepted for interface compatibility but is unused here.
        """
        try:
            # Imported lazily, presumably to avoid import cycles at module load.
            from utils.config import config, EnvMode
            if config.ENV_MODE == EnvMode.LOCAL:
                return DEFAULT_PREMIUM_MODEL

            from services.billing import get_user_subscription, SUBSCRIPTION_TIERS

            subscription = await get_user_subscription(user_id)

            is_paid_tier = False
            if subscription:
                price_id = self._extract_price_id(subscription)
                tier_info = SUBSCRIPTION_TIERS.get(price_id)
                is_paid_tier = bool(tier_info) and tier_info['name'] != 'free'

            # Log the actual configured model IDs rather than hard-coded names
            # that can drift from the constants.
            if is_paid_tier:
                logger.debug(f"Setting {DEFAULT_PREMIUM_MODEL} as default for paid user {user_id}")
                return DEFAULT_PREMIUM_MODEL
            logger.debug(f"Setting {DEFAULT_FREE_MODEL} as default for free user {user_id}")
            return DEFAULT_FREE_MODEL

        except Exception as e:
            # Deliberate best-effort: any failure degrades to the free default.
            logger.warning(f"Failed to determine user tier for {user_id}: {e}")
            return DEFAULT_FREE_MODEL


# Shared module-level instance.
model_manager = ModelManager()