lilac / lilac /router_utils.py
nsthorat-lilac's picture
Duplicate from lilacai/lilac
ddcfeb8
"""Utils for routers."""
import traceback
from typing import Callable, Iterable, Optional

from fastapi import HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.routing import APIRoute

from .auth import UserInfo
from .concepts.db_concept import DISK_CONCEPT_DB, DISK_CONCEPT_MODEL_DB
from .schema import Item, RichData
from .signals.concept_scorer import ConceptSignal
class RouteErrorHandler(APIRoute):
    """Custom APIRoute that handles application errors and exceptions.

    Exceptions that already describe an HTTP-level failure (`HTTPException`
    raised by a handler, and `RequestValidationError` raised by FastAPI when the
    request does not match the endpoint's declared parameters) are re-raised
    unchanged, so FastAPI's regular exception handlers produce the response
    (e.g. the endpoint's 404, or the default 422 for validation errors).

    Any other exception is printed to stdout together with the request URL and
    its traceback, then re-raised as a 500 `HTTPException` whose detail is the
    formatted traceback.
    """

    def get_route_handler(self) -> Callable:
        """Get the route handler, wrapped with the error handling described above."""
        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            try:
                return await original_route_handler(request)
            except (HTTPException, RequestValidationError):
                # These already carry the correct (client-facing) status; wrapping
                # them would misreport e.g. a malformed request as a server error.
                # A bare `raise` keeps the original traceback intact.
                raise
            except Exception as ex:
                formatted_tb = traceback.format_exc()
                print('Route error:', request.url)
                print(ex)
                print(formatted_tb)
                # Wrap the error into a pretty 500 exception.
                # NOTE(review): the full server traceback is returned to the client
                # as the error detail; consider trimming this for public deployments.
                raise HTTPException(status_code=500, detail=formatted_tb) from ex

        return custom_route_handler
def server_compute_concept(signal: ConceptSignal, examples: Iterable[RichData],
                           user: Optional[UserInfo]) -> list[Optional[Item]]:
    """Compute a concept from the REST endpoints.

    Fetches the concept identified by `signal.namespace` / `signal.concept_name`
    from the disk concept DB on behalf of `user`, syncs the concept model for
    `signal.embedding` (with `create=True`), then runs the signal over the examples.

    Args:
      signal: The concept signal selecting the concept, its namespace and the
        embedding used by the concept model.
      examples: Inputs to score. Falsy entries are replaced with '' before scoring.
      user: The requesting user, forwarded to both the concept DB and the concept
        model DB.

    Returns:
      The outputs of `signal.compute` over the examples, materialized as a list.

    Raises:
      HTTPException: with status 404 when the concept lookup returns a falsy value.
    """
    # TODO(nsthorat): Move this to the setup() method in the concept_scorer.
    found_concept = DISK_CONCEPT_DB.get(signal.namespace, signal.concept_name, user)
    if found_concept:
        DISK_CONCEPT_MODEL_DB.sync(
            signal.namespace, signal.concept_name, signal.embedding, user=user, create=True)
        # Missing / empty inputs are scored as the empty string.
        sanitized_texts = [text if text else '' for text in examples]
        computed = signal.compute(sanitized_texts)
        return [*computed]

    raise HTTPException(
        status_code=404, detail=f'Concept "{signal.namespace}/{signal.concept_name}" was not found')