|
from typing import Optional, List, Any, Literal, Union |
|
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx |
|
import litellm, backoff |
|
from litellm.proxy._types import UserAPIKeyAuth, DynamoDBArgs |
|
from litellm.caching import DualCache |
|
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler |
|
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter |
|
from litellm.integrations.custom_logger import CustomLogger |
|
from litellm.proxy.db.base_client import CustomDB |
|
from fastapi import HTTPException, status |
|
import smtplib |
|
from email.mime.text import MIMEText |
|
from email.mime.multipart import MIMEMultipart |
|
|
|
|
|
def print_verbose(print_statement):
    """Emit *print_statement* to stdout, but only when `litellm.set_verbose` is on."""
    if not litellm.set_verbose:
        return
    print(f"LiteLLM Proxy: {print_statement}")
|
|
|
|
|
|
|
class ProxyLogging:
    """
    Logging/Custom Handlers for proxy.

    Implemented mainly to:
    - log successful/failed db read/writes
    - support the max parallel request integration
    """

    def __init__(self, user_api_key_cache: DualCache):
        # call_details: shared objects handed to callback hooks; currently only
        # the user-api-key cache (forwarded to async_pre_call_hook implementations).
        self.call_details: dict = {}
        self.call_details["user_api_key_cache"] = user_api_key_cache
        self.max_parallel_request_limiter = MaxParallelRequestsHandler()
        self.max_budget_limiter = MaxBudgetLimiter()
        # alerting: list of alert sinks (e.g. ["slack", "sentry"]); None disables alerting.
        self.alerting: Optional[List] = None
        # alerting_threshold: seconds after which a request counts as hanging /
        # a response counts as slow.
        self.alerting_threshold: float = 300
        pass

    def update_values(
        self, alerting: Optional[List], alerting_threshold: Optional[float]
    ):
        """Overwrite the alert sinks; update the threshold only when one is given."""
        self.alerting = alerting
        if alerting_threshold is not None:
            self.alerting_threshold = alerting_threshold

    def _init_litellm_callbacks(self):
        """
        Register the proxy's built-in handlers with litellm, then mirror every
        registered callback into each of litellm's per-event callback lists.
        """
        print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
        litellm.callbacks.append(self.max_parallel_request_limiter)
        litellm.callbacks.append(self.max_budget_limiter)
        for callback in litellm.callbacks:
            # Add each callback to every event list it is not already in.
            if callback not in litellm.input_callback:
                litellm.input_callback.append(callback)
            if callback not in litellm.success_callback:
                litellm.success_callback.append(callback)
            if callback not in litellm.failure_callback:
                litellm.failure_callback.append(callback)
            if callback not in litellm._async_success_callback:
                litellm._async_success_callback.append(callback)
            if callback not in litellm._async_failure_callback:
                litellm._async_failure_callback.append(callback)

        if (
            len(litellm.input_callback) > 0
            or len(litellm.success_callback) > 0
            or len(litellm.failure_callback) > 0
        ):
            # De-duplicate across the three sync lists before handing them off.
            callback_list = list(
                set(
                    litellm.input_callback
                    + litellm.success_callback
                    + litellm.failure_callback
                )
            )
            litellm.utils.set_callbacks(callback_list=callback_list)

    async def pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: dict,
        call_type: Literal["completion", "embeddings"],
    ):
        """
        Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.

        Covers:
        1. /chat/completions
        2. /embeddings
        3. /image/generation

        Returns the (possibly rewritten) request body dict.
        """
        # NOTE(review): this background task sleeps for alerting_threshold and
        # then always alerts; nothing cancels it when the request completes, so
        # a "hanging" alert fires for every request - confirm this is intended.
        asyncio.create_task(self.response_taking_too_long())

        try:
            for callback in litellm.callbacks:
                # Only invoke hooks the callback's own class defines: vars()
                # checks the class dict directly, so inherited stubs are skipped.
                if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars(
                    callback.__class__
                ):
                    response = await callback.async_pre_call_hook(
                        user_api_key_dict=user_api_key_dict,
                        cache=self.call_details["user_api_key_cache"],
                        data=data,
                        call_type=call_type,
                    )
                    # A hook may rewrite the request body by returning a new value.
                    if response is not None:
                        data = response

            print_verbose(f"final data being sent to {call_type} call: {data}")
            return data
        except Exception as e:
            raise e

    async def success_handler(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response: Any,
        call_type: Literal["completion", "embeddings"],
        start_time,
        end_time,
    ):
        """
        Log successful API calls / db read/writes

        Currently a no-op placeholder.
        """
        pass

    async def response_taking_too_long(
        self,
        start_time: Optional[float] = None,
        end_time: Optional[float] = None,
        type: Literal["hanging_request", "slow_response"] = "hanging_request",
    ):
        """
        Fire an alert for hanging requests / slow responses.

        "hanging_request": sleep for alerting_threshold seconds, then alert
        unconditionally (callers spawn this as a background task).
        "slow_response": alert only if end_time - start_time exceeds the threshold.
        """
        if type == "hanging_request":
            # Wait out the threshold, then report a Medium-severity alert.
            await asyncio.sleep(
                self.alerting_threshold
            )

            await self.alerting_handler(
                message=f"Requests are hanging - {self.alerting_threshold}s+ request time",
                level="Medium",
            )

        elif (
            type == "slow_response" and start_time is not None and end_time is not None
        ):
            if end_time - start_time > self.alerting_threshold:
                await self.alerting_handler(
                    message=f"Responses are slow - {round(end_time-start_time,2)}s response time",
                    level="Low",
                )

    async def alerting_handler(
        self, message: str, level: Literal["Low", "Medium", "High"]
    ):
        """
        Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298

        - Responses taking too long
        - Requests are hanging
        - Calls are failing
        - DB Read/Writes are failing

        Parameters:
        level: str - Low|Medium|High - if calls might fail (Medium) or are failing (High); Currently, no alerts would be 'Low'.
        message: str - what is the alert about
        """
        formatted_message = f"Level: {level}\n\nMessage: {message}"
        # Alerting is opt-in; silently skip when no sinks are configured.
        if self.alerting is None:
            return

        for client in self.alerting:
            if client == "slack":
                slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
                if slack_webhook_url is None:
                    raise Exception("Missing SLACK_WEBHOOK_URL from environment")
                payload = {"text": formatted_message}
                headers = {"Content-type": "application/json"}
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                        slack_webhook_url, json=payload, headers=headers
                    ) as response:
                        # NOTE(review): success is a no-op and non-200 responses
                        # are silently ignored - consider logging failures.
                        if response.status == 200:
                            pass
            elif client == "sentry":
                if litellm.utils.sentry_sdk_instance is not None:
                    litellm.utils.sentry_sdk_instance.capture_message(formatted_message)
                else:
                    raise Exception("Missing SENTRY_DSN from environment")

    async def failure_handler(self, original_exception):
        """
        Log failed db read/writes

        Currently only logs exceptions to sentry
        """
        # HTTPException carries its human-readable message in .detail.
        if isinstance(original_exception, HTTPException):
            error_message = original_exception.detail
        else:
            error_message = str(original_exception)
        # Fire-and-forget so alerting latency doesn't block the DB error path.
        asyncio.create_task(
            self.alerting_handler(
                message=f"DB read/write call failed: {error_message}",
                level="High",
            )
        )

        if litellm.utils.capture_exception:
            litellm.utils.capture_exception(error=original_exception)

    async def post_call_failure_hook(
        self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
    ):
        """
        Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.

        Covers:
        1. /chat/completions
        2. /embeddings
        3. /image/generation
        """
        # Alert asynchronously; don't delay returning the failure to the caller.
        asyncio.create_task(
            self.alerting_handler(
                message=f"LLM API call failed: {str(original_exception)}", level="High"
            )
        )

        for callback in litellm.callbacks:
            try:
                if isinstance(callback, CustomLogger):
                    await callback.async_post_call_failure_hook(
                        user_api_key_dict=user_api_key_dict,
                        original_exception=original_exception,
                    )
            except Exception as e:
                raise e
        return
