Shyamnath's picture
Push core package and essential files
469eae6
raw
history blame
8.42 kB
import asyncio
from typing import Dict, List, Optional
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
DBSpendUpdateTransactions,
Litellm_EntityType,
SpendUpdateQueueItem,
)
from litellm.proxy.db.db_transaction_queue.base_update_queue import (
BaseUpdateQueue,
service_logger_obj,
)
from litellm.types.services import ServiceTypes
class SpendUpdateQueue(BaseUpdateQueue):
    """
    In memory buffer for spend updates that should be committed to the database.

    Producers enqueue ``SpendUpdateQueueItem`` entries via ``add_update``.
    ``flush_and_get_aggregated_db_spend_update_transactions`` flushes the queue
    and sums ``response_cost`` per entity id, grouped by entity type.
    """

    def __init__(self):
        super().__init__()
        self.update_queue: asyncio.Queue[SpendUpdateQueueItem] = asyncio.Queue()
        # Strong references to fire-and-forget service-logging tasks. The event
        # loop only keeps weak references to tasks, so a task nobody references
        # can be garbage collected before it finishes running.
        self._service_log_tasks: set = set()

    async def flush_and_get_aggregated_db_spend_update_transactions(
        self,
    ) -> DBSpendUpdateTransactions:
        """
        Flush all updates from the queue and return all updates aggregated by entity type.

        Returns:
            DBSpendUpdateTransactions whose ``*_list_transactions`` entries map
            ``entity_id -> summed response_cost``.
        """
        updates = await self.flush_all_updates_from_in_memory_queue()
        verbose_proxy_logger.debug("Aggregating updates by entity type: %s", updates)
        return self.get_aggregated_db_spend_update_transactions(updates)

    async def add_update(self, update: SpendUpdateQueueItem):
        """
        Enqueue an update to the spend update queue.

        Once the queue holds ``MAX_SIZE_IN_MEMORY_QUEUE`` or more entries, all
        queued updates are collapsed to one entry per (entity_type, entity_id)
        to bound the queue's memory use.
        """
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)

        # if the queue is full, aggregate the updates
        # NOTE(review): if there are >= MAX_SIZE_IN_MEMORY_QUEUE distinct entities,
        # the queue stays "full" after aggregation and every add re-aggregates.
        if self.update_queue.qsize() >= self.MAX_SIZE_IN_MEMORY_QUEUE:
            verbose_proxy_logger.warning(
                "Spend update queue is full. Aggregating all entries in queue to concatenate entries."
            )
            await self.aggregate_queue_updates()

    async def aggregate_queue_updates(self) -> None:
        """Concatenate all updates in the queue to reduce the size of in-memory queue."""
        updates: List[
            SpendUpdateQueueItem
        ] = await self.flush_all_updates_from_in_memory_queue()
        for update in self._get_aggregated_spend_update_queue_item(updates):
            await self.update_queue.put(update)

    def _get_aggregated_spend_update_queue_item(
        self, updates: List[SpendUpdateQueueItem]
    ) -> List[SpendUpdateQueueItem]:
        """
        Aggregate updates by entity type + id to reduce the size of the in-memory queue.

        eg.
        ```
        [
            {"entity_type": "user", "entity_id": "123", "response_cost": 100},
            {"entity_type": "user", "entity_id": "123", "response_cost": 200},
        ]
        ```
        becomes
        ```
        [
            {"entity_type": "user", "entity_id": "123", "response_cost": 300},
        ]
        ```

        The first update seen for a given (entity_type, entity_id) determines
        the remaining fields of the merged item. Output order follows the first
        occurrence of each key. The input dicts are not mutated.

        Args:
            updates: the spend updates to merge.

        Returns:
            A new list with one item per distinct (entity_type, entity_id).
        """
        verbose_proxy_logger.debug(
            "Aggregating %s spend updates, current queue size: %s",
            len(updates),
            self.update_queue.qsize(),
        )

        # Key=(entity_type, entity_id), Value=merged SpendUpdateQueueItem.
        # A tuple key cannot collide the way an "entity_type:entity_id" string can
        # (e.g. ids containing ":" or a None id vs the literal string "None").
        merged_by_entity: Dict[tuple, SpendUpdateQueueItem] = {}
        for update in updates:
            key = (update.get("entity_type"), update.get("entity_id"))
            merged = merged_by_entity.get(key)
            if merged is None:
                # Copy so summing costs below never mutates the caller's dict.
                merged_by_entity[key] = update.copy()
                continue
            current_cost = merged.get("response_cost", 0) or 0
            update_cost = update.get("response_cost", 0) or 0
            merged["response_cost"] = current_cost + update_cost

        aggregated_spend_updates: List[SpendUpdateQueueItem] = list(
            merged_by_entity.values()
        )
        verbose_proxy_logger.debug(
            "Aggregated spend updates: %s", aggregated_spend_updates
        )
        return aggregated_spend_updates

    def get_aggregated_db_spend_update_transactions(
        self, updates: List[SpendUpdateQueueItem]
    ) -> DBSpendUpdateTransactions:
        """
        Aggregate updates by entity type.

        For each update, ``response_cost`` (None treated as 0) is added to the
        total for its ``entity_id`` inside the transactions dict that matches its
        ``entity_type``. A missing/falsy ``entity_id`` is aggregated under ``""``.
        Updates with no entity type, or an entity type not in the mapping below,
        are skipped with a debug log.

        Args:
            updates: the spend updates to aggregate.

        Returns:
            DBSpendUpdateTransactions with every ``*_list_transactions`` key set
            (an empty dict when no update matched that entity type).
        """
        # Initialize all transaction lists as empty dicts
        db_spend_update_transactions = DBSpendUpdateTransactions(
            user_list_transactions={},
            end_user_list_transactions={},
            key_list_transactions={},
            team_list_transactions={},
            team_member_list_transactions={},
            org_list_transactions={},
        )

        # Map entity types to their corresponding transaction dictionary keys
        entity_type_to_dict_key = {
            Litellm_EntityType.USER: "user_list_transactions",
            Litellm_EntityType.END_USER: "end_user_list_transactions",
            Litellm_EntityType.KEY: "key_list_transactions",
            Litellm_EntityType.TEAM: "team_list_transactions",
            Litellm_EntityType.TEAM_MEMBER: "team_member_list_transactions",
            Litellm_EntityType.ORGANIZATION: "org_list_transactions",
        }

        for update in updates:
            entity_type = update.get("entity_type")
            if entity_type is None:
                verbose_proxy_logger.debug(
                    "Skipping update spend for update: %s, because entity_type is None",
                    update,
                )
                continue

            dict_key = entity_type_to_dict_key.get(entity_type)
            if dict_key is None:
                verbose_proxy_logger.debug(
                    "Skipping update spend for update: %s, because entity_type is not in entity_type_to_dict_key",
                    update,
                )
                continue  # Skip unknown entity types

            entity_id = update.get("entity_id") or ""
            response_cost = update.get("response_cost") or 0

            # dict_key is always one of the literal keys initialised above; mypy
            # cannot prove that for a str variable, hence the type: ignore.
            transactions_dict = db_spend_update_transactions.get(dict_key)  # type: ignore
            if transactions_dict is None:
                transactions_dict = {}
                db_spend_update_transactions[dict_key] = transactions_dict  # type: ignore

            transactions_dict[entity_id] = (
                transactions_dict.get(entity_id, 0) + response_cost
            )

        return db_spend_update_transactions

    async def _emit_new_item_added_to_queue_event(
        self,
        queue_size: Optional[int] = None,
    ):
        """
        Report ``queue_size`` as a gauge value through ``service_logger_obj``.

        The service hook is scheduled as a background task and not awaited.

        Args:
            queue_size: value reported as ``gauge_value``.
        """
        task = asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
                duration=0,
                call_type="_emit_new_item_added_to_queue_event",
                event_metadata={
                    "gauge_labels": ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
                    "gauge_value": queue_size,
                },
            )
        )
        # Keep a strong reference until the task completes (see __init__),
        # otherwise it may be garbage collected mid-execution.
        self._service_log_tasks.add(task)
        task.add_done_callback(self._service_log_tasks.discard)