litellm / litellm /proxy /utils.py
ka1kuk's picture
Upload 235 files
7db0ae4 verified
from typing import Optional, List, Any, Literal, Union
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx
import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth, DynamoDBArgs
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.db.base_client import CustomDB
from fastapi import HTTPException, status
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
def print_verbose(print_statement):
    """Print a proxy debug message, but only when `litellm.set_verbose` is truthy."""
    if not litellm.set_verbose:
        return
    print(f"LiteLLM Proxy: {print_statement}")  # noqa
### LOGGING ###
class ProxyLogging:
    """
    Logging/Custom Handlers for proxy.

    Implemented mainly to:
    - log successful/failed db read/writes
    - support the max parallel request integration
    """

    def __init__(self, user_api_key_cache: DualCache):
        ## INITIALIZE LITELLM CALLBACKS ##
        self.call_details: dict = {}
        self.call_details["user_api_key_cache"] = user_api_key_cache
        self.max_parallel_request_limiter = MaxParallelRequestsHandler()
        self.max_budget_limiter = MaxBudgetLimiter()
        self.alerting: Optional[List] = None
        self.alerting_threshold: float = 300  # default to 5 min. threshold
        # Strong references to fire-and-forget tasks. The event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected
        # before it finishes.
        self._background_tasks: set = set()

    def _create_background_task(self, coro) -> "asyncio.Task":
        """Schedule `coro` as a task and hold a reference to it until it is done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def update_values(
        self, alerting: Optional[List], alerting_threshold: Optional[float]
    ):
        """
        Update the alerting configuration.

        `alerting` always replaces the current value; `alerting_threshold` is only
        applied when it is not None.
        """
        self.alerting = alerting
        if alerting_threshold is not None:
            self.alerting_threshold = alerting_threshold

    def _init_litellm_callbacks(self):
        """
        Register the proxy's limiter hooks with litellm and make sure every entry
        of `litellm.callbacks` is present in each of litellm's callback lists.
        """
        print_verbose("INITIALIZING LITELLM CALLBACKS!")
        # Only append the limiters once, so calling this method repeatedly does
        # not register (and later invoke) the same hook several times.
        for limiter in (self.max_parallel_request_limiter, self.max_budget_limiter):
            if limiter not in litellm.callbacks:
                litellm.callbacks.append(limiter)

        callback_targets = (
            litellm.input_callback,
            litellm.success_callback,
            litellm.failure_callback,
            litellm._async_success_callback,
            litellm._async_failure_callback,
        )
        for callback in litellm.callbacks:
            for target in callback_targets:
                if callback not in target:
                    target.append(callback)

        if (
            len(litellm.input_callback) > 0
            or len(litellm.success_callback) > 0
            or len(litellm.failure_callback) > 0
        ):
            callback_list = list(
                set(
                    litellm.input_callback
                    + litellm.success_callback
                    + litellm.failure_callback
                )
            )
            litellm.utils.set_callbacks(callback_list=callback_list)

    async def pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: dict,
        call_type: Literal["completion", "embeddings"],
    ):
        """
        Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.

        Covers:
        1. /chat/completions
        2. /embeddings
        3. /image/generation

        Each `CustomLogger` in `litellm.callbacks` whose class itself defines
        `async_pre_call_hook` is awaited in order; a non-None return value
        replaces `data` for the following hooks. Exceptions raised by a hook
        propagate to the caller.

        Returns:
            The (possibly modified) request data.
        """
        ### ALERTING ###
        # NOTE(review): this timer is not tied to the request's completion, so the
        # "hanging request" alert fires once the threshold elapses even if the
        # request already finished - confirm whether that is intended.
        self._create_background_task(self.response_taking_too_long())

        for callback in litellm.callbacks:
            # vars(callback.__class__) only sees hooks defined directly on the
            # callback's own class, not ones inherited from a base class.
            if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars(
                callback.__class__
            ):
                response = await callback.async_pre_call_hook(
                    user_api_key_dict=user_api_key_dict,
                    cache=self.call_details["user_api_key_cache"],
                    data=data,
                    call_type=call_type,
                )
                if response is not None:
                    data = response

        print_verbose(f"final data being sent to {call_type} call: {data}")
        return data

    async def success_handler(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response: Any,
        call_type: Literal["completion", "embeddings"],
        start_time,
        end_time,
    ):
        """
        Log successful API calls / db read/writes

        Currently a no-op placeholder.
        """

    async def response_taking_too_long(
        self,
        start_time: Optional[float] = None,
        end_time: Optional[float] = None,
        type: Literal["hanging_request", "slow_response"] = "hanging_request",
    ):
        """
        Raise an alert for hanging requests or slow responses.

        - "hanging_request": sleep for `alerting_threshold` seconds, then send a
          Medium alert.
        - "slow_response": send a Low alert when `end_time - start_time` exceeds
          `alerting_threshold` (both timestamps must be provided).
        """
        if type == "hanging_request":
            # Simulate a long-running operation that could take more than 5 minutes
            await asyncio.sleep(
                self.alerting_threshold
            )  # Set it to 5 minutes - i'd imagine this might be different for streaming, non-streaming, non-completion (embedding + img) requests
            await self.alerting_handler(
                message=f"Requests are hanging - {self.alerting_threshold}s+ request time",
                level="Medium",
            )
        elif (
            type == "slow_response" and start_time is not None and end_time is not None
        ):
            if end_time - start_time > self.alerting_threshold:
                await self.alerting_handler(
                    message=f"Responses are slow - {round(end_time-start_time,2)}s response time",
                    level="Low",
                )

    async def alerting_handler(
        self, message: str, level: Literal["Low", "Medium", "High"]
    ):
        """
        Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298

        - Responses taking too long
        - Requests are hanging
        - Calls are failing
        - DB Read/Writes are failing

        Parameters:
            level: str - Low|Medium|High - if calls might fail (Medium) or are failing (High); Currently, no alerts would be 'Low'.
            message: str - what is the alert about

        Raises:
            Exception: if "slack" is configured but SLACK_WEBHOOK_URL is unset, or
                "sentry" is configured but no sentry SDK instance is available.
        """
        if self.alerting is None:
            return

        formatted_message = f"Level: {level}\n\nMessage: {message}"
        for client in self.alerting:
            if client == "slack":
                slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
                if slack_webhook_url is None:
                    raise Exception("Missing SLACK_WEBHOOK_URL from environment")
                payload = {"text": formatted_message}
                headers = {"Content-type": "application/json"}
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                        slack_webhook_url, json=payload, headers=headers
                    ) as response:
                        # Surface delivery failures instead of dropping them silently.
                        if response.status != 200:
                            print_verbose(
                                f"Slack alert failed with status code {response.status}"
                            )
            elif client == "sentry":
                if litellm.utils.sentry_sdk_instance is not None:
                    litellm.utils.sentry_sdk_instance.capture_message(formatted_message)
                else:
                    raise Exception("Missing SENTRY_DSN from environment")

    async def failure_handler(self, original_exception):
        """
        Log failed db read/writes

        Sends a High alert in the background and, when litellm has an exception
        capture function configured, forwards the exception to it.
        """
        ### ALERTING ###
        if isinstance(original_exception, HTTPException):
            error_message = original_exception.detail
        else:
            error_message = str(original_exception)
        self._create_background_task(
            self.alerting_handler(
                message=f"DB read/write call failed: {error_message}",
                level="High",
            )
        )

        if litellm.utils.capture_exception:
            litellm.utils.capture_exception(error=original_exception)

    async def post_call_failure_hook(
        self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
    ):
        """
        Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.

        Covers:
        1. /chat/completions
        2. /embeddings
        3. /image/generation

        Exceptions raised by a callback propagate to the caller.
        """
        ### ALERTING ###
        self._create_background_task(
            self.alerting_handler(
                message=f"LLM API call failed: {str(original_exception)}", level="High"
            )
        )

        for callback in litellm.callbacks:
            if isinstance(callback, CustomLogger):
                await callback.async_post_call_failure_hook(
                    user_api_key_dict=user_api_key_dict,
                    original_exception=original_exception,
                )
        return
### DB CONNECTOR ###
# Define the retry decorator with backoff strategy
# Function to be called whenever a retry is about to happen
def on_backoff(details):
    """Backoff hook: log the number of completed tries before the next retry."""
    # 'tries' holds the number of attempts made so far
    attempt = details["tries"]
    print_verbose(f"Backing off... this was attempt #{attempt}")
class PrismaClient:
    """
    Thin async wrapper around the generated Prisma client.

    Every DB method retries with exponential backoff (max 3 tries / 10s) and
    reports failures to the proxy logging object before re-raising.
    """

    def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
        """
        Import the Prisma client, generating it first if the import fails.

        Raises:
            Exception: if the `prisma` CLI commands cannot be run.
        """
        print_verbose(
            "LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
        )
        ## init logging object
        self.proxy_logging_obj = proxy_logging_obj
        try:
            from prisma import Prisma  # type: ignore
        except Exception:
            os.environ["DATABASE_URL"] = database_url
            # Save the current working directory
            original_dir = os.getcwd()
            # set the working directory to where this script is
            dname = os.path.dirname(os.path.abspath(__file__))
            os.chdir(dname)

            try:
                subprocess.run(["prisma", "generate"])
                subprocess.run(
                    ["prisma", "db", "push", "--accept-data-loss"]
                )  # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss
            except Exception as e:
                # `except Exception` (not a bare except) so KeyboardInterrupt /
                # SystemExit are not converted; the cause is chained for debugging.
                raise Exception(
                    "Unable to run prisma commands. Run `pip install prisma`"
                ) from e
            finally:
                os.chdir(original_dir)
            # Now you can import the Prisma Client
            from prisma import Prisma  # type: ignore

        self.db = Prisma(
            http={
                "limits": httpx.Limits(
                    max_connections=1000, max_keepalive_connections=100
                )
            }
        )  # Client to connect to Prisma db

    def hash_token(self, token: str):
        """Return the SHA-256 hex digest of `token`."""
        return hashlib.sha256(token.encode()).hexdigest()

    def jsonify_object(self, data: dict) -> dict:
        """Return a deep copy of `data` with every top-level dict value JSON-encoded."""
        db_data = copy.deepcopy(data)

        for k, v in db_data.items():
            if isinstance(v, dict):
                db_data[k] = json.dumps(v)
        return db_data

    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def get_generic_data(
        self,
        key: str,
        value: Any,
        table_name: Literal["users", "keys", "config"],
    ):
        """
        Generic implementation of get data

        Returns the first row of `table_name` where `key == value`, or None when
        nothing matches (or `table_name` is not one of the supported tables).
        """
        try:
            # Initialised so an unexpected table_name yields None rather than an
            # UnboundLocalError that would be retried and alerted on.
            response = None
            if table_name == "users":
                response = await self.db.litellm_usertable.find_first(
                    where={key: value}  # type: ignore
                )
            elif table_name == "keys":
                response = await self.db.litellm_verificationtoken.find_first(  # type: ignore
                    where={key: value}  # type: ignore
                )
            elif table_name == "config":
                response = await self.db.litellm_config.find_first(  # type: ignore
                    where={key: value}  # type: ignore
                )
            return response
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            raise

    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
        # An "invalid user key" 401 is a deterministic result, not a transient
        # DB failure - retrying it only adds latency.
        giveup=lambda e: isinstance(e, HTTPException),
    )
    async def get_data(
        self,
        token: Optional[str] = None,
        user_id: Optional[str] = None,
        table_name: Optional[Literal["user", "key", "config"]] = None,
        query_type: Literal["find_unique", "find_all"] = "find_unique",
    ):
        """
        Look up a key row (by token / all keys of a user) or a user row.

        - `token` given or `table_name == "key"`: query the verification-token
          table. Plain-text tokens (prefix "sk-") are hashed first.
        - otherwise, `user_id` given: query the user table.

        Returns:
            The matching row(s); None when no user row matches or no lookup
            criteria were supplied.

        Raises:
            HTTPException: 401 when the key lookup finds nothing.
        """
        try:
            print_verbose("PrismaClient: get_data")

            response: Any = None
            if token is not None or table_name == "key":
                # check if plain text or hash
                hashed_token: Optional[str] = None
                if token is not None:
                    hashed_token = token
                    if token.startswith("sk-"):
                        hashed_token = self.hash_token(token=token)
                print_verbose("PrismaClient: find_unique")
                # Without a token there is nothing to look up uniquely; fall
                # through to the 401 below instead of hitting an unbound name.
                if query_type == "find_unique" and hashed_token is not None:
                    response = await self.db.litellm_verificationtoken.find_unique(
                        where={"token": hashed_token}
                    )
                elif query_type == "find_all" and user_id is not None:
                    response = await self.db.litellm_verificationtoken.find_many(
                        where={"user_id": user_id}
                    )
                print_verbose(f"PrismaClient: response={response}")
                if response is not None:
                    return response
                else:
                    # Token does not exist.
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="invalid user key",
                    )
            elif user_id is not None:
                response = await self.db.litellm_usertable.find_unique(  # type: ignore
                    where={
                        "user_id": user_id,
                    }
                )
                return response
        except Exception as e:
            print_verbose(f"LiteLLM Prisma Client Exception: {e}")
            import traceback

            traceback.print_exc()
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            raise

    # Define a retrying strategy with exponential backoff
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def insert_data(
        self, data: dict, table_name: Literal["user+key", "config"] = "user+key"
    ):
        """
        Add a key to the database. If it already exists, do nothing.

        - "user+key": upsert the hashed token row and the owning user row
          (`data` must contain "token" and "user_id"); returns the token row.
        - "config": upsert one config row per item of `data`, with the value
          JSON-encoded; returns None.
        """
        try:
            if table_name == "user+key":
                token = data["token"]
                hashed_token = self.hash_token(token=token)
                db_data = self.jsonify_object(data=data)
                db_data["token"] = hashed_token
                # These two fields are written to the user row, not the token row.
                max_budget = db_data.pop("max_budget", None)
                user_email = db_data.pop("user_email", None)
                print_verbose(
                    "PrismaClient: Before upsert into litellm_verificationtoken"
                )
                new_verification_token = await self.db.litellm_verificationtoken.upsert(  # type: ignore
                    where={
                        "token": hashed_token,
                    },
                    data={
                        "create": {**db_data},  # type: ignore
                        "update": {},  # don't do anything if it already exists
                    },
                )

                await self.db.litellm_usertable.upsert(
                    where={"user_id": data["user_id"]},
                    data={
                        "create": {
                            "user_id": data["user_id"],
                            "max_budget": max_budget,
                            "user_email": user_email,
                        },
                        "update": {},  # don't do anything if it already exists
                    },
                )
                return new_verification_token
            elif table_name == "config":
                # For each param: JSON-encode the new value and upsert its row.
                # The upserts are gathered so they run concurrently.
                tasks = []
                for k, v in data.items():
                    updated_data = json.dumps(v)
                    updated_table_row = self.db.litellm_config.upsert(
                        where={"param_name": k},
                        data={
                            "create": {"param_name": k, "param_value": updated_data},
                            "update": {"param_value": updated_data},
                        },
                    )
                    tasks.append(updated_table_row)

                await asyncio.gather(*tasks)
        except Exception as e:
            print_verbose(f"LiteLLM Prisma Client Exception: {e}")
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            raise

    # Define a retrying strategy with exponential backoff
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def update_data(
        self,
        token: Optional[str] = None,
        data: Optional[dict] = None,
        user_id: Optional[str] = None,
    ):
        """
        Update existing data

        Updates the token row when `token` is given (plain-text "sk-" tokens are
        hashed first), otherwise the user row when `user_id` is given.

        Returns:
            {"token": ..., "data": ...} or {"user_id": ..., "data": ...}; None
            when neither `token` nor `user_id` is supplied.
        """
        try:
            # `None` default instead of a shared mutable `{}` default.
            db_data = self.jsonify_object(data=data or {})
            if token is not None:
                print_verbose(f"token: {token}")
                # check if plain text or hash
                if token.startswith("sk-"):
                    token = self.hash_token(token=token)
                db_data["token"] = token
                response = await self.db.litellm_verificationtoken.update(
                    where={"token": token},  # type: ignore
                    data={**db_data},  # type: ignore
                )
                print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
                return {"token": token, "data": db_data}
            elif user_id is not None:
                # Update the user table (e.g. with spend info) for this user_id.
                await self.db.litellm_usertable.update(
                    where={"user_id": user_id},  # type: ignore
                    data={**db_data},  # type: ignore
                )
                return {"user_id": user_id, "data": db_data}
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
            raise

    # Define a retrying strategy with exponential backoff
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def delete_data(self, tokens: List):
        """
        Allow user to delete a key(s)

        Every entry of `tokens` is hashed before the delete; the original
        (unhashed) list is echoed back under "deleted_keys".
        """
        try:
            hashed_tokens = [self.hash_token(token=token) for token in tokens]
            await self.db.litellm_verificationtoken.delete_many(
                where={"token": {"in": hashed_tokens}}
            )
            return {"deleted_keys": tokens}
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            raise

    # Define a retrying strategy with exponential backoff
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def connect(self):
        """Connect to the database unless the client reports it is already connected."""
        try:
            if not self.db.is_connected():
                await self.db.connect()
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            raise

    # Define a retrying strategy with exponential backoff
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
        max_tries=3,  # maximum number of retries
        max_time=10,  # maximum total time to retry for
        on_backoff=on_backoff,  # specifying the function to call on backoff
    )
    async def disconnect(self):
        """Disconnect from the database."""
        try:
            await self.db.disconnect()
        except Exception as e:
            asyncio.create_task(
                self.proxy_logging_obj.failure_handler(original_exception=e)
            )
            raise
class DBClient:
    """
    Routes requests for CustomAuth

    [TODO] route b/w customauth and prisma

    Every method delegates to the wrapped backend stored in `self.db`.
    """

    def __init__(
        self, custom_db_type: Literal["dynamo_db"], custom_db_args: dict
    ) -> None:
        """
        Build the backend wrapper for `custom_db_type`.

        Raises:
            ValueError: if `custom_db_type` is not a supported backend.
        """
        if custom_db_type == "dynamo_db":
            from litellm.proxy.db.dynamo_db import DynamoDBWrapper

            self.db = DynamoDBWrapper(database_arguments=DynamoDBArgs(**custom_db_args))
        else:
            # Fail fast: without a backend every later call would die with an
            # opaque AttributeError on `self.db`.
            raise ValueError(f"Unsupported custom_db_type: {custom_db_type!r}")

    async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
        """
        Check if key valid
        """
        return await self.db.get_data(key=key, table_name=table_name)

    async def insert_data(
        self, value: Any, table_name: Literal["user", "key", "config"]
    ):
        """
        For new key / user logic
        """
        return await self.db.insert_data(value=value, table_name=table_name)

    async def update_data(
        self, key: str, value: Any, table_name: Literal["user", "key", "config"]
    ):
        """
        For cost tracking logic

        key - hash_key value \n
        value - dict with updated values
        """
        return await self.db.update_data(key=key, value=value, table_name=table_name)

    async def delete_data(
        self, keys: List[str], table_name: Literal["user", "key", "config"]
    ):
        """
        For /key/delete endpoints
        """
        return await self.db.delete_data(keys=keys, table_name=table_name)

    async def connect(self):
        """
        For connecting to db and creating / updating any tables
        """
        return await self.db.connect()

    async def disconnect(self):
        """
        For closing connection on server shutdown
        """
        return await self.db.disconnect()
### CUSTOM FILE ###
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
    """
    Resolve a dotted path such as "pkg.module.instance" to the named attribute.

    Parameters:
        value: dotted path; everything before the last dot is the module, the
            last part is the attribute to fetch from it.
        config_file_path: when given, the module is loaded from a .py file
            located relative to this file's directory instead of being imported
            by name.

    Returns:
        The attribute `instance_name` of the loaded module.

    Raises:
        ImportError: if the module cannot be found or imported.
    """
    # `import importlib` alone does not guarantee the `util` submodule is loaded.
    import importlib.util

    print_verbose(f"value: {value}")
    # Split the path by dots to separate module from instance
    parts = value.split(".")

    # The module path is all but the last part, and the instance_name is the last part
    module_name = ".".join(parts[:-1])
    instance_name = parts[-1]

    try:
        # If config_file_path is provided, use it to determine the module spec and load the module
        if config_file_path is not None:
            directory = os.path.dirname(config_file_path)
            module_file_path = os.path.join(directory, *module_name.split("."))
            module_file_path += ".py"

            spec = importlib.util.spec_from_file_location(module_name, module_file_path)
            # A spec without a loader cannot be executed either.
            if spec is None or spec.loader is None:
                raise ImportError(
                    f"Could not find a module specification for {module_file_path}"
                )
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
        else:
            # Dynamically import the module
            module = importlib.import_module(module_name)

        # Get the instance from the module
        return getattr(module, instance_name)
    except ImportError as e:
        # Re-raise the exception with a user-friendly message
        raise ImportError(f"Could not import {instance_name} from {module_name}") from e
### HELPER FUNCTIONS ###
async def _cache_user_row(
    user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient]
):
    """
    Check if a user_id exists in cache,
    if not retrieve it.

    On a cache miss the user row is fetched from `db`; if the row exposes a
    callable `model_dump_json`, its JSON dump is cached for 10 minutes under
    "<user_id>_user_api_key_user_id". Always returns None.
    """
    print_verbose(f"Prisma: _cache_user_row, user_id: {user_id}")
    cache_key = f"{user_id}_user_api_key_user_id"
    if cache.get_cache(key=cache_key) is not None:  # Cache hit - nothing to do
        return

    # Initialised so an unsupported `db` type is a no-op rather than an
    # UnboundLocalError.
    user_row = None
    if isinstance(db, PrismaClient):
        user_row = await db.get_data(user_id=user_id)
    elif isinstance(db, DBClient):
        user_row = await db.get_data(key=user_id, table_name="user")

    if user_row is not None:
        print_verbose(f"User Row: {user_row}, type = {type(user_row)}")
        if callable(getattr(user_row, "model_dump_json", None)):
            cache_value = user_row.model_dump_json()
            cache.set_cache(
                key=cache_key, value=cache_value, ttl=600
            )  # store for 10 minutes
    return
async def send_email(sender_name, sender_email, receiver_email, subject, html):
    """
    Send an HTML email over SMTP with STARTTLS. Best-effort: failures are logged
    via `print_verbose` and not raised.

    Server settings are read from the environment:
        SMTP_HOST, SMTP_PORT (default 587), SMTP_USERNAME, SMTP_PASSWORD

    Parameters:
        sender_name, sender_email: used to build the "From" header.
        receiver_email: value of the "To" header.
        subject: value of the "Subject" header.
        html: message body, attached as text/html.
    """
    ## SERVER SETUP ##
    smtp_host = os.getenv("SMTP_HOST")
    smtp_port = os.getenv("SMTP_PORT", 587)  # default to port 587
    smtp_username = os.getenv("SMTP_USERNAME")
    smtp_password = os.getenv("SMTP_PASSWORD")

    ## EMAIL SETUP ##
    email_message = MIMEMultipart()
    email_message["From"] = f"{sender_name} <{sender_email}>"
    email_message["To"] = receiver_email
    email_message["Subject"] = subject

    # Attach the body to the email
    email_message.attach(MIMEText(html, "html"))

    # NOTE(review): smtplib is blocking, so this stalls the event loop while the
    # mail is being sent - consider offloading to an executor.
    try:
        print_verbose("SMTP Connection Init")
        # Environment values are strings; normalise the port to int. Done inside
        # the try so a malformed SMTP_PORT is logged like any other send failure.
        with smtplib.SMTP(smtp_host, int(smtp_port)) as server:
            # Establish a secure connection with the SMTP server
            server.starttls()

            # Login to your email account
            server.login(smtp_username, smtp_password)

            # Send the email
            server.send_message(email_message)

    except Exception as e:
        # print_verbose takes a single argument; passing two raised a TypeError
        # here and masked the real error.
        print_verbose(f"An error occurred while sending the email: {str(e)}")