Spaces:
Configuration error
Configuration error
import asyncio | |
import os | |
import sys | |
from unittest.mock import Mock | |
from litellm.proxy.utils import _get_redoc_url, _get_docs_url | |
import pytest | |
from fastapi import Request | |
sys.path.insert( | |
0, os.path.abspath("../..") | |
) # Adds the parent directory to the system path | |
import litellm | |
from unittest.mock import MagicMock, patch, AsyncMock | |
import httpx | |
from litellm.proxy.utils import update_spend, DB_CONNECTION_ERROR_TYPES | |
class MockPrismaClient:
    """Minimal stand-in for the proxy's Prisma client, as used by these tests.

    Only what the tests touch is modelled:

    * an async ``db`` handle whose ``litellm_spendlogs.create_many`` each test
      replaces with its own ``AsyncMock``;
    * the in-memory ``spend_log_transactions`` queue (the tests assert that
      ``update_spend`` empties it) and a ``daily_user_spend_transactions`` dict.
    """

    def __init__(self):
        # Build the spend-logs table mock first, then hang it off ``db``.
        spendlogs_table = AsyncMock()
        spendlogs_table.create_many = AsyncMock()

        self.db = AsyncMock()
        self.db.litellm_spendlogs = spendlogs_table

        # Transaction buffers start out empty for every new client.
        self.spend_log_transactions = []
        self.daily_user_spend_transactions = {}

    def jsonify_object(self, obj):
        """Return *obj* unchanged (identity passthrough).

        NOTE(review): presumably the real client serializes values for Prisma
        here; the mock skips that so test payloads reach ``create_many`` as-is.
        """
        return obj

    def add_spend_log_transaction_to_daily_user_transaction(self, payload):
        """No-op placeholder; accepts *payload* and records nothing."""
        return None
def create_mock_proxy_logging():
    """Build a ``MagicMock`` stand-in for the proxy logging object.

    ``failure_handler`` and ``db_spend_update_writer`` (plus its
    ``db_update_spend_transaction_handler``) are ``AsyncMock`` instances, so
    they can be awaited and their calls inspected by the tests.
    """
    print("creating mock proxy logging")

    spend_writer = AsyncMock()
    spend_writer.db_update_spend_transaction_handler = AsyncMock()

    mock_logging = MagicMock()
    mock_logging.failure_handler = AsyncMock()
    mock_logging.db_spend_update_writer = spend_writer

    print("returning proxy logging obj")
    return mock_logging
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "error_type",
    [
        # NOTE(review): these are assumed to be members of
        # DB_CONNECTION_ERROR_TYPES (the retryable set) — confirm.
        httpx.ConnectError("Failed to connect"),
        httpx.ReadError("Failed to read response"),
        httpx.ReadTimeout("Request timed out"),
    ],
)
async def test_update_spend_logs_connection_errors(error_type):
    """Test retry mechanism for different connection error types.

    ``create_many`` raises ``error_type`` on the first three attempts and
    succeeds on the fourth. ``update_spend`` must keep retrying until the
    write lands, then clear the in-memory spend-log queue.
    """
    # Setup (create_mock_proxy_logging already provides AsyncMock
    # db_spend_update_writer / db_update_spend_transaction_handler).
    prisma_client = MockPrismaClient()
    proxy_logging_obj = create_mock_proxy_logging()

    # Add test spend logs
    prisma_client.spend_log_transactions = [
        {"id": "1", "spend": 10},
        {"id": "2", "spend": 20},
    ]

    # Mock the database to fail with a connection error three times, then succeed
    create_many_mock = AsyncMock()
    create_many_mock.side_effect = [
        error_type,  # First attempt fails
        error_type,  # Second attempt fails
        error_type,  # Third attempt fails
        None,  # Fourth attempt succeeds
    ]
    prisma_client.db.litellm_spendlogs.create_many = create_many_mock

    # Execute. update_spend awaits asyncio.sleep between retries (exponential
    # backoff); stub it so the three backoffs don't cost real wall-clock time.
    with patch("asyncio.sleep", new_callable=AsyncMock):
        await update_spend(prisma_client, None, proxy_logging_obj)

    # Verify
    assert create_many_mock.call_count == 4  # 3 failed attempts + 1 success
    assert (
        len(prisma_client.spend_log_transactions) == 0
    )  # Should have cleared after success
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "error_type",
    [
        # NOTE(review): these are assumed to be members of
        # DB_CONNECTION_ERROR_TYPES (the retryable set) — confirm.
        httpx.ConnectError("Failed to connect"),
        httpx.ReadError("Failed to read response"),
        httpx.ReadTimeout("Request timed out"),
    ],
)
async def test_update_spend_logs_max_retries_exceeded(error_type):
    """Test that each connection error type properly fails after max retries.

    ``create_many`` always raises ``error_type``. ``update_spend`` must stop
    after four attempts (initial try + 3 retries), re-raise the original
    exception unchanged, and invoke the failure handler exactly once.
    """
    # Setup
    prisma_client = MockPrismaClient()
    proxy_logging_obj = create_mock_proxy_logging()

    # Add test spend logs
    prisma_client.spend_log_transactions = [
        {"id": "1", "spend": 10},
        {"id": "2", "spend": 20},
    ]

    # Mock the database to always fail
    create_many_mock = AsyncMock(side_effect=error_type)
    prisma_client.db.litellm_spendlogs.create_many = create_many_mock

    # Execute and verify it raises after max retries. The backoff sleeps are
    # stubbed only for the duration of the call so retries are instant.
    with patch("asyncio.sleep", new_callable=AsyncMock), pytest.raises(
        type(error_type)
    ) as exc_info:
        await update_spend(prisma_client, None, proxy_logging_obj)

    # Verify error message matches
    assert str(exc_info.value) == str(error_type)
    # Verify retry attempts (initial try + 3 retries)
    assert create_many_mock.call_count == 4

    # Real sleep (outside the patch). Presumably this gives a background task
    # that invokes the failure handler a chance to run — TODO confirm needed.
    await asyncio.sleep(2)
    # Verify failure handler was called
    assert proxy_logging_obj.failure_handler.call_count == 1
