ianshank
feat: add personality output and bug fixes
40ee6b4
"""
LLM Client Factory and Provider Registry.
This module provides a factory function to instantiate the correct LLM client
based on provider settings, with lazy loading of adapters.
"""
import importlib
import logging
from typing import Any
from .base import BaseLLMClient, LLMClient, LLMResponse, LLMToolResponse, ToolCall
from .exceptions import (
CircuitBreakerOpenError,
LLMAuthenticationError,
LLMClientError,
LLMConnectionError,
LLMContentFilterError,
LLMContextLengthError,
LLMInvalidRequestError,
LLMModelNotFoundError,
LLMQuotaExceededError,
LLMRateLimitError,
LLMResponseParseError,
LLMServerError,
LLMStreamError,
LLMTimeoutError,
)
logger = logging.getLogger(__name__)
# Provider registry with lazy loading.
# Maps provider name -> (dotted module path, class name). Nothing is imported at
# module load time; get_provider_class() imports the module the first time that
# provider is requested and then caches the class in _CLIENT_CACHE.
# "local" is an alias that points at the same LM Studio client as "lmstudio".
# NOTE(review): these module paths are absolute ("src.adapters.llm...") while this
# module imports its own siblings relatively (".base", ".exceptions"); confirm the
# package is always importable under the "src.adapters.llm" name, otherwise the
# lazy import in get_provider_class() will fail.
_PROVIDER_REGISTRY: dict[str, tuple[str, str]] = {
    "openai": ("src.adapters.llm.openai_client", "OpenAIClient"),
    "anthropic": ("src.adapters.llm.anthropic_client", "AnthropicClient"),
    "lmstudio": ("src.adapters.llm.lmstudio_client", "LMStudioClient"),
    "local": ("src.adapters.llm.lmstudio_client", "LMStudioClient"),  # Alias
}
# Cache of client classes that have already been imported, keyed by provider name.
# register_provider() evicts a provider's entry when that provider is re-registered.
_CLIENT_CACHE: dict[str, type[BaseLLMClient]] = {}
def register_provider(name: str, module_path: str, class_name: str, override: bool = False) -> None:
    """
    Add a provider to the registry, or replace one when ``override`` is True.

    The target module is not imported here. ``get_provider_class`` imports it
    lazily the first time the provider is requested.

    Args:
        name: Key to register the provider under (e.g. "azure", "bedrock").
        module_path: Importable dotted module path
            (e.g. "src.adapters.llm.azure_client").
        class_name: Name of the client class inside that module
            (e.g. "AzureOpenAIClient").
        override: When True, replace an existing registration instead of
            rejecting it.

    Raises:
        ValueError: If ``name`` is already registered and ``override`` is False.
    """
    is_duplicate = name in _PROVIDER_REGISTRY
    if is_duplicate and not override:
        raise ValueError(f"Provider '{name}' already registered. Use override=True to replace.")

    _PROVIDER_REGISTRY[name] = (module_path, class_name)

    # Forget any class loaded from the previous target, so the next lookup
    # imports from the newly registered module.
    _CLIENT_CACHE.pop(name, None)

    logger.info(f"Registered LLM provider: {name} -> {module_path}.{class_name}")
def list_providers() -> list[str]:
    """
    Return the name of every registered provider.

    Aliases (such as "local") appear as their own entries. A new list is built
    on each call, so callers may modify it without affecting the registry.

    Returns:
        Provider identifiers, in the registry's insertion order.
    """
    return [*_PROVIDER_REGISTRY]
def get_provider_class(provider: str) -> type[BaseLLMClient]:
    """
    Get the client class for a provider, importing its module on first use.

    Resolved classes are cached in ``_CLIENT_CACHE``, so later calls for the
    same provider skip the import.

    Args:
        provider: Provider identifier (must be a key of the registry).

    Returns:
        The client class (not instantiated).

    Raises:
        ValueError: If the provider is not registered.
        ImportError: If the provider's module cannot be imported (including
            when executing the module raises AttributeError), or if the module
            does not define the registered class name.
    """
    if provider not in _PROVIDER_REGISTRY:
        available = ", ".join(list_providers())
        raise ValueError(f"Unknown provider '{provider}'. Available: {available}")

    # Fast path: the class was already imported by an earlier call.
    cached = _CLIENT_CACHE.get(provider)
    if cached is not None:
        return cached

    module_path, class_name = _PROVIDER_REGISTRY[provider]

    # Import and attribute lookup are guarded separately. An AttributeError
    # raised while the adapter module *executes* (a bug in its top-level code)
    # must not be reported as "class not found". It is still surfaced as
    # ImportError so callers that catch ImportError keep working.
    try:
        module = importlib.import_module(module_path)
    except (ImportError, AttributeError) as e:
        raise ImportError(f"Failed to load provider '{provider}': {e}") from e

    try:
        client_class = getattr(module, class_name)
    except AttributeError as e:
        raise ImportError(f"Class '{class_name}' not found in module '{module_path}'") from e

    # Cache for future use.
    _CLIENT_CACHE[provider] = client_class
    return client_class
