File size: 10,703 Bytes
469eae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
import asyncio
import traceback
from datetime import datetime
from typing import Any, Optional, Union, cast

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import (
    _get_parent_otel_span_from_kwargs,
    get_litellm_metadata_from_kwargs,
)
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_checks import log_db_metrics
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.utils import ProxyUpdateSpend
from litellm.types.utils import (
    StandardLoggingPayload,
    StandardLoggingUserAPIKeyMetadata,
)
from litellm.utils import get_end_user_id_for_cost_tracking


class _ProxyDBLogger(CustomLogger):
    """CustomLogger that persists spend + error tracking data to the proxy DB.

    - On success: reads the response cost and updates the database, the
      spend cache, and customer spend alerts (``_PROXY_track_cost_callback``).
    - On failure: writes a 0-cost spend log entry with error information,
      unless error logging is disabled via ``general_settings``.
    """

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Success hook — delegate to the cost-tracking callback."""
        await self._PROXY_track_cost_callback(
            kwargs, response_obj, start_time, end_time
        )

    async def async_post_call_failure_hook(
        self,
        request_data: dict,
        original_exception: Exception,
        user_api_key_dict: UserAPIKeyAuth,
    ):
        """Write a failure entry (``response_cost=0.0``) to the spend DB.

        Skipped when error logging is disabled
        (``general_settings["disable_error_logs"]``) or when the failing
        route is not an LLM API route.
        """
        request_route = user_api_key_dict.request_route
        if _ProxyDBLogger._should_track_errors_in_db() is False:
            return
        elif request_route is not None and not RouteChecks.is_llm_api_route(
            route=request_route
        ):
            return

        from litellm.proxy.proxy_server import proxy_logging_obj

        # Standardized key/team/user metadata attached to the failure log.
        _metadata = dict(
            StandardLoggingUserAPIKeyMetadata(
                user_api_key_hash=user_api_key_dict.api_key,
                user_api_key_alias=user_api_key_dict.key_alias,
                user_api_key_user_email=user_api_key_dict.user_email,
                user_api_key_user_id=user_api_key_dict.user_id,
                user_api_key_team_id=user_api_key_dict.team_id,
                user_api_key_org_id=user_api_key_dict.org_id,
                user_api_key_team_alias=user_api_key_dict.team_alias,
                user_api_key_end_user_id=user_api_key_dict.end_user_id,
            )
        )
        _metadata["user_api_key"] = user_api_key_dict.api_key
        _metadata["status"] = "failure"
        _metadata[
            "error_information"
        ] = StandardLoggingPayloadSetup.get_error_information(
            original_exception=original_exception,
        )

        # Merge into any metadata already present on the request.
        # NOTE: this intentionally mutates request_data["metadata"] in place
        # when it already exists.
        existing_metadata: dict = request_data.get("metadata", None) or {}
        existing_metadata.update(_metadata)

        if "litellm_params" not in request_data:
            request_data["litellm_params"] = {}
        request_data["litellm_params"]["proxy_server_request"] = (
            request_data.get("proxy_server_request") or {}
        )
        request_data["litellm_params"]["metadata"] = existing_metadata
        await proxy_logging_obj.db_spend_update_writer.update_database(
            token=user_api_key_dict.api_key,
            response_cost=0.0,  # failed calls never incur spend
            user_id=user_api_key_dict.user_id,
            end_user_id=user_api_key_dict.end_user_id,
            team_id=user_api_key_dict.team_id,
            kwargs=request_data,
            completion_response=original_exception,
            start_time=datetime.now(),
            end_time=datetime.now(),
            org_id=user_api_key_dict.org_id,
        )

    @log_db_metrics
    async def _PROXY_track_cost_callback(
        self,
        kwargs,  # kwargs to completion
        completion_response: Optional[
            Union[litellm.ModelResponse, Any]
        ],  # response from completion
        start_time=None,
        end_time=None,  # start/end time for completion
    ):
        """Track the cost of a completed call: update DB, cache, and alerts.

        Cost is read from the ``standard_logging_object`` in ``kwargs``
        (falling back to ``kwargs["response_cost"]``); cache hits are logged
        at 0.0 cost. If cost tracking is impossible for a finished response,
        an exception is raised and reported via ``failed_tracking_alert``.
        All errors are contained here — this callback never propagates.
        """
        from litellm.proxy.proxy_server import (
            prisma_client,
            proxy_logging_obj,
            update_cache,
        )

        verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
        try:
            verbose_proxy_logger.debug(
                f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
            )
            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
            litellm_params = kwargs.get("litellm_params", {}) or {}
            end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
            metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
            user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None))
            team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None))
            org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None))
            key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None))
            end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
            sl_object: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )
            # Prefer the cost computed in the standard logging payload;
            # fall back to the raw kwargs value if the payload is absent.
            response_cost = (
                sl_object.get("response_cost", None)
                if sl_object is not None
                else kwargs.get("response_cost", None)
            )

            if response_cost is not None:
                user_api_key = metadata.get("user_api_key", None)
                if kwargs.get("cache_hit", False) is True:
                    response_cost = 0.0  # cache hits are free
                    verbose_proxy_logger.info(
                        f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
                    )

                verbose_proxy_logger.debug(
                    f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
                )
                if _should_track_cost_callback(
                    user_api_key=user_api_key,
                    user_id=user_id,
                    team_id=team_id,
                    end_user_id=end_user_id,
                ):
                    ## UPDATE DATABASE
                    await proxy_logging_obj.db_spend_update_writer.update_database(
                        token=user_api_key,
                        response_cost=response_cost,
                        user_id=user_id,
                        end_user_id=end_user_id,
                        team_id=team_id,
                        kwargs=kwargs,
                        completion_response=completion_response,
                        start_time=start_time,
                        end_time=end_time,
                        org_id=org_id,
                    )

                    # update cache — fire-and-forget so the hook isn't blocked
                    asyncio.create_task(
                        update_cache(
                            token=user_api_key,
                            user_id=user_id,
                            end_user_id=end_user_id,
                            response_cost=response_cost,
                            team_id=team_id,
                            parent_otel_span=parent_otel_span,
                        )
                    )

                    await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
                        token=user_api_key,
                        key_alias=key_alias,
                        end_user_id=end_user_id,
                        response_cost=response_cost,
                        max_budget=end_user_max_budget,
                    )
                else:
                    raise Exception(
                        "User API key and team id and user id missing from custom callback."
                    )
            else:
                # No cost available. Only treat this as a failure once the
                # response is actually complete (non-streaming, or streaming
                # with the full response assembled).
                if kwargs["stream"] is not True or (
                    kwargs["stream"] is True and "complete_streaming_response" in kwargs
                ):
                    if sl_object is not None:
                        cost_tracking_failure_debug_info: Union[dict, str] = (
                            sl_object["response_cost_failure_debug_info"]  # type: ignore
                            or "response_cost_failure_debug_info is None in standard_logging_object"
                        )
                    else:
                        cost_tracking_failure_debug_info = (
                            "standard_logging_object not found"
                        )
                    model = kwargs.get("model")
                    raise Exception(
                        f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
                    )
        except Exception as e:
            # Contain all errors: alert + log, never propagate to the caller.
            error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
            model = kwargs.get("model", "")
            metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
            litellm_metadata = kwargs.get("litellm_params", {}).get(
                "litellm_metadata", {}
            )
            old_metadata = kwargs.get("litellm_params", {}).get("metadata", {})
            call_type = kwargs.get("call_type", "")
            error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n chosen_metadata: {metadata}\n litellm_metadata: {litellm_metadata}\n old_metadata: {old_metadata}\n call_type: {call_type}\n"
            asyncio.create_task(
                proxy_logging_obj.failed_tracking_alert(
                    error_message=error_msg,
                    failing_model=model,
                )
            )

            verbose_proxy_logger.exception(
                "Error in tracking cost callback - %s", str(e)
            )

    @staticmethod
    def _should_track_errors_in_db() -> bool:
        """
        Returns True if errors should be tracked in the database

        By default, errors are tracked in the database

        If users want to disable error tracking, they can set the disable_error_logs flag in the general_settings
        """
        from litellm.proxy.proxy_server import general_settings

        if general_settings.get("disable_error_logs") is True:
            return False
        # Fix: previously ended with a bare `return` (None). Callers compare
        # with `is False`, so None behaved like True by accident; return the
        # documented True explicitly.
        return True


def _should_track_cost_callback(
    user_api_key: Optional[str],
    user_id: Optional[str],
    team_id: Optional[str],
    end_user_id: Optional[str],
) -> bool:
    """
    Determine if the cost callback should be tracked based on the kwargs
    """

    # don't run track cost callback if user opted into disabling spend
    if ProxyUpdateSpend.disable_spend_updates() is True:
        return False

    # Track only when at least one spend identifier is present.
    identifiers = (user_api_key, user_id, team_id, end_user_id)
    return any(identifier is not None for identifier in identifiers)