""" | |
Module responsible for | |
1. Writing spend increments to either in memory list of transactions or to redis | |
2. Reading increments from redis or in memory list of transactions and committing them to db | |
""" | |
import asyncio | |
import json | |
import os | |
import time | |
import traceback | |
from datetime import datetime, timedelta | |
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, cast, overload | |
import litellm | |
from litellm._logging import verbose_proxy_logger | |
from litellm.caching import DualCache, RedisCache | |
from litellm.constants import DB_SPEND_UPDATE_JOB_NAME | |
from litellm.proxy._types import ( | |
DB_CONNECTION_ERROR_TYPES, | |
BaseDailySpendTransaction, | |
DailyTagSpendTransaction, | |
DailyTeamSpendTransaction, | |
DailyUserSpendTransaction, | |
DBSpendUpdateTransactions, | |
Litellm_EntityType, | |
LiteLLM_UserTable, | |
SpendLogsMetadata, | |
SpendLogsPayload, | |
SpendUpdateQueueItem, | |
) | |
from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import ( | |
DailySpendUpdateQueue, | |
) | |
from litellm.proxy.db.db_transaction_queue.pod_lock_manager import PodLockManager | |
from litellm.proxy.db.db_transaction_queue.redis_update_buffer import RedisUpdateBuffer | |
from litellm.proxy.db.db_transaction_queue.spend_update_queue import SpendUpdateQueue | |
if TYPE_CHECKING: | |
from litellm.proxy.utils import PrismaClient, ProxyLogging | |
else: | |
PrismaClient = Any | |
ProxyLogging = Any | |
class DBSpendUpdateWriter: | |
""" | |
Module responsible for | |
1. Writing spend increments to either in memory list of transactions or to redis | |
2. Reading increments from redis or in memory list of transactions and committing them to db | |
""" | |
def __init__( | |
self, | |
redis_cache: Optional[RedisCache] = None, | |
): | |
self.redis_cache = redis_cache | |
self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache) | |
self.pod_lock_manager = PodLockManager() | |
self.spend_update_queue = SpendUpdateQueue() | |
self.daily_spend_update_queue = DailySpendUpdateQueue() | |
self.daily_team_spend_update_queue = DailySpendUpdateQueue() | |
self.daily_tag_spend_update_queue = DailySpendUpdateQueue() | |
    async def update_database(
        # LiteLLM management object fields
        self,
        token: Optional[str],
        user_id: Optional[str],
        end_user_id: Optional[str],
        team_id: Optional[str],
        org_id: Optional[str],
        # Completion object fields
        kwargs: Optional[dict],
        completion_response: Optional[Union[litellm.ModelResponse, Any, Exception]],
        start_time: Optional[datetime],
        end_time: Optional[datetime],
        response_cost: Optional[float],
    ):
        """
        Record the cost of a completed request across all spend tables.

        Builds the spend-log payload, then fire-and-forgets (asyncio.create_task)
        the user/key/team/org queue updates plus the daily per-entity rollups.
        Only the spend-log insert itself is awaited. Any error is swallowed and
        logged at debug level so spend tracking never fails the request.
        """
        from litellm.proxy.proxy_server import (
            disable_spend_logs,
            litellm_proxy_budget_name,
            prisma_client,
            user_api_key_cache,
        )
        from litellm.proxy.utils import ProxyUpdateSpend, hash_token

        try:
            verbose_proxy_logger.debug(
                f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}"
            )
            if ProxyUpdateSpend.disable_spend_updates() is True:
                return
            # Only hash raw virtual keys ("sk-..."); already-hashed tokens pass through.
            if token is not None and isinstance(token, str) and token.startswith("sk-"):
                hashed_token = hash_token(token=token)
            else:
                hashed_token = token

            ## CREATE SPEND LOG PAYLOAD ##
            from litellm.proxy.spend_tracking.spend_tracking_utils import (
                get_logging_payload,
            )

            payload = get_logging_payload(
                kwargs=kwargs,
                response_obj=completion_response,
                start_time=start_time,
                end_time=end_time,
            )
            payload["spend"] = response_cost or 0.0
            # Normalize datetimes to ISO strings so the payload is serializable.
            if isinstance(payload["startTime"], datetime):
                payload["startTime"] = payload["startTime"].isoformat()
            if isinstance(payload["endTime"], datetime):
                payload["endTime"] = payload["endTime"].isoformat()

            # Fire-and-forget: entity spend updates are queued concurrently and
            # committed to the DB later by the background flush job.
            asyncio.create_task(
                self._update_user_db(
                    response_cost=response_cost,
                    user_id=user_id,
                    prisma_client=prisma_client,
                    user_api_key_cache=user_api_key_cache,
                    litellm_proxy_budget_name=litellm_proxy_budget_name,
                    end_user_id=end_user_id,
                )
            )
            asyncio.create_task(
                self._update_key_db(
                    response_cost=response_cost,
                    hashed_token=hashed_token,
                    prisma_client=prisma_client,
                )
            )
            asyncio.create_task(
                self._update_team_db(
                    response_cost=response_cost,
                    team_id=team_id,
                    user_id=user_id,
                    prisma_client=prisma_client,
                )
            )
            asyncio.create_task(
                self._update_org_db(
                    response_cost=response_cost,
                    org_id=org_id,
                    prisma_client=prisma_client,
                )
            )
            if disable_spend_logs is False:
                await self._insert_spend_log_to_db(
                    payload=payload,
                    prisma_client=prisma_client,
                )
            else:
                verbose_proxy_logger.info(
                    "disable_spend_logs=True. Skipping writing spend logs to db. Other spend updates - Key/User/Team table will still occur."
                )

            # Daily (per-day, per-entity) rollups are also fire-and-forget.
            asyncio.create_task(
                self.add_spend_log_transaction_to_daily_user_transaction(
                    payload=payload,
                    prisma_client=prisma_client,
                )
            )
            asyncio.create_task(
                self.add_spend_log_transaction_to_daily_team_transaction(
                    payload=payload,
                    prisma_client=prisma_client,
                )
            )
            asyncio.create_task(
                self.add_spend_log_transaction_to_daily_tag_transaction(
                    payload=payload,
                    prisma_client=prisma_client,
                )
            )

            verbose_proxy_logger.debug("Runs spend update on all tables")
        except Exception:
            # Spend tracking is best-effort; never propagate DB errors to the caller.
            verbose_proxy_logger.debug(
                f"Error updating Prisma database: {traceback.format_exc()}"
            )
