Spaces:
Sleeping
Sleeping
| import argparse | |
| import markdown2 | |
| import os | |
| import sys | |
| import uvicorn | |
| from pathlib import Path | |
| from fastapi import FastAPI, Depends, HTTPException | |
| from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm | |
| from fastapi.responses import HTMLResponse | |
| from pydantic import BaseModel, Field | |
| from typing import Union | |
| from sse_starlette.sse import EventSourceResponse, ServerSentEvent | |
| from passlib.context import CryptContext | |
| from utils.logger import logger | |
| from networks.message_streamer import MessageStreamer | |
| from messagers.message_composer import MessageComposer | |
| from mocks.stream_chat_mocker import stream_chat_mock | |
# Application object that every route in this module is registered on.
# The interactive Swagger UI and ReDoc pages are switched off.
app = FastAPI(
    redoc_url=None,
    docs_url=None,
)

# Extracts the bearer token from the Authorization header; "token" is the
# relative URL advertised as the password-flow token endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
class Auth(BaseModel):
    """Credential record returned by ``get_api_key``.

    ``password`` holds whatever the credential store keeps for the key; for
    the entry built from the environment that is a ``get_password_hash``
    hash, not the plain-text password.
    """

    # API key this record belongs to (also its key in the credential store).
    api_key: str
    # Stored (hashed) password, checked with ``verify_password``.
    password: str
# Password hashing context. bcrypt is the only enabled scheme; with
# deprecated="auto" any non-default scheme would be flagged as deprecated.
pwd_context = CryptContext(
    deprecated="auto",
    schemes=["bcrypt"],
)
def get_password_hash(password):
    """Return the ``pwd_context`` (bcrypt) hash of *password*."""
    hashed = pwd_context.hash(password)
    return hashed
