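"""OpenAI-compatible chat completion routes backed by a vLLM engine.

Defines a FastAPI router exposing the chat completions endpoint in both
streaming (SSE) and non-streaming form, including function/tool call parsing.
"""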
import time
import traceback
import uuid
from functools import partial
from typing import (
    Dict,
    Any,
    AsyncIterator,
)

import anyio
from fastapi import APIRouter, Depends
from fastapi import HTTPException, Request
from loguru import logger
from openai.types.chat import (
    ChatCompletionMessage,
    ChatCompletion,
    ChatCompletionChunk,
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.chat.chat_completion_message import FunctionCall
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
from openai.types.completion_usage import CompletionUsage
from sse_starlette import EventSourceResponse
from vllm.outputs import RequestOutput

from api.core.vllm_engine import VllmEngine
from api.models import GENERATE_ENGINE
from api.utils.compat import model_dump, model_parse
from api.utils.protocol import Role, ChatCompletionCreateParams
from api.utils.request import (
    check_api_key,
    handle_request,
    get_event_publisher,
)

chat_router = APIRouter(prefix="/chat")


def get_engine():
    yield GENERATE_ENGINE
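

# POST /chat/completions: OpenAI-compatible chat endpoint; `check_api_key`
# (imported above) guards the route as a FastAPI dependency.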
@chat_router.post("/completions", dependencies=[Depends(check_api_key)])
async def create_chat_completion(
    request: ChatCompletionCreateParams,
    raw_request: Request,
    engine: VllmEngine = Depends(get_engine),
):
    if (not request.messages) or request.messages[-1]["role"] == Role.ASSISTANT:
        raise HTTPException(status_code=400, detail="Invalid request")

    request = await handle_request(request, engine.prompt_adapter.stop)
    request.max_tokens = request.max_tokens or 512

    params = model_dump(request, exclude={"messages"})
    params.update(dict(prompt_or_messages=request.messages, echo=False))
    logger.debug(f"==== request ====\n{params}")

    request_id: str = f"chatcmpl-{str(uuid.uuid4())}"
    generator = engine.generate(params, request_id)
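
    # For streaming requests, bridge the async generator to SSE through an anyio
    # memory channel so client disconnects can stop generation (see get_event_publisher).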
    if request.stream:
        iterator = create_chat_completion_stream(generator, params, request_id)
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(
                get_event_publisher,
                request=raw_request,
                inner_send_chan=send_chan,
                iterator=iterator,
            ),
        )
    else:
        # Non-streaming response: drain the generator, keeping only the final output.
        final_res: RequestOutput = None
        async for res in generator:
            if raw_request is not None and await raw_request.is_disconnected():
                # Client disconnected: abort the vLLM request and bail out.
                await engine.model.abort(request_id)
                return
            final_res = res

        assert final_res is not None
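
        # Build one Choice per sampled output, optionally extracting a function or
        # tool call from the generated text via the prompt adapter.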
        choices = []
        functions = params.get("functions", None)
        tools = params.get("tools", None)

        for output in final_res.outputs:
            output.text = output.text.replace("�", "")
            finish_reason = output.finish_reason

            function_call = None
            if functions or tools:
                try:
                    res, function_call = engine.prompt_adapter.parse_assistant_response(
                        output.text, functions, tools,
                    )
                    output.text = res
                except Exception:
                    traceback.print_exc()
                    logger.warning("Failed to parse tool call")

            if isinstance(function_call, dict) and "arguments" in function_call:
                # Legacy single function call format.
                function_call = FunctionCall(**function_call)
                message = ChatCompletionMessage(
                    role="assistant",
                    content=output.text,
                    function_call=function_call,
                )
                finish_reason = "function_call"
            elif isinstance(function_call, dict) and "function" in function_call:
                # Tool call format.
                finish_reason = "tool_calls"
                tool_calls = [model_parse(ChatCompletionMessageToolCall, function_call)]
                message = ChatCompletionMessage(
                    role="assistant",
                    content=output.text,
                    tool_calls=tool_calls,
                )
            else:
                # Plain text response.
                message = ChatCompletionMessage(role="assistant", content=output.text)

            choices.append(
                Choice(
                    index=output.index,
                    message=message,
                    finish_reason=finish_reason,
                )
            )
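
        # Token accounting: prompt tokens from the tokenized prompt, completion
        # tokens summed across all sampled outputs.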
        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(len(output.token_ids) for output in final_res.outputs)
        usage = CompletionUsage(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        return ChatCompletion(
            id=request_id,
            choices=choices,
            created=int(time.time()),
            model=request.model,
            object="chat.completion",
            usage=usage,
        )
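

# Chunk protocol: for each of the `n` requested completions, emit an initial chunk
# carrying only the assistant role, then incremental content deltas, and finally an
# empty delta with finish_reason="stop" once vLLM reports a finish reason.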
async def create_chat_completion_stream(
    generator: AsyncIterator, params: Dict[str, Any], request_id: str
) -> AsyncIterator:
    n = params.get("n", 1)
    for i in range(n):
        # First chunk with role
        choice = ChunkChoice(
            index=i,
            delta=ChoiceDelta(role="assistant", content=""),
            finish_reason=None,
            logprobs=None,
        )
        yield ChatCompletionChunk(
            id=request_id,
            choices=[choice],
            created=int(time.time()),
            model=params.get("model", "llm"),
            object="chat.completion.chunk",
        )

    previous_texts = [""] * n
    previous_num_tokens = [0] * n
    async for res in generator:
        res: RequestOutput
        for output in res.outputs:
            i = output.index
            output.text = output.text.replace("�", "")
            # Emit only the text generated since the previous chunk for this choice.
            delta_text = output.text[len(previous_texts[i]):]
            previous_texts[i] = output.text
            previous_num_tokens[i] = len(output.token_ids)

            choice = ChunkChoice(
                index=i,
                delta=ChoiceDelta(content=delta_text),
                finish_reason=output.finish_reason,
                logprobs=None,
            )
            yield ChatCompletionChunk(
                id=request_id,
                choices=[choice],
                created=int(time.time()),
                model=params.get("model", "llm"),
                object="chat.completion.chunk",
            )

            if output.finish_reason is not None:
                # Final chunk for this choice: an empty delta that closes the stream.
                choice = ChunkChoice(
                    index=i,
                    delta=ChoiceDelta(),
                    finish_reason="stop",
                    logprobs=None,
                )
                yield ChatCompletionChunk(
                    id=request_id,
                    choices=[choice],
                    created=int(time.time()),
                    model=params.get("model", "llm"),
                    object="chat.completion.chunk",
                )
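

# Illustrative client usage (not part of this module). Assuming the router is mounted
# under an OpenAI-compatible base URL, e.g. http://localhost:8000/v1, a streaming call
# with the official openai SDK could look like this sketch:
#
#     from openai import OpenAI
#
#     client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-...")
#     stream = client.chat.completions.create(
#         model="llm",
#         messages=[{"role": "user", "content": "Hello"}],
#         stream=True,
#     )
#     for chunk in stream:
#         print(chunk.choices[0].delta.content or "", end="")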