Spaces:
Runtime error
Runtime error
import logging
from typing import AsyncIterator, Dict, Iterator, List, Optional, Union

import httpx
class InferenceApi:
    """Async HTTP client that forwards inference requests to an LLM Server.

    Every call goes through one shared ``httpx.AsyncClient`` rooted at
    ``config["llm_server"]["base_url"]``. Release its connections with
    :meth:`close`, or use the instance as an ``async with`` context manager.
    """

    def __init__(self, config: dict):
        """Initialize the Inference API with configuration.

        Args:
            config: Mapping with an ``"llm_server"`` section that holds a
                required ``"base_url"`` and an optional ``"timeout"``
                (passed to httpx as the client timeout; defaults to 60).

        Raises:
            KeyError: If ``config["llm_server"]`` or its ``"base_url"``
                entry is missing.
        """
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing Inference API")
        server_config = config["llm_server"]
        self.base_url = server_config["base_url"]
        self.timeout = server_config.get("timeout", 60)
        # A single client is reused so connections to the server are pooled.
        self.client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=self.timeout,
        )
        self.logger.info("Inference API initialized successfully")

    async def __aenter__(self) -> "InferenceApi":
        """Return this instance so it can be used with ``async with``."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the underlying HTTP client when leaving ``async with``."""
        await self.close()

    @staticmethod
    def _generation_payload(
        prompt: str,
        system_message: Optional[str],
        max_new_tokens: Optional[int],
    ) -> Dict[str, Union[str, int, None]]:
        """Build the JSON body shared by the generate and stream endpoints.

        ``None`` values are kept and therefore sent as JSON ``null``.
        """
        return {
            "prompt": prompt,
            "system_message": system_message,
            "max_new_tokens": max_new_tokens,
        }

    async def _request_json(
        self,
        method: str,
        path: str,
        error_context: str,
        key: Optional[str] = None,
        **kwargs,
    ):
        """Send a request to the LLM Server and return its decoded JSON body.

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            path: Endpoint path, resolved against the client's base URL.
            error_context: Prefix for the error log line on failure.
            key: If given, return ``body[key]`` instead of the whole body.
                The lookup happens inside the ``try``, so a missing key is
                logged too.
            **kwargs: Forwarded to ``httpx.AsyncClient.request``
                (``json=``, ``params=``, ...).

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
            Exception: Any transport, decoding, or lookup error, re-raised
                unchanged after logging.
        """
        try:
            response = await self.client.request(method, path, **kwargs)
            response.raise_for_status()
            data = response.json()
            return data if key is None else data[key]
        except Exception as e:
            # logger.exception keeps the traceback; the message text matches
            # the previous "<context>: <error>" format.
            self.logger.exception("%s: %s", error_context, e)
            raise

    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Generate a complete response by forwarding the request to the LLM Server.

        Args:
            prompt: User prompt text.
            system_message: Optional system message, sent as-is.
            max_new_tokens: Optional generation limit, sent as-is.

        Returns:
            The ``"generated_text"`` field of the server's JSON reply.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
            KeyError: If the reply has no ``"generated_text"`` field.
        """
        self.logger.debug("Forwarding generation request for prompt: %s...", prompt[:50])
        return await self._request_json(
            "POST",
            "/api/v1/generate",
            "Error in generate_response",
            key="generated_text",
            json=self._generation_payload(prompt, system_message, max_new_tokens),
        )

    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
    ) -> AsyncIterator[str]:
        """Generate a streaming response by forwarding the request to the LLM Server.

        This is an async generator: consume it with ``async for``. It yields
        decoded text chunks as ``httpx`` delivers them from the response body.

        Raises:
            httpx.HTTPStatusError: If the server answers with a non-2xx status.
        """
        self.logger.debug("Forwarding streaming request for prompt: %s...", prompt[:50])
        try:
            async with self.client.stream(
                "POST",
                "/api/v1/generate/stream",
                json=self._generation_payload(prompt, system_message, max_new_tokens),
            ) as response:
                response.raise_for_status()
                async for chunk in response.aiter_text():
                    yield chunk
        except Exception as e:
            self.logger.exception("Error in generate_stream: %s", e)
            raise

    async def generate_embedding(self, text: str) -> List[float]:
        """Generate an embedding by forwarding the request to the LLM Server.

        Returns:
            The ``"embedding"`` field of the server's JSON reply.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
            KeyError: If the reply has no ``"embedding"`` field.
        """
        self.logger.debug("Forwarding embedding request for text: %s...", text[:50])
        return await self._request_json(
            "POST",
            "/api/v1/embedding",
            "Error in generate_embedding",
            key="embedding",
            json={"text": text},
        )

    async def check_system_status(self) -> Dict[str, Union[Dict, str]]:
        """Get system status from the LLM Server (the decoded JSON body)."""
        return await self._request_json(
            "GET", "/api/v1/system/status", "Error getting system status"
        )

    async def validate_system(self) -> Dict[str, Union[Dict, str, List[str]]]:
        """Get system validation status from the LLM Server (the decoded JSON body)."""
        return await self._request_json(
            "GET", "/api/v1/system/validate", "Error validating system"
        )

    async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """Initialize a model on the LLM Server.

        Args:
            model_name: Sent as the ``model_name`` query parameter only when
                truthy. Otherwise no query parameters are sent.
        """
        return await self._request_json(
            "POST",
            "/api/v1/model/initialize",
            "Error initializing model",
            params={"model_name": model_name} if model_name else None,
        )

    async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """Initialize an embedding model on the LLM Server.

        Args:
            model_name: Sent as the ``model_name`` query parameter only when
                truthy. Otherwise no query parameters are sent.
        """
        return await self._request_json(
            "POST",
            "/api/v1/model/initialize/embedding",
            "Error initializing embedding model",
            params={"model_name": model_name} if model_name else None,
        )

    async def close(self):
        """Close the HTTP client session."""
        await self.client.aclose()