# Test module: litellm proxy key generation, auth, and budget enforcement scenarios.
# Test the following scenarios: | |
# 1. Generate a Key, and use it to make a call | |
# 2. Make a call with invalid key, expect it to fail | |
# 3. Make a call to a key with invalid model - expect to fail | |
# 4. Make a call to a key with valid model - expect to pass | |
# 5. Make a call with user over budget, expect to fail | |
# 6. Make a streaming chat/completions call with user over budget, expect to fail | |
# 7. Make a call with an key that never expires, expect to pass | |
# 8. Make a call with an expired key, expect to fail | |
# 9. Delete a Key | |
# 10. Generate a key, call key/info. Assert info returned is the same as generated key info | |
# 11. Generate a Key, cal key/info, call key/update, call key/info | |
# 12. Make a call with key over budget, expect to fail | |
# 14. Make a streaming chat/completions call with key over budget, expect to fail | |
# 15. Generate key, when `allow_user_auth`=False - check if `/key/info` returns key_name=null | |
# 16. Generate key, when `allow_user_auth`=True - check if `/key/info` returns key_name=sk...<last-4-digits> | |
# function to call to generate key - async def new_user(data: NewUserRequest): | |
# function to validate a request - async def user_auth(request: Request): | |
import os | |
import sys | |
import traceback | |
import uuid | |
from datetime import datetime, timezone | |
from dotenv import load_dotenv | |
from fastapi import Request | |
from fastapi.routing import APIRoute | |
import httpx | |
load_dotenv() | |
import io | |
import os | |
import time | |
# this file is to test litellm/proxy | |
sys.path.insert( | |
0, os.path.abspath("../..") | |
) # Adds the parent directory to the system path | |
import asyncio | |
import logging | |
import pytest | |
import litellm | |
from litellm._logging import verbose_proxy_logger | |
from litellm.proxy.management_endpoints.internal_user_endpoints import ( | |
new_user, | |
user_info, | |
user_update, | |
) | |
from litellm.proxy.auth.auth_checks import get_key_object | |
from litellm.proxy.management_endpoints.key_management_endpoints import ( | |
delete_key_fn, | |
generate_key_fn, | |
generate_key_helper_fn, | |
info_key_fn, | |
list_keys, | |
regenerate_key_fn, | |
update_key_fn, | |
) | |
from litellm.proxy.management_endpoints.team_endpoints import ( | |
new_team, | |
team_info, | |
update_team, | |
) | |
from litellm.proxy.proxy_server import ( | |
LitellmUserRoles, | |
audio_transcriptions, | |
chat_completion, | |
completion, | |
embeddings, | |
model_list, | |
moderations, | |
user_api_key_auth, | |
) | |
from litellm.proxy.image_endpoints import image_generation | |
from litellm.proxy.management_endpoints.customer_endpoints import ( | |
new_end_user, | |
) | |
from litellm.proxy.spend_tracking.spend_management_endpoints import ( | |
global_spend, | |
spend_key_fn, | |
spend_user_fn, | |
view_spend_logs, | |
) | |
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend | |
verbose_proxy_logger.setLevel(level=logging.DEBUG) | |
from starlette.datastructures import URL | |
from litellm.caching.caching import DualCache | |
from litellm.proxy._types import ( | |
DynamoDBArgs, | |
GenerateKeyRequest, | |
KeyRequest, | |
LiteLLM_UpperboundKeyGenerateParams, | |
NewCustomerRequest, | |
NewTeamRequest, | |
NewUserRequest, | |
ProxyErrorTypes, | |
ProxyException, | |
UpdateKeyRequest, | |
UpdateTeamRequest, | |
UpdateUserRequest, | |
UserAPIKeyAuth, | |
) | |
# Shared ProxyLogging instance used by every test in this module.
proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
# Default /chat/completions request payload reused across tests.
request_data = {
    "model": "azure-gpt-3.5",
    "messages": [
        {"role": "user", "content": "this is my new test. respond in 50 lines"}
    ],
}
def prisma_client():
    """Build a PrismaClient wired to the test database.

    Reads DATABASE_URL from the environment, appends connection-pool query
    params, and returns a fresh PrismaClient bound to the module-level
    proxy_logging_obj. Also resets per-run proxy globals so budget tests
    don't leak state into each other.

    NOTE(review): tests receive this via a `prisma_client` argument, which
    suggests this should be registered as a pytest fixture — confirm a
    `@pytest.fixture` registration exists elsewhere (e.g. conftest.py).
    """
    from litellm.proxy.proxy_cli import append_query_params

    ### add connection pool + pool timeout args
    params = {"connection_limit": 100, "pool_timeout": 60}
    database_url = os.getenv("DATABASE_URL")
    modified_url = append_query_params(database_url, params)
    os.environ["DATABASE_URL"] = modified_url
    # use a distinct local name — the original shadowed the function name
    client = PrismaClient(
        database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
    )
    # unique proxy budget name per run so budget tests don't collide
    litellm.proxy.proxy_server.litellm_proxy_budget_name = (
        f"litellm-proxy-budget-{time.time()}"
    )
    litellm.proxy.proxy_server.user_custom_key_generate = None
    return client
