Spaces:
Sleeping
Sleeping
| """ | |
| Enhanced Base Plugin | |
| Location-agnostic backend plugin that supports: | |
| - Local backends (running in project) | |
| - Network backends (running on LAN) | |
| - Cloud backends (commercial APIs) | |
| Uses prompt transformation layer for backend format abstraction. | |
| """ | |
| from abc import ABC, abstractmethod | |
| from typing import List, Optional | |
| from pathlib import Path | |
| import requests | |
| from PIL import Image | |
| from .backend_config import BackendConnectionConfig, BackendLocation, BackendProtocol | |
| from .prompt_transformer import ( | |
| StandardGenerationRequest, | |
| PromptTransformer, | |
| get_transformer | |
| ) | |
class EnhancedBackendPlugin(ABC):
    """
    Enhanced base class for all backend plugins.

    Supports three deployment scenarios:
    1. Local: Backend runs in project structure
    2. Network: Backend runs on LAN (IP:PORT)
    3. Cloud: Commercial API over internet

    The application NEVER directly imports backends.
    Everything goes through this abstraction layer.

    Subclass contract:
    - ``_initialize_local`` / ``_initialize_network`` / ``_initialize_cloud``
      are optional hooks; the base versions do nothing.
    - ``_generate_local`` / ``_generate_network`` / ``_generate_cloud`` must be
      overridden for every location the subclass supports; the base versions
      raise NotImplementedError.
    """

    def __init__(self, config: BackendConnectionConfig):
        """
        Initialize plugin with connection configuration.

        Args:
            config: Backend connection configuration. This class reads its
                ``name``, ``backend_type``, ``location``, ``protocol``,
                ``endpoint``, ``api_key``, ``timeout``,
                ``health_check_endpoint`` and ``capabilities`` attributes and
                calls its ``get_full_endpoint()`` method.
        """
        self.config = config
        self.name = config.name
        self.backend_type = config.backend_type
        self.location = config.location
        self.protocol = config.protocol
        # Converts standard requests into this backend type's native format
        # (transform_request) and native responses back into images
        # (transform_response); see generate_image().
        self.transformer = get_transformer(config.backend_type)
        # Backend-specific client (set by subclass). For LOCAL backends,
        # health_check() reports healthy only once this is non-None.
        self._client = None

    def _initialize_local(self) -> None:
        """
        Hook: set up a local backend.

        Override in subclasses, e.g. to import and instantiate a local Python
        module. The base implementation does nothing.
        """

    def _initialize_network(self) -> None:
        """
        Hook: set up a network (LAN) backend.

        Override in subclasses, e.g. to create an HTTP client for the
        configured endpoint. The base implementation does nothing.
        """

    def _initialize_cloud(self) -> None:
        """
        Hook: set up a cloud API backend.

        Override in subclasses, e.g. to configure an API client with
        credentials. The base implementation does nothing.
        """

    def initialize(self) -> None:
        """
        Initialize the backend by dispatching on ``self.location``.

        Calls ``_initialize_local``, ``_initialize_network`` or
        ``_initialize_cloud`` as appropriate.

        Raises:
            ValueError: If ``self.location`` is not LOCAL, NETWORK or CLOUD
                (the same check ``generate_image`` performs).
        """
        if self.location == BackendLocation.LOCAL:
            self._initialize_local()
        elif self.location == BackendLocation.NETWORK:
            self._initialize_network()
        elif self.location == BackendLocation.CLOUD:
            self._initialize_cloud()
        else:
            raise ValueError(f"Unknown backend location: {self.location}")

    def health_check(self) -> bool:
        """
        Check whether the backend is available and healthy.

        - LOCAL: healthy once ``self._client`` has been set.
        - NETWORK / CLOUD: sends a GET to ``config.health_check_endpoint``
          (``'/health'`` when unset) with auth headers and a 5 second
          timeout; any 2xx status counts as healthy.
        - Any other location: reported as unhealthy.

        This is best-effort: any error while probing is logged and reported
        as ``False`` instead of being raised.

        Returns:
            True if the backend appears healthy, otherwise False.
        """
        import logging

        if self.location == BackendLocation.LOCAL:
            return self._client is not None

        if self.location not in (BackendLocation.NETWORK, BackendLocation.CLOUD):
            return False

        try:
            health_url = self.config.get_full_endpoint(
                self.config.health_check_endpoint or '/health'
            )
            response = requests.get(
                health_url,
                timeout=5,
                headers=self._get_auth_headers(),
            )
        except Exception as e:  # deliberate best-effort boundary
            logging.getLogger(__name__).warning(
                "Health check failed for %s: %s", self.name, e
            )
            return False

        # Accept any success status (e.g. 204 No Content), not only 200.
        return 200 <= response.status_code < 300

    def generate_image(
        self,
        request: StandardGenerationRequest
    ) -> List[Image.Image]:
        """
        Generate image using this backend.

        This is the ONLY method the application calls. It:
        1. Transforms the standard request into the backend format.
        2. Sends it to the backend (local/network/cloud).
        3. Transforms the backend response back into the standard format.

        Args:
            request: Standard generation request.

        Returns:
            List of generated images, as produced by
            ``self.transformer.transform_response``.

        Raises:
            ValueError: If ``self.location`` is not LOCAL, NETWORK or CLOUD.
            NotImplementedError: If the subclass does not implement
                generation for its configured location.
        """
        # Step 1: Transform request to backend-specific format
        backend_request = self.transformer.transform_request(request)

        # Step 2: Send to backend based on location
        if self.location == BackendLocation.LOCAL:
            backend_response = self._generate_local(backend_request)
        elif self.location == BackendLocation.NETWORK:
            backend_response = self._generate_network(backend_request)
        elif self.location == BackendLocation.CLOUD:
            backend_response = self._generate_cloud(backend_request)
        else:
            raise ValueError(f"Unknown backend location: {self.location}")

        # Step 3: Transform response to standard format
        return self.transformer.transform_response(backend_response)

    def _generate_local(self, backend_request: dict) -> object:
        """
        Generate using a local backend.

        Subclasses supporting BackendLocation.LOCAL must override this.

        Args:
            backend_request: Backend-specific request, as returned by
                ``self.transformer.transform_request``.

        Returns:
            Backend-specific response, later passed to
            ``self.transformer.transform_response``.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement local generation"
        )

    def _generate_network(self, backend_request: dict) -> object:
        """
        Generate using a network (LAN) backend.

        Subclasses supporting BackendLocation.NETWORK must override this;
        ``_send_http_request`` is available as a helper.

        Args:
            backend_request: Backend-specific request, as returned by
                ``self.transformer.transform_request``.

        Returns:
            Backend-specific response, later passed to
            ``self.transformer.transform_response``.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement network generation"
        )

    def _generate_cloud(self, backend_request: dict) -> object:
        """
        Generate using a cloud API backend.

        Subclasses supporting BackendLocation.CLOUD must override this;
        ``_send_http_request`` is available as a helper.

        Args:
            backend_request: Backend-specific request, as returned by
                ``self.transformer.transform_request``.

        Returns:
            Backend-specific response, later passed to
            ``self.transformer.transform_response``.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement cloud generation"
        )

    def _get_auth_headers(self) -> dict:
        """
        Build authentication headers for API requests.

        Returns:
            An empty dict when ``config.api_key`` is falsy. Otherwise it
            returns ``{'x-goog-api-key': key}`` for the ``'gemini'`` backend
            type and ``{'Authorization': 'Bearer <key>'}`` for all others.
        """
        headers = {}
        if self.config.api_key:
            # Common auth header patterns
            if self.backend_type == 'gemini':
                headers['x-goog-api-key'] = self.config.api_key
            else:
                headers['Authorization'] = f'Bearer {self.config.api_key}'
        return headers

    def _send_http_request(
        self,
        endpoint: str,
        data: dict,
        method: str = 'POST'
    ) -> object:
        """
        Send an HTTP request to the backend and return the decoded JSON body.

        Helper method for network/cloud backends. Uses
        ``config.get_full_endpoint(endpoint)`` as the URL, auth headers from
        ``_get_auth_headers()`` and ``config.timeout`` as the timeout.

        Args:
            endpoint: Endpoint path, resolved via ``config.get_full_endpoint``.
            data: Payload. It is sent as a JSON body for POST and as query
                parameters for GET.
            method: ``'POST'`` (default) or ``'GET'``, case-insensitive.

        Returns:
            The parsed JSON response body.

        Raises:
            ValueError: If ``method`` is neither POST nor GET.
            RuntimeError: If the request fails, the server returns an error
                status (4xx/5xx), or the body is not valid JSON.
        """
        verb = method.upper()
        if verb not in ('POST', 'GET'):
            # Previously fell through and raised UnboundLocalError.
            raise ValueError(f"Unsupported HTTP method: {method!r}")

        url = self.config.get_full_endpoint(endpoint)
        headers = self._get_auth_headers()
        headers['Content-Type'] = 'application/json'

        try:
            if verb == 'POST':
                response = requests.post(
                    url,
                    json=data,
                    headers=headers,
                    timeout=self.config.timeout
                )
            else:
                response = requests.get(
                    url,
                    params=data,
                    headers=headers,
                    timeout=self.config.timeout
                )
            response.raise_for_status()
            return response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers JSON decode errors on older requests versions,
            # where they are not RequestException subclasses.
            raise RuntimeError(f"Backend request failed: {e}") from e

    def get_capabilities(self) -> dict:
        """
        Report backend capabilities.

        Returns:
            A dict with ``name``, ``backend_type``, ``location`` and
            ``protocol`` (enum values), and ``endpoint``, merged with
            ``config.capabilities``. Config entries with the same keys
            override the built-in ones; a missing/None capabilities mapping
            is treated as empty.
        """
        return {
            'name': self.name,
            'backend_type': self.backend_type,
            'location': self.location.value,
            'protocol': self.protocol.value,
            'endpoint': self.config.endpoint,
            **(self.config.capabilities or {}),
        }

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"name={self.name}, "
            f"location={self.location.value}, "
            f"endpoint={self.config.endpoint})"
        )