Shyamnath's picture
Push core package and essential files
469eae6
raw
history blame
7.76 kB
import json
from datetime import datetime
from typing import Optional, Union
from urllib.parse import urlparse
import httpx
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy._types import PassThroughEndpointLoggingResultValues
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
PassthroughStandardLoggingPayload,
)
from litellm.types.utils import StandardPassThroughResponseObject
from litellm.utils import executor as thread_pool_executor
from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
AnthropicPassthroughLoggingHandler,
)
from .llm_provider_handlers.assembly_passthrough_logging_handler import (
AssemblyAIPassthroughLoggingHandler,
)
from .llm_provider_handlers.cohere_passthrough_logging_handler import (
CoherePassthroughLoggingHandler,
)
from .llm_provider_handlers.vertex_passthrough_logging_handler import (
VertexPassthroughLoggingHandler,
)
cohere_passthrough_logging_handler = CoherePassthroughLoggingHandler()
class PassThroughEndpointLogging:
def __init__(self):
self.TRACKED_VERTEX_ROUTES = [
"generateContent",
"streamGenerateContent",
"predict",
]
# Anthropic
self.TRACKED_ANTHROPIC_ROUTES = ["/messages"]
# Cohere
self.TRACKED_COHERE_ROUTES = ["/v2/chat"]
self.assemblyai_passthrough_logging_handler = (
AssemblyAIPassthroughLoggingHandler()
)
async def _handle_logging(
self,
logging_obj: LiteLLMLoggingObj,
standard_logging_response_object: Union[
StandardPassThroughResponseObject,
PassThroughEndpointLoggingResultValues,
dict,
],
result: str,
start_time: datetime,
end_time: datetime,
cache_hit: bool,
**kwargs,
):
"""Helper function to handle both sync and async logging operations"""
# Submit to thread pool for sync logging
thread_pool_executor.submit(
logging_obj.success_handler,
standard_logging_response_object,
start_time,
end_time,
cache_hit,
**kwargs,
)
# Handle async logging
await logging_obj.async_success_handler(
result=(
json.dumps(result)
if isinstance(result, dict)
else standard_logging_response_object
),
start_time=start_time,
end_time=end_time,
cache_hit=False,
**kwargs,
)
async def pass_through_async_success_handler(
self,
httpx_response: httpx.Response,
response_body: Optional[dict],
logging_obj: LiteLLMLoggingObj,
url_route: str,
result: str,
start_time: datetime,
end_time: datetime,
cache_hit: bool,
request_body: dict,
passthrough_logging_payload: PassthroughStandardLoggingPayload,
**kwargs,
):
standard_logging_response_object: Optional[
PassThroughEndpointLoggingResultValues
] = None
logging_obj.model_call_details["passthrough_logging_payload"] = (
passthrough_logging_payload
)
if self.is_vertex_route(url_route):
vertex_passthrough_logging_handler_result = (
VertexPassthroughLoggingHandler.vertex_passthrough_handler(
httpx_response=httpx_response,
logging_obj=logging_obj,
url_route=url_route,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
)
standard_logging_response_object = (
vertex_passthrough_logging_handler_result["result"]
)
kwargs = vertex_passthrough_logging_handler_result["kwargs"]
elif self.is_anthropic_route(url_route):
anthropic_passthrough_logging_handler_result = (
AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
httpx_response=httpx_response,
response_body=response_body or {},
logging_obj=logging_obj,
url_route=url_route,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
)
standard_logging_response_object = (
anthropic_passthrough_logging_handler_result["result"]
)
kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
elif self.is_cohere_route(url_route):
cohere_passthrough_logging_handler_result = (
cohere_passthrough_logging_handler.passthrough_chat_handler(
httpx_response=httpx_response,
response_body=response_body or {},
logging_obj=logging_obj,
url_route=url_route,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
request_body=request_body,
**kwargs,
)
)
standard_logging_response_object = (
cohere_passthrough_logging_handler_result["result"]
)
kwargs = cohere_passthrough_logging_handler_result["kwargs"]
elif self.is_assemblyai_route(url_route):
if (
AssemblyAIPassthroughLoggingHandler._should_log_request(
httpx_response.request.method
)
is not True
):
return
self.assemblyai_passthrough_logging_handler.assemblyai_passthrough_logging_handler(
httpx_response=httpx_response,
response_body=response_body or {},
logging_obj=logging_obj,
url_route=url_route,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
return
if standard_logging_response_object is None:
standard_logging_response_object = StandardPassThroughResponseObject(
response=httpx_response.text
)
await self._handle_logging(
logging_obj=logging_obj,
standard_logging_response_object=standard_logging_response_object,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
def is_vertex_route(self, url_route: str):
for route in self.TRACKED_VERTEX_ROUTES:
if route in url_route:
return True
return False
def is_anthropic_route(self, url_route: str):
for route in self.TRACKED_ANTHROPIC_ROUTES:
if route in url_route:
return True
return False
def is_cohere_route(self, url_route: str):
for route in self.TRACKED_COHERE_ROUTES:
if route in url_route:
return True
def is_assemblyai_route(self, url_route: str):
parsed_url = urlparse(url_route)
if parsed_url.hostname == "api.assemblyai.com":
return True
elif "/transcript" in parsed_url.path:
return True
return False