async def _update_key_db( | |
self, | |
response_cost: Optional[float], | |
hashed_token: Optional[str], | |
prisma_client: Optional[PrismaClient], | |
): | |
try: | |
if hashed_token is None or prisma_client is None: | |
return | |
await self.spend_update_queue.add_update( | |
update=SpendUpdateQueueItem( | |
entity_type=Litellm_EntityType.KEY, | |
entity_id=hashed_token, | |
response_cost=response_cost, | |
) | |
) | |
except Exception as e: | |
verbose_proxy_logger.exception( | |
f"Update Key DB Call failed to execute - {str(e)}" | |
) | |
raise e | |
async def _update_user_db( | |
self, | |
response_cost: Optional[float], | |
user_id: Optional[str], | |
prisma_client: Optional[PrismaClient], | |
user_api_key_cache: DualCache, | |
litellm_proxy_budget_name: Optional[str], | |
end_user_id: Optional[str] = None, | |
): | |
""" | |
- Update that user's row | |
- Update litellm-proxy-budget row (global proxy spend) | |
""" | |
## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db | |
existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id) | |
if existing_user_obj is not None and isinstance(existing_user_obj, dict): | |
existing_user_obj = LiteLLM_UserTable(**existing_user_obj) | |
try: | |
if prisma_client is not None: # update | |
user_ids = [user_id] | |
if ( | |
litellm.max_budget > 0 | |
): # track global proxy budget, if user set max budget | |
user_ids.append(litellm_proxy_budget_name) | |
for _id in user_ids: | |
if _id is not None: | |
await self.spend_update_queue.add_update( | |
update=SpendUpdateQueueItem( | |
entity_type=Litellm_EntityType.USER, | |
entity_id=_id, | |
response_cost=response_cost, | |
) | |
) | |
if end_user_id is not None: | |
await self.spend_update_queue.add_update( | |
update=SpendUpdateQueueItem( | |
entity_type=Litellm_EntityType.END_USER, | |
entity_id=end_user_id, | |
response_cost=response_cost, | |
) | |
) | |
except Exception as e: | |
verbose_proxy_logger.info( | |
"\033[91m" | |
+ f"Update User DB call failed to execute {str(e)}\n{traceback.format_exc()}" | |
) | |
async def _update_team_db( | |
self, | |
response_cost: Optional[float], | |
team_id: Optional[str], | |
user_id: Optional[str], | |
prisma_client: Optional[PrismaClient], | |
): | |
try: | |
if team_id is None or prisma_client is None: | |
verbose_proxy_logger.debug( | |
"track_cost_callback: team_id is None or prisma_client is None. Not tracking spend for team" | |
) | |
return | |
await self.spend_update_queue.add_update( | |
update=SpendUpdateQueueItem( | |
entity_type=Litellm_EntityType.TEAM, | |
entity_id=team_id, | |
response_cost=response_cost, | |
) | |
) | |
try: | |
# Track spend of the team member within this team | |
if user_id is not None: | |
# key is "team_id::<value>::user_id::<value>" | |
team_member_key = f"team_id::{team_id}::user_id::{user_id}" | |
await self.spend_update_queue.add_update( | |
update=SpendUpdateQueueItem( | |
entity_type=Litellm_EntityType.TEAM_MEMBER, | |
entity_id=team_member_key, | |
response_cost=response_cost, | |
) | |
) | |
except Exception: | |
pass | |
except Exception as e: | |
verbose_proxy_logger.info( | |
f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}" | |
) | |
raise e | |
async def _update_org_db( | |
self, | |
response_cost: Optional[float], | |
org_id: Optional[str], | |
prisma_client: Optional[PrismaClient], | |
): | |
try: | |
if org_id is None or prisma_client is None: | |
verbose_proxy_logger.debug( | |
"track_cost_callback: org_id is None or prisma_client is None. Not tracking spend for org" | |
) | |
return | |
await self.spend_update_queue.add_update( | |
update=SpendUpdateQueueItem( | |
entity_type=Litellm_EntityType.ORGANIZATION, | |
entity_id=org_id, | |
response_cost=response_cost, | |
) | |
) | |
except Exception as e: | |
verbose_proxy_logger.info( | |
f"Update Org DB failed to execute - {str(e)}\n{traceback.format_exc()}" | |
) | |
raise e | |
async def _insert_spend_log_to_db( | |
self, | |
payload: Union[dict, SpendLogsPayload], | |
prisma_client: Optional[PrismaClient] = None, | |
spend_logs_url: Optional[str] = os.getenv("SPEND_LOGS_URL"), | |
) -> Optional[PrismaClient]: | |
verbose_proxy_logger.info( | |
"Writing spend log to db - request_id: {}, spend: {}".format( | |
payload.get("request_id"), payload.get("spend") | |
) | |
) | |
if prisma_client is not None and spend_logs_url is not None: | |
prisma_client.spend_log_transactions.append(payload) | |
elif prisma_client is not None: | |
prisma_client.spend_log_transactions.append(payload) | |
else: | |
verbose_proxy_logger.debug( | |
"prisma_client is None. Skipping writing spend logs to db." | |
) | |
return prisma_client | |
async def db_update_spend_transaction_handler( | |
self, | |
prisma_client: PrismaClient, | |
n_retry_times: int, | |
proxy_logging_obj: ProxyLogging, | |
): | |
""" | |
Handles commiting update spend transactions to db | |
`UPDATES` can lead to deadlocks, hence we handle them separately | |
Args: | |
prisma_client: PrismaClient object | |
n_retry_times: int, number of retry times | |
proxy_logging_obj: ProxyLogging object | |
How this works: | |
- Check `general_settings.use_redis_transaction_buffer` | |
- If enabled, write in-memory transactions to Redis | |
- Check if this Pod should read from the DB | |
else: | |
- Regular flow of this method | |
""" | |
if RedisUpdateBuffer._should_commit_spend_updates_to_redis(): | |
await self._commit_spend_updates_to_db_with_redis( | |
prisma_client=prisma_client, | |
n_retry_times=n_retry_times, | |
proxy_logging_obj=proxy_logging_obj, | |
) | |
else: | |
await self._commit_spend_updates_to_db_without_redis_buffer( | |
prisma_client=prisma_client, | |
n_retry_times=n_retry_times, | |
proxy_logging_obj=proxy_logging_obj, | |
) | |
    async def _commit_spend_updates_to_db_with_redis(
        self,
        prisma_client: PrismaClient,
        n_retry_times: int,
        proxy_logging_obj: ProxyLogging,
    ):
        """
        Handler to commit spend updates to Redis and attempt to acquire lock to commit to db

        This is a v2 scalable approach to first commit spend updates to redis, then commit to db

        This minimizes DB Deadlocks since
        - All pods only need to write their spend updates to redis
        - Only 1 pod will commit to db at a time (based on if it can acquire the lock over writing to DB)
        """
        # Every pod first pushes its in-memory queues into the shared Redis buffer.
        await self.redis_update_buffer.store_in_memory_spend_updates_in_redis(
            spend_update_queue=self.spend_update_queue,
            daily_spend_update_queue=self.daily_spend_update_queue,
            daily_team_spend_update_queue=self.daily_team_spend_update_queue,
            daily_tag_spend_update_queue=self.daily_tag_spend_update_queue,
        )

        # Only commit from redis to db if this pod is the leader
        if await self.pod_lock_manager.acquire_lock(
            cronjob_id=DB_SPEND_UPDATE_JOB_NAME,
        ):
            verbose_proxy_logger.debug("acquired lock for spend updates")
            try:
                # 1. Key/user/end-user/team/team-member/org increments.
                db_spend_update_transactions = (
                    await self.redis_update_buffer.get_all_update_transactions_from_redis_buffer()
                )
                if db_spend_update_transactions is not None:
                    await self._commit_spend_updates_to_db(
                        prisma_client=prisma_client,
                        n_retry_times=n_retry_times,
                        proxy_logging_obj=proxy_logging_obj,
                        db_spend_update_transactions=db_spend_update_transactions,
                    )

                # 2. Daily per-user rollups.
                daily_spend_update_transactions = (
                    await self.redis_update_buffer.get_all_daily_spend_update_transactions_from_redis_buffer()
                )
                if daily_spend_update_transactions is not None:
                    await DBSpendUpdateWriter.update_daily_user_spend(
                        n_retry_times=n_retry_times,
                        prisma_client=prisma_client,
                        proxy_logging_obj=proxy_logging_obj,
                        daily_spend_transactions=daily_spend_update_transactions,
                    )

                # 3. Daily per-team rollups.
                daily_team_spend_update_transactions = (
                    await self.redis_update_buffer.get_all_daily_team_spend_update_transactions_from_redis_buffer()
                )
                if daily_team_spend_update_transactions is not None:
                    await DBSpendUpdateWriter.update_daily_team_spend(
                        n_retry_times=n_retry_times,
                        prisma_client=prisma_client,
                        proxy_logging_obj=proxy_logging_obj,
                        daily_spend_transactions=daily_team_spend_update_transactions,
                    )

                # 4. Daily per-tag rollups.
                daily_tag_spend_update_transactions = (
                    await self.redis_update_buffer.get_all_daily_tag_spend_update_transactions_from_redis_buffer()
                )
                if daily_tag_spend_update_transactions is not None:
                    await DBSpendUpdateWriter.update_daily_tag_spend(
                        n_retry_times=n_retry_times,
                        prisma_client=prisma_client,
                        proxy_logging_obj=proxy_logging_obj,
                        daily_spend_transactions=daily_tag_spend_update_transactions,
                    )
            except Exception as e:
                # Swallow commit errors; the buffered data remains in Redis for
                # the next leader cycle.
                verbose_proxy_logger.error(f"Error committing spend updates: {e}")
            finally:
                # Always release the leader lock, even on failure.
                await self.pod_lock_manager.release_lock(
                    cronjob_id=DB_SPEND_UPDATE_JOB_NAME,
                )