async def test_new_user_response(prisma_client):
    """Create a team, then a user inside it, and assert the new_user
    response echoes back models, team_id, and tpm_limit.

    NOTE(review): async test with no visible @pytest.mark.asyncio decorator
    — confirm the project's pytest config (e.g. asyncio_mode=auto) runs it.
    """
    try:
        print("prisma client=", prisma_client)
        setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
        setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
        await litellm.proxy.proxy_server.prisma_client.connect()
        from litellm.proxy.proxy_server import user_api_key_cache
        # unique team id per run to avoid collisions with earlier test data
        _team_id = "ishaan-special-team_{}".format(uuid.uuid4())
        await new_team(
            NewTeamRequest(
                team_id=_team_id,
            ),
            http_request=Request(scope={"type": "http"}),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )
        # NOTE(review): user_api_key_dict is nested inside NewUserRequest
        # here, unlike sibling tests which pass it to new_user() directly —
        # confirm this is intended.
        _response = await new_user(
            data=NewUserRequest(
                models=["azure-gpt-3.5"],
                team_id=_team_id,
                tpm_limit=20,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
        )
        print(_response)
        assert _response.models == ["azure-gpt-3.5"]
        assert _response.team_id == _team_id
        assert _response.tpm_limit == 20
    except Exception as e:
        print("Got Exception", e)
        pytest.fail(f"Got exception {e}")
def test_generate_and_call_with_valid_key(prisma_client, api_route):
    """1. Generate a key via new_user, verify /user/info reflects the new
    user, the key is not pre-cached, it is persisted in the DB, and it
    passes user_api_key_auth for the given api_route.
    """
    # 1. Generate a Key, and use it to make a call
    from unittest.mock import MagicMock
    print("prisma client=", prisma_client)
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    try:
        async def test():
            await litellm.proxy.proxy_server.prisma_client.connect()
            from litellm.proxy.proxy_server import user_api_key_cache
            user_api_key_dict = UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            )
            request = NewUserRequest(user_role=LitellmUserRoles.INTERNAL_USER)
            key = await new_user(request, user_api_key_dict=user_api_key_dict)
            print(key)
            user_id = key.user_id
            # check /user/info to verify user_role was set correctly
            request_mock = MagicMock()
            new_user_info = await user_info(
                request=request_mock, user_id=user_id, user_api_key_dict=user_api_key_dict
            )
            new_user_info = new_user_info.user_info
            print("new_user_info=", new_user_info)
            assert new_user_info["user_role"] == LitellmUserRoles.INTERNAL_USER
            assert new_user_info["user_id"] == user_id
            generated_key = key.key
            bearer_token = "Bearer " + generated_key
            # a freshly generated key must not be cached before its first auth
            assert generated_key not in user_api_key_cache.in_memory_cache.cache_dict
            # NOTE(review): lookup uses the raw key — confirm get_data hashes
            # internally, since keys are stored hashed (see hash_token usage below)
            value_from_prisma = await prisma_client.get_data(
                token=generated_key,
            )
            print("token from prisma", value_from_prisma)
            request = Request(
                {
                    "type": "http",
                    "route": api_route,
                    "path": api_route.path,
                    "headers": [("Authorization", bearer_token)],
                }
            )
            # use generated key to auth in
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("result from user auth with new key", result)
        asyncio.run(test())
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")
def test_call_with_invalid_key(prisma_client):
    """2. Auth with a key that was never issued must be rejected.

    Expects user_api_key_auth to raise a ProxyException whose message flags
    the invalid proxy server token.
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    try:

        async def test():
            await litellm.proxy.proxy_server.prisma_client.connect()
            generated_key = "sk-126666"  # never generated by the proxy
            bearer_token = "Bearer " + generated_key

            request = Request(scope={"type": "http"}, receive=None)
            request._url = URL(url="/chat/completions")

            # use generated key to auth in
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("got result", result)
            pytest.fail(f"This should have failed!. IT's an invalid key")

        asyncio.run(test())
    except Exception as e:
        print("Got Exception", e)
        # verify the exception type BEFORE touching `.message` — the broad
        # `except` would otherwise turn a swallowed pytest.fail() into a
        # confusing AttributeError
        assert isinstance(e, ProxyException)
        print(e.message)
        assert "Authentication Error, Invalid proxy server token passed" in e.message
def test_call_with_invalid_model(prisma_client):
    """3. A key restricted to ["mistral"] must be rejected when the request
    body asks for a different model (gemini-pro-vision).
    """
    litellm.set_verbose = True
    # 3. Make a call to a key with an invalid model - expect to fail
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    try:
        async def test():
            await litellm.proxy.proxy_server.prisma_client.connect()
            request = NewUserRequest(models=["mistral"])
            key = await new_user(
                data=request,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
            print(key)
            generated_key = key.key
            bearer_token = "Bearer " + generated_key
            request = Request(scope={"type": "http"})
            request._url = URL(url="/chat/completions")
            async def return_body():
                return b'{"model": "gemini-pro-vision"}'
            # patch the body so auth sees a model outside the key's allowlist
            request.body = return_body
            # use generated key to auth in
            print(
                "Bearer token being sent to user_api_key_auth() - {}".format(
                    bearer_token
                )
            )
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            pytest.fail(f"This should have failed!. IT's an invalid model")
        asyncio.run(test())
    except Exception as e:
        # expected path: model access denied for this key
        assert isinstance(e, ProxyException)
        assert e.type == ProxyErrorTypes.key_model_access_denied
        assert e.param == "model"
def test_call_with_valid_model(prisma_client):
    """4. A key scoped to model "mistral" must pass auth when the request
    body asks for "mistral".
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    try:

        async def _run():
            await litellm.proxy.proxy_server.prisma_client.connect()
            admin_auth = UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            )
            new_user_request = NewUserRequest(models=["mistral"])
            key = await new_user(new_user_request, user_api_key_dict=admin_auth)
            print(key)

            bearer_token = "Bearer " + key.key
            request = Request(scope={"type": "http"})
            request._url = URL(url="/chat/completions")

            async def return_body():
                return b'{"model": "mistral"}'

            # patch the request body so auth sees the requested model
            request.body = return_body

            # use generated key to auth in
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("result from user auth with new key", result)

        asyncio.run(_run())
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")
async def test_call_with_valid_model_using_all_models(prisma_client):
    """
    Do not delete
    this is the Admin UI flow
    1. Create a team with model = `all-proxy-models`
    2. Create a key with model = `all-team-models`
    3. Call /chat/completions with the key -> expect to pass

    NOTE(review): async test with no visible @pytest.mark.asyncio decorator
    — confirm the project's pytest config runs it.
    """
    # Make a call to a key with model = `all-proxy-models` this is an Alias from LiteLLM Admin UI
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    try:
        await litellm.proxy.proxy_server.prisma_client.connect()
        team_request = NewTeamRequest(
            team_alias="testing-team",
            models=["all-proxy-models"],
        )
        new_team_response = await new_team(
            data=team_request,
            user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
            http_request=Request(scope={"type": "http"}),
        )
        print("new_team_response", new_team_response)
        created_team_id = new_team_response["team_id"]
        request = GenerateKeyRequest(
            models=["all-team-models"], team_id=created_team_id
        )
        key = await generate_key_fn(
            data=request,
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )
        print(key)
        generated_key = key.key
        bearer_token = "Bearer " + generated_key
        request = Request(scope={"type": "http"})
        request._url = URL(url="/chat/completions")
        async def return_body():
            return b'{"model": "mistral"}'
        # patch the body: a concrete model should be allowed via the alias
        request.body = return_body
        # use generated key to auth in
        result = await user_api_key_auth(request=request, api_key=bearer_token)
        print("result from user auth with new key", result)
        # call /key/info for the key - its models should be the
        # `all-team-models` alias it was created with
        key_info = await info_key_fn(
            key=generated_key,
            user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
        )
        print("key_info", key_info)
        models = key_info["info"]["models"]
        assert models == ["all-team-models"]
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")
def test_call_with_user_over_budget(prisma_client):
    """5. Auth must fail once a user's tracked spend crosses their max_budget."""
    # 5. Make a call with a user over budget, expect to fail
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    try:
        async def test():
            await litellm.proxy.proxy_server.prisma_client.connect()
            # max_budget is tiny so one tracked response exceeds it
            request = NewUserRequest(max_budget=0.00001)
            key = await new_user(
                data=request,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
            print(key)
            generated_key = key.key
            user_id = key.user_id
            bearer_token = "Bearer " + generated_key
            request = Request(scope={"type": "http"})
            request._url = URL(url="/chat/completions")
            # use generated key to auth in
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("result from user auth with new key", result)
            # update spend using track_cost callback, make 2nd request, it should fail
            from litellm import Choices, Message, ModelResponse, Usage
            from litellm.proxy.proxy_server import _ProxyDBLogger
            proxy_db_logger = _ProxyDBLogger()
            resp = ModelResponse(
                id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac",
                choices=[
                    Choices(
                        finish_reason=None,
                        index=0,
                        message=Message(
                            content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
                            role="assistant",
                        ),
                    )
                ],
                model="gpt-35-turbo",  # azure always has model written like this
                usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
            )
            # response_cost (0.00002) > max_budget (0.00001) -> user goes over budget
            await proxy_db_logger._PROXY_track_cost_callback(
                kwargs={
                    "stream": False,
                    "litellm_params": {
                        "metadata": {
                            "user_api_key": generated_key,
                            "user_api_key_user_id": user_id,
                        }
                    },
                    "response_cost": 0.00002,
                },
                completion_response=resp,
                start_time=datetime.now(),
                end_time=datetime.now(),
            )
            # give the async spend-update time to land before re-authing
            await asyncio.sleep(5)
            # use generated key to auth in
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("result from user auth with new key", result)
            pytest.fail("This should have failed!. They key crossed it's budget")
        asyncio.run(test())
    except Exception as e:
        print("got an errror=", e)
        error_detail = e.message
        assert "ExceededBudget:" in error_detail
        assert isinstance(e, ProxyException)
        assert e.type == ProxyErrorTypes.budget_exceeded
        print(vars(e))
def test_end_user_cache_write_unit_test():
    """Placeholder — should assert the end-user object is written to cache as expected."""
def test_call_with_end_user_over_budget(prisma_client):
    """Track spend for an end user identified by the request body's "user"
    field, and expect auth to fail once their budget is crossed. Only
    enforced because litellm.max_end_user_budget is set below.
    """
    # Test if a user passed to /chat/completions is tracked & fails when they cross their budget
    # we only check this when litellm.max_end_user_budget is set
    import random  # NOTE(review): unused import — candidate for removal
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm, "max_end_user_budget", 0.00001)
    try:
        async def test():
            await litellm.proxy.proxy_server.prisma_client.connect()
            # unique end-user id per run
            user = f"ishaan {uuid.uuid4().hex}"
            request = NewCustomerRequest(
                user_id=user, max_budget=0.000001
            )  # create an end user with a tiny budget
            await new_end_user(
                request,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
            request = Request(scope={"type": "http"})
            request._url = URL(url="/chat/completions")
            # auth with the master key; the end user is identified via the body
            bearer_token = "Bearer sk-1234"
            async def return_body():
                return_string = f'{{"model": "gemini-pro-vision", "user": "{user}"}}'
                # return string as bytes
                return return_string.encode()
            request.body = return_body
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            # update spend using track_cost callback, make 2nd request, it should fail
            from litellm import Choices, Message, ModelResponse, Usage
            from litellm.proxy.proxy_server import _ProxyDBLogger
            proxy_db_logger = _ProxyDBLogger()
            resp = ModelResponse(
                id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac",
                choices=[
                    Choices(
                        finish_reason=None,
                        index=0,
                        message=Message(
                            content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
                            role="assistant",
                        ),
                    )
                ],
                model="gpt-35-turbo",  # azure always has model written like this
                usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
            )
            # response_cost (10) far exceeds the end user's budget
            await proxy_db_logger._PROXY_track_cost_callback(
                kwargs={
                    "stream": False,
                    "litellm_params": {
                        "metadata": {
                            "user_api_key": "sk-1234",
                            "user_api_key_end_user_id": user,
                        },
                        "proxy_server_request": {
                            "body": {
                                "user": user,
                            }
                        },
                    },
                    "response_cost": 10,
                },
                completion_response=resp,
                start_time=datetime.now(),
                end_time=datetime.now(),
            )
            # wait, then flush buffered spend to the DB before re-authing
            await asyncio.sleep(10)
            await update_spend(
                prisma_client=prisma_client,
                db_writer_client=None,
                proxy_logging_obj=proxy_logging_obj,
            )
            # re-auth: the end user is now over budget, so this must raise
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("result from user auth with new key", result)
            pytest.fail("This should have failed!. They key crossed it's budget")
        asyncio.run(test())
    except Exception as e:
        print(f"raised error: {e}, traceback: {traceback.format_exc()}")
        error_detail = e.message
        assert "Budget has been exceeded! Current" in error_detail
        assert isinstance(e, ProxyException)
        assert e.type == ProxyErrorTypes.budget_exceeded
        print(vars(e))
def test_call_with_proxy_over_budget(prisma_client):
    """5.1 Auth must fail once the global proxy budget (litellm.max_budget)
    is exceeded by tracked spend.
    """
    # 5.1 Make a call with a proxy over budget, expect to fail
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    # unique budget name per run so earlier tests' spend doesn't interfere
    litellm_proxy_budget_name = f"litellm-proxy-budget-{time.time()}"
    setattr(
        litellm.proxy.proxy_server,
        "litellm_proxy_admin_name",
        litellm_proxy_budget_name,
    )
    setattr(litellm, "max_budget", 0.00001)
    from litellm.proxy.proxy_server import user_api_key_cache
    # seed the proxy's spend counter at zero
    user_api_key_cache.set_cache(
        key="{}:spend".format(litellm_proxy_budget_name), value=0
    )
    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
    try:
        async def test():
            await litellm.proxy.proxy_server.prisma_client.connect()
            request = NewUserRequest()
            key = await new_user(
                data=request,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
            print(key)
            generated_key = key.key
            user_id = key.user_id
            bearer_token = "Bearer " + generated_key
            request = Request(scope={"type": "http"})
            request._url = URL(url="/chat/completions")
            # use generated key to auth in
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("result from user auth with new key", result)
            # update spend using track_cost callback, make 2nd request, it should fail
            from litellm import Choices, Message, ModelResponse, Usage
            from litellm.proxy.proxy_server import _ProxyDBLogger
            proxy_db_logger = _ProxyDBLogger()
            resp = ModelResponse(
                id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac",
                choices=[
                    Choices(
                        finish_reason=None,
                        index=0,
                        message=Message(
                            content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
                            role="assistant",
                        ),
                    )
                ],
                model="gpt-35-turbo",  # azure always has model written like this
                usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
            )
            # response_cost (0.00002) > litellm.max_budget (0.00001)
            await proxy_db_logger._PROXY_track_cost_callback(
                kwargs={
                    "stream": False,
                    "litellm_params": {
                        "metadata": {
                            "user_api_key": generated_key,
                            "user_api_key_user_id": user_id,
                        }
                    },
                    "response_cost": 0.00002,
                },
                completion_response=resp,
                start_time=datetime.now(),
                end_time=datetime.now(),
            )
            # give the async spend-update time to land before re-authing
            await asyncio.sleep(5)
            # use generated key to auth in
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("result from user auth with new key", result)
            pytest.fail(f"This should have failed!. They key crossed it's budget")
        asyncio.run(test())
    except Exception as e:
        if hasattr(e, "message"):
            error_detail = e.message
        else:
            error_detail = traceback.format_exc()
        assert "Budget has been exceeded" in error_detail
        assert isinstance(e, ProxyException)
        assert e.type == ProxyErrorTypes.budget_exceeded
        print(vars(e))
def test_call_with_user_over_budget_stream(prisma_client):
    """6. Streaming variant of the user-over-budget test: spend is recorded
    from kwargs["complete_streaming_response"], then auth must fail.
    """
    # 6. Make a streaming call with a user over budget, expect to fail
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    import logging
    from litellm._logging import verbose_proxy_logger
    litellm.set_verbose = True
    verbose_proxy_logger.setLevel(logging.DEBUG)
    try:
        async def test():
            await litellm.proxy.proxy_server.prisma_client.connect()
            # max_budget is tiny so one tracked response exceeds it
            request = NewUserRequest(max_budget=0.00001)
            key = await new_user(
                data=request,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
            print(key)
            generated_key = key.key
            user_id = key.user_id
            bearer_token = "Bearer " + generated_key
            request = Request(scope={"type": "http"})
            request._url = URL(url="/chat/completions")
            # use generated key to auth in
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("result from user auth with new key", result)
            # update spend using track_cost callback, make 2nd request, it should fail
            from litellm import Choices, Message, ModelResponse, Usage
            from litellm.proxy.proxy_server import _ProxyDBLogger
            proxy_db_logger = _ProxyDBLogger()
            resp = ModelResponse(
                id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac",
                choices=[
                    Choices(
                        finish_reason=None,
                        index=0,
                        message=Message(
                            content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
                            role="assistant",
                        ),
                    )
                ],
                model="gpt-35-turbo",  # azure always has model written like this
                usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
            )
            # stream=True: cost is attributed via the assembled streaming response
            await proxy_db_logger._PROXY_track_cost_callback(
                kwargs={
                    "stream": True,
                    "complete_streaming_response": resp,
                    "litellm_params": {
                        "metadata": {
                            "user_api_key": generated_key,
                            "user_api_key_user_id": user_id,
                        }
                    },
                    "response_cost": 0.00002,
                },
                completion_response=ModelResponse(),
                start_time=datetime.now(),
                end_time=datetime.now(),
            )
            # give the async spend-update time to land before re-authing
            await asyncio.sleep(5)
            # use generated key to auth in
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("result from user auth with new key", result)
            pytest.fail("This should have failed!. They key crossed it's budget")
        asyncio.run(test())
    except Exception as e:
        error_detail = e.message
        assert "ExceededBudget:" in error_detail
        assert isinstance(e, ProxyException)
        assert e.type == ProxyErrorTypes.budget_exceeded
        print(vars(e))
def test_call_with_proxy_over_budget_stream(prisma_client):
    """6.1 Streaming call with the global proxy budget exceeded must fail.

    Sets litellm.max_budget to a tiny value, seeds the proxy budget cache,
    issues a key, records streaming spend above the budget via the DB
    logger's track-cost callback, then expects the next auth attempt to
    raise a budget_exceeded ProxyException.
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    # unique budget name per run so earlier tests' spend doesn't interfere
    litellm_proxy_budget_name = f"litellm-proxy-budget-{time.time()}"
    setattr(
        litellm.proxy.proxy_server,
        "litellm_proxy_admin_name",
        litellm_proxy_budget_name,
    )
    setattr(litellm, "max_budget", 0.00001)
    from litellm.proxy.proxy_server import user_api_key_cache

    # seed the proxy's spend counter at zero
    user_api_key_cache.set_cache(
        key="{}:spend".format(litellm_proxy_budget_name), value=0
    )
    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
    import logging

    from litellm._logging import verbose_proxy_logger

    litellm.set_verbose = True
    verbose_proxy_logger.setLevel(logging.DEBUG)
    try:

        async def test():
            await litellm.proxy.proxy_server.prisma_client.connect()
            request = NewUserRequest()
            key = await new_user(
                data=request,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
            print(key)
            generated_key = key.key
            user_id = key.user_id
            bearer_token = "Bearer " + generated_key

            request = Request(scope={"type": "http"})
            request._url = URL(url="/chat/completions")

            # use generated key to auth in
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("result from user auth with new key", result)

            # update spend using track_cost callback, make 2nd request, it should fail
            from litellm import Choices, Message, ModelResponse, Usage
            from litellm.proxy.proxy_server import _ProxyDBLogger

            proxy_db_logger = _ProxyDBLogger()
            resp = ModelResponse(
                id="chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac",
                choices=[
                    Choices(
                        finish_reason=None,
                        index=0,
                        message=Message(
                            content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
                            role="assistant",
                        ),
                    )
                ],
                model="gpt-35-turbo",  # azure always has model written like this
                usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
            )
            # stream=True: cost is attributed via the assembled streaming response
            await proxy_db_logger._PROXY_track_cost_callback(
                kwargs={
                    "stream": True,
                    "complete_streaming_response": resp,
                    "litellm_params": {
                        "metadata": {
                            "user_api_key": generated_key,
                            "user_api_key_user_id": user_id,
                        }
                    },
                    "response_cost": 0.00002,
                },
                completion_response=ModelResponse(),
                start_time=datetime.now(),
                end_time=datetime.now(),
            )
            # give the async spend-update time to land before re-authing
            await asyncio.sleep(5)
            # use generated key to auth in
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("result from user auth with new key", result)
            pytest.fail(f"This should have failed!. They key crossed it's budget")

        asyncio.run(test())
    except Exception as e:
        # verify the exception type BEFORE touching `.message`, and assert
        # the error category — keeps this consistent with the sibling
        # over-budget tests above
        assert isinstance(e, ProxyException)
        assert e.type == ProxyErrorTypes.budget_exceeded
        error_detail = e.message
        assert "Budget has been exceeded" in error_detail
        print(vars(e))
def test_generate_and_call_with_valid_key_never_expires(prisma_client):
    """7. A key created with duration=None never expires, so auth must succeed."""
    print("prisma client=", prisma_client)
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    try:

        async def _run():
            await litellm.proxy.proxy_server.prisma_client.connect()
            admin_auth = UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            )
            # duration=None -> the key has no expiry
            key = await new_user(
                data=NewUserRequest(duration=None),
                user_api_key_dict=admin_auth,
            )
            print(key)

            bearer_token = "Bearer " + key.key
            request = Request(scope={"type": "http"})
            request._url = URL(url="/chat/completions")

            # use generated key to auth in
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("result from user auth with new key", result)

        asyncio.run(_run())
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")
def test_generate_and_call_with_expired_key(prisma_client):
    """8. Auth with an already-expired key (duration="0s") must fail.

    Expects a ProxyException of type expired_key whose message is an
    Authentication Error.
    """
    print("prisma client=", prisma_client)
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    try:

        async def test():
            await litellm.proxy.proxy_server.prisma_client.connect()
            # duration="0s" -> the key expires immediately
            request = NewUserRequest(duration="0s")
            key = await new_user(
                data=request,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
            print(key)
            generated_key = key.key
            bearer_token = "Bearer " + generated_key

            request = Request(scope={"type": "http"})
            request._url = URL(url="/chat/completions")

            # use generated key to auth in
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("result from user auth with new key", result)
            pytest.fail("This should have failed!. It's an expired key")

        asyncio.run(test())
    except Exception as e:
        print("Got Exception", e)
        # verify the exception type BEFORE touching `.message`/`.type` — the
        # broad `except` would otherwise turn a swallowed pytest.fail() into
        # a confusing AttributeError
        assert isinstance(e, ProxyException)
        print(e.message)
        assert "Authentication Error" in e.message
        assert e.type == ProxyErrorTypes.expired_key
def test_delete_key(prisma_client):
    """9. Generate a key, delete it with the admin key, and verify the
    deletion response plus cache eviction (raw and hashed key forms).
    """
    # 9. Generate a Key, delete it. Check if deletion works fine
    print("prisma client=", prisma_client)
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "user_custom_auth", None)
    try:
        async def test():
            await litellm.proxy.proxy_server.prisma_client.connect()
            from litellm.proxy.proxy_server import user_api_key_cache
            request = NewUserRequest()
            key = await new_user(
                data=request,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
            print(key)
            generated_key = key.key
            bearer_token = "Bearer " + generated_key
            delete_key_request = KeyRequest(keys=[generated_key])
            # switch to the master key to perform the delete
            bearer_token = "Bearer sk-1234"
            request = Request(scope={"type": "http"})
            request._url = URL(url="/key/delete")
            # use the master key to auth in
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print(f"result: {result}")
            result.user_role = LitellmUserRoles.PROXY_ADMIN
            # delete the key
            result_delete_key = await delete_key_fn(
                data=delete_key_request, user_api_key_dict=result
            )
            print("result from delete key", result_delete_key)
            assert result_delete_key == {"deleted_keys": [generated_key]}
            # neither the raw key nor its hash may remain cached
            assert generated_key not in user_api_key_cache.in_memory_cache.cache_dict
            assert (
                hash_token(generated_key)
                not in user_api_key_cache.in_memory_cache.cache_dict
            )
        asyncio.run(test())
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")
def test_delete_key_auth(prisma_client):
    """
    Scenario 10: generate a key, delete it, then try to authenticate with the
    deleted key — the auth call must raise an Authentication Error.
    """
    # 10. Generate a Key, delete it, use it to make a call -> expect fail
    print("prisma client=", prisma_client)
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    try:

        async def test():
            await litellm.proxy.proxy_server.prisma_client.connect()
            from litellm.proxy.proxy_server import user_api_key_cache

            request = NewUserRequest()
            key = await new_user(
                data=request,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
            print(key)

            generated_key = key.key
            bearer_token = "Bearer " + generated_key

            delete_key_request = KeyRequest(keys=[generated_key])

            # delete the key -- authenticate as master key for the delete call
            bearer_token = "Bearer sk-1234"

            request = Request(scope={"type": "http"})
            request._url = URL(url="/key/delete")

            # use generated key to auth in
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print(f"result: {result}")
            result.user_role = LitellmUserRoles.PROXY_ADMIN

            result_delete_key = await delete_key_fn(
                data=delete_key_request, user_api_key_dict=result
            )
            print("result from delete key", result_delete_key)
            assert result_delete_key == {"deleted_keys": [generated_key]}

            request = Request(scope={"type": "http"}, receive=None)
            request._url = URL(url="/chat/completions")

            # the deleted key must be gone from the cache (raw and hashed)
            assert generated_key not in user_api_key_cache.in_memory_cache.cache_dict
            assert (
                hash_token(generated_key)
                not in user_api_key_cache.in_memory_cache.cache_dict
            )

            # use generated key to auth in -- expected to raise, reaching
            # pytest.fail means deletion did not revoke the key
            bearer_token = "Bearer " + generated_key
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("got result", result)
            pytest.fail(f"This should have failed!. IT's an invalid key")

        asyncio.run(test())
    except Exception as e:
        print("Got Exception", e)
        print(e.message)
        assert "Authentication Error" in e.message
        pass
def test_generate_and_call_key_info(prisma_client):
    """
    Scenario 10: generate a key with metadata, call /key/info, and assert the
    returned info matches what was supplied at creation. Deletes the key as
    cleanup.
    """
    # 10. Generate a Key, cal key/info
    print("prisma client=", prisma_client)
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    try:

        async def test():
            await litellm.proxy.proxy_server.prisma_client.connect()
            request = NewUserRequest(
                metadata={"team": "litellm-team3", "project": "litellm-project3"}
            )
            key = await new_user(
                data=request,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
            print(key)

            generated_key = key.key

            # use generated key to auth in
            result = await info_key_fn(
                key=generated_key,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                ),
            )
            print("result from info_key_fn", result)
            assert result["key"] == generated_key
            print("\n info for key=", result["info"])
            assert result["info"]["max_parallel_requests"] == None
            # metadata must round-trip exactly as provided at creation
            assert result["info"]["metadata"] == {
                "team": "litellm-team3",
                "project": "litellm-project3",
            }

            # cleanup - delete key
            delete_key_request = KeyRequest(keys=[generated_key])
            bearer_token = "Bearer sk-1234"

            request = Request(scope={"type": "http"})
            request._url = URL(url="/key/delete")

            # authenticate as master key, then delete
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print(f"result: {result}")
            result.user_role = LitellmUserRoles.PROXY_ADMIN

            result_delete_key = await delete_key_fn(
                data=delete_key_request, user_api_key_dict=result
            )

        asyncio.run(test())
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")
def test_generate_and_update_key(prisma_client):
    """
    Scenario 11: generate a key, read it via /key/info, update it twice via
    /key/update, and verify:
      - fields passed to the update (models, tpm_limit, budget_duration,
        max_budget) are changed
      - untouched fields (metadata, team_id, max_parallel_requests) survive
      - budget_reset_at for a "1mo" duration lands roughly one month out
    Deletes the key as cleanup.
    """
    # 11. Generate a Key, cal key/info, call key/update, call key/info
    # Check if data gets updated
    # Check if untouched data does not get updated
    import uuid

    print("prisma client=", prisma_client)
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    try:

        async def test():
            await litellm.proxy.proxy_server.prisma_client.connect()
            # create team "litellm-core-infra@gmail.com""
            print("creating team litellm-core-infra@gmail.com")
            _team_1 = "litellm-core-infra@gmail.com_{}".format(uuid.uuid4())
            await new_team(
                NewTeamRequest(
                    team_id=_team_1,
                ),
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
                http_request=Request(scope={"type": "http"}),
            )

            _team_2 = "ishaan-special-team_{}".format(uuid.uuid4())
            await new_team(
                NewTeamRequest(
                    team_id=_team_2,
                ),
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
                http_request=Request(scope={"type": "http"}),
            )

            request = NewUserRequest(
                metadata={"project": "litellm-project3"},
                team_id=_team_1,
            )

            key = await new_user(
                data=request,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
            print(key)

            generated_key = key.key

            # use generated key to auth in
            result = await info_key_fn(
                key=generated_key,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                ),
            )
            print("result from info_key_fn", result)
            assert result["key"] == generated_key
            print("\n info for key=", result["info"])
            assert result["info"]["max_parallel_requests"] == None
            assert result["info"]["metadata"] == {
                "project": "litellm-project3",
            }
            assert result["info"]["team_id"] == _team_1

            request = Request(scope={"type": "http"})
            request._url = URL(url="/update/key")

            # update the key
            # BUGFIX: previously passed the `Request` class itself
            # (request=Request); pass the request instance built above.
            response1 = await update_key_fn(
                request=request,
                data=UpdateKeyRequest(
                    key=generated_key,
                    models=["ada", "babbage", "curie", "davinci"],
                    budget_duration="1mo",
                    max_budget=100,
                ),
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
            print("response1=", response1)

            # update the tpm limit
            response2 = await update_key_fn(
                request=request,
                data=UpdateKeyRequest(key=generated_key, tpm_limit=1000),
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
            print("response2=", response2)

            # get info on key after update
            result = await info_key_fn(
                key=generated_key,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                ),
            )
            print("result from info_key_fn", result)
            assert result["key"] == generated_key
            print("\n info for key=", result["info"])
            # untouched fields must be unchanged by the two updates
            assert result["info"]["max_parallel_requests"] == None
            assert result["info"]["metadata"] == {
                "project": "litellm-project3",
            }
            # updated fields must reflect the two update calls
            assert result["info"]["models"] == ["ada", "babbage", "curie", "davinci"]
            assert result["info"]["tpm_limit"] == 1000
            assert result["info"]["budget_duration"] == "1mo"
            assert result["info"]["max_budget"] == 100

            # budget_reset_at should exist for "1mo" duration
            assert result["info"]["budget_reset_at"] is not None
            budget_reset_at = result["info"]["budget_reset_at"].replace(
                tzinfo=timezone.utc
            )
            current_time = datetime.now(timezone.utc)

            print(f"Budget reset time: {budget_reset_at}")
            print(f"Current time: {current_time}")

            # Instead of checking exact timing, just verify that:
            # 1. Both are in the same day (for tests running same day)
            # 2. Or budget_reset_at is in next month
            if budget_reset_at.day == current_time.day:
                # Same day of month - just check month difference
                month_diff = budget_reset_at.month - current_time.month
                if budget_reset_at.year > current_time.year:
                    month_diff += 12
                # Should be scheduled for next month (at least 0.5 month away)
                assert (
                    month_diff >= 1
                ), f"Expected reset to be at least 1 month ahead, got {month_diff} months"
                assert (
                    month_diff <= 2
                ), f"Expected reset to be at most 2 months ahead, got {month_diff} months"
            else:
                # Just ensure the date is reasonable (not more than 40 days away)
                days_diff = (budget_reset_at - current_time).days
                assert (
                    0 <= days_diff <= 40
                ), f"Expected reset date to be reasonable, got {days_diff} days from now"

            # cleanup - delete key
            delete_key_request = KeyRequest(keys=[generated_key])

            # delete the key
            bearer_token = "Bearer sk-1234"

            request = Request(scope={"type": "http"})
            request._url = URL(url="/key/delete")

            # use generated key to auth in
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print(f"result: {result}")
            result.user_role = LitellmUserRoles.PROXY_ADMIN

            result_delete_key = await delete_key_fn(
                data=delete_key_request, user_api_key_dict=result
            )

        asyncio.run(test())
    except Exception as e:
        print("Got Exception", e)
        pytest.fail(f"An exception occurred - {str(e)}\n{traceback.format_exc()}")
def test_key_generate_with_custom_auth(prisma_client):
    """
    Verify that a user-supplied key-generation hook (user_custom_key_generate)
    can veto or allow key creation:
    - a request without the privileged team_id is rejected with the hook's
      message
    - a request with team_id="litellm-core-infra@gmail.com" succeeds
    """
    # custom - generate key function
    async def custom_generate_key_fn(data: GenerateKeyRequest) -> dict:
        """
        Asynchronous function for generating a key based on the input data.

        Args:
            data (GenerateKeyRequest): The input data for key generation.

        Returns:
            dict: A dictionary containing the decision and an optional message.
            {
                "decision": False,
                "message": "This violates LiteLLM Proxy Rules. No team id provided.",
            }
        """
        # decide if a key should be generated or not
        print("using custom auth function!")
        data_json = data.json()  # type: ignore

        # Unpacking variables
        team_id = data_json.get("team_id")
        duration = data_json.get("duration")
        models = data_json.get("models")
        aliases = data_json.get("aliases")
        config = data_json.get("config")
        spend = data_json.get("spend")
        user_id = data_json.get("user_id")
        max_parallel_requests = data_json.get("max_parallel_requests")
        metadata = data_json.get("metadata")
        tpm_limit = data_json.get("tpm_limit")
        rpm_limit = data_json.get("rpm_limit")

        if team_id is not None and team_id == "litellm-core-infra@gmail.com":
            # only team_id="litellm-core-infra@gmail.com" can make keys
            return {
                "decision": True,
            }
        else:
            print("Failed custom auth")
            return {
                "decision": False,
                "message": "This violates LiteLLM Proxy Rules. No team id provided.",
            }

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    # install the hook so generate_key_fn consults it
    setattr(
        litellm.proxy.proxy_server, "user_custom_key_generate", custom_generate_key_fn
    )
    try:

        async def test():
            try:
                await litellm.proxy.proxy_server.prisma_client.connect()
                request = GenerateKeyRequest()
                key = await generate_key_fn(
                    request,
                    user_api_key_dict=UserAPIKeyAuth(
                        user_role=LitellmUserRoles.PROXY_ADMIN,
                        api_key="sk-1234",
                        user_id="1234",
                    ),
                )
                pytest.fail(f"Expected an exception. Got {key}")
            except Exception as e:
                # this should fail
                print("Got Exception", e)
                print(e.message)
                print("First request failed!. This is expected")
                assert (
                    "This violates LiteLLM Proxy Rules. No team id provided."
                    in e.message
                )

            # second attempt: privileged team_id, hook approves key creation
            request_2 = GenerateKeyRequest(
                team_id="litellm-core-infra@gmail.com",
            )

            key = await generate_key_fn(
                request_2,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
            print(key)
            generated_key = key.key

        asyncio.run(test())
    except Exception as e:
        print("Got Exception", e)
        print(e.message)
        pytest.fail(f"An exception occurred - {str(e)}")
def test_call_with_key_over_budget(prisma_client):
    """
    Scenario 12: a key with max_budget=0.00001 authenticates once, then a
    spend of 0.00002 is tracked against it; the next auth attempt must be
    rejected with a budget-exceeded ProxyException. Also asserts the spend
    log row was written (request_id, spend, model, cache_key).
    """
    # 12. Make a call with a key over budget, expect to fail
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    try:

        async def test():
            await litellm.proxy.proxy_server.prisma_client.connect()
            request = GenerateKeyRequest(max_budget=0.00001)
            key = await generate_key_fn(
                request,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
            print(key)
            generated_key = key.key
            user_id = key.user_id
            bearer_token = "Bearer " + generated_key

            request = Request(scope={"type": "http"})
            request._url = URL(url="/chat/completions")

            # use generated key to auth in -- first call is under budget, must pass
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("result from user auth with new key", result)

            # update spend using track_cost callback, make 2nd request, it should fail
            from litellm import Choices, Message, ModelResponse, Usage
            from litellm.caching.caching import Cache
            from litellm.proxy.proxy_server import _ProxyDBLogger

            proxy_db_logger = _ProxyDBLogger()

            litellm.cache = Cache()
            import time
            import uuid

            request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{uuid.uuid4()}"

            resp = ModelResponse(
                id=request_id,
                choices=[
                    Choices(
                        finish_reason=None,
                        index=0,
                        message=Message(
                            content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
                            role="assistant",
                        ),
                    )
                ],
                model="gpt-35-turbo",  # azure always has model written like this
                usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
            )
            # log a cost (0.00002) larger than the key's max_budget (0.00001)
            await proxy_db_logger._PROXY_track_cost_callback(
                kwargs={
                    "model": "chatgpt-v-3",
                    "stream": False,
                    "litellm_params": {
                        "metadata": {
                            "user_api_key": hash_token(generated_key),
                            "user_api_key_user_id": user_id,
                        }
                    },
                    "response_cost": 0.00002,
                },
                completion_response=resp,
                start_time=datetime.now(),
                end_time=datetime.now(),
            )
            # flush buffered spend updates to the DB
            await update_spend(
                prisma_client=prisma_client,
                db_writer_client=None,
                proxy_logging_obj=proxy_logging_obj,
            )
            # test spend_log was written and we can read it
            spend_logs = await view_spend_logs(
                request_id=request_id,
                user_api_key_dict=UserAPIKeyAuth(api_key=generated_key),
            )

            print("read spend logs", spend_logs)
            assert len(spend_logs) == 1

            spend_log = spend_logs[0]

            assert spend_log.request_id == request_id
            assert spend_log.spend == float("2e-05")
            assert spend_log.model == "chatgpt-v-3"
            assert (
                spend_log.cache_key
                == "509ba0554a7129ae4f4fd13d11c141acce5549bb6aaf1f629ed543101615658e"
            )

            # use generated key to auth in -- now over budget, must raise
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("result from user auth with new key", result)
            pytest.fail("This should have failed!. They key crossed it's budget")

        asyncio.run(test())
    except Exception as e:
        # print(f"Error - {str(e)}")
        traceback.print_exc()
        if hasattr(e, "message"):
            error_detail = e.message
        else:
            error_detail = str(e)
        assert "Budget has been exceeded" in error_detail
        assert isinstance(e, ProxyException)
        assert e.type == ProxyErrorTypes.budget_exceeded
        print(vars(e))
def test_call_with_key_over_budget_no_cache(prisma_client):
    """
    Scenario 12 variant of test_call_with_key_over_budget: the in-memory key
    cache is wiped first, so budget enforcement must work purely from the DB
    (regression test for https://github.com/BerriAI/litellm/issues/3920).
    """
    # 12. Make a call with a key over budget, expect to fail
    # ✅  Tests if spend tracking works when the key does not exist in memory
    # Related to this: https://github.com/BerriAI/litellm/issues/3920
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    try:

        async def test():
            await litellm.proxy.proxy_server.prisma_client.connect()
            request = GenerateKeyRequest(max_budget=0.00001)
            key = await generate_key_fn(
                request,
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
            print(key)
            generated_key = key.key
            user_id = key.user_id
            bearer_token = "Bearer " + generated_key

            request = Request(scope={"type": "http"})
            request._url = URL(url="/chat/completions")

            # use generated key to auth in -- first call is under budget, must pass
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("result from user auth with new key", result)

            # update spend using track_cost callback, make 2nd request, it should fail
            from litellm.proxy.proxy_server import _ProxyDBLogger
            from litellm.proxy.proxy_server import user_api_key_cache

            # wipe the in-memory cache so key state must come from the DB
            user_api_key_cache.in_memory_cache.cache_dict = {}
            setattr(litellm.proxy.proxy_server, "proxy_batch_write_at", 1)

            from litellm import Choices, Message, ModelResponse, Usage
            from litellm.caching.caching import Cache

            litellm.cache = Cache()
            import time
            import uuid

            request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{uuid.uuid4()}"

            resp = ModelResponse(
                id=request_id,
                choices=[
                    Choices(
                        finish_reason=None,
                        index=0,
                        message=Message(
                            content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
                            role="assistant",
                        ),
                    )
                ],
                model="gpt-35-turbo",  # azure always has model written like this
                usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
            )
            proxy_db_logger = _ProxyDBLogger()
            # log a cost (0.00002) larger than the key's max_budget (0.00001)
            await proxy_db_logger._PROXY_track_cost_callback(
                kwargs={
                    "model": "chatgpt-v-3",
                    "stream": False,
                    "litellm_params": {
                        "metadata": {
                            "user_api_key": hash_token(generated_key),
                            "user_api_key_user_id": user_id,
                        }
                    },
                    "response_cost": 0.00002,
                },
                completion_response=resp,
                start_time=datetime.now(),
                end_time=datetime.now(),
            )
            # wait for the batched spend write (proxy_batch_write_at=1s) to land
            await asyncio.sleep(10)
            await update_spend(
                prisma_client=prisma_client,
                db_writer_client=None,
                proxy_logging_obj=proxy_logging_obj,
            )
            # test spend_log was written and we can read it
            spend_logs = await view_spend_logs(
                request_id=request_id,
                user_api_key_dict=UserAPIKeyAuth(api_key=generated_key),
            )

            print("read spend logs", spend_logs)
            assert len(spend_logs) == 1

            spend_log = spend_logs[0]

            assert spend_log.request_id == request_id
            assert spend_log.spend == float("2e-05")
            assert spend_log.model == "chatgpt-v-3"
            assert (
                spend_log.cache_key
                == "509ba0554a7129ae4f4fd13d11c141acce5549bb6aaf1f629ed543101615658e"
            )

            # use generated key to auth in -- now over budget, must raise
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            print("result from user auth with new key", result)
            pytest.fail(f"This should have failed!. They key crossed it's budget")

        asyncio.run(test())
    except Exception as e:
        # print(f"Error - {str(e)}")
        traceback.print_exc()
        if hasattr(e, "message"):
            error_detail = e.message
        else:
            error_detail = str(e)
        assert "Budget has been exceeded" in error_detail
        assert isinstance(e, ProxyException)
        assert e.type == ProxyErrorTypes.budget_exceeded
        print(vars(e))
async def test_call_with_key_over_model_budget(
    prisma_client, request_model, should_pass
):
    """
    Per-model budget enforcement on a virtual key.

    The key carries model_max_budget with a tiny budget for gpt-4o-mini and a
    large one for gpt-4o. After one acompletion call with `request_model`,
    re-running auth must raise when should_pass is False and succeed otherwise.

    NOTE(review): `request_model` and `should_pass` are presumably supplied by
    a pytest.mark.parametrize decorator above this chunk — confirm in the
    full file.
    """
    # 12. Make a call with a key over budget, expect to fail
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()
    verbose_proxy_logger.setLevel(logging.DEBUG)

    # init model max budget limiter
    from litellm.proxy.hooks.model_max_budget_limiter import (
        _PROXY_VirtualKeyModelMaxBudgetLimiter,
    )

    model_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
        dual_cache=DualCache()
    )
    litellm.callbacks.append(model_budget_limiter)

    try:
        # set budget for chatgpt-v-3 to 0.000001, expect the next request to fail
        model_max_budget = {
            "gpt-4o-mini": {
                "budget_limit": "0.000001",
                "time_period": "1d",
            },
            "gpt-4o": {
                "budget_limit": "200",
                "time_period": "30d",
            },
        }
        request = GenerateKeyRequest(
            max_budget=100000,  # the key itself has a very high budget
            model_max_budget=model_max_budget,
        )
        key = await generate_key_fn(
            request,
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )
        print(key)
        generated_key = key.key
        user_id = key.user_id
        bearer_token = "Bearer " + generated_key

        request = Request(scope={"type": "http"})
        request._url = URL(url="/chat/completions")

        # stub the request body so auth can read the requested model
        async def return_body():
            request_str = f'{{"model": "{request_model}"}}'  # Added extra curly braces to escape JSON
            return request_str.encode()

        request.body = return_body

        # use generated key to auth in
        result = await user_api_key_auth(request=request, api_key=bearer_token)
        print("result from user auth with new key", result)

        # update spend using track_cost callback, make 2nd request, it should fail
        await litellm.acompletion(
            model=request_model,
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            metadata={
                "user_api_key": hash_token(generated_key),
                "user_api_key_model_max_budget": model_max_budget,
            },
        )
        # give the budget-limiter callback time to record the spend
        await asyncio.sleep(2)

        # use generated key to auth in
        result = await user_api_key_auth(request=request, api_key=bearer_token)
        if should_pass is True:
            print(
                f"Passed request for model={request_model}, model_max_budget={model_max_budget}"
            )
            return
        print("result from user auth with new key", result)
        pytest.fail("This should have failed!. They key crossed it's budget")
    except Exception as e:
        # print(f"Error - {str(e)}")
        print(
            f"Failed request for model={request_model}, model_max_budget={model_max_budget}"
        )
        assert (
            should_pass is False
        ), f"This should have failed!. They key crossed it's budget for model={request_model}. {e}"
        traceback.print_exc()
        error_detail = e.message
        assert f"exceeded budget for model={request_model}" in error_detail
        assert isinstance(e, ProxyException)
        assert e.type == ProxyErrorTypes.budget_exceeded
        print(vars(e))
    finally:
        # always detach the budget limiter so other tests aren't affected
        litellm.callbacks.remove(model_budget_limiter)
async def test_call_with_key_never_over_budget(prisma_client):
    """
    A key generated with max_budget=None must never fail auth on budget
    grounds, even after a very large spend (200000) is tracked against it.
    """
    # Make a call with a key with budget=None, it should never fail
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    try:
        await litellm.proxy.proxy_server.prisma_client.connect()
        request = GenerateKeyRequest(max_budget=None)
        key = await generate_key_fn(
            request,
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )
        print(key)
        generated_key = key.key
        user_id = key.user_id
        bearer_token = "Bearer " + generated_key

        request = Request(scope={"type": "http"})
        request._url = URL(url="/chat/completions")

        # use generated key to auth in
        result = await user_api_key_auth(request=request, api_key=bearer_token)
        # BUGFIX: this print was missing the f-prefix and printed the literal
        # text "{result}" instead of the auth result
        print(f"result from user auth with new key: {result}")

        # update spend using track_cost callback, make 2nd request, it should fail
        import time
        import uuid

        from litellm import Choices, Message, ModelResponse, Usage
        from litellm.proxy.proxy_server import _ProxyDBLogger

        proxy_db_logger = _ProxyDBLogger()

        request_id = f"chatcmpl-{uuid.uuid4()}"

        resp = ModelResponse(
            id=request_id,
            choices=[
                Choices(
                    finish_reason=None,
                    index=0,
                    message=Message(
                        content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
                        role="assistant",
                    ),
                )
            ],
            model="gpt-35-turbo",  # azure always has model written like this
            # NOTE(review): total_tokens (41000) != prompt+completion (410000);
            # looks like a dropped zero, but only response_cost drives budget
            # tracking here, so left unchanged.
            usage=Usage(
                prompt_tokens=210000, completion_tokens=200000, total_tokens=41000
            ),
        )
        await proxy_db_logger._PROXY_track_cost_callback(
            kwargs={
                "model": "chatgpt-v-3",
                "stream": False,
                "litellm_params": {
                    "metadata": {
                        "user_api_key": hash_token(generated_key),
                        "user_api_key_user_id": user_id,
                    }
                },
                "response_cost": 200000,
            },
            completion_response=resp,
            start_time=datetime.now(),
            end_time=datetime.now(),
        )
        await update_spend(
            prisma_client=prisma_client,
            db_writer_client=None,
            proxy_logging_obj=proxy_logging_obj,
        )
        # use generated key to auth in -- must still pass: no budget cap
        result = await user_api_key_auth(request=request, api_key=bearer_token)
        print("result from user auth with new key", result)
    except Exception as e:
        pytest.fail(f"This should have not failed!. They key uses max_budget=None. {e}")
async def test_call_with_key_over_budget_stream(prisma_client):
    """
    Scenario 14: streaming variant of the over-budget test. A key with
    max_budget=0.00001 is charged 0.00005 via the cost callback using the
    streaming kwargs shape; the next auth attempt must fail with
    "Budget has been exceeded".
    """
    # 14. Make a call with a key over budget, expect to fail
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    import logging

    from litellm._logging import verbose_proxy_logger

    litellm.set_verbose = True
    verbose_proxy_logger.setLevel(logging.DEBUG)
    try:
        await litellm.proxy.proxy_server.prisma_client.connect()
        request = GenerateKeyRequest(max_budget=0.00001)
        key = await generate_key_fn(
            request,
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )
        print(key)
        generated_key = key.key
        user_id = key.user_id
        bearer_token = "Bearer " + generated_key
        print(f"generated_key: {generated_key}")

        request = Request(scope={"type": "http"})
        request._url = URL(url="/chat/completions")

        # use generated key to auth in -- first call is under budget, must pass
        result = await user_api_key_auth(request=request, api_key=bearer_token)
        print("result from user auth with new key", result)

        # update spend using track_cost callback, make 2nd request, it should fail
        import time
        import uuid

        from litellm import Choices, Message, ModelResponse, Usage
        from litellm.proxy.proxy_server import _ProxyDBLogger

        proxy_db_logger = _ProxyDBLogger()

        request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{uuid.uuid4()}"
        resp = ModelResponse(
            id=request_id,
            choices=[
                Choices(
                    finish_reason=None,
                    index=0,
                    message=Message(
                        content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
                        role="assistant",
                    ),
                )
            ],
            model="gpt-35-turbo",  # azure always has model written like this
            usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
        )
        # streaming callbacks carry call_type/stream/complete_streaming_response
        await proxy_db_logger._PROXY_track_cost_callback(
            kwargs={
                "call_type": "acompletion",
                "model": "sagemaker-chatgpt-v-3",
                "stream": True,
                "complete_streaming_response": resp,
                "litellm_params": {
                    "metadata": {
                        "user_api_key": hash_token(generated_key),
                        "user_api_key_user_id": user_id,
                    }
                },
                "response_cost": 0.00005,
            },
            completion_response=resp,
            start_time=datetime.now(),
            end_time=datetime.now(),
        )
        await update_spend(
            prisma_client=prisma_client,
            db_writer_client=None,
            proxy_logging_obj=proxy_logging_obj,
        )
        # use generated key to auth in -- now over budget, must raise
        result = await user_api_key_auth(request=request, api_key=bearer_token)
        print("result from user auth with new key", result)
        pytest.fail(f"This should have failed!. They key crossed it's budget")
    except Exception as e:
        print("Got Exception", e)
        error_detail = e.message
        assert "Budget has been exceeded" in error_detail
        print(vars(e))
async def test_aview_spend_per_user(prisma_client):
    """Fetch per-user spend via spend_user_fn and sanity-check the first row."""
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()
    try:
        spend_rows = await spend_user_fn(user_id=None)
        assert type(spend_rows) == list
        assert len(spend_rows) > 0
        top_row = spend_rows[0]
        print("\nfirst_user=", top_row)
        # spend is cumulative and therefore never negative
        assert top_row["spend"] >= 0
    except Exception as e:
        print("Got Exception", e)
        pytest.fail(f"Got exception {e}")
async def test_view_spend_per_key(prisma_client):
    """Fetch per-key spend via spend_key_fn and sanity-check the top entry."""
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()
    try:
        keys_ranked = await spend_key_fn()
        assert type(keys_ranked) == list
        assert len(keys_ranked) > 0
        head = keys_ranked[0]
        print("\nfirst_key=", head)
        # spend is cumulative and therefore never negative
        assert head.spend >= 0
    except Exception as e:
        print("Got Exception", e)
        pytest.fail(f"Got exception {e}")
async def test_key_name_null(prisma_client):
    """
    With DISABLE_KEY_NAME enabled:
    - create a key
    - fetch its info
    - verify key_name comes back as null
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    os.environ["DISABLE_KEY_NAME"] = "True"
    await litellm.proxy.proxy_server.prisma_client.connect()
    try:
        new_key = await generate_key_fn(
            GenerateKeyRequest(),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )
        print("generated key=", new_key)
        key_info = await info_key_fn(
            key=new_key.key,
            user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
        )
        print("result from info_key_fn", key_info)
        assert key_info["info"]["key_name"] is None
    except Exception as e:
        print("Got Exception", e)
        pytest.fail(f"Got exception {e}")
    finally:
        # restore the env flag so later tests get key names again
        os.environ["DISABLE_KEY_NAME"] = "False"
async def test_key_name_set(prisma_client):
    """
    With allow_user_auth enabled in general_settings:
    - generate a key
    - read it back via /key/info
    - its key_name must be populated (a string)
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True})
    await litellm.proxy.proxy_server.prisma_client.connect()
    try:
        fresh_key = await generate_key_fn(
            GenerateKeyRequest(),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )
        info = await info_key_fn(
            key=fresh_key.key,
            user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
        )
        print("result from info_key_fn", info)
        assert isinstance(info["info"]["key_name"], str)
    except Exception as e:
        print("Got Exception", e)
        pytest.fail(f"Got exception {e}")
async def test_default_key_params(prisma_client):
    """
    When litellm.default_key_generate_params sets max_budget, a key generated
    without an explicit budget must inherit that default (0.000122).
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True})
    litellm.default_key_generate_params = {"max_budget": 0.000122}
    await litellm.proxy.proxy_server.prisma_client.connect()
    try:
        issued = await generate_key_fn(
            GenerateKeyRequest(),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )
        details = await info_key_fn(
            key=issued.key,
            user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
        )
        print("result from info_key_fn", details)
        # the default must have been applied at generation time
        assert details["info"]["max_budget"] == 0.000122
    except Exception as e:
        print("Got Exception", e)
        pytest.fail(f"Got exception {e}")
async def test_upperbound_key_param_larger_budget(prisma_client):
    """
    - set litellm.upperbound_key_generate_params with max_budget=0.001
    - request a key with max_budget=200000 (over the upperbound)
    - expect key generation to fail with a 400 error
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    litellm.upperbound_key_generate_params = LiteLLM_UpperboundKeyGenerateParams(
        max_budget=0.001, budget_duration="1m"
    )
    await litellm.proxy.proxy_server.prisma_client.connect()
    try:
        request = GenerateKeyRequest(
            max_budget=200000,
            budget_duration="30d",
        )
        key = await generate_key_fn(
            request,
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )
        # fix: previously the test passed silently when no exception was raised.
        # pytest.fail raises an outcome exception derived from BaseException,
        # so it is NOT caught by the `except Exception` below.
        pytest.fail("Expected this to fail but it passed")
    except Exception as e:
        assert e.code == str(400)
async def test_upperbound_key_param_larger_duration(prisma_client):
    """Requesting a key duration above the configured upperbound must fail with a 400."""
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    litellm.upperbound_key_generate_params = LiteLLM_UpperboundKeyGenerateParams(
        max_budget=100, duration="14d"
    )
    await litellm.proxy.proxy_server.prisma_client.connect()
    try:
        await generate_key_fn(
            GenerateKeyRequest(
                max_budget=10,
                duration="30d",  # exceeds the 14d upperbound set above
            ),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )
        pytest.fail("Expected this to fail but it passed")
    except Exception as e:
        # upperbound violation is surfaced as a 400 error
        assert e.code == str(400)
async def test_upperbound_key_param_none_duration(prisma_client):
    """
    With upperbound max_budget=100 / duration="14d" and a request that sets
    neither, the generated key should inherit the upperbound values.
    """
    from datetime import datetime, timedelta

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    litellm.upperbound_key_generate_params = LiteLLM_UpperboundKeyGenerateParams(
        max_budget=100, duration="14d"
    )
    await litellm.proxy.proxy_server.prisma_client.connect()
    try:
        new_key = await generate_key_fn(
            GenerateKeyRequest(),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )
        print(new_key)
        # upperbound values act as defaults when the request omits them
        assert new_key.max_budget == 100
        assert new_key.expires is not None
        actual_expiry_date = new_key.expires.date()
        expected_expiry_date = (datetime.now() + timedelta(days=14)).date()
        assert actual_expiry_date == expected_expiry_date
    except Exception as e:
        pytest.fail(f"Got exception {e}")
def test_get_bearer_token():
    """
    Unit test for _get_bearer_token:
    - strips the `Bearer ` prefix from a well-formed header value
    - returns "" for empty, non-Bearer, or whitespace-padded inputs
    """
    from litellm.proxy.auth.user_api_key_auth import _get_bearer_token

    # Test valid Bearer token
    api_key = "Bearer valid_token"
    result = _get_bearer_token(api_key)
    assert result == "valid_token", f"Expected 'valid_token', got '{result}'"
    # Test empty API key
    api_key = ""
    result = _get_bearer_token(api_key)
    assert result == "", f"Expected '', got '{result}'"
    # Test API key without Bearer prefix
    api_key = "invalid_token"
    result = _get_bearer_token(api_key)
    assert result == "", f"Expected '', got '{result}'"
    # Test API key with Bearer prefix and extra spaces
    api_key = " Bearer valid_token "
    result = _get_bearer_token(api_key)
    assert result == "", f"Expected '', got '{result}'"
    # Test a real-looking key with Bearer prefix
    api_key = "Bearer sk-1234"
    result = _get_bearer_token(api_key)
    # fix: failure message previously said 'valid_token' instead of 'sk-1234'
    assert result == "sk-1234", f"Expected 'sk-1234', got '{result}'"
async def test_update_logs_with_spend_logs_url(prisma_client):
    """
    Unit test for making sure spend logs list is still updated when url passed in
    """
    from litellm.proxy.db.db_spend_update_writer import DBSpendUpdateWriter

    writer = DBSpendUpdateWriter()
    # case 1: no spend_logs_url -> the log is queued on the prisma client
    first_payload = {"startTime": datetime.now(), "endTime": datetime.now()}
    await writer._insert_spend_log_to_db(
        payload=first_payload, prisma_client=prisma_client
    )
    assert len(prisma_client.spend_log_transactions) > 0
    # reset the queue before the second case
    prisma_client.spend_log_transactions = []
    # case 2: spend_logs_url passed in (empty string) -> log is still queued
    second_payload = {"startTime": datetime.now(), "endTime": datetime.now()}
    await writer._insert_spend_log_to_db(
        payload=second_payload, spend_logs_url="", prisma_client=prisma_client
    )
    assert len(prisma_client.spend_log_transactions) > 0
async def test_user_api_key_auth(prisma_client):
    """Verify user_api_key_auth rejects missing, malformed, and empty API keys."""
    from litellm.proxy.proxy_server import ProxyException

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True})
    await litellm.proxy.proxy_server.prisma_client.connect()
    request = Request(scope={"type": "http"})
    request._url = URL(url="/chat/completions")
    # (api_key, expected ProxyException message) pairs:
    # no key at all, key missing the `Bearer ` prefix, empty-string key
    invalid_key_cases = [
        (None, "Authentication Error, No api key passed in."),
        (
            "my_token",
            "Authentication Error, Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: my_token",
        ),
        (
            "",
            "Authentication Error, Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: ",
        ),
    ]
    for bad_key, expected_message in invalid_key_cases:
        try:
            await user_api_key_auth(request, api_key=bad_key)
            pytest.fail(f"This should have failed!. IT's an invalid key")
        except ProxyException as exc:
            print(exc.message)
            assert exc.message == expected_message
async def test_user_api_key_auth_without_master_key(prisma_client):
    """When no master key is configured, every request should pass auth."""
    try:
        from litellm.proxy.proxy_server import ProxyException

        setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
        setattr(litellm.proxy.proxy_server, "master_key", None)
        setattr(
            litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": True}
        )
        await litellm.proxy.proxy_server.prisma_client.connect()
        request = Request(scope={"type": "http"})
        request._url = URL(url="/chat/completions")
        # none of these should be rejected while master_key is unset
        for candidate_key in (None, "my_token", "", "Bearer 1234"):
            await user_api_key_auth(request, api_key=candidate_key)
    except Exception as e:
        print("Got Exception", e)
        pytest.fail(f"Got exception {e}")
async def test_key_with_no_permissions(prisma_client):
    """
    - create a key via generate_key_helper_fn with key_max_budget=0
    - use the key to auth a /chat/completions request -> expect it to fail
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "general_settings", {"allow_user_auth": False})
    await litellm.proxy.proxy_server.prisma_client.connect()
    try:
        response = await generate_key_helper_fn(
            request_type="key",
            **{"duration": "1hr", "key_max_budget": 0, "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": "ishaan", "team_id": "litellm-dashboard"},  # type: ignore
        )
        print(response)
        key = response["token"]
        # make a /chat/completions call -> it should fail
        request = Request(scope={"type": "http"})
        request._url = URL(url="/chat/completions")
        # use generated key to auth in
        result = await user_api_key_auth(request=request, api_key="Bearer " + key)
        print("result from user auth with new key", result)
        # pytest.fail raises a BaseException-derived outcome, so the broad
        # `except Exception` below does not swallow it
        pytest.fail(f"This should have failed!. IT's an invalid key")
    except Exception as e:
        # NOTE(review): this broad except also catches unexpected setup errors,
        # so the test passes as long as *anything* above raises — consider
        # asserting on the exception type/message to tighten it.
        print("Got Exception", e)
        print(e.message)
async def track_cost_callback_helper_fn(generated_key: str, user_id: str):
    """
    Helper: simulate one completed streaming request and push its cost through
    _ProxyDBLogger._PROXY_track_cost_callback for the given key/user.

    Used by the load test below to generate many spend-log writes.
    """
    import uuid
    from litellm import Choices, Message, ModelResponse, Usage
    from litellm.proxy.proxy_server import _ProxyDBLogger
    # unique request id per call so each spend-log entry is distinct
    request_id = f"chatcmpl-e41836bb-bb8b-4df2-8e70-8f3e160155ac{uuid.uuid4()}"
    resp = ModelResponse(
        id=request_id,
        choices=[
            Choices(
                finish_reason=None,
                index=0,
                message=Message(
                    content=" Sure! Here is a short poem about the sky:\n\nA canvas of blue, a",
                    role="assistant",
                ),
            )
        ],
        model="gpt-35-turbo",  # azure always has model written like this
        usage=Usage(prompt_tokens=210, completion_tokens=200, total_tokens=410),
    )
    proxy_db_logger = _ProxyDBLogger()
    # the callback reads the key hash + user id out of litellm_params.metadata
    await proxy_db_logger._PROXY_track_cost_callback(
        kwargs={
            "call_type": "acompletion",
            "model": "sagemaker-chatgpt-v-3",
            "stream": True,
            "complete_streaming_response": resp,
            "litellm_params": {
                "metadata": {
                    "user_api_key": hash_token(generated_key),
                    "user_api_key_user_id": user_id,
                }
            },
            "response_cost": 0.00005,
        },
        completion_response=resp,
        start_time=datetime.now(),
        end_time=datetime.now(),
    )
async def test_proxy_load_test_db(prisma_client):
    """
    Run 1500 req./s against track_cost_callback function

    NOTE(review): this test unconditionally raises at the end (the
    `raise Exception("it worked! ...")` below runs on the success path too),
    so it always reports via pytest.fail — it appears intended for manual /
    load-test inspection rather than CI pass/fail.
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    import logging
    import time
    from litellm._logging import verbose_proxy_logger
    litellm.set_verbose = True
    verbose_proxy_logger.setLevel(logging.DEBUG)
    try:
        start_time = time.time()
        await litellm.proxy.proxy_server.prisma_client.connect()
        request = GenerateKeyRequest(max_budget=0.00001)
        key = await generate_key_fn(
            request,
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )
        print(key)
        generated_key = key.key
        user_id = key.user_id
        bearer_token = "Bearer " + generated_key
        request = Request(scope={"type": "http"})
        request._url = URL(url="/chat/completions")
        # use generated key to auth in
        result = await user_api_key_auth(request=request, api_key=bearer_token)
        print("result from user auth with new key", result)
        # update spend using track_cost callback, make 2nd request, it should fail
        n = 5000
        # fire n concurrent cost-tracking calls against the same key
        tasks = [
            track_cost_callback_helper_fn(generated_key=generated_key, user_id=user_id)
            for _ in range(n)
        ]
        completions = await asyncio.gather(*tasks)
        # give the background writer time to flush spend logs to the DB
        await asyncio.sleep(120)
        try:
            # call spend logs
            spend_logs = await view_spend_logs(
                api_key=generated_key,
                user_api_key_dict=UserAPIKeyAuth(api_key=generated_key),
            )
            print(f"len responses: {len(spend_logs)}")
            assert len(spend_logs) == n
            print(n, time.time() - start_time, len(spend_logs))
        except Exception:
            print(n, time.time() - start_time, 0)
        raise Exception(f"it worked! key={key.key}")
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")
async def test_master_key_hashing(prisma_client):
    """
    - create a team and a user scoped to that team using the master key
    - auth a request with the raw master key
    - assert the resolved api_key is the *hashed* master key
    """
    try:
        import uuid
        print("prisma client=", prisma_client)
        master_key = "sk-1234"
        setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
        setattr(litellm.proxy.proxy_server, "master_key", master_key)
        await litellm.proxy.proxy_server.prisma_client.connect()
        from litellm.proxy.proxy_server import user_api_key_cache
        # unique team id per run to avoid collisions in a shared test DB
        _team_id = "ishaans-special-team_{}".format(uuid.uuid4())
        user_api_key_dict = UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
            user_id="1234",
        )
        await new_team(
            NewTeamRequest(team_id=_team_id),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
            http_request=Request(scope={"type": "http"}),
        )
        _response = await new_user(
            data=NewUserRequest(
                models=["azure-gpt-3.5"],
                team_id=_team_id,
                tpm_limit=20,
            ),
            user_api_key_dict=user_api_key_dict,
        )
        print(_response)
        assert _response.models == ["azure-gpt-3.5"]
        assert _response.team_id == _team_id
        assert _response.tpm_limit == 20
        bearer_token = "Bearer " + master_key
        request = Request(scope={"type": "http"})
        request._url = URL(url="/chat/completions")
        # use generated key to auth in; auth should resolve the raw master key
        # to its hashed form rather than returning it verbatim
        result: UserAPIKeyAuth = await user_api_key_auth(
            request=request, api_key=bearer_token
        )
        assert result.api_key == hash_token(master_key)
    except Exception as e:
        print("Got Exception", e)
        pytest.fail(f"Got exception {e}")
async def test_reset_spend_authentication(prisma_client):
    """
    1. Test master key can access this route -> ONLY MASTER KEY SHOULD BE ABLE TO RESET SPEND
    2. Test that non-master key gets rejected
    3. Test that non-master key with role == LitellmUserRoles.PROXY_ADMIN or admin gets rejected
    """
    print("prisma client=", prisma_client)
    master_key = "sk-1234"
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", master_key)
    await litellm.proxy.proxy_server.prisma_client.connect()
    from litellm.proxy.proxy_server import user_api_key_cache
    bearer_token = "Bearer " + master_key
    request = Request(scope={"type": "http"})
    request._url = URL(url="/global/spend/reset")
    # Test 1 - Master Key: should be allowed through
    result: UserAPIKeyAuth = await user_api_key_auth(
        request=request, api_key=bearer_token
    )
    print("result from user auth with Master key", result)
    assert result.token is not None
    # Test 2 - Non-Master Key: even a valid generated key must be rejected
    _response = await new_user(
        data=NewUserRequest(
            tpm_limit=20,
        )
    )
    generate_key = "Bearer " + _response.key
    try:
        await user_api_key_auth(request=request, api_key=generate_key)
        pytest.fail(f"This should have failed!. IT's an expired key")
    except Exception as e:
        print("Got Exception", e)
        assert (
            "Tried to access route=/global/spend/reset, which is only for MASTER KEY"
            in e.message
        )
    # Test 3 - Non-Master Key with role == LitellmUserRoles.PROXY_ADMIN or admin
    # even admin-role keys are rejected; only the literal master key may reset spend
    _response = await new_user(
        data=NewUserRequest(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            tpm_limit=20,
        )
    )
    generate_key = "Bearer " + _response.key
    try:
        await user_api_key_auth(request=request, api_key=generate_key)
        pytest.fail(f"This should have failed!. IT's an expired key")
    except Exception as e:
        print("Got Exception", e)
        assert (
            "Tried to access route=/global/spend/reset, which is only for MASTER KEY"
            in e.message
        )
async def test_create_update_team(prisma_client):
    """
    - Set max_budget, budget_duration, max_budget, tpm_limit, rpm_limit
    - Assert response has correct values
    - Update max_budget, budget_duration, max_budget, tpm_limit, rpm_limit
    - Assert response has correct values
    - Call team_info and assert response has correct values
    """
    print("prisma client=", prisma_client)
    master_key = "sk-1234"
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", master_key)
    # NOTE: local `import datetime` shadows the module-level
    # `from datetime import datetime, timezone` inside this function;
    # that is why the isinstance checks below use `datetime.datetime`
    import datetime
    await litellm.proxy.proxy_server.prisma_client.connect()
    from litellm.proxy.proxy_server import user_api_key_cache
    # unique team id per run to avoid collisions in a shared test DB
    _team_id = "test-team_{}".format(uuid.uuid4())
    response = await new_team(
        NewTeamRequest(
            team_id=_team_id,
            max_budget=20,
            budget_duration="30d",
            tpm_limit=20,
            rpm_limit=20,
        ),
        http_request=Request(scope={"type": "http"}),
        user_api_key_dict=UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
            user_id="1234",
        ),
    )
    print("RESPONSE from new_team", response)
    assert response["team_id"] == _team_id
    assert response["max_budget"] == 20
    assert response["tpm_limit"] == 20
    assert response["rpm_limit"] == 20
    assert response["budget_duration"] == "30d"
    assert response["budget_reset_at"] is not None and isinstance(
        response["budget_reset_at"], datetime.datetime
    )
    # updating team budget duration and reset at
    response = await update_team(
        UpdateTeamRequest(
            team_id=_team_id,
            max_budget=30,
            budget_duration="2d",
            tpm_limit=30,
            rpm_limit=30,
        ),
        http_request=Request(scope={"type": "http"}),
        user_api_key_dict=UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
            user_id="1234",
        ),
    )
    print("RESPONSE from update_team", response)
    _updated_info = response["data"]
    _updated_info = dict(_updated_info)
    assert _updated_info["team_id"] == _team_id
    assert _updated_info["max_budget"] == 30
    assert _updated_info["tpm_limit"] == 30
    assert _updated_info["rpm_limit"] == 30
    assert _updated_info["budget_duration"] == "2d"
    assert _updated_info["budget_reset_at"] is not None and isinstance(
        _updated_info["budget_reset_at"], datetime.datetime
    )
    # budget_reset_at should be 2 days from now
    budget_reset_at = _updated_info["budget_reset_at"].replace(tzinfo=timezone.utc)
    current_time = datetime.datetime.now(timezone.utc)
    # Verify that budget_reset_at is at midnight (hour, minute, second are all 0)
    assert budget_reset_at.hour == 0
    assert budget_reset_at.minute == 0
    assert budget_reset_at.second == 0
    # Calculate days difference - should be close to 2 days (within 1 day to account for time of test execution)
    days_diff = (budget_reset_at.date() - current_time.date()).days
    assert 1 <= days_diff <= 2
    # now hit team_info
    try:
        response = await team_info(
            team_id=_team_id,
            http_request=Request(scope={"type": "http"}),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )
    except Exception as e:
        print(e)
        pytest.fail("Receives error - {}".format(e))
    # team_info must reflect the updated values, not the originals
    _team_info = response["team_info"]
    _team_info = dict(_team_info)
    assert _team_info["team_id"] == _team_id
    assert _team_info["max_budget"] == 30
    assert _team_info["tpm_limit"] == 30
    assert _team_info["rpm_limit"] == 30
    assert _team_info["budget_duration"] == "2d"
    assert _team_info["budget_reset_at"] is not None and isinstance(
        _team_info["budget_reset_at"], datetime.datetime
    )
async def test_update_user_role(prisma_client):
    """
    Tests if we update user role, incorrect values are not stored in cache
    -> create a user with role == INTERNAL_USER
    -> access an Admin only route -> expect to fail
    -> update user role to == PROXY_ADMIN
    -> access an Admin only route -> expect to succeed
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()
    key = await new_user(
        data=NewUserRequest(
            user_role=LitellmUserRoles.INTERNAL_USER,
        )
    )
    print(key)
    api_key = "Bearer " + key.key
    api_route = APIRoute(path="/global/spend", endpoint=global_spend)
    request = Request(
        {
            "type": "http",
            "route": api_route,
            "path": "/global/spend",
            "headers": [("Authorization", api_key)],
        }
    )
    request._url = URL(url="/global/spend")
    # use generated key to auth in
    try:
        result = await user_api_key_auth(request=request, api_key=api_key)
        print("result from user auth with new key", result)
    except Exception as e:
        # NOTE(review): failure is expected here (INTERNAL_USER hitting an
        # admin route) but nothing asserts that it actually failed — the
        # except just swallows the error. Consider pytest.fail on success.
        print(e)
        pass
    await user_update(
        data=UpdateUserRequest(
            user_id=key.user_id, user_role=LitellmUserRoles.PROXY_ADMIN
        )
    )
    # await asyncio.sleep(3)
    # use generated key to auth in
    print("\n\nMAKING NEW REQUEST WITH UPDATED USER ROLE\n\n")
    # after the role update, the same key must pass auth (raises on failure)
    result = await user_api_key_auth(request=request, api_key=api_key)
    print("result from user auth with new key", result)
async def test_update_user_unit_test(prisma_client):
    """
    Unit test for /user/update
    Ensure that params are updated for UpdateUserRequest
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()
    # unique email per run to avoid collisions in a shared test DB
    key = await new_user(
        data=NewUserRequest(
            user_email=f"test-{uuid.uuid4()}@test.com",
        )
    )
    print(key)
    user_info = await user_update(
        data=UpdateUserRequest(
            user_id=key.user_id,
            team_id="1234",
            max_budget=100,
            budget_duration="10d",
            tpm_limit=100,
            rpm_limit=100,
            metadata={"very-new-metadata": "something"},
        )
    )
    print("user_info", user_info)
    assert user_info is not None
    _user_info = user_info["data"].model_dump()
    assert _user_info["user_id"] == key.user_id
    assert _user_info["team_id"] == "1234"
    assert _user_info["max_budget"] == 100
    assert _user_info["budget_duration"] == "10d"
    assert _user_info["tpm_limit"] == 100
    assert _user_info["rpm_limit"] == 100
    assert _user_info["metadata"] == {"very-new-metadata": "something"}
    # budget_reset_at should be at midnight 10 days from now
    budget_reset_at = _user_info["budget_reset_at"].replace(tzinfo=timezone.utc)
    current_time = datetime.now(timezone.utc)
    # Verify that budget_reset_at is at midnight (hour, minute, second are all 0)
    assert budget_reset_at.hour == 0
    assert budget_reset_at.minute == 0
    assert budget_reset_at.second == 0
    # Calculate days difference - should be close to 10 days (within 1 day to account for time of test execution)
    days_diff = (budget_reset_at.date() - current_time.date()).days
    assert 9 <= days_diff <= 10
async def test_custom_api_key_header_name(prisma_client):
    """
    When general_settings.litellm_key_header_name is set, auth should read the
    key from that custom header; without the header, auth must fail even if a
    valid key is passed via the standard api_key argument.
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(
        litellm.proxy.proxy_server,
        "general_settings",
        {"litellm_key_header_name": "x-litellm-key"},
    )
    await litellm.proxy.proxy_server.prisma_client.connect()
    api_route = APIRoute(path="/chat/completions", endpoint=chat_completion)

    def _build_request(headers):
        # minimal ASGI scope carrying only what user_api_key_auth inspects
        return Request(
            {
                "type": "http",
                "route": api_route,
                "path": api_route.path,
                "headers": headers,
            }
        )

    # master key sent via the custom header -> should pass, even though the
    # standard api_key argument is invalid
    request = _build_request([(b"x-litellm-key", b"Bearer sk-1234")])
    result = await user_api_key_auth(request=request, api_key="Bearer invalid-key")
    # no custom header present -> expect auth to fail
    request = _build_request([])
    try:
        result = await user_api_key_auth(request=request, api_key="Bearer sk-1234")
        pytest.fail(f"This should have failed!. invalid Auth on this request")
    except Exception as e:
        print("failed with error", e)
        assert (
            "Malformed API Key passed in. Ensure Key has `Bearer ` prefix" in e.message
        )
async def test_generate_key_with_model_tpm_limit(prisma_client):
    """
    - generate a key with per-model tpm/rpm limits in metadata
    - assert /key/info returns the metadata as stored
    - update the limits via /key/update and assert the new values are returned
    """
    print("prisma client=", prisma_client)
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()
    request = GenerateKeyRequest(
        metadata={
            "team": "litellm-team3",
            "model_tpm_limit": {"gpt-4": 100},
            "model_rpm_limit": {"gpt-4": 2},
        }
    )
    key = await generate_key_fn(
        data=request,
        user_api_key_dict=UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
            user_id="1234",
        ),
    )
    print(key)
    generated_key = key.key
    # use generated key to auth in
    result = await info_key_fn(
        key=generated_key,
        user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
    )
    print("result from info_key_fn", result)
    assert result["key"] == generated_key
    print("\n info for key=", result["info"])
    assert result["info"]["metadata"] == {
        "team": "litellm-team3",
        "model_tpm_limit": {"gpt-4": 100},
        "model_rpm_limit": {"gpt-4": 2},
    }
    # Update model tpm_limit and rpm_limit
    request = UpdateKeyRequest(
        key=generated_key,
        model_tpm_limit={"gpt-4": 200},
        model_rpm_limit={"gpt-4": 3},
    )
    _request = Request(scope={"type": "http"})
    _request._url = URL(url="/update/key")
    await update_key_fn(
        data=request,
        request=_request,
        user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
    )
    result = await info_key_fn(
        key=generated_key,
        user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
    )
    print("result from info_key_fn", result)
    assert result["key"] == generated_key
    print("\n info for key=", result["info"])
    # update must merge into metadata: team is preserved, limits are replaced
    assert result["info"]["metadata"] == {
        "team": "litellm-team3",
        "model_tpm_limit": {"gpt-4": 200},
        "model_rpm_limit": {"gpt-4": 3},
    }