def create_client(
    provider: str = "openai",
    *,
    api_key: str | None = None,
    model: str | None = None,
    base_url: str | None = None,
    timeout: float | None = None,
    max_retries: int | None = None,
    **kwargs: Any,
) -> BaseLLMClient:
    """
    Instantiate an LLM client for the given provider.

    This is the main factory entry point. The provider class is resolved via
    ``get_provider_class``. Common options are forwarded only when explicitly
    given (not None), together with any provider-specific keyword arguments.

    Args:
        provider: Registered provider name ("openai", "anthropic", "lmstudio", ...).
        api_key: API key (may be optional for some providers).
        model: Model identifier.
        base_url: Base URL of the API.
        timeout: Request timeout in seconds.
        max_retries: Maximum number of retry attempts.
        **kwargs: Provider-specific constructor parameters.

    Returns:
        A configured client instance.

    Raises:
        ValueError: If ``provider`` is not registered.
        ImportError: If the provider's module or class cannot be loaded.

    Examples:
        # OpenAI client
        client = create_client("openai", model="gpt-4-turbo-preview")
        # Anthropic client
        client = create_client("anthropic", model="sonnet")
        # Local LM Studio
        client = create_client("lmstudio", base_url="http://localhost:1234/v1")
        # With custom settings
        client = create_client(
            "openai",
            api_key="sk-...",
            timeout=120.0,
            max_retries=5,
            organization="org-..."
        )
    """
    client_class = get_provider_class(provider)

    # Options left as None are omitted entirely, so the client constructor's
    # own defaults apply to them.
    common_options = {
        "api_key": api_key,
        "model": model,
        "base_url": base_url,
        "timeout": timeout,
        "max_retries": max_retries,
    }
    init_kwargs = dict(kwargs)
    init_kwargs.update(
        (option, value) for option, value in common_options.items() if value is not None
    )

    logger.info(f"Creating {provider} client with model={model or 'default'}")
    return client_class(**init_kwargs)
def create_client_from_config(config: dict[str, Any]) -> BaseLLMClient:
    """
    Create an LLM client from a configuration mapping.

    Useful for loading settings from YAML/JSON config files. The mapping passed
    in is never modified.

    Args:
        config: Configuration dict (any ``Mapping`` with string keys also works)
            with keys:
            - provider: Optional provider name; defaults to "openai" when absent.
            - Every other key is passed to ``create_client`` as a keyword
              argument (e.g. model, api_key, base_url, timeout, max_retries, or
              provider-specific options).

    Returns:
        Configured client instance.

    Raises:
        ValueError: If the provider is not registered.
        ImportError: If the provider's module or class cannot be loaded.

    Example:
        config = {
            "provider": "openai",
            "model": "gpt-4-turbo-preview",
            "timeout": 60.0,
            "max_retries": 3
        }
        client = create_client_from_config(config)
    """
    # Shallow copy into a plain dict: this works for any Mapping (not just
    # objects with a .copy() method) and keeps the pop below from mutating the
    # caller's config.
    options = dict(config)
    provider = options.pop("provider", "openai")
    return create_client(provider, **options)
# Convenience aliases for common use cases
def create_openai_client(**kwargs: Any) -> BaseLLMClient:
    """
    Shortcut for ``create_client("openai", **kwargs)``.

    All keyword arguments are forwarded unchanged to ``create_client``.
    """
    provider_name = "openai"
    return create_client(provider_name, **kwargs)
def create_anthropic_client(**kwargs: Any) -> BaseLLMClient:
    """
    Shortcut for ``create_client("anthropic", **kwargs)`` (Anthropic Claude).

    All keyword arguments are forwarded unchanged to ``create_client``.
    """
    provider_name = "anthropic"
    return create_client(provider_name, **kwargs)
def create_local_client(**kwargs: Any) -> BaseLLMClient:
    """
    Shortcut for ``create_client("lmstudio", **kwargs)`` (local LM Studio server).

    All keyword arguments are forwarded unchanged to ``create_client``.
    """
    provider_name = "lmstudio"
    return create_client(provider_name, **kwargs)
# Public API of this module: the names exposed by `from <this module> import *`.
# The registry state (_PROVIDER_REGISTRY, _CLIENT_CACHE) is intentionally not
# exported; use the registry functions listed below instead.
__all__ = [
    # Base types (re-exported from .base)
    "LLMClient",
    "LLMResponse",
    "LLMToolResponse",
    "ToolCall",
    "BaseLLMClient",
    # Exceptions (re-exported from .exceptions)
    "LLMClientError",
    "LLMAuthenticationError",
    "LLMRateLimitError",
    "LLMQuotaExceededError",
    "LLMModelNotFoundError",
    "LLMContextLengthError",
    "LLMInvalidRequestError",
    "LLMTimeoutError",
    "LLMConnectionError",
    "LLMServerError",
    "LLMResponseParseError",
    "LLMStreamError",
    "LLMContentFilterError",
    "CircuitBreakerOpenError",
    # Factory functions
    "create_client",
    "create_client_from_config",
    "create_openai_client",
    "create_anthropic_client",
    "create_local_client",
    # Registry functions
    "register_provider",
    "list_providers",
    "get_provider_class",
]