Spaces:
Configuration error
Configuration error
File size: 4,325 Bytes
447ebeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
# tests/test_budget_endpoints.py
import os
import sys
import types
import pytest
from unittest.mock import AsyncMock, MagicMock
from fastapi.testclient import TestClient
import litellm.proxy.proxy_server as ps
from litellm.proxy.proxy_server import app
from litellm.proxy._types import UserAPIKeyAuth, LitellmUserRoles, CommonProxyErrors
import litellm.proxy.management_endpoints.budget_management_endpoints as bm
# Give the directory three levels above the current working directory top
# priority on the import path. The path is relative to the working directory,
# not to this file, and only influences imports executed after this line.
sys.path.insert(0, os.path.abspath("../../../"))
@pytest.fixture
def client_and_mocks(monkeypatch):
    """Provide a ``TestClient`` backed by a mocked Prisma client and a fake user.

    Setup:
        * ``ps.prisma_client`` is replaced (via ``monkeypatch``) with a
          ``MagicMock`` whose ``db`` exposes ``litellm_budgettable`` and
          ``litellm_dailyspend``; both point at the same mock table.
        * ``ps.user_api_key_auth`` is overridden on the app so every request
          is authenticated as ``test_user`` with the ``INTERNAL_USER`` role.

    Yields:
        tuple: ``(client, mock_prisma, mock_table)``.
    """
    # The table mock echoes its input so tests can assert on the response body:
    # ``create`` returns ``data``; ``update`` returns ``where`` merged with ``data``.
    mock_table = MagicMock()
    mock_table.create = AsyncMock(side_effect=lambda *, data: data)
    mock_table.update = AsyncMock(side_effect=lambda *, where, data: {**where, **data})

    mock_prisma = MagicMock()
    mock_prisma.db = types.SimpleNamespace(
        litellm_budgettable=mock_table,
        litellm_dailyspend=mock_table,
    )

    # monkeypatch restores the original ``prisma_client`` on its own teardown,
    # so no manual restoration is needed here.
    monkeypatch.setattr(ps, "prisma_client", mock_prisma)

    fake_user = UserAPIKeyAuth(
        user_id="test_user",
        user_role=LitellmUserRoles.INTERNAL_USER,
    )

    # Snapshot the overrides so teardown puts back exactly what was there
    # before, rather than wiping overrides this fixture did not install.
    previous_overrides = dict(app.dependency_overrides)
    app.dependency_overrides[ps.user_api_key_auth] = lambda: fake_user
    try:
        client = TestClient(app)
        yield client, mock_prisma, mock_table
    finally:
        app.dependency_overrides.clear()
        app.dependency_overrides.update(previous_overrides)
@pytest.mark.asyncio
async def test_new_budget_success(client_and_mocks):
    """POST /budget/new returns the submitted fields plus the audit user ids."""
    test_client, _, budget_table = client_and_mocks

    request_body = {
        "budget_id": "budget_123",
        "max_budget": 42.0,
        "budget_duration": "30d",
    }
    response = test_client.post("/budget/new", json=request_body)
    assert response.status_code == 200, response.text

    created = response.json()
    # Every field that was sent must come back unchanged.
    for field, expected in request_body.items():
        assert created[field] == expected
    # Both audit fields must carry the id of the authenticated user.
    for audit_field in ("created_by", "updated_by"):
        assert created[audit_field] == "test_user"

    # The endpoint must have written to the budget table exactly once.
    budget_table.create.assert_awaited_once()
@pytest.mark.asyncio
async def test_new_budget_db_not_connected(client_and_mocks, monkeypatch):
    """POST /budget/new answers 500 with the db-not-connected error when
    ``ps.prisma_client`` is ``None``."""
    client, _, _ = client_and_mocks

    # Replace the mocked client installed by the fixture with None; ``ps`` is
    # the proxy_server module already imported at the top of this file.
    monkeypatch.setattr(ps, "prisma_client", None)

    resp = client.post("/budget/new", json={"budget_id": "no_db", "max_budget": 1.0})
    assert resp.status_code == 500, resp.text

    detail = resp.json()["detail"]
    assert detail["error"] == CommonProxyErrors.db_not_connected_error.value
@pytest.mark.asyncio
async def test_update_budget_success(client_and_mocks, monkeypatch):
    """POST /budget/update returns the updated fields and the updating user id."""
    client, _, mock_table = client_and_mocks

    payload = {
        "budget_id": "budget_456",
        "max_budget": 99.0,
        "soft_budget": 50.0,
    }
    resp = client.post("/budget/update", json=payload)
    assert resp.status_code == 200, resp.text

    body = resp.json()
    assert body["budget_id"] == payload["budget_id"]
    assert body["max_budget"] == payload["max_budget"]
    assert body["soft_budget"] == payload["soft_budget"]
    assert body["updated_by"] == "test_user"

    # Mirror the create test: the response should come from a single table write.
    # NOTE(review): assumes the endpoint issues exactly one ``update`` call - confirm.
    mock_table.update.assert_awaited_once()
@pytest.mark.asyncio
async def test_update_budget_missing_id(client_and_mocks, monkeypatch):
    """POST /budget/update without a ``budget_id`` is rejected with a 400."""
    http_client, _prisma, _table = client_and_mocks

    # The request body deliberately omits "budget_id".
    response = http_client.post("/budget/update", json={"max_budget": 10.0})

    assert response.status_code == 400, response.text
    error_detail = response.json()["detail"]
    assert error_detail["error"] == "budget_id is required"
@pytest.mark.asyncio
async def test_update_budget_db_not_connected(client_and_mocks, monkeypatch):
    """POST /budget/update answers 500 with the db-not-connected error when
    ``ps.prisma_client`` is ``None``."""
    client, _, _ = client_and_mocks

    # Replace the mocked client installed by the fixture with None; ``ps`` is
    # the proxy_server module already imported at the top of this file.
    monkeypatch.setattr(ps, "prisma_client", None)

    payload = {"budget_id": "any", "max_budget": 1.0}
    resp = client.post("/budget/update", json=payload)
    assert resp.status_code == 500, resp.text

    detail = resp.json()["detail"]
    assert detail["error"] == CommonProxyErrors.db_not_connected_error.value
|