async def test_generate_key_with_guardrails(prisma_client):
    """
    - generate a key with guardrails; assert they are stored in metadata
    - update the key's guardrails; assert the new list is returned by /key/info
    """
    print("prisma client=", prisma_client)
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()
    request = GenerateKeyRequest(
        guardrails=["aporia-pre-call"],
        metadata={
            "team": "litellm-team3",
        },
    )
    key = await generate_key_fn(
        data=request,
        user_api_key_dict=UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
            user_id="1234",
        ),
    )
    print("generated key=", key)
    generated_key = key.key
    # use generated key to auth in
    result = await info_key_fn(
        key=generated_key,
        user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
    )
    print("result from info_key_fn", result)
    assert result["key"] == generated_key
    print("\n info for key=", result["info"])
    # the top-level guardrails param is folded into the key's metadata
    assert result["info"]["metadata"] == {
        "team": "litellm-team3",
        "guardrails": ["aporia-pre-call"],
    }
    # Update the key's guardrails list
    request = UpdateKeyRequest(
        key=generated_key,
        guardrails=["aporia-pre-call", "aporia-post-call"],
    )
    _request = Request(scope={"type": "http"})
    _request._url = URL(url="/update/key")
    await update_key_fn(
        data=request,
        request=_request,
        user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
    )
    result = await info_key_fn(
        key=generated_key,
        user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
    )
    print("result from info_key_fn", result)
    assert result["key"] == generated_key
    print("\n info for key=", result["info"])
    # update must merge into metadata: team is preserved, guardrails replaced
    assert result["info"]["metadata"] == {
        "team": "litellm-team3",
        "guardrails": ["aporia-pre-call", "aporia-post-call"],
    }
