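# NOTE: the text "Spaces: Sleeping Sleeping" appeared here; it looks like page-header residue captured with the file and is not part of the module.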
import asyncio
import traceback
from datetime import datetime
from typing import Any, Optional, Union, cast

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import (
    _get_parent_otel_span_from_kwargs,
    get_litellm_metadata_from_kwargs,
)
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_checks import log_db_metrics
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.utils import ProxyUpdateSpend
from litellm.types.utils import (
    StandardLoggingPayload,
    StandardLoggingUserAPIKeyMetadata,
)
from litellm.utils import get_end_user_id_for_cost_tracking


class _ProxyDBLogger(CustomLogger):
    """Custom logger that forwards request outcomes to the proxy's DB spend writer.

    Successful calls are handled by ``_PROXY_track_cost_callback``; failed calls
    are written with a cost of ``0.0`` by ``async_post_call_failure_hook``.
    """

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Run cost tracking for a successful call.

        All four arguments are passed through unchanged to
        ``_PROXY_track_cost_callback``.
        """
        await self._PROXY_track_cost_callback(
            kwargs, response_obj, start_time, end_time
        )

    async def async_post_call_failure_hook(
        self,
        request_data: dict,
        original_exception: Exception,
        user_api_key_dict: UserAPIKeyAuth,
    ):
        """Write a failed request to the database with zero cost.

        Returns without doing anything when error tracking is disabled, or when
        the key's ``request_route`` is set and is not an LLM API route.

        Args:
            request_data: The request payload. It is mutated in place: its
                ``metadata`` dict (if any) is updated and ``litellm_params`` is
                created/filled with ``proxy_server_request`` and ``metadata``.
            original_exception: The exception raised for this request; used to
                build ``error_information`` and passed as ``completion_response``.
            user_api_key_dict: Auth info of the calling key; source of the
                key/user/team/org/end-user identifiers.
        """
        request_route = user_api_key_dict.request_route
        if _ProxyDBLogger._should_track_errors_in_db() is False:
            return
        elif request_route is not None and not RouteChecks.is_llm_api_route(
            route=request_route
        ):
            return

        # Imported lazily, as everywhere else in this class.
        from litellm.proxy.proxy_server import proxy_logging_obj

        _metadata = dict(
            StandardLoggingUserAPIKeyMetadata(
                user_api_key_hash=user_api_key_dict.api_key,
                user_api_key_alias=user_api_key_dict.key_alias,
                user_api_key_user_email=user_api_key_dict.user_email,
                user_api_key_user_id=user_api_key_dict.user_id,
                user_api_key_team_id=user_api_key_dict.team_id,
                user_api_key_org_id=user_api_key_dict.org_id,
                user_api_key_team_alias=user_api_key_dict.team_alias,
                user_api_key_end_user_id=user_api_key_dict.end_user_id,
            )
        )
        _metadata["user_api_key"] = user_api_key_dict.api_key
        _metadata["status"] = "failure"
        _metadata[
            "error_information"
        ] = StandardLoggingPayloadSetup.get_error_information(
            original_exception=original_exception,
        )

        # Key-derived fields win over whatever the request already carried.
        existing_metadata: dict = request_data.get("metadata", None) or {}
        existing_metadata.update(_metadata)

        # Treat an explicit ``None`` like a missing key so the item assignments
        # below cannot fail with a TypeError.
        if request_data.get("litellm_params") is None:
            request_data["litellm_params"] = {}
        request_data["litellm_params"]["proxy_server_request"] = (
            request_data.get("proxy_server_request") or {}
        )
        request_data["litellm_params"]["metadata"] = existing_metadata

        await proxy_logging_obj.db_spend_update_writer.update_database(
            token=user_api_key_dict.api_key,
            response_cost=0.0,
            user_id=user_api_key_dict.user_id,
            end_user_id=user_api_key_dict.end_user_id,
            team_id=user_api_key_dict.team_id,
            kwargs=request_data,
            completion_response=original_exception,
            start_time=datetime.now(),
            end_time=datetime.now(),
            org_id=user_api_key_dict.org_id,
        )

    # NOTE(review): the file imports ``log_db_metrics`` but nothing visible uses
    # it - confirm whether this method was meant to be decorated with it.
    async def _PROXY_track_cost_callback(
        self,
        kwargs,  # kwargs to completion
        completion_response: Optional[
            Union[litellm.ModelResponse, Any]
        ],  # response from completion
        start_time=None,
        end_time=None,  # start/end time for completion
    ):
        """Persist the cost of a completed call; alert instead of raising on failure.

        The cost is read from ``kwargs["standard_logging_object"]`` when that is
        present, otherwise from ``kwargs["response_cost"]``. When a cost is
        available and ``_should_track_cost_callback`` approves, the method awaits
        the DB update, schedules the cache update as a task and awaits the
        customer spend alert.

        Every exception raised inside is caught: a ``failed_tracking_alert`` task
        is scheduled and the error is logged. Nothing propagates to the caller.

        Args:
            kwargs: The completion call's kwargs.
            completion_response: The completion result, forwarded to the DB writer.
            start_time: Call start time, forwarded to the DB writer.
            end_time: Call end time, forwarded to the DB writer.
        """
        from litellm.proxy.proxy_server import (
            prisma_client,
            proxy_logging_obj,
            update_cache,
        )

        verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
        try:
            verbose_proxy_logger.debug(
                f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
            )
            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
            litellm_params = kwargs.get("litellm_params", {}) or {}
            end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
            metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
            user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None))
            team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None))
            org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None))
            key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None))
            end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
            sl_object: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )
            response_cost = (
                sl_object.get("response_cost", None)
                if sl_object is not None
                else kwargs.get("response_cost", None)
            )

            if response_cost is not None:
                user_api_key = metadata.get("user_api_key", None)
                if kwargs.get("cache_hit", False) is True:
                    # Cache hits are recorded at zero cost.
                    response_cost = 0.0
                    verbose_proxy_logger.info(
                        f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
                    )

                verbose_proxy_logger.debug(
                    f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
                )
                if _should_track_cost_callback(
                    user_api_key=user_api_key,
                    user_id=user_id,
                    team_id=team_id,
                    end_user_id=end_user_id,
                ):
                    ## UPDATE DATABASE
                    await proxy_logging_obj.db_spend_update_writer.update_database(
                        token=user_api_key,
                        response_cost=response_cost,
                        user_id=user_id,
                        end_user_id=end_user_id,
                        team_id=team_id,
                        kwargs=kwargs,
                        completion_response=completion_response,
                        start_time=start_time,
                        end_time=end_time,
                        org_id=org_id,
                    )

                    # update cache
                    # NOTE(review): the task object is not retained; the asyncio
                    # docs recommend keeping a reference to created tasks.
                    asyncio.create_task(
                        update_cache(
                            token=user_api_key,
                            user_id=user_id,
                            end_user_id=end_user_id,
                            response_cost=response_cost,
                            team_id=team_id,
                            parent_otel_span=parent_otel_span,
                        )
                    )

                    await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
                        token=user_api_key,
                        key_alias=key_alias,
                        end_user_id=end_user_id,
                        response_cost=response_cost,
                        max_budget=end_user_max_budget,
                    )
                else:
                    # NOTE(review): _should_track_cost_callback also returns False
                    # when spend updates are disabled, in which case this message
                    # does not describe the real reason - confirm intent.
                    raise Exception(
                        "User API key and team id and user id missing from custom callback."
                    )
            else:
                # A missing cost is only an error once the response is complete:
                # either a non-streaming call, or a stream whose assembled
                # response is already in kwargs. ``.get`` avoids a KeyError (and
                # a misleading alert) when "stream" was never set.
                stream = kwargs.get("stream")
                if stream is not True or "complete_streaming_response" in kwargs:
                    if sl_object is not None:
                        cost_tracking_failure_debug_info: Union[dict, str] = (
                            sl_object.get("response_cost_failure_debug_info")
                            or "response_cost_failure_debug_info is None in standard_logging_object"
                        )
                    else:
                        cost_tracking_failure_debug_info = (
                            "standard_logging_object not found"
                        )
                    model = kwargs.get("model")
                    raise Exception(
                        f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
                    )
        except Exception as e:
            error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
            model = kwargs.get("model", "")
            metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
            # ``litellm_params`` may be present but None; fall back to {} so the
            # handler itself cannot raise and lose the alert.
            failed_call_params = kwargs.get("litellm_params") or {}
            litellm_metadata = failed_call_params.get("litellm_metadata", {})
            old_metadata = failed_call_params.get("metadata", {})
            call_type = kwargs.get("call_type", "")
            error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n chosen_metadata: {metadata}\n litellm_metadata: {litellm_metadata}\n old_metadata: {old_metadata}\n call_type: {call_type}\n"
            asyncio.create_task(
                proxy_logging_obj.failed_tracking_alert(
                    error_message=error_msg,
                    failing_model=model,
                )
            )
            verbose_proxy_logger.exception(
                "Error in tracking cost callback - %s", str(e)
            )

    @staticmethod
    def _should_track_errors_in_db() -> bool:
        """
        Returns True if errors should be tracked in the database

        By default, errors are tracked in the database

        If users want to disable error tracking, they can set the disable_error_logs flag in the general_settings
        """
        from litellm.proxy.proxy_server import general_settings

        # Only an explicit ``True`` disables tracking; a missing or any other
        # value keeps the default (tracking on).
        return general_settings.get("disable_error_logs") is not True


def _should_track_cost_callback(
    user_api_key: Optional[str],
    user_id: Optional[str],
    team_id: Optional[str],
    end_user_id: Optional[str],
) -> bool:
    """
    Decide whether the cost callback should record spend for this call.

    Returns False when ``ProxyUpdateSpend.disable_spend_updates()`` is True.
    Otherwise returns True if at least one of the four identifiers is not
    None, and False when all of them are None.
    """
    # Opting out of spend updates takes precedence over everything else.
    spend_updates_disabled = ProxyUpdateSpend.disable_spend_updates() is True
    if spend_updates_disabled:
        return False

    # Spend has to be attributed to something: any single identifier suffices.
    identifiers = (user_api_key, user_id, team_id, end_user_id)
    return any(identifier is not None for identifier in identifiers)