async def _commit_spend_updates_to_db_without_redis_buffer( | |
self, | |
prisma_client: PrismaClient, | |
n_retry_times: int, | |
proxy_logging_obj: ProxyLogging, | |
): | |
""" | |
Commits all the spend `UPDATE` transactions to the Database | |
This is the regular flow of committing to db without using a redis buffer | |
Note: This flow causes Deadlocks in production (1K RPS+). Use self._commit_spend_updates_to_db_with_redis() instead if you expect 1K+ RPS. | |
""" | |
# Aggregate all in memory spend updates (key, user, end_user, team, team_member, org) and commit to db | |
################## Spend Update Transactions ################## | |
db_spend_update_transactions = ( | |
await self.spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions() | |
) | |
await self._commit_spend_updates_to_db( | |
prisma_client=prisma_client, | |
n_retry_times=n_retry_times, | |
proxy_logging_obj=proxy_logging_obj, | |
db_spend_update_transactions=db_spend_update_transactions, | |
) | |
################## Daily Spend Update Transactions ################## | |
# Aggregate all in memory daily spend transactions and commit to db | |
daily_spend_update_transactions = cast( | |
Dict[str, DailyUserSpendTransaction], | |
await self.daily_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions(), | |
) | |
await DBSpendUpdateWriter.update_daily_user_spend( | |
n_retry_times=n_retry_times, | |
prisma_client=prisma_client, | |
proxy_logging_obj=proxy_logging_obj, | |
daily_spend_transactions=daily_spend_update_transactions, | |
) | |
################## Daily Team Spend Update Transactions ################## | |
# Aggregate all in memory daily team spend transactions and commit to db | |
daily_team_spend_update_transactions = cast( | |
Dict[str, DailyTeamSpendTransaction], | |
await self.daily_team_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions(), | |
) | |
await DBSpendUpdateWriter.update_daily_team_spend( | |
n_retry_times=n_retry_times, | |
prisma_client=prisma_client, | |
proxy_logging_obj=proxy_logging_obj, | |
daily_spend_transactions=daily_team_spend_update_transactions, | |
) | |
################## Daily Tag Spend Update Transactions ################## | |
# Aggregate all in memory daily tag spend transactions and commit to db | |
daily_tag_spend_update_transactions = cast( | |
Dict[str, DailyTagSpendTransaction], | |
await self.daily_tag_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions(), | |
) | |
await DBSpendUpdateWriter.update_daily_tag_spend( | |
n_retry_times=n_retry_times, | |
prisma_client=prisma_client, | |
proxy_logging_obj=proxy_logging_obj, | |
daily_spend_transactions=daily_tag_spend_update_transactions, | |
) | |
    async def _commit_spend_updates_to_db(  # noqa: PLR0915
        self,
        prisma_client: PrismaClient,
        n_retry_times: int,
        proxy_logging_obj: ProxyLogging,
        db_spend_update_transactions: DBSpendUpdateTransactions,
    ):
        """
        Commits all the spend `UPDATE` transactions to the Database

        Each entity table (user, end-user, key, team, team-membership, org) is
        committed in its own 60s batched transaction. DB connection errors are
        retried up to `n_retry_times` with exponential backoff; any other error
        is surfaced immediately via `_raise_failed_update_spend_exception`.
        """
        from litellm.proxy.utils import (
            ProxyUpdateSpend,
            _raise_failed_update_spend_exception,
        )

        ### UPDATE USER TABLE ###
        user_list_transactions = db_spend_update_transactions["user_list_transactions"]
        verbose_proxy_logger.debug(
            "User Spend transactions: {}".format(user_list_transactions)
        )
        if (
            user_list_transactions is not None
            and len(user_list_transactions.keys()) > 0
        ):
            for i in range(n_retry_times + 1):
                start_time = time.time()
                try:
                    async with prisma_client.db.tx(
                        timeout=timedelta(seconds=60)
                    ) as transaction:
                        async with transaction.batch_() as batcher:
                            for (
                                user_id,
                                response_cost,
                            ) in user_list_transactions.items():
                                batcher.litellm_usertable.update_many(
                                    where={"user_id": user_id},
                                    data={"spend": {"increment": response_cost}},
                                )
                    break  # commit succeeded; stop retrying
                except DB_CONNECTION_ERROR_TYPES as e:
                    if (
                        i >= n_retry_times
                    ):  # If we've reached the maximum number of retries
                        _raise_failed_update_spend_exception(
                            e=e,
                            start_time=start_time,
                            proxy_logging_obj=proxy_logging_obj,
                        )
                    # Optionally, sleep for a bit before retrying
                    await asyncio.sleep(2**i)  # Exponential backoff
                except Exception as e:
                    # Non-connection errors are not retried.
                    _raise_failed_update_spend_exception(
                        e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
                    )

        ### UPDATE END-USER TABLE ###
        # End-user rows may not exist yet, so this path upserts via a helper.
        end_user_list_transactions = db_spend_update_transactions[
            "end_user_list_transactions"
        ]
        verbose_proxy_logger.debug(
            "End-User Spend transactions: {}".format(end_user_list_transactions)
        )
        if (
            end_user_list_transactions is not None
            and len(end_user_list_transactions.keys()) > 0
        ):
            await ProxyUpdateSpend.update_end_user_spend(
                n_retry_times=n_retry_times,
                prisma_client=prisma_client,
                proxy_logging_obj=proxy_logging_obj,
                end_user_list_transactions=end_user_list_transactions,
            )

        ### UPDATE KEY TABLE ###
        key_list_transactions = db_spend_update_transactions["key_list_transactions"]
        verbose_proxy_logger.debug(
            "KEY Spend transactions: {}".format(key_list_transactions)
        )
        if key_list_transactions is not None and len(key_list_transactions.keys()) > 0:
            for i in range(n_retry_times + 1):
                start_time = time.time()
                try:
                    async with prisma_client.db.tx(
                        timeout=timedelta(seconds=60)
                    ) as transaction:
                        async with transaction.batch_() as batcher:
                            for (
                                token,
                                response_cost,
                            ) in key_list_transactions.items():
                                batcher.litellm_verificationtoken.update_many(  # 'update_many' prevents error from being raised if no row exists
                                    where={"token": token},
                                    data={"spend": {"increment": response_cost}},
                                )
                    break  # commit succeeded; stop retrying
                except DB_CONNECTION_ERROR_TYPES as e:
                    if (
                        i >= n_retry_times
                    ):  # If we've reached the maximum number of retries
                        _raise_failed_update_spend_exception(
                            e=e,
                            start_time=start_time,
                            proxy_logging_obj=proxy_logging_obj,
                        )
                    # Optionally, sleep for a bit before retrying
                    await asyncio.sleep(2**i)  # Exponential backoff
                except Exception as e:
                    # Non-connection errors are not retried.
                    _raise_failed_update_spend_exception(
                        e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
                    )

        ### UPDATE TEAM TABLE ###
        team_list_transactions = db_spend_update_transactions["team_list_transactions"]
        verbose_proxy_logger.debug(
            "Team Spend transactions: {}".format(team_list_transactions)
        )
        if (
            team_list_transactions is not None
            and len(team_list_transactions.keys()) > 0
        ):
            for i in range(n_retry_times + 1):
                start_time = time.time()
                try:
                    async with prisma_client.db.tx(
                        timeout=timedelta(seconds=60)
                    ) as transaction:
                        async with transaction.batch_() as batcher:
                            for (
                                team_id,
                                response_cost,
                            ) in team_list_transactions.items():
                                verbose_proxy_logger.debug(
                                    "Updating spend for team id={} by {}".format(
                                        team_id, response_cost
                                    )
                                )
                                batcher.litellm_teamtable.update_many(  # 'update_many' prevents error from being raised if no row exists
                                    where={"team_id": team_id},
                                    data={"spend": {"increment": response_cost}},
                                )
                    break  # commit succeeded; stop retrying
                except DB_CONNECTION_ERROR_TYPES as e:
                    if (
                        i >= n_retry_times
                    ):  # If we've reached the maximum number of retries
                        _raise_failed_update_spend_exception(
                            e=e,
                            start_time=start_time,
                            proxy_logging_obj=proxy_logging_obj,
                        )
                    # Optionally, sleep for a bit before retrying
                    await asyncio.sleep(2**i)  # Exponential backoff
                except Exception as e:
                    # Non-connection errors are not retried.
                    _raise_failed_update_spend_exception(
                        e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
                    )

        ### UPDATE TEAM Membership TABLE with spend ###
        team_member_list_transactions = db_spend_update_transactions[
            "team_member_list_transactions"
        ]
        verbose_proxy_logger.debug(
            "Team Membership Spend transactions: {}".format(
                team_member_list_transactions
            )
        )
        if (
            team_member_list_transactions is not None
            and len(team_member_list_transactions.keys()) > 0
        ):
            for i in range(n_retry_times + 1):
                start_time = time.time()
                try:
                    async with prisma_client.db.tx(
                        timeout=timedelta(seconds=60)
                    ) as transaction:
                        async with transaction.batch_() as batcher:
                            for (
                                key,
                                response_cost,
                            ) in team_member_list_transactions.items():
                                # key is "team_id::<value>::user_id::<value>"
                                team_id = key.split("::")[1]
                                user_id = key.split("::")[3]
                                batcher.litellm_teammembership.update_many(  # 'update_many' prevents error from being raised if no row exists
                                    where={"team_id": team_id, "user_id": user_id},
                                    data={"spend": {"increment": response_cost}},
                                )
                    break  # commit succeeded; stop retrying
                except DB_CONNECTION_ERROR_TYPES as e:
                    if (
                        i >= n_retry_times
                    ):  # If we've reached the maximum number of retries
                        _raise_failed_update_spend_exception(
                            e=e,
                            start_time=start_time,
                            proxy_logging_obj=proxy_logging_obj,
                        )
                    # Optionally, sleep for a bit before retrying
                    await asyncio.sleep(2**i)  # Exponential backoff
                except Exception as e:
                    # Non-connection errors are not retried.
                    _raise_failed_update_spend_exception(
                        e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
                    )

        ### UPDATE ORG TABLE ###
        org_list_transactions = db_spend_update_transactions["org_list_transactions"]
        verbose_proxy_logger.debug(
            "Org Spend transactions: {}".format(org_list_transactions)
        )
        if org_list_transactions is not None and len(org_list_transactions.keys()) > 0:
            for i in range(n_retry_times + 1):
                start_time = time.time()
                try:
                    async with prisma_client.db.tx(
                        timeout=timedelta(seconds=60)
                    ) as transaction:
                        async with transaction.batch_() as batcher:
                            for (
                                org_id,
                                response_cost,
                            ) in org_list_transactions.items():
                                batcher.litellm_organizationtable.update_many(  # 'update_many' prevents error from being raised if no row exists
                                    where={"organization_id": org_id},
                                    data={"spend": {"increment": response_cost}},
                                )
                    break  # commit succeeded; stop retrying
                except DB_CONNECTION_ERROR_TYPES as e:
                    if (
                        i >= n_retry_times
                    ):  # If we've reached the maximum number of retries
                        _raise_failed_update_spend_exception(
                            e=e,
                            start_time=start_time,
                            proxy_logging_obj=proxy_logging_obj,
                        )
                    # Optionally, sleep for a bit before retrying
                    await asyncio.sleep(2**i)  # Exponential backoff
                except Exception as e:
                    # Non-connection errors are not retried.
                    _raise_failed_update_spend_exception(
                        e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj
                    )
