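"""Async client wrapper that forwards generation, streaming, and embedding
requests to a separate LLM Server over HTTP."""
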
import httpx
from typing import Optional, AsyncIterator, List, Dict, Union
import logging


class InferenceApi:
    def __init__(self, config: dict):
        """Initialize the Inference API with configuration."""
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing Inference API")

        # Get base URL and request timeout from config
        self.base_url = config["llm_server"]["base_url"]
        self.timeout = config["llm_server"].get("timeout", 60)

        # Initialize the async HTTP client used for all requests
        self.client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=self.timeout
        )
        self.logger.info("Inference API initialized successfully")

    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> str:
        """
        Generate a complete response by forwarding the request to the LLM Server.
        """
        self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")
        try:
            response = await self.client.post(
                "/api/v1/generate",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                }
            )
            response.raise_for_status()
            data = response.json()
            return data["generated_text"]
        except Exception as e:
            self.logger.error(f"Error in generate_response: {e}")
            raise

    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """
        Generate a streaming response by forwarding the request to the LLM Server.
        """
        self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")
        try:
            async with self.client.stream(
                "POST",
                "/api/v1/generate/stream",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                }
            ) as response:
                response.raise_for_status()
                # Relay the response body to the caller chunk by chunk
                async for chunk in response.aiter_text():
                    yield chunk
        except Exception as e:
            self.logger.error(f"Error in generate_stream: {e}")
            raise

    async def generate_embedding(self, text: str) -> List[float]:
        """
        Generate an embedding by forwarding the request to the LLM Server.
        """
        self.logger.debug(f"Forwarding embedding request for text: {text[:50]}...")
        try:
            response = await self.client.post(
                "/api/v1/embedding",
                json={"text": text}
            )
            response.raise_for_status()
            data = response.json()
            return data["embedding"]
        except Exception as e:
            self.logger.error(f"Error in generate_embedding: {e}")
            raise

    async def check_system_status(self) -> Dict[str, Union[Dict, str]]:
        """
        Get system status from the LLM Server.
        """
        try:
            response = await self.client.get("/api/v1/system/status")
            response.raise_for_status()
            return response.json()
        except Exception as e:
            self.logger.error(f"Error getting system status: {e}")
            raise

    async def validate_system(self) -> Dict[str, Union[Dict, str, List[str]]]:
        """
        Get system validation status from the LLM Server.
        """
        try:
            response = await self.client.get("/api/v1/system/validate")
            response.raise_for_status()
            return response.json()
        except Exception as e:
            self.logger.error(f"Error validating system: {e}")
            raise

    async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """
        Initialize a model on the LLM Server.
        """
        try:
            response = await self.client.post(
                "/api/v1/model/initialize",
                params={"model_name": model_name} if model_name else None
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            self.logger.error(f"Error initializing model: {e}")
            raise

    async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """
        Initialize an embedding model on the LLM Server.
        """
        try:
            response = await self.client.post(
                "/api/v1/model/initialize/embedding",
                params={"model_name": model_name} if model_name else None
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            self.logger.error(f"Error initializing embedding model: {e}")
            raise

    async def close(self):
        """Close the HTTP client session."""
        await self.client.aclose()
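

# A minimal usage sketch, assuming a config dict with an "llm_server" section as
# read in __init__ above, and a running LLM Server reachable at the base_url.
# The base_url value and prompts here are hypothetical; adapt them to your
# deployment and config loading.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        # Hypothetical config; in a real app this would come from a config file.
        config = {"llm_server": {"base_url": "http://localhost:8001", "timeout": 60}}
        api = InferenceApi(config)
        try:
            # One-shot generation
            reply = await api.generate_response(
                "What does this service do?",
                system_message="Answer briefly."
            )
            print(reply)

            # Streaming generation: chunks are printed as they arrive
            async for chunk in api.generate_stream("Count to three."):
                print(chunk, end="", flush=True)
            print()
        finally:
            # Always close the underlying httpx client
            await api.close()

    asyncio.run(_demo())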