# Header captured from the Hugging Face Spaces page this file was copied from
# (Space status shown there: "Build error").
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse

from utils.logger import logger
from networks.message_streamer import MessageStreamer
from messagers.message_composer import MessageComposer
class ChatAPIApp:
    """FastAPI application exposing a chat-completions style HTTP API.

    The same routes are registered twice: at the root and under ``/v1``.

    * ``GET  {prefix}/models`` returns the models this server advertises.
    * ``POST {prefix}/chat/completions`` runs a chat completion and returns
      the output as server-sent events.

    The ASGI application to serve is available as ``self.app``.
    """

    # Catalog returned by the ``/models`` route. It is defined once here
    # instead of being rebuilt inside the handler on every request.
    AVAILABLE_MODELS = [
        {
            "id": "mixtral-8x7b",
            "description": "Mixtral-8x7B: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1",
        },
    ]

    def __init__(self):
        self.app = FastAPI(
            docs_url="/",  # serve the Swagger UI at the root path
            title="HuggingFace LLM API",
            # Per Swagger UI's configuration, -1 hides the "Schemas" (models)
            # section of the docs page.
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()

    def get_available_models(self):
        """Return the list of models served by this API.

        Each entry is a dict with ``id`` and ``description`` keys. The list is
        also stored on ``self.available_models``, as in earlier versions.
        """
        # Copy each entry so a caller cannot mutate the class-level catalog.
        self.available_models = [dict(model) for model in self.AVAILABLE_MODELS]
        return self.available_models

    class ChatCompletionsPostItem(BaseModel):
        """Request body accepted by the ``/chat/completions`` route."""

        model: str = Field(
            default="mixtral-8x7b",
            description="(str) `mixtral-8x7b`",
        )
        # Pydantic copies non-hashable defaults per model instance, so this
        # list literal is not shared between requests. It is kept as a plain
        # default (not default_factory) so the Swagger UI still shows it as
        # the prefilled example.
        messages: list = Field(
            default=[{"role": "user", "content": "Hello, who are you?"}],
            description="(list) Messages",
        )
        temperature: float = Field(
            default=0.01,
            description="(float) Temperature",
        )
        max_tokens: int = Field(
            default=32000,
            description="(int) Max tokens",
        )
        stream: bool = Field(
            default=True,
            description="(bool) Stream",
        )

    def chat_completions(self, item: ChatCompletionsPostItem):
        """Run a chat completion for *item* and return its output as SSE.

        The request's ``messages`` are merged by ``MessageComposer``, and the
        merged prompt (read back from ``composer.merged_str``) is passed to
        ``MessageStreamer.chat``. Whatever that call yields is sent to the
        client through an ``EventSourceResponse``.

        NOTE(review): an ``EventSourceResponse`` is returned even when
        ``item.stream`` is False. How ``MessageStreamer.chat`` behaves in that
        case is not visible here; confirm that non-streaming clients are
        handled.
        """
        streamer = MessageStreamer(model=item.model)
        composer = MessageComposer(model=item.model)
        composer.merge(messages=item.messages)
        return EventSourceResponse(
            streamer.chat(
                prompt=composer.merged_str,
                temperature=item.temperature,
                # The API-level ``max_tokens`` maps onto the streamer's
                # ``max_new_tokens`` argument.
                max_new_tokens=item.max_tokens,
                stream=item.stream,
                yield_output=True,
            ),
            media_type="text/event-stream",
        )

    def setup_routes(self):
        """Register the API routes at the root and again under ``/v1``."""
        for prefix in ["", "/v1"]:
            # ``app.get``/``app.post`` return decorators. Calling them directly
            # registers the bound methods as route handlers.
            self.app.get(
                prefix + "/models",
                summary="Get available models",
            )(self.get_available_models)
            self.app.post(
                prefix + "/chat/completions",
                summary="Chat completions in conversation session",
            )(self.chat_completions)
# Module-level ASGI application, so servers can load it as "<module>:app".
app = ChatAPIApp().app

if __name__ == "__main__":
    # Development entry point: serve this module's ``app`` on all interfaces
    # on port 23333, with uvicorn's auto-reload enabled.
    server_options = dict(host="0.0.0.0", port=23333, reload=True)
    uvicorn.run("__main__:app", **server_options)