async def test_team_guardrails(prisma_client):
    """
    - Test setting guardrails on a team
    - Assert this is returned when calling /team/info
    - Team/update with guardrails should update the guardrails
    - Assert new guardrails are returned when calling /team/info
    """
    litellm.set_verbose = True
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()
    _new_team = NewTeamRequest(
        team_alias="test-teamA",
        guardrails=["aporia-pre-call"],
    )
    new_team_response = await new_team(
        data=_new_team,
        user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
        http_request=Request(scope={"type": "http"}),
    )
    print("new_team_response", new_team_response)
    # call /team/info
    team_info_response = await team_info(
        team_id=new_team_response["team_id"],
        user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
        http_request=Request(scope={"type": "http"}),
    )
    print("team_info_response", team_info_response)
    # guardrails are stored in the team's metadata
    assert team_info_response["team_info"].metadata["guardrails"] == ["aporia-pre-call"]
    # team update with guardrails
    team_update_response = await update_team(
        data=UpdateTeamRequest(
            team_id=new_team_response["team_id"],
            guardrails=["aporia-pre-call", "aporia-post-call"],
        ),
        user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
        http_request=Request(scope={"type": "http"}),
    )
    print("team_update_response", team_update_response)
    # call /team/info again
    team_info_response = await team_info(
        team_id=new_team_response["team_id"],
        user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
        http_request=Request(scope={"type": "http"}),
    )
    print("team_info_response", team_info_response)
    # update replaces the guardrails list entirely
    assert team_info_response["team_info"].metadata["guardrails"] == [
        "aporia-pre-call",
        "aporia-post-call",
    ]
