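"""OpenAI-compatible /chat/completions route backed by a vLLM engine."""
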
import time
import traceback
import uuid
from functools import partial
from typing import (
    Dict,
    Any,
    AsyncIterator,
    Optional,
)
import anyio
from fastapi import APIRouter, Depends, HTTPException, Request
from loguru import logger
from openai.types.chat import (
    ChatCompletionMessage,
    ChatCompletion,
    ChatCompletionChunk,
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.chat.chat_completion_message import FunctionCall
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
from openai.types.completion_usage import CompletionUsage
from sse_starlette import EventSourceResponse
from vllm.outputs import RequestOutput
from api.core.vllm_engine import VllmEngine
from api.models import GENERATE_ENGINE
from api.utils.compat import model_dump, model_parse
from api.utils.protocol import Role, ChatCompletionCreateParams
from api.utils.request import (
    check_api_key,
    handle_request,
    get_event_publisher,
)

chat_router = APIRouter(prefix="/chat")
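
# FastAPI dependency yielding the shared GENERATE_ENGINE instance from api.models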
def get_engine():
    yield GENERATE_ENGINE


@chat_router.post("/completions", dependencies=[Depends(check_api_key)])
async def create_chat_completion(
    request: ChatCompletionCreateParams,
    raw_request: Request,
    engine: VllmEngine = Depends(get_engine),
):
    if (not request.messages) or request.messages[-1]["role"] == Role.ASSISTANT:
        raise HTTPException(status_code=400, detail="Invalid request")

    request = await handle_request(request, engine.prompt_adapter.stop)
    request.max_tokens = request.max_tokens or 512

    params = model_dump(request, exclude={"messages"})
    params.update(dict(prompt_or_messages=request.messages, echo=False))
    logger.debug(f"==== request ====\n{params}")

    request_id: str = f"chatcmpl-{uuid.uuid4()}"
    generator = engine.generate(params, request_id)
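
    # Streaming: feed chunks to the SSE publisher through a bounded memory channel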
    if request.stream:
        iterator = create_chat_completion_stream(generator, params, request_id)
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(
                get_event_publisher,
                request=raw_request,
                inner_send_chan=send_chan,
                iterator=iterator,
            ),
        )
    else:
        # Non-streaming response
        final_res: Optional[RequestOutput] = None
        async for res in generator:
            if raw_request is not None and await raw_request.is_disconnected():
                # Client disconnected: abort generation inside the engine
                await engine.model.abort(request_id)
                return
            final_res = res
        assert final_res is not None
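
        # Convert each sampled output into an OpenAI Choice, parsing tool calls if requested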
        choices = []
        functions = params.get("functions", None)
        tools = params.get("tools", None)
        for output in final_res.outputs:
            # Drop replacement characters left by incomplete UTF-8 decoding
            output.text = output.text.replace("�", "")

            finish_reason = output.finish_reason
            function_call = None
            if functions or tools:
                try:
                    res, function_call = engine.prompt_adapter.parse_assistant_response(
                        output.text, functions, tools,
                    )
                    output.text = res
                except Exception:
                    traceback.print_exc()
                    logger.warning("Failed to parse tool call")

            if isinstance(function_call, dict) and "arguments" in function_call:
                function_call = FunctionCall(**function_call)
                message = ChatCompletionMessage(
                    role="assistant",
                    content=output.text,
                    function_call=function_call,
                )
                finish_reason = "function_call"
            elif isinstance(function_call, dict) and "function" in function_call:
                finish_reason = "tool_calls"
                tool_calls = [model_parse(ChatCompletionMessageToolCall, function_call)]
                message = ChatCompletionMessage(
                    role="assistant",
                    content=output.text,
                    tool_calls=tool_calls,
                )
            else:
                message = ChatCompletionMessage(role="assistant", content=output.text)

            choices.append(
                Choice(
                    index=output.index,
                    message=message,
                    finish_reason=finish_reason,
                )
            )
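
        # Aggregate token usage over the prompt and all sampled outputs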
        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(len(output.token_ids) for output in final_res.outputs)
        usage = CompletionUsage(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        return ChatCompletion(
            id=request_id,
            choices=choices,
            created=int(time.time()),
            model=request.model,
            object="chat.completion",
            usage=usage,
        )


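# Translate vLLM RequestOutputs into OpenAI-style chat.completion.chunk events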
async def create_chat_completion_stream(
    generator: AsyncIterator[RequestOutput],
    params: Dict[str, Any],
    request_id: str,
) -> AsyncIterator[ChatCompletionChunk]:
    n = params.get("n", 1)
    for i in range(n):
        # First chunk for each choice carries the role with empty content
        choice = ChunkChoice(
            index=i,
            delta=ChoiceDelta(role="assistant", content=""),
            finish_reason=None,
            logprobs=None,
        )
        yield ChatCompletionChunk(
            id=request_id,
            choices=[choice],
            created=int(time.time()),
            model=params.get("model", "llm"),
            object="chat.completion.chunk",
        )

    previous_texts = [""] * n
    previous_num_tokens = [0] * n
    async for res in generator:
        for output in res.outputs:
            i = output.index
            output.text = output.text.replace("�", "")

            # Emit only the text generated since the previous chunk
            delta_text = output.text[len(previous_texts[i]):]
            previous_texts[i] = output.text
            previous_num_tokens[i] = len(output.token_ids)

            choice = ChunkChoice(
                index=i,
                delta=ChoiceDelta(content=delta_text),
                finish_reason=output.finish_reason,
                logprobs=None,
            )
            yield ChatCompletionChunk(
                id=request_id,
                choices=[choice],
                created=int(time.time()),
                model=params.get("model", "llm"),
                object="chat.completion.chunk",
            )

            if output.finish_reason is not None:
                # Terminal chunk with an empty delta once the choice finishes
                choice = ChunkChoice(
                    index=i,
                    delta=ChoiceDelta(),
                    finish_reason="stop",
                    logprobs=None,
                )
                yield ChatCompletionChunk(
                    id=request_id,
                    choices=[choice],
                    created=int(time.time()),
                    model=params.get("model", "llm"),
                    object="chat.completion.chunk",
                )
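

# A minimal client-side sketch, not part of this module. It assumes the app
# mounts chat_router under a "/v1" prefix on localhost:8000, that check_api_key
# accepts the key below, and that the served model is named "llm" (all of
# these are assumptions, not guarantees of this repo's configuration):
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-test")
#   stream = client.chat.completions.create(
#       model="llm",
#       messages=[{"role": "user", "content": "Hello!"}],
#       stream=True,
#   )
#   for chunk in stream:
#       print(chunk.choices[0].delta.content or "", end="")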