from __future__ import annotations

import sys
import traceback
import time
from re import compile, Match, Pattern
from typing import Callable, Coroutine, Optional, Tuple, Union, Dict

from typing_extensions import TypedDict

from fastapi import (
    Request,
    Response,
    HTTPException,
)
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute

from llama_cpp.server.types import (
    CreateCompletionRequest,
    CreateEmbeddingRequest,
    CreateChatCompletionRequest,
)


class ErrorResponse(TypedDict):
    """OpenAI style error response"""

    message: str
    type: str
    param: Optional[str]
    code: Optional[str]
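
# For illustration only (not part of the original module): an ErrorResponse,
# wrapped in the {"error": ...} envelope produced below, serializes to the
# same shape as an OpenAI API error payload, e.g.
#
#   {"error": {"message": "The model `foo` does not exist",
#              "type": "invalid_request_error",
#              "param": null,
#              "code": "model_not_found"}}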


class ErrorResponseFormatters:
    """Collection of formatters for error responses.

    Args:
        request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
            Request body
        match (Match[str]): Match object from regex pattern

    Returns:
        Tuple[int, ErrorResponse]: Status code and error response
    """
    @staticmethod
    def context_length_exceeded(
        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
        match,  # type: Match[str]  # type: ignore
    ) -> Tuple[int, ErrorResponse]:
        """Formatter for context length exceeded error"""

        context_window = int(match.group(2))
        prompt_tokens = int(match.group(1))
        completion_tokens = request.max_tokens
        if hasattr(request, "messages"):
            # Chat completion
            message = (
                "This model's maximum context length is {} tokens. "
                "However, you requested {} tokens "
                "({} in the messages, {} in the completion). "
                "Please reduce the length of the messages or completion."
            )
        else:
            # Text completion
            message = (
                "This model's maximum context length is {} tokens, "
                "however you requested {} tokens "
                "({} in your prompt; {} for the completion). "
                "Please reduce your prompt; or completion length."
            )
        return 400, ErrorResponse(
            message=message.format(
                context_window,
                (completion_tokens or 0) + prompt_tokens,
                prompt_tokens,
                completion_tokens,
            ),  # type: ignore
            type="invalid_request_error",
            param="messages",
            code="context_length_exceeded",
        )
    @staticmethod
    def model_not_found(
        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
        match,  # type: Match[str]  # type: ignore
    ) -> Tuple[int, ErrorResponse]:
        """Formatter for model_not_found error"""

        model_path = str(match.group(1))
        message = f"The model `{model_path}` does not exist"
        return 400, ErrorResponse(
            message=message,
            type="invalid_request_error",
            param=None,
            code="model_not_found",
        )
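
# For illustration (hypothetical values, not part of the original module):
# a llama_cpp error string such as
#   "Requested tokens (3000) exceed context window of 2048"
# yields match.group(1) == "3000" (prompt tokens) and
# match.group(2) == "2048" (context window), so context_length_exceeded
# builds a 400 response with code "context_length_exceeded".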


class RouteErrorHandler(APIRoute):
    """Custom APIRoute that handles application errors and exceptions"""

    # key: regex pattern for original error message from llama_cpp
    # value: formatter function
    pattern_and_formatters: Dict[
        "Pattern[str]",
        Callable[
            [
                Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
                "Match[str]",
            ],
            Tuple[int, ErrorResponse],
        ],
    ] = {
        compile(
            r"Requested tokens \((\d+)\) exceed context window of (\d+)"
        ): ErrorResponseFormatters.context_length_exceeded,
        compile(
            r"Model path does not exist: (.+)"
        ): ErrorResponseFormatters.model_not_found,
    }
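
    # Additional llama_cpp error messages can be mapped to OpenAI-style
    # responses by adding (compiled pattern, formatter) pairs above, e.g.
    # (hypothetical, not part of the original module):
    #
    #   compile(r"<new llama_cpp error>: (.+)"): my_formatter,
    #
    # where my_formatter matches the signature of the
    # ErrorResponseFormatters methods.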

    def error_message_wrapper(
        self,
        error: Exception,
        body: Optional[
            Union[
                "CreateChatCompletionRequest",
                "CreateCompletionRequest",
                "CreateEmbeddingRequest",
            ]
        ] = None,
    ) -> Tuple[int, ErrorResponse]:
        """Wraps error message in OpenAI style error response"""

        print(f"Exception: {str(error)}", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)
        if body is not None and isinstance(
            body,
            (
                CreateCompletionRequest,
                CreateChatCompletionRequest,
            ),
        ):
            # When text completion or chat completion
            for pattern, callback in self.pattern_and_formatters.items():
                match = pattern.search(str(error))
                if match is not None:
                    return callback(body, match)

        # Wrap other errors as internal server error
        return 500, ErrorResponse(
            message=str(error),
            type="internal_server_error",
            param=None,
            code=None,
        )

    def get_route_handler(
        self,
    ) -> Callable[[Request], Coroutine[None, None, Response]]:
        """Defines custom route handler that catches exceptions and formats
        them in OpenAI style error response"""

        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            try:
                start_sec = time.perf_counter()
                response = await original_route_handler(request)
                elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000)
                response.headers["openai-processing-ms"] = f"{elapsed_time_ms}"
                return response
            except HTTPException as unauthorized:
                # API key check failed; re-raise so FastAPI returns it as-is
                raise unauthorized
            except Exception as exc:
                try:
                    # Parsing the body may itself fail (e.g. invalid JSON),
                    # so it happens inside the try block
                    json_body = await request.json()
                    if "messages" in json_body:
                        # Chat completion
                        body: Optional[
                            Union[
                                CreateChatCompletionRequest,
                                CreateCompletionRequest,
                                CreateEmbeddingRequest,
                            ]
                        ] = CreateChatCompletionRequest(**json_body)
                    elif "prompt" in json_body:
                        # Text completion
                        body = CreateCompletionRequest(**json_body)
                    else:
                        # Embedding
                        body = CreateEmbeddingRequest(**json_body)
                except Exception:
                    # Invalid request body
                    body = None

                # Get proper error message from the exception
                (
                    status_code,
                    error_message,
                ) = self.error_message_wrapper(error=exc, body=body)
                return JSONResponse(
                    {"error": error_message},
                    status_code=status_code,
                )

        return custom_route_handler
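

# Minimal usage sketch (illustrative, not part of the original module): wiring
# the custom route class into a FastAPI app so every route registered on the
# router gets the OpenAI-style error handling. `route_class` is standard
# FastAPI APIRouter API; the stub handler below is hypothetical.
if __name__ == "__main__":
    from fastapi import APIRouter, FastAPI

    router = APIRouter(route_class=RouteErrorHandler)

    @router.post("/v1/completions")
    async def create_completion(body: CreateCompletionRequest) -> dict:
        # Hypothetical stub; a real server would run inference here, and any
        # exception it raises would be converted by RouteErrorHandler into an
        # OpenAI-style {"error": ...} JSON response.
        return {"object": "text_completion", "choices": []}

    app = FastAPI()
    app.include_router(router)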