async def _update_daily_spend( | |
n_retry_times: int, | |
prisma_client: PrismaClient, | |
proxy_logging_obj: ProxyLogging, | |
daily_spend_transactions: Dict[str, DailyUserSpendTransaction], | |
entity_type: Literal["user"], | |
entity_id_field: str, | |
table_name: str, | |
unique_constraint_name: str, | |
) -> None: | |
... | |
async def _update_daily_spend( | |
n_retry_times: int, | |
prisma_client: PrismaClient, | |
proxy_logging_obj: ProxyLogging, | |
daily_spend_transactions: Dict[str, DailyTeamSpendTransaction], | |
entity_type: Literal["team"], | |
entity_id_field: str, | |
table_name: str, | |
unique_constraint_name: str, | |
) -> None: | |
... | |
async def _update_daily_spend( | |
n_retry_times: int, | |
prisma_client: PrismaClient, | |
proxy_logging_obj: ProxyLogging, | |
daily_spend_transactions: Dict[str, DailyTagSpendTransaction], | |
entity_type: Literal["tag"], | |
entity_id_field: str, | |
table_name: str, | |
unique_constraint_name: str, | |
) -> None: | |
... | |
async def _update_daily_spend( | |
n_retry_times: int, | |
prisma_client: PrismaClient, | |
proxy_logging_obj: ProxyLogging, | |
daily_spend_transactions: Union[ | |
Dict[str, DailyUserSpendTransaction], | |
Dict[str, DailyTeamSpendTransaction], | |
Dict[str, DailyTagSpendTransaction], | |
], | |
entity_type: Literal["user", "team", "tag"], | |
entity_id_field: str, | |
table_name: str, | |
unique_constraint_name: str, | |
) -> None: | |
""" | |
Generic function to update daily spend for any entity type (user, team, tag) | |
""" | |
from litellm.proxy.utils import _raise_failed_update_spend_exception | |
verbose_proxy_logger.debug( | |
f"Daily {entity_type.capitalize()} Spend transactions: {len(daily_spend_transactions)}" | |
) | |
BATCH_SIZE = 100 | |
start_time = time.time() | |
try: | |
for i in range(n_retry_times + 1): | |
try: | |
transactions_to_process = dict( | |
list(daily_spend_transactions.items())[:BATCH_SIZE] | |
) | |
if len(transactions_to_process) == 0: | |
verbose_proxy_logger.debug( | |
f"No new transactions to process for daily {entity_type} spend update" | |
) | |
break | |
async with prisma_client.db.batch_() as batcher: | |
for _, transaction in transactions_to_process.items(): | |
entity_id = transaction.get(entity_id_field) | |
if not entity_id: | |
continue | |
# Construct the where clause dynamically | |
where_clause = { | |
unique_constraint_name: { | |
entity_id_field: entity_id, | |
"date": transaction["date"], | |
"api_key": transaction["api_key"], | |
"model": transaction["model"], | |
"custom_llm_provider": transaction.get( | |
"custom_llm_provider" | |
), | |
} | |
} | |
# Get the table dynamically | |
table = getattr(batcher, table_name) | |
# Common data structure for both create and update | |
common_data = { | |
entity_id_field: entity_id, | |
"date": transaction["date"], | |
"api_key": transaction["api_key"], | |
"model": transaction["model"], | |
"model_group": transaction.get("model_group"), | |
"custom_llm_provider": transaction.get( | |
"custom_llm_provider" | |
), | |
"prompt_tokens": transaction["prompt_tokens"], | |
"completion_tokens": transaction["completion_tokens"], | |
"spend": transaction["spend"], | |
"api_requests": transaction["api_requests"], | |
"successful_requests": transaction[ | |
"successful_requests" | |
], | |
"failed_requests": transaction["failed_requests"], | |
} | |
# Add cache-related fields if they exist | |
if "cache_read_input_tokens" in transaction: | |
common_data[ | |
"cache_read_input_tokens" | |
] = transaction.get("cache_read_input_tokens", 0) | |
if "cache_creation_input_tokens" in transaction: | |
common_data[ | |
"cache_creation_input_tokens" | |
] = transaction.get("cache_creation_input_tokens", 0) | |
# Create update data structure | |
update_data = { | |
"prompt_tokens": { | |
"increment": transaction["prompt_tokens"] | |
}, | |
"completion_tokens": { | |
"increment": transaction["completion_tokens"] | |
}, | |
"spend": {"increment": transaction["spend"]}, | |
"api_requests": { | |
"increment": transaction["api_requests"] | |
}, | |
"successful_requests": { | |
"increment": transaction["successful_requests"] | |
}, | |
"failed_requests": { | |
"increment": transaction["failed_requests"] | |
}, | |
} | |
# Add cache-related fields to update if they exist | |
if "cache_read_input_tokens" in transaction: | |
update_data["cache_read_input_tokens"] = { | |
"increment": transaction.get( | |
"cache_read_input_tokens", 0 | |
) | |
} | |
if "cache_creation_input_tokens" in transaction: | |
update_data["cache_creation_input_tokens"] = { | |
"increment": transaction.get( | |
"cache_creation_input_tokens", 0 | |
) | |
} | |
table.upsert( | |
where=where_clause, | |
data={ | |
"create": common_data, | |
"update": update_data, | |
}, | |
) | |
verbose_proxy_logger.info( | |
f"Processed {len(transactions_to_process)} daily {entity_type} transactions in {time.time() - start_time:.2f}s" | |
) | |
# Remove processed transactions | |
for key in transactions_to_process.keys(): | |
daily_spend_transactions.pop(key, None) | |
break | |
except DB_CONNECTION_ERROR_TYPES as e: | |
if i >= n_retry_times: | |
_raise_failed_update_spend_exception( | |
e=e, | |
start_time=start_time, | |
proxy_logging_obj=proxy_logging_obj, | |
) | |
await asyncio.sleep(2**i) | |
except Exception as e: | |
if "transactions_to_process" in locals(): | |
for key in transactions_to_process.keys(): # type: ignore | |
daily_spend_transactions.pop(key, None) | |
_raise_failed_update_spend_exception( | |
e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj | |
) | |
async def update_daily_user_spend( | |
n_retry_times: int, | |
prisma_client: PrismaClient, | |
proxy_logging_obj: ProxyLogging, | |
daily_spend_transactions: Dict[str, DailyUserSpendTransaction], | |
): | |
""" | |
Batch job to update LiteLLM_DailyUserSpend table using in-memory daily_spend_transactions | |
""" | |
await DBSpendUpdateWriter._update_daily_spend( | |
n_retry_times=n_retry_times, | |
prisma_client=prisma_client, | |
proxy_logging_obj=proxy_logging_obj, | |
daily_spend_transactions=daily_spend_transactions, | |
entity_type="user", | |
entity_id_field="user_id", | |
table_name="litellm_dailyuserspend", | |
unique_constraint_name="user_id_date_api_key_model_custom_llm_provider", | |
) | |
async def update_daily_team_spend( | |
n_retry_times: int, | |
prisma_client: PrismaClient, | |
proxy_logging_obj: ProxyLogging, | |
daily_spend_transactions: Dict[str, DailyTeamSpendTransaction], | |
): | |
""" | |
Batch job to update LiteLLM_DailyTeamSpend table using in-memory daily_spend_transactions | |
""" | |
await DBSpendUpdateWriter._update_daily_spend( | |
n_retry_times=n_retry_times, | |
prisma_client=prisma_client, | |
proxy_logging_obj=proxy_logging_obj, | |
daily_spend_transactions=daily_spend_transactions, | |
entity_type="team", | |
entity_id_field="team_id", | |
table_name="litellm_dailyteamspend", | |
unique_constraint_name="team_id_date_api_key_model_custom_llm_provider", | |
) | |
async def update_daily_tag_spend(
    n_retry_times: int,
    prisma_client: PrismaClient,
    proxy_logging_obj: ProxyLogging,
    daily_spend_transactions: Dict[str, DailyTagSpendTransaction],
):
    """
    Batch job to update LiteLLM_DailyTagSpend table using in-memory daily_spend_transactions

    Thin wrapper: delegates to the shared ``_update_daily_spend`` routine,
    parameterized for the "tag" entity rollup.
    """
    # Common retry/DB plumbing shared by all entity rollups.
    shared_kwargs = dict(
        n_retry_times=n_retry_times,
        prisma_client=prisma_client,
        proxy_logging_obj=proxy_logging_obj,
        daily_spend_transactions=daily_spend_transactions,
    )
    await DBSpendUpdateWriter._update_daily_spend(
        **shared_kwargs,
        entity_type="tag",
        entity_id_field="tag",
        table_name="litellm_dailytagspend",
        unique_constraint_name="tag_date_api_key_model_custom_llm_provider",
    )
