Spaces:
Configuration error
Configuration error
# litellm/proxy/guardrails/guardrail_registry.py
import importlib
import importlib.util
import os
import uuid
from datetime import datetime, timezone
from typing import Dict, List, Optional, cast

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy.utils import PrismaClient
from litellm.secret_managers.main import get_secret
from litellm.types.guardrails import (
    Guardrail,
    GuardrailEventHooks,
    LakeraCategoryThresholds,
    LitellmParams,
    SupportedGuardrailIntegrations,
)

from .guardrail_initializers import (
    initialize_aim,
    initialize_aporia,
    initialize_bedrock,
    initialize_guardrails_ai,
    initialize_hide_secrets,
    initialize_lakera,
    initialize_lakera_v2,
    initialize_lasso,
    initialize_pangea,
    initialize_presidio,
)
# Lookup table: built-in guardrail integration name (the enum member's `.value`)
# -> initializer called as `initializer(litellm_params, guardrail)` to build the
# guardrail's callback. Guardrail types missing from this table that contain a
# "." are handled as custom "<file>.<Class>" guardrails by
# InMemoryGuardrailHandler.initialize_guardrail.
guardrail_initializer_registry = {
    integration.value: init_fn
    for integration, init_fn in (
        (SupportedGuardrailIntegrations.APORIA, initialize_aporia),
        (SupportedGuardrailIntegrations.BEDROCK, initialize_bedrock),
        (SupportedGuardrailIntegrations.LAKERA, initialize_lakera),
        (SupportedGuardrailIntegrations.LAKERA_V2, initialize_lakera_v2),
        (SupportedGuardrailIntegrations.AIM, initialize_aim),
        (SupportedGuardrailIntegrations.PRESIDIO, initialize_presidio),
        (SupportedGuardrailIntegrations.HIDE_SECRETS, initialize_hide_secrets),
        (SupportedGuardrailIntegrations.GURDRAILS_AI, initialize_guardrails_ai),
        (SupportedGuardrailIntegrations.PANGEA, initialize_pangea),
        (SupportedGuardrailIntegrations.LASSO, initialize_lasso),
    )
}
class GuardrailRegistry:
    """
    Registry for guardrails.

    Handles adding, updating, deleting and fetching guardrails in the database,
    and looking up already-initialized guardrail callbacks in memory.
    """

    def __init__(self):
        pass

    ###########################################################
    ########### In memory management helpers for guardrails ###########
    ############################################################

    def get_initialized_guardrail_callback(
        self, guardrail_name: str
    ) -> Optional[CustomGuardrail]:
        """
        Return the initialized guardrail callback registered under a name.

        Scans the CustomGuardrail loggers held by
        `litellm.logging_callback_manager` and returns the first one whose
        `guardrail_name` equals `guardrail_name`.

        Args:
            guardrail_name: name of the guardrail to look up.

        Returns:
            The matching CustomGuardrail instance, or None if none matches.
        """
        active_guardrails = (
            litellm.logging_callback_manager.get_custom_loggers_for_type(
                callback_type=CustomGuardrail
            )
        )
        for active_guardrail in active_guardrails:
            if (
                isinstance(active_guardrail, CustomGuardrail)
                and active_guardrail.guardrail_name == guardrail_name
            ):
                return active_guardrail
        return None

    ###########################################################
    ########### DB management helpers for guardrails ###########
    ############################################################

    async def add_guardrail_to_db(
        self, guardrail: Guardrail, prisma_client: PrismaClient
    ):
        """
        Insert a guardrail row into the database.

        `litellm_params` and `guardrail_info` are serialized with `safe_dumps`
        before being stored.

        Args:
            guardrail: guardrail definition to persist.
            prisma_client: Prisma client used to access the DB.

        Returns:
            A dict copy of `guardrail` with the DB-assigned `guardrail_id` added.

        Raises:
            Exception: wrapping (and chained to) any error raised while writing.
        """
        try:
            guardrail_name = guardrail.get("guardrail_name")
            litellm_params: str = safe_dumps(dict(guardrail.get("litellm_params", {})))
            guardrail_info: str = safe_dumps(guardrail.get("guardrail_info", {}))
            # One timestamp so created_at and updated_at are identical on insert.
            now = datetime.now(timezone.utc)

            created_guardrail = await prisma_client.db.litellm_guardrailstable.create(
                data={
                    "guardrail_name": guardrail_name,
                    "litellm_params": litellm_params,
                    "guardrail_info": guardrail_info,
                    "created_at": now,
                    "updated_at": now,
                }
            )

            # Return the caller's guardrail enriched with the generated id.
            guardrail_dict = dict(guardrail)
            guardrail_dict["guardrail_id"] = created_guardrail.guardrail_id
            return guardrail_dict
        except Exception as e:
            raise Exception(f"Error adding guardrail to DB: {str(e)}") from e

    async def delete_guardrail_from_db(
        self, guardrail_id: str, prisma_client: PrismaClient
    ):
        """
        Delete a guardrail row from the database by id.

        Args:
            guardrail_id: id of the guardrail to delete.
            prisma_client: Prisma client used to access the DB.

        Returns:
            A dict with a success `message`.

        Raises:
            Exception: wrapping (and chained to) any error raised by the delete.

        NOTE(review): the success message is returned without inspecting the
        result of `delete`; whether an unknown id raises or silently returns
        depends on the Prisma client — confirm.
        """
        try:
            await prisma_client.db.litellm_guardrailstable.delete(
                where={"guardrail_id": guardrail_id}
            )
            return {"message": f"Guardrail {guardrail_id} deleted successfully"}
        except Exception as e:
            raise Exception(f"Error deleting guardrail from DB: {str(e)}") from e

    async def update_guardrail_in_db(
        self, guardrail_id: str, guardrail: Guardrail, prisma_client: PrismaClient
    ):
        """
        Overwrite name, params and info of an existing guardrail row.

        Args:
            guardrail_id: id of the guardrail to update.
            guardrail: new guardrail definition; `litellm_params` and
                `guardrail_info` are serialized with `safe_dumps`.
            prisma_client: Prisma client used to access the DB.

        Returns:
            The updated row converted to a dict.

        Raises:
            Exception: wrapping (and chained to) any error raised by the update.
        """
        try:
            guardrail_name = guardrail.get("guardrail_name")
            litellm_params: str = safe_dumps(dict(guardrail.get("litellm_params", {})))
            guardrail_info: str = safe_dumps(guardrail.get("guardrail_info", {}))

            updated_guardrail = await prisma_client.db.litellm_guardrailstable.update(
                where={"guardrail_id": guardrail_id},
                data={
                    "guardrail_name": guardrail_name,
                    "litellm_params": litellm_params,
                    "guardrail_info": guardrail_info,
                    "updated_at": datetime.now(timezone.utc),
                },
            )
            return dict(updated_guardrail)
        except Exception as e:
            raise Exception(f"Error updating guardrail in DB: {str(e)}") from e

    async def get_all_guardrails_from_db(
        self, prisma_client: PrismaClient
    ) -> List[Guardrail]:
        """
        Fetch every guardrail row, newest first (ordered by `created_at` desc).

        Args:
            prisma_client: Prisma client used to access the DB.

        Returns:
            A list of Guardrail objects built from the rows.

        Raises:
            Exception: wrapping (and chained to) any error raised by the query.
        """
        try:
            guardrails_from_db = (
                await prisma_client.db.litellm_guardrailstable.find_many(
                    order={"created_at": "desc"},
                )
            )
            return [Guardrail(**dict(row)) for row in guardrails_from_db]
        except Exception as e:
            raise Exception(f"Error getting guardrails from DB: {str(e)}") from e

    async def get_guardrail_by_id_from_db(
        self, guardrail_id: str, prisma_client: PrismaClient
    ) -> Optional[Guardrail]:
        """
        Fetch a single guardrail by its id.

        Args:
            guardrail_id: id of the guardrail.
            prisma_client: Prisma client used to access the DB.

        Returns:
            The Guardrail, or None when no row matches.

        Raises:
            Exception: wrapping (and chained to) any error raised by the query.
        """
        try:
            guardrail = await prisma_client.db.litellm_guardrailstable.find_unique(
                where={"guardrail_id": guardrail_id}
            )
            if not guardrail:
                return None
            return Guardrail(**dict(guardrail))
        except Exception as e:
            raise Exception(f"Error getting guardrail from DB: {str(e)}") from e

    async def get_guardrail_by_name_from_db(
        self, guardrail_name: str, prisma_client: PrismaClient
    ) -> Optional[Guardrail]:
        """
        Fetch a single guardrail by its name.

        NOTE(review): uses `find_unique` on `guardrail_name`, which presumes the
        column is unique in the Prisma schema (not visible here) — confirm.

        Args:
            guardrail_name: name of the guardrail.
            prisma_client: Prisma client used to access the DB.

        Returns:
            The Guardrail, or None when no row matches.

        Raises:
            Exception: wrapping (and chained to) any error raised by the query.
        """
        try:
            guardrail = await prisma_client.db.litellm_guardrailstable.find_unique(
                where={"guardrail_name": guardrail_name}
            )
            if not guardrail:
                return None
            return Guardrail(**dict(guardrail))
        except Exception as e:
            raise Exception(f"Error getting guardrail from DB: {str(e)}") from e