|
|
|
|
|
|
|
|
|
|
|
def on_backoff(details):
    """Backoff hook: report which retry attempt we're on (verbose mode only)."""
    attempt = details['tries']
    print_verbose(f"Backing off... this was attempt #{attempt}")
|
|
|
|
|
class PrismaClient:
    """
    Async DB client for the proxy's key / user / config tables, backed by Prisma.

    If the generated ``prisma`` package cannot be imported, the constructor runs
    ``prisma generate`` / ``prisma db push`` against ``database_url`` and then
    imports it. Every read/write helper retries with exponential backoff
    (3 tries, <= 10s total) and reports failures through
    ``proxy_logging_obj.failure_handler`` before re-raising.
    """

    def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
        print_verbose(
            "LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
        )

        self.proxy_logging_obj = proxy_logging_obj
        try:
            from prisma import Prisma  # type: ignore
        except Exception:
            # Prisma client not generated yet - generate it against database_url.
            os.environ["DATABASE_URL"] = database_url

            original_dir = os.getcwd()
            # The prisma CLI must run from the directory containing the schema.
            abspath = os.path.abspath(__file__)
            dname = os.path.dirname(abspath)
            os.chdir(dname)

            try:
                # check=True surfaces a non-zero CLI exit status; previously a
                # failed command was silently ignored because subprocess.run
                # does not raise on its own.
                subprocess.run(["prisma", "generate"], check=True)
                subprocess.run(
                    ["prisma", "db", "push", "--accept-data-loss"], check=True
                )
            except Exception as e:
                # Narrowed from a bare `except:`; chain the cause for debugging.
                raise Exception(
                    "Unable to run prisma commands. Run `pip install prisma`"
                ) from e
            finally:
                os.chdir(original_dir)

            # The client is generated now - import it for real.
            from prisma import Prisma  # type: ignore

        self.db = Prisma(
            http={
                "limits": httpx.Limits(
                    max_connections=1000, max_keepalive_connections=100
                )
            }
        )

    def hash_token(self, token: str):
        """Return the SHA-256 hex digest used to store API keys at rest."""
        hashed_token = hashlib.sha256(token.encode()).hexdigest()
        return hashed_token

    def jsonify_object(self, data: dict) -> dict:
        """
        Deep-copy *data* and JSON-encode any nested dict values, so they can be
        written to Prisma Json columns. The input dict is never mutated.
        """
        db_data = copy.deepcopy(data)
        for k, v in db_data.items():
            if isinstance(v, dict):
                db_data[k] = json.dumps(v)
        return db_data

    @backoff.on_exception(
        backoff.expo,
        Exception,
        max_tries=3,
        max_time=10,
        on_backoff=on_backoff,
    )
    async def get_generic_data(
        self,
        key: str,
        value: Any,
        table_name: Literal["users", "keys", "config"],
    ):
        """
        Generic implementation of get data: return the first row in
        *table_name* where column *key* equals *value* (None when not found).
        """
        try:
            # Pre-initialize so an unexpected table_name returns None instead
            # of raising NameError at the `return` below.
            response = None
            if table_name == "users":
                response = await self.db.litellm_usertable.find_first(
                    where={key: value}
                )
            elif table_name == "keys":
                response = await self.db.litellm_verificationtoken.find_first(
                    where={key: value}
                )
            elif table_name == "config":
                response = await self.db.litellm_config.find_first(
                    where={key: value}
                )
            return response
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            raise e

    @backoff.on_exception(
        backoff.expo,
        Exception,
        max_tries=3,
        max_time=10,
        on_backoff=on_backoff,
    )
    async def get_data(
        self,
        token: Optional[str] = None,
        user_id: Optional[str] = None,
        table_name: Optional[Literal["user", "key", "config"]] = None,
        query_type: Literal["find_unique", "find_all"] = "find_unique",
    ):
        """
        Fetch verification-token row(s) (by token, or all tokens for a user),
        or a user row by user_id.

        Raises:
            HTTPException(401) when a token lookup finds no matching row.
        """
        try:
            print_verbose("PrismaClient: get_data")

            response: Any = None
            if token is not None or (table_name is not None and table_name == "key"):
                if token is not None:
                    # Raw keys (sk-...) are stored hashed at rest.
                    hashed_token = token
                    if token.startswith("sk-"):
                        hashed_token = self.hash_token(token=token)
                print_verbose("PrismaClient: find_unique")
                if query_type == "find_unique":
                    # NOTE(review): if token is None (table_name == "key" path),
                    # hashed_token is unbound here and this raises NameError -
                    # confirm callers always pass a token for find_unique.
                    response = await self.db.litellm_verificationtoken.find_unique(
                        where={"token": hashed_token}
                    )
                elif query_type == "find_all" and user_id is not None:
                    response = await self.db.litellm_verificationtoken.find_many(
                        where={"user_id": user_id}
                    )
                print_verbose(f"PrismaClient: response={response}")
                if response is not None:
                    return response
                else:
                    # Token lookup miss -> treat as an invalid key.
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="invalid user key",
                    )
            elif user_id is not None:
                response = await self.db.litellm_usertable.find_unique(
                    where={
                        "user_id": user_id,
                    }
                )
            return response
        except Exception as e:
            print_verbose(f"LiteLLM Prisma Client Exception: {e}")
            import traceback

            traceback.print_exc()
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            raise e

    @backoff.on_exception(
        backoff.expo,
        Exception,
        max_tries=3,
        max_time=10,
        on_backoff=on_backoff,
    )
    async def insert_data(
        self, data: dict, table_name: Literal["user+key", "config"] = "user+key"
    ):
        """
        Add a key to the database. If it already exists, do nothing.

        "user+key": upsert the verification token and make sure a user row
        exists for its owner; returns the token row.
        "config": upsert every param_name/param_value pair concurrently.
        """
        try:
            if table_name == "user+key":
                token = data["token"]
                hashed_token = self.hash_token(token=token)
                db_data = self.jsonify_object(data=data)
                db_data["token"] = hashed_token
                # max_budget / user_email belong on the user table, not the token table.
                max_budget = db_data.pop("max_budget", None)
                user_email = db_data.pop("user_email", None)
                print_verbose(
                    "PrismaClient: Before upsert into litellm_verificationtoken"
                )
                new_verification_token = await self.db.litellm_verificationtoken.upsert(
                    where={
                        "token": hashed_token,
                    },
                    data={
                        "create": {**db_data},
                        "update": {},  # already exists -> leave untouched
                    },
                )

                # Ensure the owning user row exists (no-op when it does);
                # the result is intentionally unused.
                await self.db.litellm_usertable.upsert(
                    where={"user_id": data["user_id"]},
                    data={
                        "create": {
                            "user_id": data["user_id"],
                            "max_budget": max_budget,
                            "user_email": user_email,
                        },
                        "update": {},  # already exists -> leave untouched
                    },
                )
                return new_verification_token
            elif table_name == "config":
                """
                For each param,
                get the existing table values

                Add the new values

                Update DB
                """
                tasks = []
                for k, v in data.items():
                    updated_data = json.dumps(v)
                    # upsert() returns an awaitable; collect them and run all at once.
                    tasks.append(
                        self.db.litellm_config.upsert(
                            where={"param_name": k},
                            data={
                                "create": {"param_name": k, "param_value": updated_data},
                                "update": {"param_value": updated_data},
                            },
                        )
                    )

                await asyncio.gather(*tasks)
        except Exception as e:
            print_verbose(f"LiteLLM Prisma Client Exception: {e}")
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            raise e

    @backoff.on_exception(
        backoff.expo,
        Exception,
        max_tries=3,
        max_time=10,
        on_backoff=on_backoff,
    )
    async def update_data(
        self,
        token: Optional[str] = None,
        data: dict = {},  # read-only here (jsonify_object deep-copies), so the mutable default is safe
        user_id: Optional[str] = None,
    ):
        """
        Update existing data

        Updates a verification-token row (by token) or a user row (by user_id)
        with the key/value pairs in *data*.
        """
        try:
            db_data = self.jsonify_object(data=data)
            if token is not None:
                print_verbose(f"token: {token}")
                # Raw keys (sk-...) are stored hashed at rest.
                if token.startswith("sk-"):
                    token = self.hash_token(token=token)
                db_data["token"] = token
                response = await self.db.litellm_verificationtoken.update(
                    where={"token": token},
                    data={**db_data},
                )
                print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
                return {"token": token, "data": db_data}
            elif user_id is not None:
                """
                If data['spend'] + data['user'], update the user table with spend info as well
                """
                await self.db.litellm_usertable.update(
                    where={"user_id": user_id},
                    data={**db_data},
                )
                return {"user_id": user_id, "data": db_data}
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
            raise e

    @backoff.on_exception(
        backoff.expo,
        Exception,
        max_tries=3,
        max_time=10,
        on_backoff=on_backoff,
    )
    async def delete_data(self, tokens: List):
        """
        Allow user to delete a key(s)

        Tokens are hashed before deletion to match how they are stored.
        """
        try:
            hashed_tokens = [self.hash_token(token=token) for token in tokens]
            await self.db.litellm_verificationtoken.delete_many(
                where={"token": {"in": hashed_tokens}}
            )
            return {"deleted_keys": tokens}
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            raise e

    @backoff.on_exception(
        backoff.expo,
        Exception,
        max_tries=3,
        max_time=10,
        on_backoff=on_backoff,
    )
    async def connect(self):
        """Open the Prisma connection if it isn't connected already."""
        try:
            # idiomatic truth test (was `== False`)
            if not self.db.is_connected():
                await self.db.connect()
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            raise e

    @backoff.on_exception(
        backoff.expo,
        Exception,
        max_tries=3,
        max_time=10,
        on_backoff=on_backoff,
    )
    async def disconnect(self):
        """Close the Prisma connection (used on server shutdown)."""
        try:
            await self.db.disconnect()
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            raise e
|
|
|
|
|
class DBClient:
    """
    Routes requests for CustomAuth

    [TODO] route b/w customauth and prisma
    """

    def __init__(
        self, custom_db_type: Literal["dynamo_db"], custom_db_args: dict
    ) -> None:
        if custom_db_type == "dynamo_db":
            from litellm.proxy.db.dynamo_db import DynamoDBWrapper

            db_arguments = DynamoDBArgs(**custom_db_args)
            self.db = DynamoDBWrapper(database_arguments=db_arguments)

    async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
        """Look up a row by key - used e.g. to check whether a key is valid."""
        result = await self.db.get_data(key=key, table_name=table_name)
        return result

    async def insert_data(
        self, value: Any, table_name: Literal["user", "key", "config"]
    ):
        """Write a brand-new row (new key / new user)."""
        result = await self.db.insert_data(value=value, table_name=table_name)
        return result

    async def update_data(
        self, key: str, value: Any, table_name: Literal["user", "key", "config"]
    ):
        """
        Update an existing row - used for cost tracking.

        key: the hash-key of the row; value: dict of updated fields.
        """
        result = await self.db.update_data(key=key, value=value, table_name=table_name)
        return result

    async def delete_data(
        self, keys: List[str], table_name: Literal["user", "key", "config"]
    ):
        """Remove rows by key - backs the /key/delete endpoints."""
        result = await self.db.delete_data(keys=keys, table_name=table_name)
        return result

    async def connect(self):
        """Connect to the backing store, creating/updating tables as needed."""
        return await self.db.connect()

    async def disconnect(self):
        """Close the connection on server shutdown."""
        return await self.db.disconnect()