async def _common_add_spend_log_transaction_to_daily_transaction(
    self,
    payload: Union[dict, SpendLogsPayload],
    prisma_client: PrismaClient,
    type: Literal["user", "team", "request_tags"] = "user",
) -> Optional[BaseDailySpendTransaction]:
    """
    Validate a spend-log payload and convert it into the entity-agnostic
    ``BaseDailySpendTransaction`` shared by the user/team/tag daily rollups.

    Args:
        payload: Raw spend-log row (dict or ``SpendLogsPayload``).
        prisma_client: Used only to resolve the request's success/failure
            status via ``get_request_status``.
        type: Which entity rollup is being built; selects the payload keys
            that must be present. (Name shadows the builtin ``type`` but is
            kept for interface compatibility with existing callers.)

    Returns:
        The daily transaction, or ``None`` when the payload is missing a
        required key or ``startTime`` is neither a ``datetime`` nor a string.

    Raises:
        ValueError: If ``type`` is not one of the supported literals.
    """
    common_expected_keys = ["startTime", "api_key", "model", "custom_llm_provider"]
    if type == "user":
        expected_keys = ["user", *common_expected_keys]
    elif type == "team":
        expected_keys = ["team_id", *common_expected_keys]
    elif type == "request_tags":
        expected_keys = ["request_tags", *common_expected_keys]
    else:
        raise ValueError(f"Invalid type: {type}")

    if not all(key in payload for key in expected_keys):
        verbose_proxy_logger.debug(
            f"Missing expected keys: {expected_keys}, in payload, skipping from daily_user_spend_transactions"
        )
        return None

    request_status = prisma_client.get_request_status(payload)
    verbose_proxy_logger.info(f"Logged request status: {request_status}")
    _metadata: SpendLogsMetadata = json.loads(payload["metadata"])
    # usage_object may be missing or explicitly null — normalize to {}.
    usage_obj = _metadata.get("usage_object", {}) or {}

    # Normalize startTime to a YYYY-MM-DD date string; reject other types.
    if isinstance(payload["startTime"], datetime):
        start_time = payload["startTime"].isoformat()
        date = start_time.split("T")[0]
    elif isinstance(payload["startTime"], str):
        date = payload["startTime"].split("T")[0]
    else:
        verbose_proxy_logger.debug(
            f"Invalid start time: {payload['startTime']}, skipping from daily_user_spend_transactions"
        )
        return None

    # NOTE: previously wrapped in `try: ... except Exception as e: raise e`,
    # which is a no-op that only obscures the traceback; removed.
    return BaseDailySpendTransaction(
        date=date,
        api_key=payload["api_key"],
        model=payload["model"],
        model_group=payload["model_group"],
        custom_llm_provider=payload["custom_llm_provider"],
        prompt_tokens=payload["prompt_tokens"],
        completion_tokens=payload["completion_tokens"],
        spend=payload["spend"],
        api_requests=1,
        successful_requests=1 if request_status == "success" else 0,
        failed_requests=1 if request_status != "success" else 0,
        cache_read_input_tokens=usage_obj.get("cache_read_input_tokens", 0) or 0,
        cache_creation_input_tokens=usage_obj.get("cache_creation_input_tokens", 0)
        or 0,
    )