async def test_team_access_groups(prisma_client):
    """
    Test team based model access groups
    - Test calling a model in the access group -> pass
    - Test calling a model not in the access group -> fail
    """
    litellm.set_verbose = True
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()
    # create router with access groups: both deployments belong to "beta-models"
    litellm_router = litellm.Router(
        model_list=[
            {
                "model_name": "gemini-pro-vision",
                "litellm_params": {
                    "model": "vertex_ai/gemini-1.0-pro-vision-001",
                },
                "model_info": {"access_groups": ["beta-models"]},
            },
            {
                "model_name": "gpt-4o",
                "litellm_params": {
                    "model": "gpt-4o",
                },
                "model_info": {"access_groups": ["beta-models"]},
            },
        ]
    )
    setattr(litellm.proxy.proxy_server, "llm_router", litellm_router)
    # Create team with models=["beta-models"]
    team_request = NewTeamRequest(
        team_alias="testing-team",
        models=["beta-models"],
    )
    new_team_response = await new_team(
        data=team_request,
        user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
        http_request=Request(scope={"type": "http"}),
    )
    print("new_team_response", new_team_response)
    created_team_id = new_team_response["team_id"]
    # create key with team_id=created_team_id
    request = GenerateKeyRequest(
        team_id=created_team_id,
    )
    key = await generate_key_fn(
        data=request,
        user_api_key_dict=UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
            user_id="1234",
        ),
    )
    print(key)
    generated_key = key.key
    bearer_token = "Bearer " + generated_key
    request._url = URL(url="/chat/completions")
    for model in ["gpt-4o", "gemini-pro-vision"]:
        # Expect these to pass - both models are in the team's access group.
        # NOTE: return_body closes over the loop variable `model`; safe here
        # because the request is made within the same iteration.
        async def return_body():
            return_string = f'{{"model": "{model}"}}'
            # return string as bytes
            return return_string.encode()
        request = Request(scope={"type": "http"})
        request._url = URL(url="/chat/completions")
        # stub the request body so auth can read the requested model
        request.body = return_body
        # use generated key to auth in
        print(
            "Bearer token being sent to user_api_key_auth() - {}".format(bearer_token)
        )
        result = await user_api_key_auth(request=request, api_key=bearer_token)
    for model in ["gpt-4", "gpt-4o-mini", "gemini-experimental"]:
        # Expect these to fail - none of these models are in "beta-models"
        async def return_body_2():
            return_string = f'{{"model": "{model}"}}'
            # return string as bytes
            return return_string.encode()
        request = Request(scope={"type": "http"})
        request._url = URL(url="/chat/completions")
        request.body = return_body_2
        # use generated key to auth in
        print(
            "Bearer token being sent to user_api_key_auth() - {}".format(bearer_token)
        )
        try:
            result = await user_api_key_auth(request=request, api_key=bearer_token)
            pytest.fail(f"This should have failed!. IT's an invalid model")
        except Exception as e:
            print("got exception", e)
            # rejection must be a typed ProxyException on the `model` param
            assert isinstance(e, ProxyException)
            assert e.type == ProxyErrorTypes.team_model_access_denied
            assert e.param == "model"
