# Rerank endpoint handlers for the LiteLLM proxy.
#### Rerank Endpoints #####
import asyncio

import orjson
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.responses import ORJSONResponse

from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing

router = APIRouter()
async def rerank(
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Handle a rerank request on the proxy.

    Parses the JSON request body, enriches it with proxy metadata, runs the
    pre-call hooks (which may modify or reject the request), routes the call
    to the matching deployment via the router, and attaches custom response
    headers before returning the provider's response.

    Args:
        request: Incoming FastAPI request; body must be JSON.
        fastapi_response: Response object used to attach custom headers.
        user_api_key_dict: Authenticated key info resolved by ``user_api_key_auth``.

    Raises:
        ProxyException: On any failure, after running the post-call failure
            hook. HTTPExceptions keep their original status code; all other
            errors map to 500.
    """
    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import (
        add_litellm_data_to_request,
        general_settings,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        route_request,
        user_model,
        version,
    )

    # Initialized before the try so the failure hook always receives a dict,
    # even when body parsing itself raises.
    data = {}
    try:
        body = await request.body()
        data = orjson.loads(body)

        # Include original request and headers in the data
        data = await add_litellm_data_to_request(
            data=data,
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_config=proxy_config,
        )

        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
        data = await proxy_logging_obj.pre_call_hook(
            user_api_key_dict=user_api_key_dict, data=data, call_type="rerank"
        )

        ## ROUTE TO CORRECT ENDPOINT ##
        llm_call = await route_request(
            data=data,
            route_type="arerank",
            llm_router=llm_router,
            user_model=user_model,
        )
        response = await llm_call

        ### ALERTING ### - fire-and-forget so the response is not delayed
        asyncio.create_task(
            proxy_logging_obj.update_request_status(
                litellm_call_id=data.get("litellm_call_id", ""), status="success"
            )
        )

        ### RESPONSE HEADERS ###
        # _hidden_params may be absent or None depending on the provider response.
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""
        additional_headers = hidden_params.get("additional_headers", None) or {}
        fastapi_response.headers.update(
            ProxyBaseLLMRequestProcessing.get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
                request_data=data,
                **additional_headers,
            )
        )

        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        # Fixed typo ("occured" -> "occurred"); use logger.exception with lazy
        # %-formatting so the full traceback is recorded.
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.rerank(): Exception occurred - %s", str(e)
        )
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "message", str(e)),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        else:
            error_msg = f"{str(e)}"
            raise ProxyException(
                message=getattr(e, "message", error_msg),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", 500),
            )