async def add_spend_log_transaction_to_daily_user_transaction(
    self,
    payload: Union[dict, SpendLogsPayload],
    prisma_client: Optional[PrismaClient] = None,
):
    """
    Add a spend log transaction to the `daily_spend_update_queue`

    Key = @@unique([user_id, date, api_key, model, custom_llm_provider]) )

    If key exists, update the transaction with the new spend and usage
    """
    # Without a DB client there is nothing to enqueue for.
    if prisma_client is None:
        verbose_proxy_logger.debug(
            "prisma_client is None. Skipping writing spend logs to db."
        )
        return

    base_daily_transaction = (
        await self._common_add_spend_log_transaction_to_daily_transaction(
            payload, prisma_client, "user"
        )
    )
    if base_daily_transaction is None:
        # Payload was missing required fields or had an unusable start time.
        return

    # Build the unique rollup key from its constituent parts.
    key_parts = (
        payload["user"],
        base_daily_transaction["date"],
        payload["api_key"],
        payload["model"],
        payload["custom_llm_provider"],
    )
    daily_transaction_key = "_".join(str(part) for part in key_parts)

    await self.daily_spend_update_queue.add_update(
        update={
            daily_transaction_key: DailyUserSpendTransaction(
                user_id=payload["user"], **base_daily_transaction
            )
        }
    )