def verify_password(plain_password, hashed_password):
    """Return whether *plain_password* matches *hashed_password*.

    The comparison is delegated to ``pwd_context.verify``.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
# Credentials for the single login account come from the environment.
api_key = os.getenv("XCHE_API_KEY")
password = os.getenv("XCHE_PASSWORD")

# In-memory credential store keyed by API key. Each record stores the hash of
# the password, not the plain text.
fake_data_db = {}
if api_key and password:
    fake_data_db[api_key] = {
        "api_key": api_key,
        "password": get_password_hash(password),
    }
else:
    # Hashing a missing (None) password would raise inside passlib and abort
    # the import, and a missing key would register a ``None`` entry. Leave
    # the store empty instead so every login attempt fails closed.
    print(
        "WARNING: XCHE_API_KEY and/or XCHE_PASSWORD is not set; "
        "API key login is disabled.",
        file=sys.stderr,
    )
def get_api_key(db, api_key: str):
    """Look up *api_key* in *db* and wrap its record in ``Auth``.

    Returns ``None`` when *db* has no entry for the key.
    """
    if api_key not in db:
        return None
    return Auth(**db[api_key])
def authenticate(fake_db, api_key: str, password: str):
    """Return the ``Auth`` record for *api_key* if *password* verifies.

    Returns ``False`` (not ``None``) when the key is unknown or the password
    does not match the stored hash. The hash check is skipped for unknown
    keys.
    """
    record = get_api_key(fake_db, api_key)
    if record and verify_password(password, record.password):
        return record
    return False
# Bearer tokens handed out by ``login``, mapped to the API key each was issued
# for. They live only in this process's memory: they never expire, are lost
# on restart, are not shared between worker processes, and one entry is added
# per successful login.
_issued_tokens = {}


async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange an API key and password for a bearer token.

    The OAuth2 password form's ``username`` field carries the API key.

    NOTE(review): no route registration for this function is visible here;
    ``OAuth2PasswordBearer(tokenUrl="token")`` expects it at ``/token`` --
    confirm it is wired up.

    Returns:
        ``{"access_token": <token>, "token_type": "bearer"}`` where the token
        is a fresh random value accepted by ``check_api_token``.

    Raises:
        HTTPException: 400 when the API key is unknown or the password does
            not match.
    """
    import secrets  # only needed when minting a token

    api_data = authenticate(fake_data_db, form_data.username, form_data.password)
    if not api_data:
        raise HTTPException(
            status_code=400,
            detail="Incorrect API KEY or Password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # Return an opaque random token instead of echoing the API key: if the
    # token were the API key itself, anyone holding the key could present it
    # straight to check_api_token and never supply the password.
    token = secrets.token_urlsafe(32)
    _issued_tokens[token] = api_data.api_key
    return {"access_token": token, "token_type": "bearer"}


def check_api_token(token: str = Depends(oauth2_scheme)):
    """Resolve a bearer token issued by ``login`` to its ``Auth`` record.

    Raises:
        HTTPException: 403 when the token was not issued by ``login`` or its
            API key is no longer present in ``fake_data_db``.
    """
    issued_for = _issued_tokens.get(token)
    api_data = None if issued_for is None else get_api_key(fake_data_db, issued_for)
    if not api_data:
        raise HTTPException(status_code=403, detail="Invalid or missing API Key")
    return api_data
class ChatAPIApp:
    """Attach the chat-completions endpoint to the module-level FastAPI app."""

    def __init__(self):
        # Reuse the module-level ``app`` instead of creating a new FastAPI
        # instance, then register the routes on it.
        self.app = app
        self.setup_routes()

    # Other model ids noted by the original author (nothing here enforces or
    # validates them): "mixtral-8x7b", "mistral - 7b", "nous-mixtral-8x7b",
    # "zephyr-7b-beta", "starchat2-15b-v0.1"
    class ChatCompletionsPostItem(BaseModel):
        """Request body accepted by the chat-completions endpoint."""

        model: str = Field(
            default="mixtral-8x7b",
            description="(str) `mixtral-8x7b`",
        )
        messages: list = Field(
            default=[{"role": "user", "content": "Hello, who are you?"}],
            description="(list) Messages",
        )
        temperature: Union[float, None] = Field(
            default=0.5,
            description="(float) Temperature",
        )
        top_p: Union[float, None] = Field(
            default=0.95,
            description="(float) top p",
        )
        max_tokens: Union[int, None] = Field(
            default=-1,
            description="(int) Max tokens",
        )
        use_cache: bool = Field(
            default=False,
            description="(bool) Use cache",
        )
        stream: bool = Field(
            default=True,
            description="(bool) Stream",
        )

    def chat_completions(self, item: ChatCompletionsPostItem):
        """Run one chat completion for *item*.

        The messages are merged into a single prompt by ``MessageComposer``
        and sent through ``MessageStreamer``. When ``item.stream`` is true the
        result is returned as an ``EventSourceResponse``; otherwise as the
        value of ``streamer.chat_return_dict``.

        NOTE(review): no auth dependency (e.g. ``check_api_token``) is
        attached to this route, so it is reachable without a token -- confirm
        that is intended.
        """
        model_name = item.model
        streamer = MessageStreamer(model=model_name)
        composer = MessageComposer(model=model_name)
        composer.merge(messages=item.messages)
        stream_response = streamer.chat_response(
            prompt=composer.merged_str,
            temperature=item.temperature,
            top_p=item.top_p,
            max_new_tokens=item.max_tokens,
            use_cache=item.use_cache,
        )
        if not item.stream:
            return streamer.chat_return_dict(stream_response)
        # NOTE(review): sse-starlette takes ``ping`` in seconds, so this is a
        # ~33-minute keep-alive interval -- confirm milliseconds weren't meant.
        return EventSourceResponse(
            streamer.chat_return_generator(stream_response),
            media_type="text/event-stream",
            ping=2000,
            ping_message_factory=lambda: ServerSentEvent(comment=""),
        )

    def setup_routes(self):
        """Expose ``chat_completions`` as POST under several URL prefixes.

        Only the ``/api/v1`` variant appears in the OpenAPI schema.
        """
        for prefix in ("", "/v1", "/api", "/api/v1"):
            register = self.app.post(
                f"{prefix}/chat/completions",
                summary="Chat completions in conversation session",
                include_in_schema=(prefix == "/api/v1"),
            )
            register(self.chat_completions)
class ArgParser(argparse.ArgumentParser):
    """Command-line options for launching the chat API server.

    The options are parsed from ``sys.argv`` as soon as the parser is built,
    and the result is stored on ``self.args``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        option_table = (
            (
                ("-s", "--server"),
                {
                    "type": str,
                    "default": "0.0.0.0",
                    "help": "Server IP for HF LLM Chat API",
                },
            ),
            (
                ("-p", "--port"),
                {
                    "type": int,
                    "default": 7860,
                    "help": "Server Port for HF LLM Chat API",
                },
            ),
            (
                ("-d", "--dev"),
                {
                    "default": False,
                    "action": "store_true",
                    "help": "Run in dev mode",
                },
            ),
        )
        for flags, options in option_table:
            self.add_argument(*flags, **options)
        # parse_args() with no argument reads sys.argv[1:].
        self.args = self.parse_args()
# Register the routes on the shared app at import time; uvicorn resolves the
# "__main__:app" import string below against this module-level name.
app = ChatAPIApp().app

if __name__ == "__main__":
    args = ArgParser().args
    # -d/--dev enables uvicorn's auto-reload; otherwise serve without reload.
    uvicorn.run(
        "__main__:app",
        host=args.server,
        port=args.port,
        reload=args.dev,
    )