Spaces:
Running
Running
from functools import partial | |
from typing import Iterator | |
import anyio | |
from fastapi import APIRouter, Depends, HTTPException, Request | |
from loguru import logger | |
from sse_starlette import EventSourceResponse | |
from starlette.concurrency import run_in_threadpool | |
from api.core.default import DefaultEngine | |
from api.models import GENERATE_ENGINE | |
from api.utils.compat import model_dump | |
from api.utils.protocol import CompletionCreateParams | |
from api.utils.request import ( | |
handle_request, | |
check_api_key, | |
get_event_publisher, | |
) | |
# FastAPI router for the completion endpoint(s) defined in this module.
# NOTE(review): nothing in this view registers a route on it — confirm the
# handler below is attached (e.g. via a decorator) elsewhere.
completion_router = APIRouter()
def get_engine():
    """Provide the module-level ``GENERATE_ENGINE`` as a FastAPI dependency.

    Implemented as a one-shot generator so it can be passed to ``Depends``;
    no teardown runs after the value is handed out.
    """
    engine = GENERATE_ENGINE
    yield engine
# NOTE(review): no route decorator is visible here, so as shown this handler is
# never registered on ``completion_router`` and the imported ``check_api_key``
# is unused. Confirm the decorator (path + ``Depends(check_api_key)``) was not
# lost.
async def create_completion(
    request: CompletionCreateParams,
    raw_request: Request,
    engine: DefaultEngine = Depends(get_engine),
):
    """Run a text-completion request, returning a full result or an SSE stream.

    Args:
        request: Parsed completion parameters. A string ``prompt`` is wrapped
            into a one-element list; only the first prompt is sent to the
            engine.
        raw_request: The underlying request object, forwarded to
            ``get_event_publisher`` when streaming.
        engine: Generation engine injected through ``get_engine``.

    Returns:
        The engine's result unchanged when ``engine.create_completion`` does
        not return an iterator; otherwise an ``EventSourceResponse`` that
        streams the iterator's chunks.

    Raises:
        HTTPException: 400 when ``prompt`` is missing or an empty list.
    """
    if isinstance(request.prompt, str):
        request.prompt = [request.prompt]

    # ``not`` also rejects a missing (None) prompt instead of failing on len().
    if not request.prompt:
        raise HTTPException(status_code=400, detail="Invalid request")

    request = await handle_request(request, engine.stop, chat=False)
    # Falls back to 128 whenever max_tokens is falsy (unset or 0).
    request.max_tokens = request.max_tokens or 128

    params = model_dump(request, exclude={"prompt"})
    # NOTE(review): only the first prompt is forwarded; any further prompts in
    # the list are silently ignored.
    params.update(dict(prompt_or_messages=request.prompt[0]))
    logger.debug(f"==== request ====\n{params}")

    # Run the engine call in a worker thread so a (presumably blocking)
    # generation call does not stall the event loop.
    iterator_or_completion = await run_in_threadpool(engine.create_completion, params)
    if not isinstance(iterator_or_completion, Iterator):
        return iterator_or_completion

    # It's easier to ask for forgiveness than permission: pull the first chunk
    # eagerly so that errors raised by the engine surface here, before the
    # streaming response has started. A default is passed to ``next`` because a
    # StopIteration raised in the worker thread cannot propagate through the
    # awaitable (anyio re-raises it as RuntimeError), which would turn an
    # empty stream into a 500. An empty stream now simply yields no chunks.
    exhausted = object()
    first_response = await run_in_threadpool(next, iterator_or_completion, exhausted)

    def iterator() -> Iterator:
        # Re-attach the already-consumed first chunk in front of the rest.
        if first_response is exhausted:
            return
        yield first_response
        yield from iterator_or_completion

    send_chan, recv_chan = anyio.create_memory_object_stream(10)
    return EventSourceResponse(
        recv_chan,
        data_sender_callable=partial(
            get_event_publisher,
            request=raw_request,
            inner_send_chan=send_chan,
            iterator=iterator(),
        ),
    )