async def add_spend_log_transaction_to_daily_team_transaction(
    self,
    payload: SpendLogsPayload,
    prisma_client: Optional[PrismaClient] = None,
) -> None:
    """
    Add a spend log transaction to the `daily_team_spend_update_queue`.

    Key = @@unique([team_id, date, api_key, model, custom_llm_provider])

    Skips silently (debug log only) when there is no prisma client, the
    request has no team_id, or the payload fails common validation.
    """
    if prisma_client is None:
        verbose_proxy_logger.debug(
            "prisma_client is None. Skipping writing spend logs to db."
        )
        return

    # Check for a team BEFORE the common transform: the original checked
    # team_id only afterwards, wasting a request-status lookup and a
    # metadata JSON parse on requests that can never be attributed to a
    # team. `.get` keeps a missing key from raising.
    if payload.get("team_id") is None:
        verbose_proxy_logger.debug(
            "team_id is None for request. Skipping incrementing team spend."
        )
        return

    base_daily_transaction = (
        await self._common_add_spend_log_transaction_to_daily_transaction(
            payload, prisma_client, "team"
        )
    )
    if base_daily_transaction is None:
        return

    daily_transaction_key = f"{payload['team_id']}_{base_daily_transaction['date']}_{payload['api_key']}_{payload['model']}_{payload['custom_llm_provider']}"
    daily_transaction = DailyTeamSpendTransaction(
        team_id=payload["team_id"], **base_daily_transaction
    )
    await self.daily_team_spend_update_queue.add_update(
        update={daily_transaction_key: daily_transaction}
    )
