# test3 / litellm / proxy / guardrails / guardrail_registry.py
# DesertWolf's picture
# Upload folder using huggingface_hub
# 447ebeb verified
# litellm/proxy/guardrails/guardrail_registry.py
import importlib
import os
import uuid
from datetime import datetime, timezone
from typing import Dict, List, Optional, cast
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy.utils import PrismaClient
from litellm.secret_managers.main import get_secret
from litellm.types.guardrails import (
Guardrail,
GuardrailEventHooks,
LakeraCategoryThresholds,
LitellmParams,
SupportedGuardrailIntegrations,
)
from .guardrail_initializers import (
initialize_aim,
initialize_aporia,
initialize_bedrock,
initialize_guardrails_ai,
initialize_hide_secrets,
initialize_lakera,
initialize_lakera_v2,
initialize_lasso,
initialize_pangea,
initialize_presidio,
)
# Lookup table from a built-in guardrail integration name (the enum member's
# ``.value`` string) to the initializer function that builds its callback.
# Guardrail types missing from this table are handled by
# InMemoryGuardrailHandler.initialize_guardrail as custom "<file>.<Class>"
# guardrails (when they contain a dot) or rejected as unsupported.
guardrail_initializer_registry = dict(
    (integration.value, initializer_fn)
    for integration, initializer_fn in (
        (SupportedGuardrailIntegrations.APORIA, initialize_aporia),
        (SupportedGuardrailIntegrations.BEDROCK, initialize_bedrock),
        (SupportedGuardrailIntegrations.LAKERA, initialize_lakera),
        (SupportedGuardrailIntegrations.LAKERA_V2, initialize_lakera_v2),
        (SupportedGuardrailIntegrations.AIM, initialize_aim),
        (SupportedGuardrailIntegrations.PRESIDIO, initialize_presidio),
        (SupportedGuardrailIntegrations.HIDE_SECRETS, initialize_hide_secrets),
        (SupportedGuardrailIntegrations.GURDRAILS_AI, initialize_guardrails_ai),
        (SupportedGuardrailIntegrations.PANGEA, initialize_pangea),
        (SupportedGuardrailIntegrations.LASSO, initialize_lasso),
    )
)
class GuardrailRegistry:
    """
    Registry for guardrails.

    Handles adding, removing, and getting guardrails in the DB (through the
    Prisma ``litellm_guardrailstable`` model) and looking up guardrail
    callbacks that are already registered in memory with
    ``litellm.logging_callback_manager``.
    """

    def __init__(self):
        pass

    ###########################################################
    ########### In memory management helpers for guardrails ###########
    ############################################################

    def get_initialized_guardrail_callback(
        self, guardrail_name: str
    ) -> Optional[CustomGuardrail]:
        """
        Return the initialized guardrail callback for a given guardrail name.

        Scans the ``CustomGuardrail`` loggers registered with
        ``litellm.logging_callback_manager`` and returns the first one whose
        ``guardrail_name`` equals ``guardrail_name``.

        Returns:
            The matching callback, or ``None`` if none is registered.
        """
        active_guardrails = (
            litellm.logging_callback_manager.get_custom_loggers_for_type(
                callback_type=CustomGuardrail
            )
        )
        for active_guardrail in active_guardrails:
            # Defensive re-check of the type before reading guardrail_name.
            if (
                isinstance(active_guardrail, CustomGuardrail)
                and active_guardrail.guardrail_name == guardrail_name
            ):
                return active_guardrail
        return None

    ###########################################################
    ########### DB management helpers for guardrails ###########
    ############################################################

    async def add_guardrail_to_db(
        self, guardrail: Guardrail, prisma_client: PrismaClient
    ):
        """
        Add a guardrail to the database.

        ``litellm_params`` and ``guardrail_info`` are serialized with
        ``safe_dumps`` before being stored.

        Args:
            guardrail: the guardrail definition to persist.
            prisma_client: Prisma client used for the insert.

        Returns:
            A dict copy of ``guardrail`` with the DB-generated ``guardrail_id``
            added.

        Raises:
            Exception: wrapping (and chained to) any error raised while
                serializing or inserting.
        """
        try:
            guardrail_name = guardrail.get("guardrail_name")
            # `or {}` also covers an explicit None value, which dict() rejects.
            litellm_params: str = safe_dumps(
                dict(guardrail.get("litellm_params") or {})
            )
            guardrail_info: str = safe_dumps(guardrail.get("guardrail_info", {}))
            # One timestamp so created_at and updated_at are identical on insert.
            now = datetime.now(timezone.utc)

            created_guardrail = await prisma_client.db.litellm_guardrailstable.create(
                data={
                    "guardrail_name": guardrail_name,
                    "litellm_params": litellm_params,
                    "guardrail_info": guardrail_info,
                    "created_at": now,
                    "updated_at": now,
                }
            )

            # Add guardrail_id to the returned guardrail object
            guardrail_dict = dict(guardrail)
            guardrail_dict["guardrail_id"] = created_guardrail.guardrail_id
            return guardrail_dict
        except Exception as e:
            raise Exception(f"Error adding guardrail to DB: {str(e)}") from e

    async def delete_guardrail_from_db(
        self, guardrail_id: str, prisma_client: PrismaClient
    ):
        """
        Delete a guardrail from the database by its ``guardrail_id``.

        Returns:
            A dict with a human-readable success ``message``.

        Raises:
            Exception: wrapping (and chained to) any error from the delete.
        """
        try:
            await prisma_client.db.litellm_guardrailstable.delete(
                where={"guardrail_id": guardrail_id}
            )
            return {"message": f"Guardrail {guardrail_id} deleted successfully"}
        except Exception as e:
            raise Exception(f"Error deleting guardrail from DB: {str(e)}") from e

    async def update_guardrail_in_db(
        self, guardrail_id: str, guardrail: Guardrail, prisma_client: PrismaClient
    ):
        """
        Update a guardrail in the database.

        Overwrites ``guardrail_name``, ``litellm_params`` and ``guardrail_info``
        (the latter two serialized with ``safe_dumps``) and bumps ``updated_at``.

        Returns:
            The updated DB row converted to a dict.

        Raises:
            Exception: wrapping (and chained to) any error from the update.
        """
        try:
            guardrail_name = guardrail.get("guardrail_name")
            # `or {}` also covers an explicit None value, which dict() rejects.
            litellm_params: str = safe_dumps(
                dict(guardrail.get("litellm_params") or {})
            )
            guardrail_info: str = safe_dumps(guardrail.get("guardrail_info", {}))

            updated_guardrail = await prisma_client.db.litellm_guardrailstable.update(
                where={"guardrail_id": guardrail_id},
                data={
                    "guardrail_name": guardrail_name,
                    "litellm_params": litellm_params,
                    "guardrail_info": guardrail_info,
                    "updated_at": datetime.now(timezone.utc),
                },
            )

            return dict(updated_guardrail)
        except Exception as e:
            raise Exception(f"Error updating guardrail in DB: {str(e)}") from e

    @staticmethod
    async def get_all_guardrails_from_db(
        prisma_client: PrismaClient,
    ) -> List[Guardrail]:
        """
        Get all guardrails from the database, newest (by ``created_at``) first.

        Raises:
            Exception: wrapping (and chained to) any error from the query.
        """
        try:
            guardrails_from_db = (
                await prisma_client.db.litellm_guardrailstable.find_many(
                    order={"created_at": "desc"},
                )
            )
            return [Guardrail(**(dict(row))) for row in guardrails_from_db]
        except Exception as e:
            raise Exception(f"Error getting guardrails from DB: {str(e)}") from e

    async def get_guardrail_by_id_from_db(
        self, guardrail_id: str, prisma_client: PrismaClient
    ) -> Optional[Guardrail]:
        """
        Get a guardrail by its ID from the database.

        Returns:
            The guardrail, or ``None`` if no row matches.

        Raises:
            Exception: wrapping (and chained to) any error from the query.
        """
        try:
            guardrail = await prisma_client.db.litellm_guardrailstable.find_unique(
                where={"guardrail_id": guardrail_id}
            )
            if not guardrail:
                return None
            return Guardrail(**(dict(guardrail)))
        except Exception as e:
            raise Exception(f"Error getting guardrail from DB: {str(e)}") from e

    async def get_guardrail_by_name_from_db(
        self, guardrail_name: str, prisma_client: PrismaClient
    ) -> Optional[Guardrail]:
        """
        Get a guardrail by its name from the database.

        NOTE(review): uses ``find_unique`` on ``guardrail_name``, which assumes
        the schema declares that column unique — confirm against schema.prisma.

        Returns:
            The guardrail, or ``None`` if no row matches.

        Raises:
            Exception: wrapping (and chained to) any error from the query.
        """
        try:
            guardrail = await prisma_client.db.litellm_guardrailstable.find_unique(
                where={"guardrail_name": guardrail_name}
            )
            if not guardrail:
                return None
            return Guardrail(**(dict(guardrail)))
        except Exception as e:
            raise Exception(f"Error getting guardrail from DB: {str(e)}") from e