class InMemoryGuardrailHandler:
    """
    Initializes guardrails and keeps in-memory references to them.

    Each initialized guardrail is stored by id together with the callback its
    initializer produced; custom file-based guardrails are additionally
    registered with `litellm.logging_callback_manager`.
    """

    def __init__(self):
        # Guardrail id -> parsed Guardrail object.
        self.IN_MEMORY_GUARDRAILS: Dict[str, Guardrail] = {}
        # Guardrail id -> callback returned by its initializer (may be None).
        self.guardrail_id_to_custom_guardrail: Dict[str, Optional[CustomGuardrail]] = {}

    def initialize_guardrail(
        self,
        guardrail: Dict,
        config_file_path: Optional[str] = None,
    ) -> Optional[Guardrail]:
        """
        Initialize a guardrail from a dict and store it in memory.

        If `guardrail` has no truthy `guardrail_id`, a random UUID4 string is
        generated; either way the id is written back into `guardrail` (the input
        dict is mutated). If that id is already stored, the stored Guardrail is
        returned and nothing is re-initialized.

        Args:
            guardrail: dict with at least `guardrail_name` and `litellm_params`.
            config_file_path: path of the proxy config file; required only for
                custom guardrails whose type looks like "<file>.<Class>".

        Returns:
            The parsed Guardrail object.

        Raises:
            ValueError: if `litellm_params.guardrail` is missing or is neither a
                built-in integration nor a "<file>.<Class>" reference.
            KeyError: if `litellm_params` or `guardrail_name` is missing.
        """
        guardrail_id = guardrail.get("guardrail_id") or str(uuid.uuid4())
        guardrail["guardrail_id"] = guardrail_id

        if guardrail_id in self.IN_MEMORY_GUARDRAILS:
            verbose_proxy_logger.debug(
                "guardrail_id already exists in IN_MEMORY_GUARDRAILS"
            )
            return self.IN_MEMORY_GUARDRAILS[guardrail_id]

        custom_guardrail_callback: Optional[CustomGuardrail] = None
        litellm_params_data = guardrail["litellm_params"]
        verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)

        litellm_params = LitellmParams(**litellm_params_data)

        # Lakera thresholds arrive as a plain dict; convert to the typed model.
        if (
            "category_thresholds" in litellm_params_data
            and litellm_params_data["category_thresholds"]
        ):
            lakera_category_thresholds = LakeraCategoryThresholds(
                **litellm_params_data["category_thresholds"]
            )
            litellm_params.category_thresholds = lakera_category_thresholds

        # Resolve "os.environ/<NAME>" references through the secret manager.
        if litellm_params.api_key and litellm_params.api_key.startswith("os.environ/"):
            litellm_params.api_key = str(get_secret(litellm_params.api_key))

        if litellm_params.api_base and litellm_params.api_base.startswith(
            "os.environ/"
        ):
            litellm_params.api_base = str(get_secret(litellm_params.api_base))

        guardrail_type = litellm_params.guardrail
        if guardrail_type is None:
            raise ValueError("guardrail_type is required")

        initializer = guardrail_initializer_registry.get(guardrail_type)

        if initializer:
            custom_guardrail_callback = initializer(litellm_params, guardrail)
        elif isinstance(guardrail_type, str) and "." in guardrail_type:
            # Not a built-in integration: treat as "<file_name>.<ClassName>".
            custom_guardrail_callback = self.initialize_custom_guardrail(
                guardrail=guardrail,
                guardrail_type=guardrail_type,
                litellm_params=litellm_params,
                config_file_path=config_file_path,
            )
        else:
            raise ValueError(f"Unsupported guardrail: {guardrail_type}")

        parsed_guardrail = Guardrail(
            guardrail_id=guardrail.get("guardrail_id"),
            guardrail_name=guardrail["guardrail_name"],
            litellm_params=litellm_params,
        )

        # Store references to the guardrail and its callback in memory.
        self.IN_MEMORY_GUARDRAILS[guardrail_id] = parsed_guardrail
        self.guardrail_id_to_custom_guardrail[guardrail_id] = custom_guardrail_callback

        return parsed_guardrail

    def initialize_custom_guardrail(
        self,
        guardrail: Dict,
        guardrail_type: str,
        litellm_params: LitellmParams,
        config_file_path: Optional[str] = None,
    ) -> Optional[CustomGuardrail]:
        """
        Load a custom guardrail class from a python file and register it.

        `guardrail_type` must be "<file_name>.<ClassName>". The file
        `<file_name>.py` is resolved relative to the directory containing
        `config_file_path`, executed as a module, and `<ClassName>` is
        instantiated with the guardrail name, `litellm_params.mode` as the
        event hook and `litellm_params.default_on`. The instance is added to
        `litellm.logging_callback_manager`.

        Security: this executes Python code from the referenced file, so the
        config file must come from a trusted source.

        Args:
            guardrail: guardrail dict; `guardrail_name` is read.
            guardrail_type: "<file_name>.<ClassName>" reference.
            litellm_params: parsed params; `mode` and `default_on` are read.
            config_file_path: path of the proxy config file.

        Returns:
            The instantiated guardrail callback.

        Raises:
            Exception: if `config_file_path` is not provided.
            ValueError: if `guardrail_type` is not exactly "<file>.<Class>" or
                `litellm_params.mode` is not set.
            ImportError: if no module spec/loader can be built for the file.
            AttributeError: if the module does not define `<ClassName>`.
        """
        if not config_file_path:
            raise Exception(
                "GuardrailsAIException - Please pass the config_file_path to initialize_guardrails_v2"
            )

        parts = guardrail_type.split(".")
        if len(parts) != 2:
            # Same exception type as the previous tuple-unpack failure, but
            # with an actionable message.
            raise ValueError(
                f"Invalid custom guardrail '{guardrail_type}': expected '<file_name>.<ClassName>'"
            )
        _file_name, _class_name = parts

        verbose_proxy_logger.debug(
            "Initializing custom guardrail: %s, file_name: %s, class_name: %s",
            guardrail_type,
            _file_name,
            _class_name,
        )

        directory = os.path.dirname(config_file_path)
        module_file_path = os.path.join(directory, _file_name) + ".py"

        spec = importlib.util.spec_from_file_location(_class_name, module_file_path)
        # A spec without a loader cannot execute the module; fail clearly
        # instead of with an AttributeError on `spec.loader.exec_module`.
        if spec is None or spec.loader is None:
            raise ImportError(
                f"Could not find a module specification for {module_file_path}"
            )

        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        _guardrail_class = getattr(module, _class_name)

        mode = litellm_params.mode
        if mode is None:
            raise ValueError(
                f"mode is required for guardrail {guardrail_type} please set mode to one of the following: {', '.join(GuardrailEventHooks)}"
            )

        default_on = litellm_params.default_on
        _guardrail_callback = _guardrail_class(
            guardrail_name=guardrail["guardrail_name"],
            event_hook=mode,
            default_on=default_on,
        )
        litellm.logging_callback_manager.add_litellm_callback(_guardrail_callback)  # type: ignore

        return _guardrail_callback

    def update_in_memory_guardrail(
        self, guardrail_id: str, guardrail: Guardrail
    ) -> None:
        """
        Replace a stored guardrail and push its new params to its callback.

        The Guardrail is stored under `guardrail_id`; if a callback is known
        for that id, its `update_in_memory_litellm_params` is called with the
        guardrail's `litellm_params` (or `{}` when absent).
        """
        self.IN_MEMORY_GUARDRAILS[guardrail_id] = guardrail

        custom_guardrail_callback = self.guardrail_id_to_custom_guardrail.get(
            guardrail_id
        )
        if custom_guardrail_callback:
            updated_litellm_params = cast(
                LitellmParams, guardrail.get("litellm_params", {})
            )
            custom_guardrail_callback.update_in_memory_litellm_params(
                litellm_params=updated_litellm_params
            )

    def delete_in_memory_guardrail(self, guardrail_id: str) -> None:
        """
        Remove a guardrail and its callback reference from memory.

        Unknown ids are ignored.

        NOTE(review): the callback is not unregistered from
        `litellm.logging_callback_manager`; custom guardrails added by
        `initialize_custom_guardrail` remain registered there — confirm
        whether that is intended.
        """
        self.IN_MEMORY_GUARDRAILS.pop(guardrail_id, None)
        # Also drop the callback reference so no stale entry is left behind.
        self.guardrail_id_to_custom_guardrail.pop(guardrail_id, None)

    def list_in_memory_guardrails(self) -> List[Guardrail]:
        """
        Return all guardrails currently stored in memory as a new list.
        """
        return list(self.IN_MEMORY_GUARDRAILS.values())

    def get_guardrail_by_id(self, guardrail_id: str) -> Optional[Guardrail]:
        """
        Return the in-memory guardrail with `guardrail_id`, or None if absent.
        """
        return self.IN_MEMORY_GUARDRAILS.get(guardrail_id)
########################################################
# Module-level in-memory guardrail handler shared by the LiteLLM Proxy
########################################################
IN_MEMORY_GUARDRAIL_HANDLER: InMemoryGuardrailHandler = InMemoryGuardrailHandler()
######################################################## | |