import time
import uuid
from functools import partial
from typing import Any, AsyncIterator, Dict

import anyio
from fastapi import APIRouter, Depends, Request
from loguru import logger
from openai.types.completion import Completion
from openai.types.completion_choice import CompletionChoice
from openai.types.completion_usage import CompletionUsage
from sse_starlette import EventSourceResponse
from text_generation.types import Response, StreamResponse

from api.core.tgi import TGIEngine
from api.models import GENERATE_ENGINE
from api.utils.compat import model_dump
from api.utils.protocol import CompletionCreateParams
from api.utils.request import (
    check_api_key,
    get_event_publisher,
    handle_request,
)

completion_router = APIRouter()


def get_engine():
    # Dependency that yields the process-wide generation engine.
    yield GENERATE_ENGINE


@completion_router.post("/completions", dependencies=[Depends(check_api_key)])
async def create_completion(
    request: CompletionCreateParams,
    raw_request: Request,
    engine: TGIEngine = Depends(get_engine),
):
    """Completion API similar to OpenAI's API."""
    request.max_tokens = request.max_tokens or 128
    request = await handle_request(request, engine.prompt_adapter.stop, chat=False)

    # TGI generates from a single prompt string; if a batch is sent,
    # only the first prompt is used.
    if isinstance(request.prompt, list):
        request.prompt = request.prompt[0]

    request_id: str = f"cmpl-{uuid.uuid4()}"

    # Copy the parameters that map one-to-one onto TGI's generate arguments,
    # then translate the OpenAI-style fields that need renaming or clamping.
    include = {
        "temperature",
        "best_of",
        "repetition_penalty",
        "typical_p",
        "watermark",
    }
    params = model_dump(request, include=include)
    params.update(
        dict(
            prompt=request.prompt,
            do_sample=request.temperature > 1e-5,  # greedy decoding at ~zero temperature
            max_new_tokens=request.max_tokens,
            stop_sequences=request.stop,
            top_p=request.top_p if request.top_p < 1.0 else 0.99,  # TGI requires top_p < 1.0
            return_full_text=request.echo,
        )
    )
    logger.debug(f"==== request ====\n{params}")

    if request.stream:
        generator = engine.generate_stream(**params)
        iterator = create_completion_stream(generator, params, request_id)
        # A bounded in-memory channel decouples generation from the SSE writer;
        # get_event_publisher drains the iterator into the send channel.
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(
                get_event_publisher,
                request=raw_request,
                inner_send_chan=send_chan,
                iterator=iterator,
            ),
        )

    # Non-streaming response
    response: Response = await engine.generate(**params)

    # Normalize TGI finish reasons to the OpenAI vocabulary ("length" or "stop").
    finish_reason = response.details.finish_reason.value
    finish_reason = "length" if finish_reason == "length" else "stop"

    choice = CompletionChoice(
        index=0,
        text=response.generated_text,
        finish_reason=finish_reason,
        logprobs=None,
    )

    num_prompt_tokens = len(response.details.prefill)
    num_generated_tokens = response.details.generated_tokens
    usage = CompletionUsage(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )
    return Completion(
        id=request_id,
        choices=[choice],
        created=int(time.time()),
        model=params.get("model", "llm"),
        object="text_completion",
        usage=usage,
    )


async def create_completion_stream(
    generator: AsyncIterator[StreamResponse],
    params: Dict[str, Any],
    request_id: str,
) -> AsyncIterator[Completion]:
    async for output in generator:
        output: StreamResponse
        # Skip special tokens (e.g. end-of-sequence) so they never reach the client.
        if output.token.special:
            continue
        choice = CompletionChoice(
            index=0,
            text=output.token.text,
            # CompletionChoice requires a finish_reason, so every chunk carries
            # "stop" rather than the null OpenAI sends mid-stream.
            finish_reason="stop",
            logprobs=None,
        )
        yield Completion(
            id=request_id,
            choices=[choice],
            created=int(time.time()),
            model=params.get("model", "llm"),
            object="text_completion",
        )
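
# ---------------------------------------------------------------------------
# Usage sketch: mounting this router in a FastAPI app under an OpenAI-style
# "/v1" prefix. A minimal sketch, assuming the module lives at
# api/routes/completion.py; the import path is hypothetical.
#
#     from fastapi import FastAPI
#     from api.routes.completion import completion_router  # hypothetical path
#
#     app = FastAPI()
#     app.include_router(completion_router, prefix="/v1")
#
# Clients can then POST an OpenAI-style payload to /v1/completions.
# ---------------------------------------------------------------------------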