|
|
|
|
|
|
|
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
    """
    Resolve a dotted path like "my_module.my_object" to the named attribute.

    When *config_file_path* is given, the module is loaded from a .py file
    located relative to the config file's directory; otherwise a normal
    import of *module_name* is performed.

    Raises:
        ImportError: when the module/attribute cannot be imported (the original
        cause is chained).
    """
    # Initialize up-front so the except-ImportError message below never
    # references unbound names if parsing fails early.
    module_name = ""
    instance_name = ""
    try:
        print_verbose(f"value: {value}")
        parts = value.split(".")
        module_name = ".".join(parts[:-1])
        instance_name = parts[-1]

        if config_file_path is not None:
            # Load the module from a file path relative to the config file.
            directory = os.path.dirname(config_file_path)
            module_file_path = os.path.join(directory, *module_name.split("."))
            module_file_path += ".py"

            # `import importlib` alone does not guarantee the `util` submodule
            # is loaded; import it explicitly before using importlib.util.
            import importlib.util

            spec = importlib.util.spec_from_file_location(module_name, module_file_path)
            if spec is None:
                raise ImportError(
                    f"Could not find a module specification for {module_file_path}"
                )
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
        else:
            module = importlib.import_module(module_name)

        instance = getattr(module, instance_name)

        return instance
    except ImportError as e:
        raise ImportError(f"Could not import {instance_name} from {module_name}") from e
    except Exception as e:
        raise e
