import argparse
import os
import sys
import time
import uvicorn
import requests
import asyncio
import logging

from pathlib import Path
from fastapi import FastAPI, Depends, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Union, List, Dict, Any
from sse_starlette.sse import EventSourceResponse, ServerSentEvent

from utils.logger import logger
from networks.message_streamer import MessageStreamer
from messagers.message_composer import MessageComposer
from mocks.stream_chat_mocker import stream_chat_mock


class EmbeddingResponseItem(BaseModel):
    object: str = "embedding"
    index: int
    embedding: List[List[float]]


class EmbeddingResponse(BaseModel):
    object: str = "list"
    data: List[EmbeddingResponseItem]
    model: str
    usage: Dict[str, Any]
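
# Note: these response models mirror the OpenAI embeddings response envelope
# ("object" / "data" / "model" / "usage"), so clients written against that
# format can parse the /embeddings output of this service.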


class ChatAPIApp:
    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="HuggingFace LLM API",
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],  # You can specify specific origins here
            allow_credentials=True,
            allow_methods=["*"],  # Or specify just the methods you need: ["GET", "POST"]
            allow_headers=["*"],  # Or specify headers you need
        )

    def get_available_models(self):
        # https://platform.openai.com/docs/api-reference/models/list
        # ANCHOR[id=available-models]: Available models
        current_time = int(time.time())
        self.available_models = {
            "object": "list",
            "data": [
                {
                    "id": "mixtral-8x7b",
                    "description": "[mistralai/Mixtral-8x7B-Instruct-v0.1]: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "mistralai",
                },
                {
                    "id": "mistral-7b",
                    "description": "[mistralai/Mistral-7B-Instruct-v0.2]: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "mistralai",
                },
                {
                    "id": "nous-mixtral-8x7b",
                    "description": "[NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO]: https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "NousResearch",
                },
                {
                    "id": "gemma-7b",
                    "description": "[google/gemma-7b-it]: https://huggingface.co/google/gemma-7b-it",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "Google",
                },
                {
                    "id": "codellama-7b",
                    "description": "[codellama/CodeLlama-7b-hf]: https://huggingface.co/codellama/CodeLlama-7b-hf",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "codellama",
                },
                {
                    "id": "bert-base-uncased",
                    "description": "[google-bert/bert-base-uncased]: https://huggingface.co/google-bert/bert-base-uncased",
                    "object": "embedding",
                    "created": current_time,
                    "owned_by": "google",
                },
            ],
        }
        return self.available_models
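
    # Example (assuming the default host/port from ArgParser below):
    #   curl http://127.0.0.1:23333/v1/models
    # returns the model list above in the OpenAI "list of models" format.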

    def extract_api_key(
        credentials: HTTPAuthorizationCredentials = Depends(
            HTTPBearer(auto_error=False)
        ),
    ):
        api_key = None
        if credentials:
            api_key = credentials.credentials
        else:
            api_key = os.getenv("HF_TOKEN")

        if api_key:
            if api_key.startswith("hf_"):
                return api_key
            else:
                logger.warning("Invalid HF Token!")
        else:
            logger.warning("No HF Token provided!")
        return None
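
    # The API key is taken from the "Authorization: Bearer ..." header when
    # present, and falls back to the HF_TOKEN environment variable otherwise;
    # only tokens starting with "hf_" are accepted, anything else yields None.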

    class QueryRequest(BaseModel):
        input: str
        model: str = Field(default="bert-base-uncased")
        encoding_format: str

    class ChatCompletionsPostItem(BaseModel):
        model: str = Field(
            default="mixtral-8x7b",
            description="(str) `mixtral-8x7b`",
        )
        messages: list = Field(
            default=[{"role": "user", "content": "Hello, who are you?"}],
            description="(list) Messages",
        )
        temperature: Union[float, None] = Field(
            default=0.5,
            description="(float) Temperature",
        )
        top_p: Union[float, None] = Field(
            default=0.95,
            description="(float) Top p",
        )
        max_tokens: Union[int, None] = Field(
            default=-1,
            description="(int) Max tokens",
        )
        use_cache: bool = Field(
            default=False,
            description="(bool) Use cache",
        )
        stream: bool = Field(
            default=False,
            description="(bool) Stream",
        )

    def chat_completions(
        self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
    ):
        streamer = MessageStreamer(model=item.model)
        composer = MessageComposer(model=item.model)
        composer.merge(messages=item.messages)
        # streamer.chat = stream_chat_mock
        stream_response = streamer.chat_response(
            prompt=composer.merged_str,
            temperature=item.temperature,
            top_p=item.top_p,
            max_new_tokens=item.max_tokens,
            api_key=api_key,
            use_cache=item.use_cache,
        )
        if item.stream:
            event_source_response = EventSourceResponse(
                streamer.chat_return_generator(stream_response),
                media_type="text/event-stream",
                ping=2000,
                ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
            )
            return event_source_response
        else:
            data_response = streamer.chat_return_dict(stream_response)
            return data_response
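
    # Example (assuming the default host/port from ArgParser below; hf_xxx is a
    # placeholder token):
    #   curl -X POST http://127.0.0.1:23333/v1/chat/completions \
    #     -H "Authorization: Bearer hf_xxx" -H "Content-Type: application/json" \
    #     -d '{"model": "mixtral-8x7b", "messages": [{"role": "user", "content": "Hello, who are you?"}], "stream": false}'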

    async def embedding(
        self, request: QueryRequest, api_key: str = Depends(extract_api_key)
    ):
        api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{request.model}"
        headers = {"Authorization": f"Bearer {api_key}"}
        # `requests` is synchronous, so the call must not be awaited.
        response = requests.post(api_url, headers=headers, json={"inputs": request.input})
        result = response.json()

        if "error" in result:
            logging.error(f"Error from Hugging Face API: {result['error']}")
            error_detail = result.get("error", "No detailed error message provided.")
            raise HTTPException(
                status_code=503,
                detail=f"The model is currently loading, please re-run the query. Detail: {error_detail}",
            )

        if isinstance(result, list) and len(result) > 0 and isinstance(result[0], list):
            # Flatten the nested lists returned by the feature-extraction pipeline
            flattened_embeddings = [item for sublist in result for item in sublist]
            data = [
                {"object": "embedding", "index": i, "embedding": embedding}
                for i, embedding in enumerate(flattened_embeddings)
            ]
            return EmbeddingResponse(
                object="list",
                data=data,
                model=request.model,
                # Note: token counts are approximated by input character length
                usage={"prompt_tokens": len(request.input), "total_tokens": len(request.input)},
            )
        else:
            logging.error(f"Unexpected response format: {result}")
            raise HTTPException(status_code=500, detail="Unexpected response format.")
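
    # Example (assuming the default host/port; "float" for encoding_format is
    # only an illustrative value, the field is not used by the handler above):
    #   curl -X POST http://127.0.0.1:23333/v1/embeddings \
    #     -H "Authorization: Bearer hf_xxx" -H "Content-Type: application/json" \
    #     -d '{"input": "Hello world", "model": "bert-base-uncased", "encoding_format": "float"}'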

    def setup_routes(self):
        for prefix in ["", "/v1", "/api", "/api/v1"]:
            if prefix in ["/api/v1"]:
                include_in_schema = True
            else:
                include_in_schema = False

            self.app.get(
                prefix + "/models",
                summary="Get available models",
                include_in_schema=include_in_schema,
            )(self.get_available_models)

            self.app.post(
                prefix + "/chat/completions",
                summary="Chat completions in conversation session",
                include_in_schema=include_in_schema,
            )(self.chat_completions)

            self.app.post(
                prefix + "/embeddings",  # Use the specific prefix for this route
                summary="Generate embeddings for the given texts",
                include_in_schema=include_in_schema,
                response_model=EmbeddingResponse,  # Adapt based on your actual response model
            )(self.embedding)
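
    # Each route is registered under "", "/v1", "/api", and "/api/v1"; only the
    # "/api/v1" variants are included in the OpenAPI schema shown at docs_url="/".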


class ArgParser(argparse.ArgumentParser):
    def __init__(self, *args, **kwargs):
        super(ArgParser, self).__init__(*args, **kwargs)
        self.add_argument(
            "-s",
            "--server",
            type=str,
            default="0.0.0.0",
            help="Server IP for HF LLM Chat API",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=23333,
            help="Server Port for HF LLM Chat API",
        )
        self.add_argument(
            "-d",
            "--dev",
            default=False,
            action="store_true",
            help="Run in dev mode",
        )

        self.args = self.parse_args(sys.argv[1:])


app = ChatAPIApp().app

if __name__ == "__main__":
    args = ArgParser().args
    if args.dev:
        uvicorn.run("__main__:app", host=args.server, port=args.port, reload=True)
    else:
        uvicorn.run("__main__:app", host=args.server, port=args.port, reload=False)

    # python -m apis.chat_api       # [Docker] in production mode
    # python -m apis.chat_api -d    # [Dev] in development mode