class InMemoryGuardrailHandler:
    """
    Class that handles initializing guardrails and adding them to the CallbackManager.

    Keeps two in-memory maps keyed by guardrail id: the parsed ``Guardrail``
    definition and the ``CustomGuardrail`` callback built for it.
    """

    def __init__(self):
        # Guardrail id -> Guardrail object mapping
        self.IN_MEMORY_GUARDRAILS: Dict[str, Guardrail] = {}
        # Guardrail id -> CustomGuardrail callback mapping (value may be None
        # when the initializer produced no callback)
        self.guardrail_id_to_custom_guardrail: Dict[str, Optional[CustomGuardrail]] = {}

    def initialize_guardrail(
        self,
        guardrail: Dict,
        config_file_path: Optional[str] = None,
    ) -> Optional[Guardrail]:
        """
        Initialize a guardrail from a dictionary and add it to the litellm callback manager.

        Args:
            guardrail: dict with at least ``guardrail_name`` and
                ``litellm_params``; an optional ``guardrail_id`` is reused,
                otherwise a UUID4 is generated and written back into this dict.
            config_file_path: path of the proxy config file; required only for
                custom ``"<file>.<Class>"`` guardrails, whose file is resolved
                relative to this path's directory.

        Returns:
            The parsed ``Guardrail``. If the id is already registered, the
            existing in-memory ``Guardrail`` is returned unchanged.

        Raises:
            ValueError: if ``litellm_params.guardrail`` is missing or is not a
                known integration nor a dotted custom guardrail path.
        """
        guardrail_id = guardrail.get("guardrail_id") or str(uuid.uuid4())
        # Written back so callers holding this dict see the assigned id.
        guardrail["guardrail_id"] = guardrail_id
        if guardrail_id in self.IN_MEMORY_GUARDRAILS:
            verbose_proxy_logger.debug(
                "guardrail_id already exists in IN_MEMORY_GUARDRAILS"
            )
            return self.IN_MEMORY_GUARDRAILS[guardrail_id]

        custom_guardrail_callback: Optional[CustomGuardrail] = None
        litellm_params_data = guardrail["litellm_params"]
        verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)

        litellm_params = LitellmParams(**litellm_params_data)

        # Re-parse category_thresholds into the Lakera-specific type when present.
        if (
            "category_thresholds" in litellm_params_data
            and litellm_params_data["category_thresholds"]
        ):
            lakera_category_thresholds = LakeraCategoryThresholds(
                **litellm_params_data["category_thresholds"]
            )
            litellm_params.category_thresholds = lakera_category_thresholds

        # Resolve "os.environ/<NAME>" references through get_secret.
        # NOTE(review): str() turns a missing secret (None) into the literal "None".
        if litellm_params.api_key and litellm_params.api_key.startswith("os.environ/"):
            litellm_params.api_key = str(get_secret(litellm_params.api_key))

        if litellm_params.api_base and litellm_params.api_base.startswith(
            "os.environ/"
        ):
            litellm_params.api_base = str(get_secret(litellm_params.api_base))

        guardrail_type = litellm_params.guardrail

        if guardrail_type is None:
            raise ValueError("guardrail_type is required")

        initializer = guardrail_initializer_registry.get(guardrail_type)

        if initializer:
            custom_guardrail_callback = initializer(litellm_params, guardrail)
        elif isinstance(guardrail_type, str) and "." in guardrail_type:
            # Dotted types name a user-supplied "<file>.<Class>" guardrail.
            custom_guardrail_callback = self.initialize_custom_guardrail(
                guardrail=guardrail,
                guardrail_type=guardrail_type,
                litellm_params=litellm_params,
                config_file_path=config_file_path,
            )
        else:
            raise ValueError(f"Unsupported guardrail: {guardrail_type}")

        parsed_guardrail = Guardrail(
            guardrail_id=guardrail_id,
            guardrail_name=guardrail["guardrail_name"],
            litellm_params=litellm_params,
        )

        # store references to the guardrail in memory
        self.IN_MEMORY_GUARDRAILS[guardrail_id] = parsed_guardrail
        self.guardrail_id_to_custom_guardrail[guardrail_id] = custom_guardrail_callback

        return parsed_guardrail

    def initialize_custom_guardrail(
        self,
        guardrail: Dict,
        guardrail_type: str,
        litellm_params: LitellmParams,
        config_file_path: Optional[str] = None,
    ) -> Optional[CustomGuardrail]:
        """
        Initialize a Custom Guardrail from a python file.

        ``guardrail_type`` must look like ``"<file_name>.<ClassName>"``; the
        class is loaded from ``<config dir>/<file_name>.py``, instantiated with
        the guardrail name, ``mode`` (as ``event_hook``) and ``default_on``,
        and added to the litellm callback manager.

        Raises:
            Exception: if ``config_file_path`` is not provided.
            ValueError: if ``guardrail_type`` is not exactly ``<file>.<Class>``
                or ``litellm_params.mode`` is missing.
            ImportError: if no loadable module spec exists for the file.
        """
        # `import importlib` alone does not guarantee the `util` submodule is loaded.
        import importlib.util

        if not config_file_path:
            raise Exception(
                "GuardrailsAIException - Please pass the config_file_path to initialize_guardrails_v2"
            )

        type_parts = guardrail_type.split(".")
        if len(type_parts) != 2:
            raise ValueError(
                f"Custom guardrail '{guardrail_type}' must be of the form '<file_name>.<ClassName>'"
            )
        _file_name, _class_name = type_parts
        verbose_proxy_logger.debug(
            "Initializing custom guardrail: %s, file_name: %s, class_name: %s",
            guardrail_type,
            _file_name,
            _class_name,
        )

        directory = os.path.dirname(config_file_path)
        module_file_path = os.path.join(directory, _file_name) + ".py"

        spec = importlib.util.spec_from_file_location(_class_name, module_file_path)
        # A spec without a loader cannot be executed; report it as an import failure.
        if spec is None or spec.loader is None:
            raise ImportError(
                f"Could not find a module specification for {module_file_path}"
            )

        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        _guardrail_class = getattr(module, _class_name)

        mode = litellm_params.mode
        if mode is None:
            raise ValueError(
                f"mode is required for guardrail {guardrail_type} please set mode to one of the following: {', '.join(GuardrailEventHooks)}"
            )

        default_on = litellm_params.default_on
        _guardrail_callback = _guardrail_class(
            guardrail_name=guardrail["guardrail_name"],
            event_hook=mode,
            default_on=default_on,
        )
        litellm.logging_callback_manager.add_litellm_callback(_guardrail_callback)  # type: ignore

        return _guardrail_callback

    def update_in_memory_guardrail(
        self, guardrail_id: str, guardrail: Guardrail
    ) -> None:
        """
        Update a guardrail in memory.

        - replaces (or inserts) the stored ``Guardrail`` for ``guardrail_id``
        - if a callback exists for this id, pushes the new ``litellm_params``
          to it via ``update_in_memory_litellm_params``
        """
        self.IN_MEMORY_GUARDRAILS[guardrail_id] = guardrail

        custom_guardrail_callback = self.guardrail_id_to_custom_guardrail.get(
            guardrail_id
        )
        if custom_guardrail_callback:
            updated_litellm_params = cast(
                LitellmParams, guardrail.get("litellm_params", {})
            )
            custom_guardrail_callback.update_in_memory_litellm_params(
                litellm_params=updated_litellm_params
            )

    def delete_in_memory_guardrail(self, guardrail_id: str) -> None:
        """
        Delete a guardrail in memory.

        Removes both the stored ``Guardrail`` and its callback reference; a
        missing id is ignored.

        NOTE(review): the callback is not removed from
        ``litellm.logging_callback_manager`` here — confirm whether callers
        do that separately.
        """
        self.IN_MEMORY_GUARDRAILS.pop(guardrail_id, None)
        # Drop the callback reference too so a later update for this id
        # cannot reach the deleted guardrail's stale callback.
        self.guardrail_id_to_custom_guardrail.pop(guardrail_id, None)

    def list_in_memory_guardrails(self) -> List[Guardrail]:
        """
        List all guardrails in memory, in insertion order.
        """
        return list(self.IN_MEMORY_GUARDRAILS.values())

    def get_guardrail_by_id(self, guardrail_id: str) -> Optional[Guardrail]:
        """
        Get a guardrail by its ID from memory, or ``None`` if unknown.
        """
        return self.IN_MEMORY_GUARDRAILS.get(guardrail_id)
########################################################
# In Memory Guardrail Handler for LiteLLM Proxy
########################################################
# Module-level instance created at import time; every importer of this module
# shares this same handler and therefore the same in-memory guardrail maps.
IN_MEMORY_GUARDRAIL_HANDLER = InMemoryGuardrailHandler()
########################################################