Spaces:
Sleeping
Sleeping
import asyncio | |
from typing import Dict, List, Optional | |
from litellm._logging import verbose_proxy_logger | |
from litellm.proxy._types import ( | |
DBSpendUpdateTransactions, | |
Litellm_EntityType, | |
SpendUpdateQueueItem, | |
) | |
from litellm.proxy.db.db_transaction_queue.base_update_queue import ( | |
BaseUpdateQueue, | |
service_logger_obj, | |
) | |
from litellm.types.services import ServiceTypes | |
class SpendUpdateQueue(BaseUpdateQueue):
    """
    In memory buffer for spend updates that should be committed to the database
    """

    def __init__(self):
        super().__init__()
        self.update_queue: asyncio.Queue[SpendUpdateQueueItem] = asyncio.Queue()
        # Strong references to fire-and-forget tasks. The event loop only keeps
        # weak references to tasks, so an unreferenced task may be garbage
        # collected before it finishes.
        self._background_tasks: set = set()

    async def flush_and_get_aggregated_db_spend_update_transactions(
        self,
    ) -> DBSpendUpdateTransactions:
        """Flush all updates from the queue and return all updates aggregated by entity type."""
        updates = await self.flush_all_updates_from_in_memory_queue()
        verbose_proxy_logger.debug("Aggregating updates by entity type: %s", updates)
        return self.get_aggregated_db_spend_update_transactions(updates)

    async def add_update(self, update: SpendUpdateQueueItem):
        """
        Enqueue an update to the spend update queue.

        After enqueueing, if the queue size has reached
        `self.MAX_SIZE_IN_MEMORY_QUEUE` (not defined in this class; presumably
        set by `BaseUpdateQueue`), the queued entries are compacted via
        `aggregate_queue_updates`.
        """
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)

        # if the queue is full, aggregate the updates
        if self.update_queue.qsize() >= self.MAX_SIZE_IN_MEMORY_QUEUE:
            verbose_proxy_logger.warning(
                "Spend update queue is full. Aggregating all entries in queue to concatenate entries."
            )
            await self.aggregate_queue_updates()

    async def aggregate_queue_updates(self):
        """Concatenate all updates in the queue to reduce the size of in-memory queue"""
        updates: List[
            SpendUpdateQueueItem
        ] = await self.flush_all_updates_from_in_memory_queue()
        aggregated_updates = self._get_aggregated_spend_update_queue_item(updates)
        # Re-enqueue the compacted entries (one per entity_type + entity_id).
        for update in aggregated_updates:
            await self.update_queue.put(update)

    def _get_aggregated_spend_update_queue_item(
        self, updates: List[SpendUpdateQueueItem]
    ) -> List[SpendUpdateQueueItem]:
        """
        This is used to reduce the size of the in-memory queue by aggregating updates by entity type + id

        Only `response_cost` is summed; every other field of the result is taken
        from the first update seen for that entity type + id. The items in
        `updates` are not modified: accumulation happens on shallow copies.
        Results are returned in order of first occurrence of each key.

        eg.

        ```
        [
            {
                "entity_type": "user",
                "entity_id": "123",
                "response_cost": 100
            },
            {
                "entity_type": "user",
                "entity_id": "123",
                "response_cost": 200
            }
        ]
        ```

        becomes

        ```
        [
            {
                "entity_type": "user",
                "entity_id": "123",
                "response_cost": 300
            }
        ]
        ```
        """
        verbose_proxy_logger.debug(
            "Aggregating spend updates, current queue size: %s",
            self.update_queue.qsize(),
        )

        # Used for combining several updates into a single update
        # Key=entity_type:entity_id
        # Value=SpendUpdateQueueItem
        _in_memory_map: Dict[str, SpendUpdateQueueItem] = {}

        for update in updates:
            _key = f"{update.get('entity_type')}:{update.get('entity_id')}"
            existing = _in_memory_map.get(_key)
            if existing is None:
                # Copy so that later accumulation does not rewrite the
                # caller's item in place.
                _in_memory_map[_key] = update.copy()
            else:
                # A missing or None cost counts as 0.
                current_cost = existing.get("response_cost", 0) or 0
                update_cost = update.get("response_cost", 0) or 0
                existing["response_cost"] = current_cost + update_cost

        aggregated_spend_updates: List[SpendUpdateQueueItem] = list(
            _in_memory_map.values()
        )

        verbose_proxy_logger.debug(
            "Aggregated spend updates: %s", aggregated_spend_updates
        )
        return aggregated_spend_updates

    def get_aggregated_db_spend_update_transactions(
        self, updates: List[SpendUpdateQueueItem]
    ) -> DBSpendUpdateTransactions:
        """
        Aggregate updates by entity type.

        Returns a `DBSpendUpdateTransactions` whose six transaction dicts map
        entity_id -> summed response_cost. Updates whose entity_type is None or
        is not one of the mapped entity types are skipped (logged at debug
        level). A missing/None entity_id is bucketed under the empty string and
        a missing/None response_cost counts as 0.
        """
        # Initialize all transaction lists as empty dicts
        db_spend_update_transactions = DBSpendUpdateTransactions(
            user_list_transactions={},
            end_user_list_transactions={},
            key_list_transactions={},
            team_list_transactions={},
            team_member_list_transactions={},
            org_list_transactions={},
        )

        # Map entity types to their corresponding transaction dictionary keys
        entity_type_to_dict_key = {
            Litellm_EntityType.USER: "user_list_transactions",
            Litellm_EntityType.END_USER: "end_user_list_transactions",
            Litellm_EntityType.KEY: "key_list_transactions",
            Litellm_EntityType.TEAM: "team_list_transactions",
            Litellm_EntityType.TEAM_MEMBER: "team_member_list_transactions",
            Litellm_EntityType.ORGANIZATION: "org_list_transactions",
        }

        for update in updates:
            entity_type = update.get("entity_type")
            entity_id = update.get("entity_id") or ""
            response_cost = update.get("response_cost") or 0

            if entity_type is None:
                verbose_proxy_logger.debug(
                    "Skipping update spend for update: %s, because entity_type is None",
                    update,
                )
                continue

            dict_key = entity_type_to_dict_key.get(entity_type)
            if dict_key is None:
                verbose_proxy_logger.debug(
                    "Skipping update spend for update: %s, because entity_type is not in entity_type_to_dict_key",
                    update,
                )
                continue  # Skip unknown entity types

            # type ignore: dict_key is guaranteed to be one of the six keys
            # initialized above, because it comes from entity_type_to_dict_key.
            transactions_dict = db_spend_update_transactions[dict_key]  # type: ignore
            if transactions_dict is None:
                # Defensive: every bucket is initialized to {} above.
                transactions_dict = {}
                db_spend_update_transactions[dict_key] = transactions_dict  # type: ignore

            transactions_dict[entity_id] = (
                transactions_dict.get(entity_id, 0) + response_cost
            )

        return db_spend_update_transactions

    async def _emit_new_item_added_to_queue_event(
        self,
        queue_size: Optional[int] = None,
    ):
        """
        Schedule a service-success log event carrying the queue size as a gauge.

        The hook runs as a background task and is not awaited here. A reference
        to the task is held until it completes so it cannot be garbage
        collected mid-execution.
        """
        task = asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
                duration=0,
                call_type="_emit_new_item_added_to_queue_event",
                event_metadata={
                    "gauge_labels": ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
                    "gauge_value": queue_size,
                },
            )
        )
        self._background_tasks.add(task)
        # Drop the reference once the task is done so the set does not grow.
        task.add_done_callback(self._background_tasks.discard)