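"""Completion route for the llama.cpp engine.

Exposes POST /completions with optional server-sent-event streaming.
"""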
from functools import partial
from typing import Iterator
import anyio
from fastapi import APIRouter, Depends, Request
from loguru import logger
from sse_starlette import EventSourceResponse
from starlette.concurrency import run_in_threadpool
from api.core.llama_cpp_engine import LlamaCppEngine
from api.llama_cpp_routes.utils import get_llama_cpp_engine
from api.utils.compat import model_dump
from api.utils.protocol import CompletionCreateParams
from api.utils.request import (
    handle_request,
    check_api_key,
    get_event_publisher,
)

completion_router = APIRouter()


@completion_router.post("/completions", dependencies=[Depends(check_api_key)])
async def create_completion(
    request: CompletionCreateParams,
    raw_request: Request,
    engine: LlamaCppEngine = Depends(get_llama_cpp_engine),
):
    """Create a completion for the given prompt and sampling parameters."""
    # llama.cpp expects a single string prompt, so unwrap a singleton list.
    if isinstance(request.prompt, list):
        assert len(request.prompt) <= 1, "batched prompts are not supported"
        request.prompt = request.prompt[0] if len(request.prompt) > 0 else ""

    request.max_tokens = request.max_tokens or 256
    request = await handle_request(request, engine.stop)

    # Forward only the parameters that engine.create_completion accepts.
    include = {
        "temperature",
        "top_p",
        "stream",
        "stop",
        "model",
        "max_tokens",
        "presence_penalty",
        "frequency_penalty",
    }
    kwargs = model_dump(request, include=include)
    logger.debug(f"==== request ====\n{kwargs}")

    # Inference is blocking, so run it in a worker thread to keep the event loop free.
    iterator_or_completion = await run_in_threadpool(engine.create_completion, **kwargs)

    if isinstance(iterator_or_completion, Iterator):
        # EAFP: eagerly pull the first chunk so that any generation error
        # surfaces here as a normal HTTP error rather than mid-stream.
        first_response = await run_in_threadpool(next, iterator_or_completion)

        # If no exception was raised, the iterator is valid and we can use it
        # to stream the rest of the response.
        def iterator() -> Iterator:
            yield first_response
            yield from iterator_or_completion

        # Bridge the iterator to SSE through a small buffered channel;
        # get_event_publisher pushes events from the iterator into it.
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(
                get_event_publisher,
                request=raw_request,
                inner_send_chan=send_chan,
                iterator=iterator(),
            ),
        )
    else:
        return iterator_or_completion
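

# A minimal usage sketch (assumptions: this router is mounted at the app root
# and served on localhost:8000; the model name, API key, and port below are
# hypothetical and not defined in this file):
#
#   import httpx
#
#   resp = httpx.post(
#       "http://localhost:8000/completions",
#       headers={"Authorization": "Bearer sk-..."},
#       json={"model": "llama-2", "prompt": "Hello,", "max_tokens": 16},
#       timeout=60,
#   )
#   print(resp.json())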