Spaces:
Paused
Paused
| """ | |
| LLM Provider Factory for Flare | |
| """ | |
| import os | |
| from typing import Optional, Dict, Any | |
| from dotenv import load_dotenv | |
| from llm_interface import LLMInterface, SparkLLM, GPT4oLLM | |
| from config_provider import ConfigProvider | |
| from utils import log | |
class LLMFactory:
    """Factory that builds the LLM client selected by ``llm_provider`` in the global config.

    The factory holds no state; every method is a ``@staticmethod`` so it can be
    called either as ``LLMFactory.create_provider()`` or through an instance.
    """

    @staticmethod
    def create_provider() -> LLMInterface:
        """Create and return the LLM provider configured in ``global_config.llm_provider``.

        Returns:
            A ``SparkLLM`` for provider name ``"spark"``, or a ``GPT4oLLM`` for
            ``"gpt4o"`` / ``"gpt4o-mini"``.

        Raises:
            ValueError: if no provider name is configured, the config has no
                entry for the provider, the provider requires an API key that
                cannot be found, or the provider name has no implementation in
                this factory.
        """
        cfg = ConfigProvider.get()
        llm_provider = cfg.global_config.llm_provider
        if not llm_provider or not llm_provider.name:
            raise ValueError("No LLM provider configured")

        provider_name = llm_provider.name
        log(f"π Creating LLM provider: {provider_name}")

        # The provider's config entry tells us whether an API key is mandatory;
        # a missing entry means the name is not known to the configuration.
        provider_config = cfg.global_config.get_provider_config("llm", provider_name)
        if not provider_config:
            raise ValueError(f"Unknown LLM provider: {provider_name}")

        # May legitimately be None when the provider does not require a key.
        api_key = LLMFactory._get_api_key(provider_name)
        if not api_key and provider_config.requires_api_key:
            raise ValueError(f"API key required for {provider_name} but not configured")

        settings = llm_provider.settings or {}

        if provider_name == "spark":
            return LLMFactory._create_spark_provider(api_key, llm_provider.endpoint, settings)
        if provider_name in ("gpt4o", "gpt4o-mini"):
            return LLMFactory._create_gpt_provider(provider_name, api_key, settings)
        raise ValueError(f"Unsupported LLM provider: {provider_name}")

    @staticmethod
    def _create_spark_provider(
        api_key: Optional[str], endpoint: Optional[str], settings: Dict[str, Any]
    ) -> SparkLLM:
        """Create a ``SparkLLM`` client.

        Args:
            api_key: Spark token passed as ``spark_token``; may be ``None`` when
                the provider config does not require an API key.
            endpoint: Spark service endpoint. It is passed through ``str()``,
                presumably because the configured value may be a URL object
                rather than a plain string (TODO confirm).
            settings: Provider settings forwarded unchanged to ``SparkLLM``.

        Raises:
            ValueError: if ``endpoint`` is missing or empty.
        """
        if not endpoint:
            raise ValueError("Spark requires endpoint to be configured")

        log("π Creating SparkLLM provider")
        log(f"π Endpoint: {endpoint}")

        # Kept for backward compatibility: the variant is derived from the
        # deployment mode (cloud vs. on-premise), not from the provider name.
        is_cloud = ConfigProvider.get().global_config.is_cloud_mode()
        provider_variant = "spark-cloud" if is_cloud else "spark-onpremise"

        return SparkLLM(
            spark_endpoint=str(endpoint),
            spark_token=api_key,
            provider_variant=provider_variant,
            settings=settings,
        )

    @staticmethod
    def _create_gpt_provider(
        model_type: str, api_key: Optional[str], settings: Dict[str, Any]
    ) -> GPT4oLLM:
        """Create a ``GPT4oLLM`` client.

        Args:
            model_type: Provider name; ``"gpt4o-mini"`` selects model
                ``"gpt-4o-mini"``, any other value selects ``"gpt-4o"``.
            api_key: OpenAI API key; may be ``None`` when the provider config
                does not require one.
            settings: Provider settings forwarded unchanged to ``GPT4oLLM``.
        """
        model = "gpt-4o-mini" if model_type == "gpt4o-mini" else "gpt-4o"
        log(f"π€ Creating GPT4oLLM provider with model: {model}")
        return GPT4oLLM(api_key=api_key, model=model, settings=settings)

    @staticmethod
    def _get_api_key(provider_name: str) -> Optional[str]:
        """Resolve the API key for ``provider_name``.

        Lookup order:
          1. ``global_config.get_plain_api_key("llm")``, which returns the key
             stored in the config in plain (decrypted) form.
          2. The provider's environment variable: ``SPARK_TOKEN`` for
             ``"spark"``, ``OPENAI_API_KEY`` for ``"gpt4o"`` / ``"gpt4o-mini"``.
             Outside a HuggingFace Space (no ``SPACE_ID`` variable set), a
             ``.env`` file is loaded first via ``load_dotenv()``.

        Returns:
            The API key, or ``None`` if none was found or the provider has no
            associated environment variable.
        """
        cfg = ConfigProvider.get()

        api_key = cfg.global_config.get_plain_api_key("llm")
        if api_key:
            log("π Using decrypted API key from config")
            return api_key

        env_var_map = {
            "spark": "SPARK_TOKEN",
            "gpt4o": "OPENAI_API_KEY",
            "gpt4o-mini": "OPENAI_API_KEY",
        }
        env_var = env_var_map.get(provider_name)
        if not env_var:
            return None

        # A HuggingFace Space exposes its secrets as environment variables and
        # sets SPACE_ID. In local development, variables come from a .env file.
        in_hf_space = bool(os.environ.get("SPACE_ID"))
        if not in_hf_space:
            load_dotenv()

        api_key = os.environ.get(env_var)
        if not api_key:
            return None

        source = "HuggingFace secrets" if in_hf_space else ".env file"
        log(f"π Using {env_var} from {source}")
        return api_key