@pytest.mark.asyncio
async def test_update_spend_logs_non_connection_error():
    """Test handling of non-connection related errors.

    A ``ValueError`` from ``create_many`` is not a connection error, so
    ``update_spend`` must not retry it. The error must propagate after a
    single attempt, and the failure handler must be invoked.
    """
    # Setup
    prisma_client = MockPrismaClient()
    proxy_logging_obj = create_mock_proxy_logging()

    # Add test spend logs
    prisma_client.spend_log_transactions = [
        {"id": "1", "spend": 10},
        {"id": "2", "spend": 20},
    ]

    # Mock a different type of error (not connection-related)
    unexpected_error = ValueError("Unexpected database error")
    create_many_mock = AsyncMock(side_effect=unexpected_error)
    prisma_client.db.litellm_spendlogs.create_many = create_many_mock

    # Execute and verify it raises immediately without retrying
    with pytest.raises(ValueError) as exc_info:
        await update_spend(prisma_client, None, proxy_logging_obj)

    # Verify error message
    assert str(exc_info.value) == "Unexpected database error"
    # Verify only tried once (no retries for non-connection errors)
    assert create_many_mock.call_count == 1
    # Verify failure handler was called
    assert proxy_logging_obj.failure_handler.called
@pytest.mark.asyncio
async def test_update_spend_logs_exponential_backoff():
    """Test that exponential backoff is working correctly.

    Two connection errors followed by a success should cause exactly two
    backoff sleeps, of 2**0 and then 2**1 seconds, and three ``create_many``
    calls.
    """
    # Setup
    prisma_client = MockPrismaClient()
    proxy_logging_obj = create_mock_proxy_logging()

    # Add test spend logs
    prisma_client.spend_log_transactions = [{"id": "1", "spend": 10}]

    # Delays requested from asyncio.sleep, in call order
    sleep_times = []

    # Stand-in for asyncio.sleep: record the requested delay, don't wait
    async def mock_sleep(seconds):
        sleep_times.append(seconds)

    # Mock the database to fail with connection errors
    create_many_mock = AsyncMock(
        side_effect=[
            httpx.ConnectError("Failed to connect"),  # First attempt
            httpx.ConnectError("Failed to connect"),  # Second attempt
            None,  # Third attempt succeeds
        ]
    )
    prisma_client.db.litellm_spendlogs.create_many = create_many_mock

    # Apply mocks
    with patch("asyncio.sleep", mock_sleep):
        await update_spend(prisma_client, None, proxy_logging_obj)

    # Verify exponential backoff: slept twice, after 2^0 and then 2^1 seconds
    assert sleep_times == [1, 2]
    # Two failed attempts + one successful write
    assert create_many_mock.call_count == 3
