nikhil_no_persistent / lilac /router_utils.py
nsthorat-lilac's picture
Duplicate from lilacai/nikhil_staging
bfc0ec6
raw
history blame
No virus
1.8 kB
"""Utils for routers."""
import traceback
from typing import Callable, Iterable, Optional

from fastapi import HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.routing import APIRoute

from .auth import UserInfo
from .concepts.db_concept import DISK_CONCEPT_DB, DISK_CONCEPT_MODEL_DB
from .schema import Item, RichData
from .signals.concept_scorer import ConceptSignal
class RouteErrorHandler(APIRoute):
  """Custom APIRoute that converts unexpected route exceptions into HTTP 500 errors.

  `HTTPException`s raised by a route and FastAPI's own `RequestValidationError`s (which
  FastAPI's default handler answers with a 422) are propagated untouched. Any other exception
  is printed to stdout together with the request URL and its traceback, then re-raised as an
  `HTTPException` with status 500 whose `detail` is the formatted traceback.
  """

  def get_route_handler(self) -> Callable:
    """Return the default route handler wrapped with the error handling described above."""
    original_route_handler = super().get_route_handler()

    async def custom_route_handler(request: Request) -> Response:
      try:
        return await original_route_handler(request)
      except (HTTPException, RequestValidationError):
        # These already map to a proper HTTP status; let FastAPI's exception handlers build the
        # response. Request validation happens inside `original_route_handler`, so without this
        # clause malformed requests would be reported as 500s instead of 422s.
        raise
      except Exception as ex:
        trace = traceback.format_exc()
        print('Route error:', request.url)
        print(ex)
        print(trace)
        # Wrap the error into a pretty 500 exception.
        # NOTE(review): the full server traceback is returned to the client in `detail`;
        # consider limiting this to development deployments.
        raise HTTPException(status_code=500, detail=trace) from ex

    return custom_route_handler
def server_compute_concept(signal: ConceptSignal, examples: Iterable[RichData],
                           user: Optional[UserInfo]) -> list[Optional[Item]]:
  """Run a concept signal over `examples` on behalf of a REST endpoint.

  The concept named by `signal.namespace` / `signal.concept_name` is looked up in
  `DISK_CONCEPT_DB` with `user`. Its model is then synced in `DISK_CONCEPT_MODEL_DB` for
  `signal.embedding` (with `create=True`). Finally `signal.compute` is run over the examples,
  with every falsy example replaced by an empty string.

  Args:
    signal: The concept signal to compute.
    examples: The inputs to score.
    user: The requesting user, forwarded to the concept DB lookup and model sync.

  Returns:
    The results of `signal.compute`, materialized as a list.

  Raises:
    HTTPException: 404 when the concept lookup returns a falsy value.
  """
  namespace, concept_name = signal.namespace, signal.concept_name

  # TODO(nsthorat): Move this to the setup() method in the concept_scorer.
  if not DISK_CONCEPT_DB.get(namespace, concept_name, user):
    raise HTTPException(
      status_code=404, detail=f'Concept "{namespace}/{concept_name}" was not found')

  DISK_CONCEPT_MODEL_DB.sync(namespace, concept_name, signal.embedding, user=user, create=True)

  # Falsy inputs (e.g. None) are scored as empty strings.
  inputs = [text if text else '' for text in examples]
  scores = signal.compute(inputs)
  return list(scores)