async def test_team_tags(prisma_client):
    """
    - Test setting tags on a team
    - Assert this is returned when calling /team/info
    - Team/update with tags should update the tags
    - Assert new tags are returned when calling /team/info
    """
    litellm.set_verbose = True
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    def _admin_auth():
        # fresh admin credentials for each management-endpoint call
        return UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN)

    def _http_request():
        return Request(scope={"type": "http"})

    # Create a team that starts out tagged with "teamA"
    create_request = NewTeamRequest(
        team_alias="test-teamA",
        tags=["teamA"],
    )
    new_team_response = await new_team(
        data=create_request,
        user_api_key_dict=_admin_auth(),
        http_request=_http_request(),
    )
    print("new_team_response", new_team_response)
    team_id = new_team_response["team_id"]

    # /team/info should surface the original tags via team metadata
    team_info_response = await team_info(
        team_id=team_id,
        user_api_key_dict=_admin_auth(),
        http_request=_http_request(),
    )
    print("team_info_response", team_info_response)
    assert team_info_response["team_info"].metadata["tags"] == ["teamA"]

    # Replace the tag list via /team/update
    team_update_response = await update_team(
        data=UpdateTeamRequest(
            team_id=team_id,
            tags=["teamA", "teamB"],
        ),
        user_api_key_dict=_admin_auth(),
        http_request=_http_request(),
    )
    print("team_update_response", team_update_response)

    # /team/info should now reflect the updated tag list
    team_info_response = await team_info(
        team_id=team_id,
        user_api_key_dict=_admin_auth(),
        http_request=_http_request(),
    )
    print("team_info_response", team_info_response)
    assert team_info_response["team_info"].metadata["tags"] == ["teamA", "teamB"]
