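"""OpenAI-compatible ``/chat/completions`` route backed by a text-generation-inference (TGI) engine."""
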
import time
import uuid
from functools import partial
from typing import (
    Any,
    AsyncIterator,
    Dict,
)

import anyio
from fastapi import APIRouter, Depends, HTTPException, Request
from loguru import logger
from openai.types.chat import (
    ChatCompletion,
    ChatCompletionChunk,
    ChatCompletionMessage,
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from sse_starlette import EventSourceResponse
from text_generation.types import Response, StreamResponse

from api.core.tgi import TGIEngine
from api.models import GENERATE_ENGINE
from api.utils.compat import model_dump
from api.utils.protocol import ChatCompletionCreateParams, Role
from api.utils.request import (
    check_api_key,
    get_event_publisher,
    handle_request,
)

chat_router = APIRouter(prefix="/chat")


def get_engine():
    """Dependency that yields the globally configured TGI generation engine."""
    yield GENERATE_ENGINE


@chat_router.post("/completions", dependencies=[Depends(check_api_key)])
async def create_chat_completion(
    request: ChatCompletionCreateParams,
    raw_request: Request,
    engine: TGIEngine = Depends(get_engine),
):
    """Create an OpenAI-compatible chat completion, optionally streamed over SSE."""
    # Reject empty conversations and conversations that already end with an
    # assistant turn: there is nothing for the model to answer.
    if (not request.messages) or request.messages[-1]["role"] == Role.ASSISTANT:
        raise HTTPException(status_code=400, detail="Invalid request")

    request = await handle_request(request, engine.prompt_adapter.stop)
    request.max_tokens = request.max_tokens or 512

    # Render the chat history into a single prompt string via the engine's
    # chat template.
    prompt = engine.apply_chat_template(request.messages)
    # Keep only the sampling parameters that map directly onto TGI's API;
    # the remaining arguments are translated explicitly below.
    include = {
        "temperature",
        "best_of",
        "repetition_penalty",
        "typical_p",
        "watermark",
    }
    params = model_dump(request, include=include)
    params.update(
        dict(
            prompt=prompt,
            # Fall back to greedy decoding at (near-)zero temperature.
            do_sample=request.temperature > 1e-5,
            max_new_tokens=request.max_tokens,
            stop_sequences=request.stop,
            # TGI requires 0 < top_p < 1, so clamp values of 1.0 down to 0.99.
            top_p=request.top_p if request.top_p < 1.0 else 0.99,
        )
    )
    logger.debug(f"==== request ====\n{params}")

    request_id: str = f"chatcmpl-{uuid.uuid4()}"
    if request.stream:
        generator = engine.generate_stream(**params)
        iterator = create_chat_completion_stream(generator, params, request_id)
        # Pump chunks through a bounded in-memory channel so the publisher can
        # stop early if the client disconnects.
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(
                get_event_publisher,
                request=raw_request,
                inner_send_chan=send_chan,
                iterator=iterator,
            ),
        )
    # Non-streaming path: generate the full response, then map TGI's result
    # onto the OpenAI chat-completion schema.
    response: Response = await engine.generate(**params)
    finish_reason = response.details.finish_reason.value
    # TGI reports "length" / "eos_token" / "stop_sequence"; OpenAI only
    # distinguishes "length" from "stop".
    finish_reason = "length" if finish_reason == "length" else "stop"

    message = ChatCompletionMessage(role="assistant", content=response.generated_text)
    choice = Choice(
        index=0,
        message=message,
        finish_reason=finish_reason,
        logprobs=None,
    )

    num_prompt_tokens = len(response.details.prefill)
    num_generated_tokens = response.details.generated_tokens
    usage = CompletionUsage(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )
    return ChatCompletion(
        id=request_id,
        choices=[choice],
        created=int(time.time()),
        model=request.model,
        object="chat.completion",
        usage=usage,
    )

async def create_chat_completion_stream(
    generator: AsyncIterator[StreamResponse], params: Dict[str, Any], request_id: str
) -> AsyncIterator[ChatCompletionChunk]:
    """Adapt a TGI token stream into OpenAI ``chat.completion.chunk`` objects."""
    # First chunk carries only the assistant role, per the OpenAI streaming
    # convention.
    choice = ChunkChoice(
        index=0,
        delta=ChoiceDelta(role="assistant", content=""),
        finish_reason=None,
        logprobs=None,
    )
    yield ChatCompletionChunk(
        id=request_id,
        choices=[choice],
        created=int(time.time()),
        # ``params`` holds TGI kwargs and has no "model" key, so this falls
        # back to the generic "llm" label.
        model=params.get("model", "llm"),
        object="chat.completion.chunk",
    )
    async for output in generator:
        output: StreamResponse
        # Skip special tokens (e.g. end-of-sequence markers) so they never
        # reach the client.
        if output.token.special:
            continue

        choice = ChunkChoice(
            index=0,
            delta=ChoiceDelta(content=output.token.text),
            finish_reason=None,
            logprobs=None,
        )
        yield ChatCompletionChunk(
            id=request_id,
            choices=[choice],
            created=int(time.time()),
            model=params.get("model", "llm"),
            object="chat.completion.chunk",
        )
    # Final chunk carries an empty delta and the finish reason.
    choice = ChunkChoice(
        index=0,
        delta=ChoiceDelta(),
        finish_reason="stop",
        logprobs=None,
    )
    yield ChatCompletionChunk(
        id=request_id,
        choices=[choice],
        created=int(time.time()),
        model=params.get("model", "llm"),
        object="chat.completion.chunk",
    )
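
# A minimal usage sketch, assuming the application mounts ``chat_router``
# under a ``/v1`` prefix and ``check_api_key`` accepts the bearer token shown
# (host, port, prefix, and key all depend on the deployment):
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Authorization: Bearer sk-xxx" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "llm", "messages": [{"role": "user", "content": "Hello"}], "stream": true}'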