@pytest.mark.asyncio
async def test_update_spend_logs_multiple_batches_success():
    """
    Test successful processing of multiple batches of spend logs.

    update_spend writes spend logs in batches of 100. This test creates 150
    logs, so it should make 2 batch calls (ids 0-99, then ids 100-149) and
    leave the spend-log queue empty.
    """
    # Setup
    prisma_client = MockPrismaClient()
    proxy_logging_obj = create_mock_proxy_logging()

    # Create 150 test spend logs (1.5x BATCH_SIZE)
    prisma_client.spend_log_transactions = [
        {"id": str(i), "spend": 10} for i in range(150)
    ]

    create_many_mock = AsyncMock(return_value=None)
    prisma_client.db.litellm_spendlogs.create_many = create_many_mock

    # Execute
    await update_spend(prisma_client, None, proxy_logging_obj)

    # Verify
    assert create_many_mock.call_count == 2  # Should have made 2 batch calls

    # Get the actual data passed (as the `data=` keyword) to each batch call
    first_batch = create_many_mock.call_args_list[0].kwargs["data"]
    second_batch = create_many_mock.call_args_list[1].kwargs["data"]

    # Verify batch sizes
    assert len(first_batch) == 100
    assert len(second_batch) == 50

    # Verify exact IDs in each batch
    expected_first_batch_ids = {str(i) for i in range(100)}
    expected_second_batch_ids = {str(i) for i in range(100, 150)}

    actual_first_batch_ids = {item["id"] for item in first_batch}
    actual_second_batch_ids = {item["id"] for item in second_batch}

    assert actual_first_batch_ids == expected_first_batch_ids
    assert actual_second_batch_ids == expected_second_batch_ids

    # Verify all logs were processed
    assert len(prisma_client.spend_log_transactions) == 0
@pytest.mark.asyncio
async def test_update_spend_logs_multiple_batches_with_failure():
    """
    Test processing of multiple batches where one batch fails.

    Creates 400 logs (4 batches of 100). The second ``create_many`` call
    raises a connection error. The expected call count of 6 means the retry
    pass re-sends all 4 batches, including the already-written batch 0.
    In the end every id 0-399 is written and the queue is cleared.
    """
    # Setup
    prisma_client = MockPrismaClient()
    proxy_logging_obj = create_mock_proxy_logging()

    # Create 400 test spend logs (4x BATCH_SIZE)
    prisma_client.spend_log_transactions = [
        {"id": str(i), "spend": 10} for i in range(400)
    ]

    # Fail only on the overall second create_many call (batch 1, first pass)
    call_count = 0

    async def create_many_side_effect(**kwargs):
        nonlocal call_count
        call_count += 1
        if call_count == 2:
            raise httpx.ConnectError("Failed to connect")
        return None

    create_many_mock = AsyncMock(side_effect=create_many_side_effect)
    prisma_client.db.litellm_spendlogs.create_many = create_many_mock

    # Execute, stubbing the backoff sleep so the retry is instant
    with patch("asyncio.sleep", new_callable=AsyncMock):
        await update_spend(prisma_client, None, proxy_logging_obj)

    # Verify: first pass = batch 0 (ok) + batch 1 (fails) = 2 calls;
    # the retry pass re-sends all 4 batches = 4 more calls.
    assert create_many_mock.call_count == 6

    # Collect every log handed to create_many across all calls
    all_processed_logs = []
    for recorded_call in create_many_mock.call_args_list:
        all_processed_logs.extend(recorded_call.kwargs["data"])

    # Every id 0-399 must have been sent at least once
    processed_ids = {item["id"] for item in all_processed_logs}
    expected_ids = {str(i) for i in range(400)}
    assert processed_ids == expected_ids

    # Verify all logs were cleared from transactions
    assert len(prisma_client.spend_log_transactions) == 0