async def test_aadmin_only_routes(prisma_client):
    """
    Tests if setting admin_only_routes works.

    Only an admin should be able to access admin-only routes; a non-admin
    hitting the same route must be rejected with an "admin only route" error.
    """
    litellm.set_verbose = True
    print(f"os.getenv('DATABASE_URL')={os.getenv('DATABASE_URL')}")
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    # /key/generate is an allowed route, but restricted to admins only
    general_settings = {
        "allowed_routes": ["/embeddings", "/key/generate"],
        "admin_only_routes": ["/key/generate"],
    }
    from litellm.proxy import proxy_server

    initial_general_settings = getattr(proxy_server, "general_settings")
    setattr(proxy_server, "general_settings", general_settings)

    try:
        # create one admin and one non-admin user; each gets a virtual key
        admin_user = await new_user(
            data=NewUserRequest(
                user_name="admin",
                user_role=LitellmUserRoles.PROXY_ADMIN,
            ),
            user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
        )
        non_admin_user = await new_user(
            data=NewUserRequest(
                user_name="non-admin",
                user_role=LitellmUserRoles.INTERNAL_USER,
            ),
            user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
        )

        admin_user_key = admin_user.key
        non_admin_user_key = non_admin_user.key
        assert admin_user_key is not None
        assert non_admin_user_key is not None

        request = Request(scope={"type": "http"})
        request._url = URL(url="/key/generate")

        # the admin's key should be accepted on the admin-only route
        await user_api_key_auth(
            request=request,
            api_key="Bearer " + admin_user_key,
        )

        # the non-admin's key must be rejected on the same route
        try:
            await user_api_key_auth(
                request=request,
                api_key="Bearer " + non_admin_user_key,
            )
            # fixed message: the original said "User is over limit.", which
            # was copy-pasted from an unrelated budget test
            pytest.fail(
                "Expected this call to fail. Non-admin user should not be able to access an admin-only route."
            )
        except Exception as e:
            print("error str=", str(e.message))
            error_str = str(e.message)
            assert "Route" in error_str and "admin only route" in error_str
    finally:
        # always restore the original settings so later tests are unaffected,
        # even if an assertion above fails (the original leaked the override)
        setattr(proxy_server, "general_settings", initial_general_settings)
async def test_list_keys(prisma_client):
    """
    Test the list_keys function:
    - Test basic key listing
    - Test pagination
    - Test filtering by user_id, and key_alias
    """
    from litellm.proxy.proxy_server import hash_token
    from litellm.proxy._types import LitellmUserRoles

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    # Test basic listing
    request = Request(scope={"type": "http", "query_string": b""})
    response = await list_keys(
        request,
        UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN.value,
        ),
        page=1,
        size=10,
    )
    print("response=", response)
    assert "keys" in response
    assert len(response["keys"]) > 0
    assert "total_count" in response
    assert "current_page" in response
    assert "total_pages" in response

    # Test pagination
    response = await list_keys(
        request,
        UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
        page=1,
        size=2,
    )
    print("pagination response=", response)
    assert len(response["keys"]) == 2
    assert response["current_page"] == 1

    # Create a fresh user (with an auto-generated key) so we can filter on
    # known values. The original re-formatted the same f-strings inline and
    # declared an unused team_id; we reuse the precomputed variables instead.
    unique_id = str(uuid.uuid4())
    key_alias = f"key-list-alias-{unique_id}"
    user_id = f"key-list-user-{unique_id}"
    response = await new_user(
        data=NewUserRequest(
            user_id=user_id,
            user_role=LitellmUserRoles.INTERNAL_USER,
            key_alias=key_alias,
        ),
        user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
    )

    # list_keys returns hashed tokens, so hash the generated key for comparison
    _key = hash_token(response.key)

    # brief wait so the newly created key is visible to the listing endpoint
    await asyncio.sleep(2)

    # Test filtering by user_id
    response = await list_keys(
        request,
        UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
        user_id=user_id,
        page=1,
        size=10,
    )
    print("filtered user_id response=", response)
    assert len(response["keys"]) == 1
    assert _key in response["keys"]

    # Test filtering by key_alias
    response = await list_keys(
        request,
        UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN.value),
        key_alias=key_alias,
        page=1,
        size=10,
    )
    assert len(response["keys"]) == 1
    assert _key in response["keys"]
async def test_auth_vertex_ai_route(prisma_client):
    """
    If user is premium user and vertex-ai route is used, assert Virtual Key
    checks are run: an invalid key must be rejected with a 401.
    """
    litellm.set_verbose = True
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "premium_user", True)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    # vertex-ai passthrough route - auth must still validate the virtual key
    route = "/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent"
    request = Request(scope={"type": "http"})
    request._url = URL(url=route)
    request._headers = {"Authorization": "Bearer sk-12345"}
    try:
        # sk-12345 is not a valid key, so auth should reject the request
        await user_api_key_auth(request=request, api_key="Bearer " + "sk-12345")
        # fixed message: original said "User is over limit." - copy-pasted
        # from an unrelated budget test; this test checks key validity
        pytest.fail("Expected this call to fail. Invalid key passed to vertex-ai route.")
    except Exception as e:
        print(vars(e))
        print("error str=", str(e.message))
        error_str = str(e.message)
        assert e.code == "401"
        assert "Invalid proxy server token passed" in error_str
async def test_user_api_key_auth_db_unavailable():
    """
    Test that user_api_key_auth handles DB connection failures appropriately when:
    1. DB connection fails during token validation
    2. allow_requests_on_db_unavailable=True
    """
    litellm.set_verbose = True

    # Mock dependencies: a prisma client whose lookups always raise a
    # connection error, and a cache that never holds the token.
    class MockPrismaClient:
        async def get_data(self, *args, **kwargs):
            print("MockPrismaClient.get_data() called")
            raise httpx.ConnectError("Failed to connect to DB")

        async def connect(self):
            print("MockPrismaClient.connect() called")

    class MockDualCache:
        async def async_get_cache(self, *args, **kwargs):
            return None

        async def async_set_cache(self, *args, **kwargs):
            return None

        async def set_cache(self, *args, **kwargs):
            return None

    # Set up test environment: requests are explicitly allowed while the DB
    # is unreachable.
    litellm.proxy.proxy_server.prisma_client = MockPrismaClient()
    litellm.proxy.proxy_server.user_api_key_cache = MockDualCache()
    litellm.proxy.proxy_server.master_key = "sk-1234"
    litellm.proxy.proxy_server.general_settings = {
        "allow_requests_on_db_unavailable": True
    }

    # Create test request
    incoming_request = Request(scope={"type": "http"})
    incoming_request._url = URL(url="/chat/completions")

    # Run auth with a sample API key - should succeed with a fallback identity
    auth_result = await user_api_key_auth(
        request=incoming_request,
        api_key="Bearer sk-123456789",
    )

    # Verify the fallback virtual key is returned instead of an error
    assert isinstance(auth_result, UserAPIKeyAuth)
    assert auth_result.key_name == "failed-to-connect-to-db"
    assert auth_result.user_id == litellm.proxy.proxy_server.litellm_proxy_admin_name
async def test_user_api_key_auth_db_unavailable_not_allowed():
    """
    Test that user_api_key_auth raises an exception when:
    This is default behavior
    1. DB connection fails during token validation
    2. allow_requests_on_db_unavailable=False (default behavior)
    """

    # Mock dependencies: DB lookups always fail, cache never has the token.
    class MockPrismaClient:
        async def get_data(self, *args, **kwargs):
            print("MockPrismaClient.get_data() called")
            raise httpx.ConnectError("Failed to connect to DB")

        async def connect(self):
            print("MockPrismaClient.connect() called")

    class MockDualCache:
        async def async_get_cache(self, *args, **kwargs):
            return None

        async def async_set_cache(self, *args, **kwargs):
            return None

        async def set_cache(self, *args, **kwargs):
            return None

    # Set up test environment - general_settings is empty, so
    # allow_requests_on_db_unavailable falls back to its default (False).
    litellm.proxy.proxy_server.prisma_client = MockPrismaClient()
    litellm.proxy.proxy_server.user_api_key_cache = MockDualCache()
    litellm.proxy.proxy_server.general_settings = {}
    litellm.proxy.proxy_server.master_key = "sk-1234"

    # Create test request
    incoming_request = Request(scope={"type": "http"})
    incoming_request._url = URL(url="/chat/completions")

    # Auth must surface the DB failure as a ProxyException
    with pytest.raises(litellm.proxy._types.ProxyException):
        await user_api_key_auth(
            request=incoming_request,
            api_key="Bearer sk-123456789",
        )
