from typing import Optional, List, Any, Literal, Union import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx import litellm, backoff from litellm.proxy._types import UserAPIKeyAuth, DynamoDBArgs from litellm.caching import DualCache from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter from litellm.integrations.custom_logger import CustomLogger from litellm.proxy.db.base_client import CustomDB from fastapi import HTTPException, status import smtplib from email.mime.text import MIMEText from email.mime.multipart import MIMEMultipart def print_verbose(print_statement): if litellm.set_verbose: print(f"LiteLLM Proxy: {print_statement}") # noqa ### LOGGING ### class ProxyLogging: """ Logging/Custom Handlers for proxy. Implemented mainly to: - log successful/failed db read/writes - support the max parallel request integration """ def __init__(self, user_api_key_cache: DualCache): ## INITIALIZE LITELLM CALLBACKS ## self.call_details: dict = {} self.call_details["user_api_key_cache"] = user_api_key_cache self.max_parallel_request_limiter = MaxParallelRequestsHandler() self.max_budget_limiter = MaxBudgetLimiter() self.alerting: Optional[List] = None self.alerting_threshold: float = 300 # default to 5 min. threshold pass def update_values( self, alerting: Optional[List], alerting_threshold: Optional[float] ): self.alerting = alerting if alerting_threshold is not None: self.alerting_threshold = alerting_threshold def _init_litellm_callbacks(self): print_verbose(f"INITIALIZING LITELLM CALLBACKS!") litellm.callbacks.append(self.max_parallel_request_limiter) litellm.callbacks.append(self.max_budget_limiter) for callback in litellm.callbacks: if callback not in litellm.input_callback: litellm.input_callback.append(callback) if callback not in litellm.success_callback: litellm.success_callback.append(callback) if callback not in litellm.failure_callback: litellm.failure_callback.append(callback) if callback not in litellm._async_success_callback: litellm._async_success_callback.append(callback) if callback not in litellm._async_failure_callback: litellm._async_failure_callback.append(callback) if ( len(litellm.input_callback) > 0 or len(litellm.success_callback) > 0 or len(litellm.failure_callback) > 0 ): callback_list = list( set( litellm.input_callback + litellm.success_callback + litellm.failure_callback ) ) litellm.utils.set_callbacks(callback_list=callback_list) async def pre_call_hook( self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"], ): """ Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body. Covers: 1. /chat/completions 2. /embeddings 3. /image/generation """ ### ALERTING ### asyncio.create_task(self.response_taking_too_long()) try: for callback in litellm.callbacks: if isinstance(callback, CustomLogger) and "async_pre_call_hook" in vars( callback.__class__ ): response = await callback.async_pre_call_hook( user_api_key_dict=user_api_key_dict, cache=self.call_details["user_api_key_cache"], data=data, call_type=call_type, ) if response is not None: data = response print_verbose(f"final data being sent to {call_type} call: {data}") return data except Exception as e: raise e async def success_handler( self, user_api_key_dict: UserAPIKeyAuth, response: Any, call_type: Literal["completion", "embeddings"], start_time, end_time, ): """ Log successful API calls / db read/writes """ pass async def response_taking_too_long( self, start_time: Optional[float] = None, end_time: Optional[float] = None, type: Literal["hanging_request", "slow_response"] = "hanging_request", ): if type == "hanging_request": # Simulate a long-running operation that could take more than 5 minutes await asyncio.sleep( self.alerting_threshold ) # Set it to 5 minutes - i'd imagine this might be different for streaming, non-streaming, non-completion (embedding + img) requests await self.alerting_handler( message=f"Requests are hanging - {self.alerting_threshold}s+ request time", level="Medium", ) elif ( type == "slow_response" and start_time is not None and end_time is not None ): if end_time - start_time > self.alerting_threshold: await self.alerting_handler( message=f"Responses are slow - {round(end_time-start_time,2)}s response time", level="Low", ) async def alerting_handler( self, message: str, level: Literal["Low", "Medium", "High"] ): """ Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298 - Responses taking too long - Requests are hanging - Calls are failing - DB Read/Writes are failing Parameters: level: str - Low|Medium|High - if calls might fail (Medium) or are failing (High); Currently, no alerts would be 'Low'. message: str - what is the alert about """ formatted_message = f"Level: {level}\n\nMessage: {message}" if self.alerting is None: return for client in self.alerting: if client == "slack": slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None) if slack_webhook_url is None: raise Exception("Missing SLACK_WEBHOOK_URL from environment") payload = {"text": formatted_message} headers = {"Content-type": "application/json"} async with aiohttp.ClientSession() as session: async with session.post( slack_webhook_url, json=payload, headers=headers ) as response: if response.status == 200: pass elif client == "sentry": if litellm.utils.sentry_sdk_instance is not None: litellm.utils.sentry_sdk_instance.capture_message(formatted_message) else: raise Exception("Missing SENTRY_DSN from environment") async def failure_handler(self, original_exception): """ Log failed db read/writes Currently only logs exceptions to sentry """ ### ALERTING ### if isinstance(original_exception, HTTPException): error_message = original_exception.detail else: error_message = str(original_exception) asyncio.create_task( self.alerting_handler( message=f"DB read/write call failed: {error_message}", level="High", ) ) if litellm.utils.capture_exception: litellm.utils.capture_exception(error=original_exception) async def post_call_failure_hook( self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth ): """ Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body. Covers: 1. /chat/completions 2. /embeddings 3. /image/generation """ ### ALERTING ### asyncio.create_task( self.alerting_handler( message=f"LLM API call failed: {str(original_exception)}", level="High" ) ) for callback in litellm.callbacks: try: if isinstance(callback, CustomLogger): await callback.async_post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=original_exception, ) except Exception as e: raise e return ### DB CONNECTOR ### # Define the retry decorator with backoff strategy # Function to be called whenever a retry is about to happen def on_backoff(details): # The 'tries' key in the details dictionary contains the number of completed tries print_verbose(f"Backing off... this was attempt #{details['tries']}") class PrismaClient: def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging): print_verbose( "LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'" ) ## init logging object self.proxy_logging_obj = proxy_logging_obj try: from prisma import Prisma # type: ignore except Exception as e: os.environ["DATABASE_URL"] = database_url # Save the current working directory original_dir = os.getcwd() # set the working directory to where this script is abspath = os.path.abspath(__file__) dname = os.path.dirname(abspath) os.chdir(dname) try: subprocess.run(["prisma", "generate"]) subprocess.run( ["prisma", "db", "push", "--accept-data-loss"] ) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss except: raise Exception( f"Unable to run prisma commands. Run `pip install prisma`" ) finally: os.chdir(original_dir) # Now you can import the Prisma Client from prisma import Prisma # type: ignore self.db = Prisma( http={ "limits": httpx.Limits( max_connections=1000, max_keepalive_connections=100 ) } ) # Client to connect to Prisma db def hash_token(self, token: str): # Hash the string using SHA-256 hashed_token = hashlib.sha256(token.encode()).hexdigest() return hashed_token def jsonify_object(self, data: dict) -> dict: db_data = copy.deepcopy(data) for k, v in db_data.items(): if isinstance(v, dict): db_data[k] = json.dumps(v) return db_data @backoff.on_exception( backoff.expo, Exception, # base exception to catch for the backoff max_tries=3, # maximum number of retries max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on backoff ) async def get_generic_data( self, key: str, value: Any, table_name: Literal["users", "keys", "config"], ): """ Generic implementation of get data """ try: if table_name == "users": response = await self.db.litellm_usertable.find_first( where={key: value} # type: ignore ) elif table_name == "keys": response = await self.db.litellm_verificationtoken.find_first( # type: ignore where={key: value} # type: ignore ) elif table_name == "config": response = await self.db.litellm_config.find_first( # type: ignore where={key: value} # type: ignore ) return response except Exception as e: asyncio.create_task( self.proxy_logging_obj.failure_handler(original_exception=e) ) raise e @backoff.on_exception( backoff.expo, Exception, # base exception to catch for the backoff max_tries=3, # maximum number of retries max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on backoff ) async def get_data( self, token: Optional[str] = None, user_id: Optional[str] = None, table_name: Optional[Literal["user", "key", "config"]] = None, query_type: Literal["find_unique", "find_all"] = "find_unique", ): try: print_verbose("PrismaClient: get_data") response: Any = None if token is not None or (table_name is not None and table_name == "key"): # check if plain text or hash if token is not None: hashed_token = token if token.startswith("sk-"): hashed_token = self.hash_token(token=token) print_verbose("PrismaClient: find_unique") if query_type == "find_unique": response = await self.db.litellm_verificationtoken.find_unique( where={"token": hashed_token} ) elif query_type == "find_all" and user_id is not None: response = await self.db.litellm_verificationtoken.find_many( where={"user_id": user_id} ) print_verbose(f"PrismaClient: response={response}") if response is not None: return response else: # Token does not exist. raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid user key", ) elif user_id is not None: response = await self.db.litellm_usertable.find_unique( # type: ignore where={ "user_id": user_id, } ) return response except Exception as e: print_verbose(f"LiteLLM Prisma Client Exception: {e}") import traceback traceback.print_exc() asyncio.create_task( self.proxy_logging_obj.failure_handler(original_exception=e) ) raise e # Define a retrying strategy with exponential backoff @backoff.on_exception( backoff.expo, Exception, # base exception to catch for the backoff max_tries=3, # maximum number of retries max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on backoff ) async def insert_data( self, data: dict, table_name: Literal["user+key", "config"] = "user+key" ): """ Add a key to the database. If it already exists, do nothing. """ try: if table_name == "user+key": token = data["token"] hashed_token = self.hash_token(token=token) db_data = self.jsonify_object(data=data) db_data["token"] = hashed_token max_budget = db_data.pop("max_budget", None) user_email = db_data.pop("user_email", None) print_verbose( "PrismaClient: Before upsert into litellm_verificationtoken" ) new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore where={ "token": hashed_token, }, data={ "create": {**db_data}, # type: ignore "update": {}, # don't do anything if it already exists }, ) new_user_row = await self.db.litellm_usertable.upsert( where={"user_id": data["user_id"]}, data={ "create": { "user_id": data["user_id"], "max_budget": max_budget, "user_email": user_email, }, "update": {}, # don't do anything if it already exists }, ) return new_verification_token elif table_name == "config": """ For each param, get the existing table values Add the new values Update DB """ tasks = [] for k, v in data.items(): updated_data = v updated_data = json.dumps(updated_data) updated_table_row = self.db.litellm_config.upsert( where={"param_name": k}, data={ "create": {"param_name": k, "param_value": updated_data}, "update": {"param_value": updated_data}, }, ) tasks.append(updated_table_row) await asyncio.gather(*tasks) except Exception as e: print_verbose(f"LiteLLM Prisma Client Exception: {e}") asyncio.create_task( self.proxy_logging_obj.failure_handler(original_exception=e) ) raise e # Define a retrying strategy with exponential backoff @backoff.on_exception( backoff.expo, Exception, # base exception to catch for the backoff max_tries=3, # maximum number of retries max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on backoff ) async def update_data( self, token: Optional[str] = None, data: dict = {}, user_id: Optional[str] = None, ): """ Update existing data """ try: db_data = self.jsonify_object(data=data) if token is not None: print_verbose(f"token: {token}") # check if plain text or hash if token.startswith("sk-"): token = self.hash_token(token=token) db_data["token"] = token response = await self.db.litellm_verificationtoken.update( where={"token": token}, # type: ignore data={**db_data}, # type: ignore ) print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m") return {"token": token, "data": db_data} elif user_id is not None: """ If data['spend'] + data['user'], update the user table with spend info as well """ update_user_row = await self.db.litellm_usertable.update( where={"user_id": user_id}, # type: ignore data={**db_data}, # type: ignore ) return {"user_id": user_id, "data": db_data} except Exception as e: asyncio.create_task( self.proxy_logging_obj.failure_handler(original_exception=e) ) print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m") raise e # Define a retrying strategy with exponential backoff @backoff.on_exception( backoff.expo, Exception, # base exception to catch for the backoff max_tries=3, # maximum number of retries max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on backoff ) async def delete_data(self, tokens: List): """ Allow user to delete a key(s) """ try: hashed_tokens = [self.hash_token(token=token) for token in tokens] await self.db.litellm_verificationtoken.delete_many( where={"token": {"in": hashed_tokens}} ) return {"deleted_keys": tokens} except Exception as e: asyncio.create_task( self.proxy_logging_obj.failure_handler(original_exception=e) ) raise e # Define a retrying strategy with exponential backoff @backoff.on_exception( backoff.expo, Exception, # base exception to catch for the backoff max_tries=3, # maximum number of retries max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on backoff ) async def connect(self): try: if self.db.is_connected() == False: await self.db.connect() except Exception as e: asyncio.create_task( self.proxy_logging_obj.failure_handler(original_exception=e) ) raise e # Define a retrying strategy with exponential backoff @backoff.on_exception( backoff.expo, Exception, # base exception to catch for the backoff max_tries=3, # maximum number of retries max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on backoff ) async def disconnect(self): try: await self.db.disconnect() except Exception as e: asyncio.create_task( self.proxy_logging_obj.failure_handler(original_exception=e) ) raise e class DBClient: """ Routes requests for CustomAuth [TODO] route b/w customauth and prisma """ def __init__( self, custom_db_type: Literal["dynamo_db"], custom_db_args: dict ) -> None: if custom_db_type == "dynamo_db": from litellm.proxy.db.dynamo_db import DynamoDBWrapper self.db = DynamoDBWrapper(database_arguments=DynamoDBArgs(**custom_db_args)) async def get_data(self, key: str, table_name: Literal["user", "key", "config"]): """ Check if key valid """ return await self.db.get_data(key=key, table_name=table_name) async def insert_data( self, value: Any, table_name: Literal["user", "key", "config"] ): """ For new key / user logic """ return await self.db.insert_data(value=value, table_name=table_name) async def update_data( self, key: str, value: Any, table_name: Literal["user", "key", "config"] ): """ For cost tracking logic key - hash_key value \n value - dict with updated values """ return await self.db.update_data(key=key, value=value, table_name=table_name) async def delete_data( self, keys: List[str], table_name: Literal["user", "key", "config"] ): """ For /key/delete endpoints """ return await self.db.delete_data(keys=keys, table_name=table_name) async def connect(self): """ For connecting to db and creating / updating any tables """ return await self.db.connect() async def disconnect(self): """ For closing connection on server shutdown """ return await self.db.disconnect() ### CUSTOM FILE ### def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: try: print_verbose(f"value: {value}") # Split the path by dots to separate module from instance parts = value.split(".") # The module path is all but the last part, and the instance_name is the last part module_name = ".".join(parts[:-1]) instance_name = parts[-1] # If config_file_path is provided, use it to determine the module spec and load the module if config_file_path is not None: directory = os.path.dirname(config_file_path) module_file_path = os.path.join(directory, *module_name.split(".")) module_file_path += ".py" spec = importlib.util.spec_from_file_location(module_name, module_file_path) if spec is None: raise ImportError( f"Could not find a module specification for {module_file_path}" ) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) # type: ignore else: # Dynamically import the module module = importlib.import_module(module_name) # Get the instance from the module instance = getattr(module, instance_name) return instance except ImportError as e: # Re-raise the exception with a user-friendly message raise ImportError(f"Could not import {instance_name} from {module_name}") from e except Exception as e: raise e ### HELPER FUNCTIONS ### async def _cache_user_row( user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient] ): """ Check if a user_id exists in cache, if not retrieve it. """ print_verbose(f"Prisma: _cache_user_row, user_id: {user_id}") cache_key = f"{user_id}_user_api_key_user_id" response = cache.get_cache(key=cache_key) if response is None: # Cache miss if isinstance(db, PrismaClient): user_row = await db.get_data(user_id=user_id) elif isinstance(db, DBClient): user_row = await db.get_data(key=user_id, table_name="user") if user_row is not None: print_verbose(f"User Row: {user_row}, type = {type(user_row)}") if hasattr(user_row, "model_dump_json") and callable( getattr(user_row, "model_dump_json") ): cache_value = user_row.model_dump_json() cache.set_cache( key=cache_key, value=cache_value, ttl=600 ) # store for 10 minutes return async def send_email(sender_name, sender_email, receiver_email, subject, html): """ smtp_host, smtp_port, smtp_username, smtp_password, sender_name, sender_email, """ ## SERVER SETUP ## smtp_host = os.getenv("SMTP_HOST") smtp_port = os.getenv("SMTP_PORT", 587) # default to port 587 smtp_username = os.getenv("SMTP_USERNAME") smtp_password = os.getenv("SMTP_PASSWORD") ## EMAIL SETUP ## email_message = MIMEMultipart() email_message["From"] = f"{sender_name} <{sender_email}>" email_message["To"] = receiver_email email_message["Subject"] = subject # Attach the body to the email email_message.attach(MIMEText(html, "html")) try: print_verbose(f"SMTP Connection Init") # Establish a secure connection with the SMTP server with smtplib.SMTP(smtp_host, smtp_port) as server: server.starttls() # Login to your email account server.login(smtp_username, smtp_password) # Send the email server.send_message(email_message) except Exception as e: print_verbose("An error occurred while sending the email:", str(e))