|
|
|
|
|
|
|
async def _cache_user_row(
    user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient]
):
    """
    Check if a user_id exists in cache,
    if not retrieve it.
    """
    print_verbose(f"Prisma: _cache_user_row, user_id: {user_id}")
    cache_key = f"{user_id}_user_api_key_user_id"
    # Nothing to do when the row is already cached.
    if cache.get_cache(key=cache_key) is not None:
        return
    if isinstance(db, PrismaClient):
        user_row = await db.get_data(user_id=user_id)
    elif isinstance(db, DBClient):
        user_row = await db.get_data(key=user_id, table_name="user")
    if user_row is not None:
        print_verbose(f"User Row: {user_row}, type = {type(user_row)}")
        # Cache pydantic-style rows only (duck-typed via model_dump_json).
        dump = getattr(user_row, "model_dump_json", None)
        if callable(dump):
            cache.set_cache(key=cache_key, value=dump(), ttl=600)
    return
|
|
|
|
|
async def send_email(sender_name, sender_email, receiver_email, subject, html):
    """
    Send an HTML email over SMTP (STARTTLS + login).

    Connection settings come from the environment:
        SMTP_HOST, SMTP_PORT (default 587), SMTP_USERNAME, SMTP_PASSWORD

    Failures are swallowed after being logged via print_verbose (best-effort send).
    """
    smtp_host = os.getenv("SMTP_HOST")
    # Normalize to int: os.getenv returns a str, while the old default was an
    # int - smtplib expects an integer port.
    smtp_port = int(os.getenv("SMTP_PORT", "587"))
    smtp_username = os.getenv("SMTP_USERNAME")
    smtp_password = os.getenv("SMTP_PASSWORD")

    email_message = MIMEMultipart()
    email_message["From"] = f"{sender_name} <{sender_email}>"
    email_message["To"] = receiver_email
    email_message["Subject"] = subject

    email_message.attach(MIMEText(html, "html"))

    try:
        print_verbose("SMTP Connection Init")

        with smtplib.SMTP(smtp_host, smtp_port) as server:
            server.starttls()
            server.login(smtp_username, smtp_password)
            server.send_message(email_message)
    except Exception as e:
        # BUG FIX: print_verbose takes a single argument; the previous
        # two-argument call raised a TypeError inside this handler and
        # masked the real SMTP error.
        print_verbose(f"An error occurred while sending the email: {str(e)}")
|
|