| """ | |
| LLM Provider Interface for Flare | |
| """ | |
| import os | |
| from abc import ABC, abstractmethod | |
| from typing import Dict, List, Optional, Any | |
| import httpx | |
| from openai import AsyncOpenAI | |
| from utils import log | |
| class LLMInterface(ABC): | |
| """Abstract base class for LLM providers""" | |
| def __init__(self, settings: Dict[str, Any] = None): | |
| """Initialize with provider settings""" | |
| self.settings = settings or {} | |
| self.internal_prompt = self.settings.get("internal_prompt", "") | |
| self.parameter_collection_config = self.settings.get("parameter_collection_config", {}) | |
| async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str: | |
| """Generate response from LLM""" | |
| pass | |
| async def startup(self, project_config: Dict) -> bool: | |
| """Initialize LLM with project config""" | |
| pass | |
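

# A minimal sketch (not one of Flare's shipped providers) showing how a
# concrete class satisfies the abstract contract above; useful as a test
# double. The class name and canned reply are illustrative assumptions.
class EchoLLM(LLMInterface):
    """Test stub that echoes the user input back"""

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        # No model call: deterministic output keeps unit tests simple
        return f"echo: {user_input}"

    async def startup(self, project_config: Dict) -> bool:
        # Nothing to initialize for a stub
        return True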


class SparkLLM(LLMInterface):
    """Spark LLM integration"""

    def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "cloud", settings: Dict[str, Any] = None):
        super().__init__(settings)
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.provider_variant = provider_variant
        log(f"🚀 SparkLLM initialized with endpoint: {self.spark_endpoint}")

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response from Spark LLM"""
        headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json"
        }

        # Build payload
        payload = {
            "system_prompt": system_prompt,
            "user_input": user_input,
            "context": context
        }

        try:
            async with httpx.AsyncClient(timeout=60) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/generate",
                    json=payload,
                    headers=headers
                )
                response.raise_for_status()
                data = response.json()

                # The answer field varies across Spark versions, so fall back
                # through the known candidates (guarding against null values)
                raw = (data.get("model_answer") or "").strip()
                if not raw:
                    raw = (data.get("assistant") or data.get("text") or "").strip()
                return raw
        except Exception as e:
            log(f"❌ Spark error: {e}")
            raise

    async def startup(self, project_config: Dict) -> bool:
        """Send startup request to Spark"""
        headers = {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json"
        }

        # Extract required fields from project config
        body = {
            "work_mode": self.provider_variant,
            "cloud_token": self.spark_token,
            "project_name": project_config.get("name"),
            "project_version": project_config.get("version_id"),
            "repo_id": project_config.get("repo_id"),
            "generation_config": project_config.get("generation_config", {}),
            "use_fine_tune": project_config.get("use_fine_tune", False),
            "fine_tune_zip": project_config.get("fine_tune_zip", "")
        }

        try:
            async with httpx.AsyncClient(timeout=10) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/startup",
                    json=body,
                    headers=headers
                )
                if response.status_code >= 400:
                    log(f"❌ Spark startup failed: {response.status_code} - {response.text}")
                    return False
                log(f"✅ Spark acknowledged startup ({response.status_code})")
                return True
        except Exception as e:
            log(f"⚠️ Spark startup error: {e}")
            return False
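

# Usage sketch for SparkLLM (the endpoint, token, and variable names below
# are placeholders, not real values):
#
#     spark = SparkLLM("https://spark.example.com", spark_token, "cloud")
#     if await spark.startup(project_config):
#         reply = await spark.generate(system_prompt, user_text, chat_history)
#
# Note the asymmetry in error handling: startup() logs and returns False so
# callers can degrade gracefully, while generate() re-raises so the chat
# layer can surface the failure.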


class GPT4oLLM(LLMInterface):
    """OpenAI GPT integration"""

    def __init__(self, api_key: str, model: str = "gpt-4o-mini", settings: Dict[str, Any] = None):
        super().__init__(settings)
        self.api_key = api_key
        self.model = self._map_model_name(model)
        self.client = AsyncOpenAI(api_key=api_key)

        # Extract model-specific settings (self.settings is never None after
        # super().__init__, so no extra guard is needed)
        self.temperature = self.settings.get("temperature", 0.7)
        self.max_tokens = self.settings.get("max_tokens", 4096)
        log(f"✅ Initialized GPT LLM with model: {self.model}")

    def _map_model_name(self, model: str) -> str:
        """Map provider name to actual model name"""
        mappings = {
            "gpt4o": "gpt-4o",
            "gpt4o-mini": "gpt-4o-mini"
        }
        return mappings.get(model, model)

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response from OpenAI"""
        try:
            # Build messages
            messages = [{"role": "system", "content": system_prompt}]

            # Add context
            for msg in context:
                messages.append({
                    "role": msg.get("role", "user"),
                    "content": msg.get("content", "")
                })

            # Add current user input
            messages.append({"role": "user", "content": user_input})

            # Call OpenAI
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens
            )
            # content can be None in edge cases, so guard before stripping
            return (response.choices[0].message.content or "").strip()
        except Exception as e:
            log(f"❌ OpenAI error: {e}")
            raise

    async def startup(self, project_config: Dict) -> bool:
        """GPT doesn't need a startup handshake; always return True"""
        log("✅ GPT provider ready (no startup needed)")
        return True
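

# A hedged factory sketch showing how a caller might pick a provider from a
# project config. The "llm_provider" key and the config field names used
# here are assumptions for illustration, not a documented Flare schema.
def create_llm(config: Dict[str, Any]) -> LLMInterface:
    """Instantiate the provider named in the config (illustrative)"""
    provider = config.get("llm_provider", "spark")
    settings = config.get("settings", {})
    if provider.startswith("gpt"):
        return GPT4oLLM(api_key=config["openai_api_key"], model=provider, settings=settings)
    return SparkLLM(
        spark_endpoint=config["spark_endpoint"],
        spark_token=config["spark_token"],
        provider_variant=config.get("provider_variant", "cloud"),
        settings=settings,
    )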