async def add_spend_log_transaction_to_daily_tag_transaction(
    self,
    payload: SpendLogsPayload,
    prisma_client: Optional[PrismaClient] = None,
) -> None:
    """
    Add a spend log transaction to the `daily_tag_spend_update_queue`,
    fanning out one queued transaction per request tag.

    Key = tag + date + api_key + model + custom_llm_provider.

    Raises:
        ValueError: If ``request_tags`` is neither a JSON string nor a list.
    """
    if prisma_client is None:
        verbose_proxy_logger.debug(
            "prisma_client is None. Skipping writing spend logs to db."
        )
        return

    # Check for tags BEFORE the common transform: the original checked
    # request_tags only afterwards, wasting a request-status lookup and a
    # metadata JSON parse on untagged requests. `.get` keeps a missing key
    # from raising.
    if payload.get("request_tags") is None:
        verbose_proxy_logger.debug(
            "request_tags is None for request. Skipping incrementing tag spend."
        )
        return

    base_daily_transaction = (
        await self._common_add_spend_log_transaction_to_daily_transaction(
            payload, prisma_client, "request_tags"
        )
    )
    if base_daily_transaction is None:
        return

    # request_tags may arrive as a JSON-encoded string or an actual list.
    if isinstance(payload["request_tags"], str):
        request_tags = json.loads(payload["request_tags"])
    elif isinstance(payload["request_tags"], list):
        request_tags = payload["request_tags"]
    else:
        raise ValueError(f"Invalid request_tags: {payload['request_tags']}")

    # One daily rollup entry per tag.
    for tag in request_tags:
        daily_transaction_key = f"{tag}_{base_daily_transaction['date']}_{payload['api_key']}_{payload['model']}_{payload['custom_llm_provider']}"
        daily_transaction = DailyTagSpendTransaction(
            tag=tag, **base_daily_transaction
        )
        await self.daily_tag_spend_update_queue.add_update(
            update={daily_transaction_key: daily_transaction}
        )