# coding=utf-8
# Implements an API for Qwen-7B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python openai_api.py
# Visit http://localhost:8000/docs for the API documentation.
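#
# Example request (a sketch, assuming the default host/port below and the
# /v1/chat/completions route registered further down; the model name is only
# echoed back by this server):
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}'
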
from argparse import ArgumentParser
import time

import torch
import uvicorn
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional, Union
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import GenerationConfig
from sse_starlette.sse import ServerSentEvent, EventSourceResponse


@asynccontextmanager
async def lifespan(app: FastAPI):  # collect GPU memory on shutdown
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    model_card = ModelCard(id="gpt-3.5-turbo")
    return ModelList(data=[model_card])


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    if request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request: the last message must come from the user.")
    query = request.messages[-1].content

    prev_messages = request.messages[:-1]
    # Temporarily, the system role does not work as expected. We advise that you write the setups for role-play in your query.
    # if len(prev_messages) > 0 and prev_messages[0].role == "system":
    #     query = prev_messages.pop(0).content + query

    # Pair up the remaining history as alternating (user, assistant) turns.
    history = []
    if len(prev_messages) % 2 == 0:
        for i in range(0, len(prev_messages), 2):
            if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant":
                history.append([prev_messages[i].content, prev_messages[i + 1].content])
            else:
                raise HTTPException(status_code=400, detail="Invalid request: history must alternate user/assistant messages.")
    else:
        raise HTTPException(status_code=400, detail="Invalid request: history must contain an even number of messages.")

    if request.stream:
        generate = predict(query, history, request.model)
        return EventSourceResponse(generate, media_type="text/event-stream")

    response, _ = model.chat(tokenizer, query, history=history)
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=response),
        finish_reason="stop"
    )

    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")


async def predict(query: str, history: List[List[str]], model_id: str):
    global model, tokenizer

    # The first chunk carries only the assistant role, per the OpenAI streaming format.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    # Stream only the newly generated suffix of each partial response.
    current_length = 0
    for new_response in model.chat_stream(tokenizer, query, history):
        if len(new_response) == current_length:
            continue

        new_text = new_response[current_length:]
        current_length = len(new_response)

        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=new_text),
            finish_reason=None
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    # The final chunk signals completion, followed by the [DONE] sentinel.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
    yield '[DONE]'


def _get_args():
    parser = ArgumentParser()
    parser.add_argument("-c", "--checkpoint-path", type=str, default='QWen/QWen-7B-Chat',
                        help="Checkpoint name or path, default to %(default)r")
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
    parser.add_argument("--server-port", type=int, default=8000,
                        help="Demo server port.")
    parser.add_argument("--server-name", type=str, default="127.0.0.1",
                        help="Demo server name.")

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = _get_args()

    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True,
    )

    if args.cpu_only:
        device_map = "cpu"
    else:
        device_map = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        trust_remote_code=True,
        resume_download=True,
    ).eval()
    model.generation_config = GenerationConfig.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True,
    )

    uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)
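
# A minimal client sketch (assumptions: the server is running on the default
# host/port above and the legacy `openai` 0.x SDK is installed; this server
# ignores the model name beyond echoing it back in the response):
#
#   import openai
#   openai.api_base = "http://localhost:8000/v1"
#   openai.api_key = "none"  # not checked by this server, but required by the SDK
#   resp = openai.ChatCompletion.create(
#       model="gpt-3.5-turbo",
#       messages=[{"role": "user", "content": "Hello!"}],
#   )
#   print(resp.choices[0].message.content)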