## E2E Virtual Key + Secret Manager Tests ######################################### | |
async def test_key_generate_with_secret_manager_call(prisma_client):
    """
    Generate a key
    assert it exists in the secret manager

    delete the key
    assert it is deleted from the secret manager
    """
    from litellm.secret_managers.aws_secret_manager_v2 import AWSSecretsManagerV2
    from litellm.proxy._types import KeyManagementSystem, KeyManagementSettings
    from litellm.proxy.hooks.key_management_event_hooks import (
        LITELLM_PREFIX_STORED_VIRTUAL_KEYS,
    )

    litellm.set_verbose = True

    #### Test Setup ############################################################
    # Point litellm at AWS Secrets Manager and enable virtual-key storage so
    # that key generate/delete events are mirrored into the secret manager.
    # NOTE(review): requires live AWS credentials in the environment - this is
    # an E2E test, not a unit test.
    aws_secret_manager_client = AWSSecretsManagerV2()
    litellm.secret_manager_client = aws_secret_manager_client
    litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
    litellm._key_management_settings = KeyManagementSettings(
        store_virtual_keys=True,
    )
    general_settings = {
        "key_management_system": "aws_secret_manager",
        "key_management_settings": {
            "store_virtual_keys": True,
        },
    }

    setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    await litellm.proxy.proxy_server.prisma_client.connect()
    ############################################################################

    # generate new key
    key_alias = f"test_alias_secret_manager_key-{uuid.uuid4()}"
    spend = 100
    max_budget = 400
    models = ["fake-openai-endpoint"]
    new_key = await generate_key_fn(
        data=GenerateKeyRequest(
            key_alias=key_alias, spend=spend, max_budget=max_budget, models=models
        ),
        user_api_key_dict=UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            api_key="sk-1234",
            user_id="1234",
        ),
    )

    generated_key = new_key.key
    print(generated_key)

    # give the async key-management event hook time to write the secret
    await asyncio.sleep(2)

    # read from the secret manager; secrets are stored under
    # <prefix_for_stored_virtual_keys><key_alias>
    result = await aws_secret_manager_client.async_read_secret(
        secret_name=f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}{key_alias}"
    )

    # Assert the correct key is stored in the secret manager
    print("response from AWS Secret Manager")
    print(result)
    assert result == generated_key

    # delete the key
    await delete_key_fn(
        data=KeyRequest(keys=[generated_key]),
        user_api_key_dict=UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1234", user_id="1234"
        ),
    )

    # allow the delete event hook to propagate to the secret manager
    await asyncio.sleep(2)

    # Assert the key is deleted from the secret manager
    result = await aws_secret_manager_client.async_read_secret(
        secret_name=f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}{key_alias}"
    )
    assert result is None

    # cleanup: reset general_settings so later tests don't use the AWS config
    setattr(litellm.proxy.proxy_server, "general_settings", {})
################################################################################ | |
async def test_key_alias_uniqueness(prisma_client):
    """
    Test that:
    1. We cannot create two keys with the same alias
    2. We cannot update a key to use an alias that's already taken
    3. We can update a key while keeping its existing alias
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    try:
        # Create first key with an alias
        unique_alias = f"test-alias-{uuid.uuid4()}"
        key1 = await generate_key_fn(
            data=GenerateKeyRequest(key_alias=unique_alias),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )

        # Try to create second key with same alias - should fail
        try:
            key2 = await generate_key_fn(
                data=GenerateKeyRequest(key_alias=unique_alias),
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
            pytest.fail("Should not be able to create a second key with the same alias")
        except Exception as e:
            print("vars(e)=", vars(e))
            # the uniqueness check surfaces this specific error message
            assert "Unique key aliases across all keys are required" in str(e.message)

        # Create another key with different alias - this should succeed
        another_alias = f"test-alias-{uuid.uuid4()}"
        key3 = await generate_key_fn(
            data=GenerateKeyRequest(key_alias=another_alias),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )

        # Try to update key3 to use key1's alias - should fail
        try:
            await update_key_fn(
                data=UpdateKeyRequest(key=key3.key, key_alias=unique_alias),
                request=Request(scope={"type": "http"}),
                user_api_key_dict=UserAPIKeyAuth(
                    user_role=LitellmUserRoles.PROXY_ADMIN,
                    api_key="sk-1234",
                    user_id="1234",
                ),
            )
            pytest.fail("Should not be able to update a key to use an existing alias")
        except Exception as e:
            assert "Unique key aliases across all keys are required" in str(e.message)

        # Update key1 with its own existing alias - should succeed
        # (a key keeping its current alias is not a duplicate)
        updated_key = await update_key_fn(
            data=UpdateKeyRequest(key=key1.key, key_alias=unique_alias),
            request=Request(scope={"type": "http"}),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )
        assert updated_key is not None
    except Exception as e:
        # any failure other than the expected inner ones is a test failure
        print("got exceptions, e=", e)
        print("vars(e)=", vars(e))
        pytest.fail(f"An unexpected error occurred: {str(e)}")
async def test_enforce_unique_key_alias(prisma_client):
    """
    Unit test the _enforce_unique_key_alias function:
    1. Test it allows unique aliases
    2. Test it blocks duplicate aliases for new keys
    3. Test it allows updating a key with its own existing alias
    4. Test it blocks updating a key with another key's alias
    """
    from litellm.proxy.management_endpoints.key_management_endpoints import (
        _enforce_unique_key_alias,
    )

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    await litellm.proxy.proxy_server.prisma_client.connect()

    try:
        # Test 1: Allow unique alias - no key with this alias exists yet
        unique_alias = f"test-alias-{uuid.uuid4()}"
        await _enforce_unique_key_alias(
            key_alias=unique_alias,
            prisma_client=prisma_client,
        )  # Should pass

        # Create a key with this alias in the database
        key1 = await generate_key_fn(
            data=GenerateKeyRequest(key_alias=unique_alias),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )

        # Test 2: Block duplicate alias for new key
        try:
            await _enforce_unique_key_alias(
                key_alias=unique_alias,
                prisma_client=prisma_client,
            )
            pytest.fail("Should not allow duplicate alias")
        except Exception as e:
            assert "Unique key aliases across all keys are required" in str(e.message)

        # Test 3: Allow updating key with its own alias - existing_key_token
        # identifies the key being updated, so its own alias is not a duplicate
        await _enforce_unique_key_alias(
            key_alias=unique_alias,
            existing_key_token=hash_token(key1.key),
            prisma_client=prisma_client,
        )  # Should pass

        # Test 4: Block updating with another key's alias
        another_key = await generate_key_fn(
            data=GenerateKeyRequest(key_alias=f"test-alias-{uuid.uuid4()}"),
            user_api_key_dict=UserAPIKeyAuth(
                user_role=LitellmUserRoles.PROXY_ADMIN,
                api_key="sk-1234",
                user_id="1234",
            ),
        )
        try:
            await _enforce_unique_key_alias(
                key_alias=unique_alias,
                # NOTE(review): Test 3 hashes the token before passing it, but
                # here the raw key is passed unhashed - presumably it then
                # matches no stored token and the duplicate check still fires;
                # confirm this is intentional.
                existing_key_token=another_key.key,
                prisma_client=prisma_client,
            )
            pytest.fail("Should not allow using another key's alias")
        except Exception as e:
            assert "Unique key aliases across all keys are required" in str(e.message)
    except Exception as e:
        print("Unexpected error:", e)
        pytest.fail(f"An unexpected error occurred: {str(e)}")
def test_should_track_cost_callback():
    """
    Test that the should_track_cost_callback function works as expected
    """
    from litellm.proxy.hooks.proxy_track_cost_callback import (
        _should_track_cost_callback,
    )

    # Only the end-user id is supplied; cost tracking should still be enabled.
    should_track = _should_track_cost_callback(
        user_api_key=None,
        user_id=None,
        team_id=None,
        end_user_id="1234",
    )
    assert should_track
async def test_get_paginated_teams(prisma_client):
    """
    Test the get_paginated_teams function:
    1. Test pagination returns valid results
    2. Test total count matches across pages
    3. Test page size is respected
    """
    from litellm.proxy.management_endpoints.team_endpoints import get_paginated_teams

    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    try:
        # Fetch the first two pages with a page size of 2
        teams_page_1, total_count_1 = await get_paginated_teams(
            prisma_client=prisma_client, page_size=2, page=1
        )
        print("teams_page_1=", teams_page_1)
        print("total_count_1=", total_count_1)

        teams_page_2, total_count_2 = await get_paginated_teams(
            prisma_client=prisma_client, page_size=2, page=2
        )
        print("teams_page_2=", teams_page_2)
        print("total_count_2=", total_count_2)

        # A page is a list, the count is an int, the count is stable across
        # pages, and no page exceeds the requested page size.
        assert isinstance(teams_page_1, list)
        assert isinstance(total_count_1, int)
        assert total_count_1 == total_count_2
        assert len(teams_page_1) <= 2
    except Exception as e:
        print(f"Error occurred: {e}")
        pytest.fail(f"Test failed with exception: {e}")
async def test_reset_budget_job(prisma_client, entity_type):
    """
    Test that the ResetBudgetJob correctly resets budgets for keys, users, and teams.

    For each entity type:
    1. Create a new entity with max_budget=100, spend=99, budget_duration=5s
    2. Call the reset_budget function
    3. Verify the entity's spend is reset to 0
    """
    from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
    from litellm.proxy.utils import ProxyLogging

    # Setup
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    proxy_logging_obj = ProxyLogging(user_api_key_cache=None)
    reset_budget_job = ResetBudgetJob(
        proxy_logging_obj=proxy_logging_obj, prisma_client=prisma_client
    )

    admin_auth = UserAPIKeyAuth(
        user_role=LitellmUserRoles.PROXY_ADMIN,
        api_key="sk-1234",
        user_id="1234",
    )

    # Create the entity under test and force its spend close to the budget.
    entity_id = None
    if entity_type == "key":
        # Create a key with specific budget settings
        key = await generate_key_fn(
            data=GenerateKeyRequest(
                max_budget=100,
                budget_duration="5s",
            ),
            user_api_key_dict=admin_auth,
        )
        entity_id = key.token_id
        print("generated key=", key)

        # Update the key to set spend
        updated = await prisma_client.db.litellm_verificationtoken.update_many(
            where={"token": key.token_id},
            data={
                "spend": 99.0,
            },
        )
        print("Updated key=", updated)
    elif entity_type == "user":
        # Create a user with specific budget settings
        user = await new_user(
            data=NewUserRequest(
                max_budget=100,
                budget_duration="5s",
            ),
            user_api_key_dict=admin_auth,
        )
        entity_id = user.user_id

        # Update the user to set spend
        await prisma_client.db.litellm_usertable.update_many(
            where={"user_id": user.user_id},
            data={
                "spend": 99.0,
            },
        )
    elif entity_type == "team":
        # Create a team with specific budget settings
        team_id = f"test-team-{uuid.uuid4()}"
        await new_team(
            NewTeamRequest(
                team_id=team_id,
                max_budget=100,
                budget_duration="5s",
            ),
            user_api_key_dict=admin_auth,
            http_request=Request(scope={"type": "http"}),
        )
        entity_id = team_id

        # Update the team to set spend
        await prisma_client.db.litellm_teamtable.update(
            where={"team_id": team_id},
            data={
                "spend": 99.0,
            },
        )

    async def _fetch_entity():
        # Look up the entity in the table matching its type (replaces the
        # original's triplicated before/after dispatch blocks).
        if entity_type == "key":
            return await prisma_client.db.litellm_verificationtoken.find_unique(
                where={"token": entity_id}
            )
        if entity_type == "user":
            return await prisma_client.db.litellm_usertable.find_unique(
                where={"user_id": entity_id}
            )
        return await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": entity_id}
        )

    # Verify entity was created and updated with spend
    entity_before = await _fetch_entity()
    assert entity_before is not None
    assert entity_before.spend == 99.0

    # Wait for the 5s budget window to elapse. Use asyncio.sleep instead of
    # the original time.sleep, which blocked the event loop of an async test.
    print("sleeping for 5 seconds")
    await asyncio.sleep(5)

    # Call the reset_budget function
    await reset_budget_job.reset_budget()

    # Verify the entity's spend is reset
    entity_after = await _fetch_entity()
    assert entity_after is not None
    assert entity_after.spend == 0.0
def test_delete_nonexistent_key_returns_404(prisma_client):
    """
    Deleting a key that does not exist must raise a ProxyException with a
    404 code and a "no keys found" style message.
    """
    from litellm.proxy._types import (
        KeyRequest,
        LitellmUserRoles,
        ProxyException,
        UserAPIKeyAuth,
    )
    from litellm.proxy.management_endpoints.key_management_endpoints import (
        delete_key_fn,
    )
    from starlette.datastructures import URL
    from fastapi import Request

    print("prisma client=", prisma_client)
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")

    async def _run():
        await litellm.proxy.proxy_server.prisma_client.connect()

        # Build a random key that cannot exist in the DB. uuid4 (already a
        # file-level import) replaces the original random/string imports.
        random_key = "sk-" + uuid.uuid4().hex[:24]
        delete_key_request = KeyRequest(keys=[random_key])

        # authenticate as the proxy admin using the master key
        bearer_token = "Bearer sk-1234"
        request = Request(scope={"type": "http"})
        request._url = URL(url="/key/delete")
        result = await litellm.proxy.proxy_server.user_api_key_auth(
            request=request, api_key=bearer_token
        )
        result.user_role = LitellmUserRoles.PROXY_ADMIN

        try:
            await delete_key_fn(data=delete_key_request, user_api_key_dict=result)
            pytest.fail(
                "Expected ProxyException 404 for non-existent key, but delete_key_fn did not raise."
            )
        except ProxyException as e:
            print("Caught ProxyException:", e)
            assert str(e.code) == "404"
            assert "No keys found" in str(e.message) or (
                "No matching keys or aliases found to delete" in str(e.message)
            )

    try:
        # asyncio is imported at the top of the file; the original re-imported
        # it locally for no reason
        asyncio.run(_run())
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")