import argparse
import os
import sys
import time
import asyncio
import logging
from pathlib import Path
from typing import Union, List, Dict, Any

import uvicorn
import requests
from fastapi import FastAPI, Depends, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse, ServerSentEvent

from utils.logger import logger
from networks.message_streamer import MessageStreamer
from messagers.message_composer import MessageComposer
from mocks.stream_chat_mocker import stream_chat_mock


# NOTE: the pydantic models and route handlers below are documented with `#`
# comments rather than docstrings so the generated OpenAPI schema/descriptions
# stay exactly as before.


# One entry of the `data` list returned by the embeddings endpoint.
class EmbeddingResponseItem(BaseModel):
    object: str = "embedding"
    index: int
    embedding: List[List[float]]


# Response envelope returned by the embeddings endpoint.
class EmbeddingResponse(BaseModel):
    object: str = "list"
    data: List[EmbeddingResponseItem]
    model: str
    usage: Dict[str, Any]


class ChatAPIApp:
    """Build the FastAPI application and register its routes and CORS policy."""

    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="HuggingFace LLM API",
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],  # You can specify specific origins here
            allow_credentials=True,
            allow_methods=["*"],  # Or specify just the methods you need: ["GET", "POST"]
            allow_headers=["*"],  # Or specify headers you need
        )

    # Route handler: return the static model catalogue. Every entry's
    # `created` field is stamped with the time of the call, and the result is
    # also stored on `self.available_models`.
    def get_available_models(self):
        # https://platform.openai.com/docs/api-reference/models/list
        # ANCHOR[id=available-models]: Available models
        current_time = int(time.time())
        self.available_models = {
            "object": "list",
            "data": [
                {
                    "id": "mixtral-8x7b",
                    "description": "[mistralai/Mixtral-8x7B-Instruct-v0.1]: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "mistralai",
                },
                {
                    "id": "mistral-7b",
                    "description": "[mistralai/Mistral-7B-Instruct-v0.2]: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "mistralai",
                },
                {
                    "id": "nous-mixtral-8x7b",
                    "description": "[NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO]: https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "NousResearch",
                },
                {
                    "id": "gemma-7b",
                    "description": "[google/gemma-7b-it]: https://huggingface.co/google/gemma-7b-it",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "Google",
                },
                {
                    "id": "codellama-7b",
                    "description": "[codellama/CodeLlama-7b-hf]: https://huggingface.co/codellama/CodeLlama-7b-hf",
                    "object": "model",
                    "created": current_time,
                    "owned_by": "codellama",
                },
                {
                    "id": "bert-base-uncased",
                    "description": "[google-bert/bert-base-uncased]: https://huggingface.co/google-bert/bert-base-uncased",
                    "object": "embedding",
                    "created": current_time,
                    "owned_by": "google",
                },
            ],
        }
        return self.available_models

    # Deliberately defined without `self`: it is referenced as a plain function
    # via `Depends(extract_api_key)` in the handler signatures below, which are
    # evaluated while the class body is still executing.
    def extract_api_key(
        credentials: HTTPAuthorizationCredentials = Depends(
            HTTPBearer(auto_error=False)
        ),
    ):
        """Resolve the HF token for a request.

        Uses the bearer credentials when present, otherwise the `HF_TOKEN`
        environment variable. Returns the token only if it starts with
        `hf_`; in every other case a warning is logged and `None` is returned.
        """
        api_key = None
        if credentials:
            api_key = credentials.credentials
        else:
            api_key = os.getenv("HF_TOKEN")

        if api_key:
            if api_key.startswith("hf_"):
                return api_key
            else:
                logger.warn("Invalid HF Token!")
        else:
            logger.warn("Not provide HF Token!")
        return None

    # Request body of the embeddings endpoint.
    class QueryRequest(BaseModel):
        input: str
        model: str = Field(default="bert-base-uncased")
        encoding_format: str

    # Request body of the chat completions endpoint.
    class ChatCompletionsPostItem(BaseModel):
        model: str = Field(
            default="mixtral-8x7b",
            description="(str) `mixtral-8x7b`",
        )
        messages: list = Field(
            default=[{"role": "user", "content": "Hello, who are you?"}],
            description="(list) Messages",
        )
        temperature: Union[float, None] = Field(
            default=0.5,
            description="(float) Temperature",
        )
        top_p: Union[float, None] = Field(
            default=0.95,
            description="(float) top p",
        )
        max_tokens: Union[int, None] = Field(
            default=-1,
            description="(int) Max tokens",
        )
        use_cache: bool = Field(
            default=False,
            description="(bool) Use cache",
        )
        stream: bool = Field(
            default=False,
            description="(bool) Stream",
        )

    # Route handler: merge the posted messages into a single prompt, send it
    # through MessageStreamer, and return either a server-sent-event stream
    # (item.stream is true) or the dict built by `chat_return_dict`.
    def chat_completions(
        self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
    ):
        streamer = MessageStreamer(model=item.model)
        composer = MessageComposer(model=item.model)
        composer.merge(messages=item.messages)
        # streamer.chat = stream_chat_mock

        stream_response = streamer.chat_response(
            prompt=composer.merged_str,
            temperature=item.temperature,
            top_p=item.top_p,
            max_new_tokens=item.max_tokens,
            api_key=api_key,
            use_cache=item.use_cache,
        )
        if item.stream:
            event_source_response = EventSourceResponse(
                streamer.chat_return_generator(stream_response),
                media_type="text/event-stream",
                ping=2000,
                ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
            )
            return event_source_response
        else:
            data_response = streamer.chat_return_dict(stream_response)
            return data_response

    # Route handler: request feature-extraction output for `request.input`
    # from the Hugging Face inference API and wrap it in an EmbeddingResponse.
    # Raises HTTPException 503 when the upstream payload carries an "error"
    # key, and HTTPException 500 when the payload is not a non-empty list of
    # lists.
    async def embedding(
        self, request: QueryRequest, api_key: str = Depends(extract_api_key)
    ):
        api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{request.model}"

        # extract_api_key may return None; only send the Authorization header
        # when there is a real token instead of the literal "Bearer None".
        headers = {}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        # `requests.post` is a blocking call and its return value is not
        # awaitable, so run it in the default executor and await that future.
        # This also keeps the event loop free while the request is in flight.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.post(
                api_url, headers=headers, json={"inputs": request.input}
            ),
        )
        result = response.json()

        # Error payloads arrive as a JSON object carrying an "error" key.
        if isinstance(result, dict) and "error" in result:
            logging.error(f"Error from Hugging Face API: {result['error']}")
            error_detail = result.get('error', 'No detailed error message provided.')
            raise HTTPException(status_code=503, detail=f"The model is currently loading, please re-run the query. Detail: {error_detail}")

        if isinstance(result, list) and result and isinstance(result[0], list):
            # Flatten list of lists
            flattened_embeddings = [item for sublist in result for item in sublist]
            data = [
                {"object": "embedding", "index": i, "embedding": vector}
                for i, vector in enumerate(flattened_embeddings)
            ]
            return EmbeddingResponse(
                object="list",
                data=data,
                model=request.model,
                # Both counts are len() of the input string.
                usage={"prompt_tokens": len(request.input), "total_tokens": len(request.input)}
            )
        else:
            logging.error(f"Unexpected response format: {result}")
            raise HTTPException(status_code=500, detail="Unexpected response format.")

    def setup_routes(self):
        """Register every endpoint under each supported URL prefix.

        Only the `/api/v1` variants are included in the OpenAPI schema.
        """
        for prefix in ["", "/v1", "/api", "/api/v1"]:
            include_in_schema = prefix in ["/api/v1"]

            self.app.get(
                prefix + "/models",
                summary="Get available models",
                include_in_schema=include_in_schema,
            )(self.get_available_models)

            self.app.post(
                prefix + "/chat/completions",
                summary="Chat completions in conversation session",
                include_in_schema=include_in_schema,
            )(self.chat_completions)

            self.app.post(
                prefix + "/embeddings",  # Use the specific prefix for this route
                summary="Generate embeddings for the given texts",
                include_in_schema=include_in_schema,
                response_model=EmbeddingResponse  # Adapt based on your actual response model
            )(self.embedding)


class ArgParser(argparse.ArgumentParser):
    """Command-line options for the server; parses `sys.argv` on construction.

    The parsed namespace is exposed as `self.args`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.add_argument(
            "-s",
            "--server",
            type=str,
            default="0.0.0.0",
            help="Server IP for HF LLM Chat API",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=23333,
            help="Server Port for HF LLM Chat API",
        )
        self.add_argument(
            "-d",
            "--dev",
            default=False,
            action="store_true",
            help="Run in dev mode",
        )

        self.args = self.parse_args(sys.argv[1:])


# Module-level ASGI app; uvicorn resolves it through the "__main__:app" string.
app = ChatAPIApp().app

if __name__ == "__main__":
    args = ArgParser().args
    # `--dev` is a store_true flag, so it maps directly onto uvicorn's reload.
    uvicorn.run("__main__:app", host=args.server, port=args.port, reload=args.dev)

    # python -m apis.chat_api      # [Docker] on product mode
    # python -m apis.chat_api -d   # [Dev] on develop mode