# +-----------------------------------------------+
# |                                               |
# |             PII Masking                       |
# |        with Microsoft Presidio                |
# |   https://github.com/BerriAI/litellm/issues/  |
# +-----------------------------------------------+
#
#  Tell us how we can improve! - Krrish & Ishaan
import asyncio
import json
import uuid
from datetime import datetime
from typing import (
    Any,
    AsyncGenerator,
    Dict,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
    cast,
)

import aiohttp

import litellm  # noqa: E401
from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.exceptions import BlockedPiiEntityError
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.guardrails import (
    GuardrailEventHooks,
    LitellmParams,
    PiiAction,
    PiiEntityType,
    PresidioPerRequestConfig,
)
from litellm.types.proxy.guardrails.guardrail_hooks.presidio import (
    PresidioAnalyzeRequest,
    PresidioAnalyzeResponseItem,
)
from litellm.types.utils import CallTypes as LitellmCallTypes
from litellm.utils import (
    EmbeddingResponse,
    ImageResponse,
    ModelResponse,
    ModelResponseStream,
    StreamingChoices,
)
class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
    """PII masking guardrail backed by a Microsoft Presidio deployment.

    Calls Presidio's ``/analyze`` and ``/anonymize`` HTTP endpoints to detect
    and mask PII in request text, and can optionally un-mask tokens in LLM
    responses (``output_parse_pii``).
    """

    # NOTE(review): appears unused within this file — confirm against callers.
    user_api_key_cache = None
    # Parsed JSON ad-hoc recognizer config, loaded from file in __init__.
    ad_hoc_recognizers = None

    # Class variables or attributes
    def __init__(
        self,
        mock_testing: bool = False,
        mock_redacted_text: Optional[dict] = None,
        presidio_analyzer_api_base: Optional[str] = None,
        presidio_anonymizer_api_base: Optional[str] = None,
        output_parse_pii: Optional[bool] = False,
        presidio_ad_hoc_recognizers: Optional[str] = None,
        logging_only: Optional[bool] = None,
        pii_entities_config: Optional[Dict[PiiEntityType, PiiAction]] = None,
        presidio_language: Optional[str] = None,
        **kwargs,
    ):
        """Initialize the guardrail and (unless mock testing) validate config.

        Args:
            mock_testing: when True, skip recognizer loading and environment
                validation entirely (unit-test mode).
            mock_redacted_text: canned anonymizer response used instead of
                real HTTP calls when set.
            presidio_analyzer_api_base: analyzer URL; falls back to the
                ``PRESIDIO_ANALYZER_API_BASE`` env var.
            presidio_anonymizer_api_base: anonymizer URL; falls back to the
                ``PRESIDIO_ANONYMIZER_API_BASE`` env var.
            output_parse_pii: enable un-masking of replacement tokens in LLM
                responses.
            presidio_ad_hoc_recognizers: path to a JSON file of ad-hoc
                recognizers to send with every analyze request.
            logging_only: when True, the guardrail only runs on the logging
                hook (sets ``event_hook`` before calling the base __init__).
            pii_entities_config: per-entity action map (MASK/BLOCK).
            presidio_language: analyzer language code; defaults to "en".
            **kwargs: forwarded to ``CustomGuardrail.__init__``.

        Raises:
            Exception: if the ad-hoc recognizers file is missing, is not valid
                JSON, or cannot be read.
        """
        if logging_only is True:
            # NOTE(review): ``self.logging_only`` is only assigned on this
            # path; other paths presumably rely on a default provided by
            # CustomGuardrail — confirm.
            self.logging_only = True
            # Must be set before super().__init__ consumes kwargs.
            kwargs["event_hook"] = GuardrailEventHooks.logging_only
        super().__init__(**kwargs)
        # Mapping of PII replacement token -> original text; only used with
        # the Presidio `replace` operation when output parsing is enabled.
        self.pii_tokens: dict = {}
        self.mock_redacted_text = mock_redacted_text
        self.output_parse_pii = output_parse_pii or False
        self.pii_entities_config: Dict[PiiEntityType, PiiAction] = (
            pii_entities_config or {}
        )
        self.presidio_language = presidio_language or "en"
        if mock_testing is True:  # for testing purposes only
            return

        ad_hoc_recognizers = presidio_ad_hoc_recognizers
        if ad_hoc_recognizers is not None:
            try:
                with open(ad_hoc_recognizers, "r") as file:
                    self.ad_hoc_recognizers = json.load(file)
            except FileNotFoundError:
                raise Exception(f"File not found. file_path={ad_hoc_recognizers}")
            except json.JSONDecodeError as e:
                raise Exception(
                    f"Error decoding JSON file: {str(e)}, file_path={ad_hoc_recognizers}"
                )
            except Exception as e:
                raise Exception(
                    f"An error occurred: {str(e)}, file_path={ad_hoc_recognizers}"
                )
        self.validate_environment(
            presidio_analyzer_api_base=presidio_analyzer_api_base,
            presidio_anonymizer_api_base=presidio_anonymizer_api_base,
        )
def validate_environment( | |
self, | |
presidio_analyzer_api_base: Optional[str] = None, | |
presidio_anonymizer_api_base: Optional[str] = None, | |
): | |
self.presidio_analyzer_api_base: Optional[ | |
str | |
] = presidio_analyzer_api_base or get_secret( | |
"PRESIDIO_ANALYZER_API_BASE", None | |
) # type: ignore | |
self.presidio_anonymizer_api_base: Optional[ | |
str | |
] = presidio_anonymizer_api_base or litellm.get_secret( | |
"PRESIDIO_ANONYMIZER_API_BASE", None | |
) # type: ignore | |
if self.presidio_analyzer_api_base is None: | |
raise Exception("Missing `PRESIDIO_ANALYZER_API_BASE` from environment") | |
if not self.presidio_analyzer_api_base.endswith("/"): | |
self.presidio_analyzer_api_base += "/" | |
if not ( | |
self.presidio_analyzer_api_base.startswith("http://") | |
or self.presidio_analyzer_api_base.startswith("https://") | |
): | |
# add http:// if unset, assume communicating over private network - e.g. render | |
self.presidio_analyzer_api_base = ( | |
"http://" + self.presidio_analyzer_api_base | |
) | |
if self.presidio_anonymizer_api_base is None: | |
raise Exception("Missing `PRESIDIO_ANONYMIZER_API_BASE` from environment") | |
if not self.presidio_anonymizer_api_base.endswith("/"): | |
self.presidio_anonymizer_api_base += "/" | |
if not ( | |
self.presidio_anonymizer_api_base.startswith("http://") | |
or self.presidio_anonymizer_api_base.startswith("https://") | |
): | |
# add http:// if unset, assume communicating over private network - e.g. render | |
self.presidio_anonymizer_api_base = ( | |
"http://" + self.presidio_anonymizer_api_base | |
) | |
def _get_presidio_analyze_request_payload( | |
self, | |
text: str, | |
presidio_config: Optional[PresidioPerRequestConfig], | |
request_data: dict, | |
) -> PresidioAnalyzeRequest: | |
""" | |
Construct the payload for the Presidio analyze request | |
API Ref: https://microsoft.github.io/presidio/api-docs/api-docs.html#tag/Analyzer/paths/~1analyze/post | |
""" | |
analyze_payload: PresidioAnalyzeRequest = PresidioAnalyzeRequest( | |
text=text, | |
language=self.presidio_language, | |
) | |
################################################################## | |
###### Check if user has configured any params for this guardrail | |
################################################################ | |
if self.ad_hoc_recognizers is not None: | |
analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers | |
if self.pii_entities_config: | |
analyze_payload["entities"] = list(self.pii_entities_config.keys()) | |
################################################################## | |
######### End of adding config params | |
################################################################## | |
# Check if client side request passed any dynamic params | |
if presidio_config and presidio_config.language: | |
analyze_payload["language"] = presidio_config.language | |
casted_analyze_payload: dict = cast(dict, analyze_payload) | |
casted_analyze_payload.update( | |
self.get_guardrail_dynamic_request_body_params(request_data=request_data) | |
) | |
return cast(PresidioAnalyzeRequest, casted_analyze_payload) | |
async def analyze_text( | |
self, | |
text: str, | |
presidio_config: Optional[PresidioPerRequestConfig], | |
request_data: dict, | |
) -> Union[List[PresidioAnalyzeResponseItem], Dict]: | |
""" | |
Send text to the Presidio analyzer endpoint and get analysis results | |
""" | |
try: | |
async with aiohttp.ClientSession() as session: | |
if self.mock_redacted_text is not None: | |
return self.mock_redacted_text | |
# Make the request to /analyze | |
analyze_url = f"{self.presidio_analyzer_api_base}analyze" | |
analyze_payload: PresidioAnalyzeRequest = ( | |
self._get_presidio_analyze_request_payload( | |
text=text, | |
presidio_config=presidio_config, | |
request_data=request_data, | |
) | |
) | |
verbose_proxy_logger.debug( | |
"Making request to: %s with payload: %s", | |
analyze_url, | |
analyze_payload, | |
) | |
async with session.post(analyze_url, json=analyze_payload) as response: | |
analyze_results = await response.json() | |
verbose_proxy_logger.debug("analyze_results: %s", analyze_results) | |
final_results = [] | |
for item in analyze_results: | |
final_results.append(PresidioAnalyzeResponseItem(**item)) | |
return final_results | |
except Exception as e: | |
raise e | |
    async def anonymize_text(
        self,
        text: str,
        analyze_results: Any,
        output_parse_pii: bool,
        masked_entity_count: Dict[str, int],
    ) -> str:
        """
        Send analysis results to the Presidio anonymizer endpoint to get redacted text

        Args:
            text: the original (unredacted) text.
            analyze_results: output of the ``/analyze`` call, forwarded verbatim.
            output_parse_pii: when True, remember replacement-token -> original
                text in ``self.pii_tokens`` so LLM responses can be un-masked.
            masked_entity_count: mutated in place; counts masked entities by type.

        Returns:
            The redacted text from the anonymizer response.

        Raises:
            Exception: if the anonymizer returns a falsy/invalid response.
        """
        try:
            async with aiohttp.ClientSession() as session:
                # Make the request to /anonymize
                anonymize_url = f"{self.presidio_anonymizer_api_base}anonymize"
                verbose_proxy_logger.debug("Making request to: %s", anonymize_url)
                anonymize_payload = {
                    "text": text,
                    "analyzer_results": analyze_results,
                }

                async with session.post(
                    anonymize_url, json=anonymize_payload
                ) as response:
                    redacted_text = await response.json()

            new_text = text
            if redacted_text is not None:
                verbose_proxy_logger.debug("redacted_text: %s", redacted_text)
                for item in redacted_text["items"]:
                    start = item["start"]
                    end = item["end"]
                    replacement = item["text"]  # replacement token
                    if item["operator"] == "replace" and output_parse_pii is True:
                        # check if token in dict
                        # if exists, add a uuid to the replacement token for swapping back to the original text in llm response output parsing
                        if replacement in self.pii_tokens:
                            replacement = replacement + str(uuid.uuid4())

                        # NOTE(review): ``start``/``end`` index into the
                        # anonymizer's redacted text, while ``new_text`` starts
                        # as the ORIGINAL text — when masked spans differ in
                        # length, the recovered slice may be offset; confirm
                        # against Presidio's item ordering/offsets.
                        self.pii_tokens[replacement] = new_text[
                            start:end
                        ]  # get text it'll replace

                        new_text = new_text[:start] + replacement + new_text[end:]

                    entity_type = item.get("entity_type", None)
                    if entity_type is not None:
                        masked_entity_count[entity_type] = (
                            masked_entity_count.get(entity_type, 0) + 1
                        )
                return redacted_text["text"]
            else:
                raise Exception(f"Invalid anonymizer response: {redacted_text}")
        except Exception as e:
            raise e
def raise_exception_if_blocked_entities_detected( | |
self, analyze_results: Union[List[PresidioAnalyzeResponseItem], Dict] | |
): | |
""" | |
Raise an exception if blocked entities are detected | |
""" | |
if self.pii_entities_config is None: | |
return | |
if isinstance(analyze_results, Dict): | |
# if mock testing is enabled, analyze_results is a dict | |
# we don't need to raise an exception in this case | |
return | |
for result in analyze_results: | |
entity_type = result.get("entity_type") | |
if entity_type: | |
casted_entity_type: PiiEntityType = cast(PiiEntityType, entity_type) | |
if ( | |
casted_entity_type in self.pii_entities_config | |
and self.pii_entities_config[casted_entity_type] == PiiAction.BLOCK | |
): | |
raise BlockedPiiEntityError( | |
entity_type=entity_type, | |
guardrail_name=self.guardrail_name, | |
) | |
    async def check_pii(
        self,
        text: str,
        output_parse_pii: bool,
        presidio_config: Optional[PresidioPerRequestConfig],
        request_data: dict,
    ) -> str:
        """
        Calls Presidio Analyze + Anonymize endpoints for PII Analysis + Masking

        Args:
            text: the text to analyze and mask.
            output_parse_pii: forwarded to ``anonymize_text``; enables
                replacement-token bookkeeping for later un-masking.
            presidio_config: optional per-request overrides (e.g. language).
            request_data: original request dict; used for dynamic guardrail
                params and for attaching the guardrail trace in ``finally``.

        Returns:
            The masked text (or the mock redacted text when mock testing).

        Raises:
            BlockedPiiEntityError: when a detected entity is configured to BLOCK.
            Exception: propagated from the analyze/anonymize calls; the failure
                is still recorded in the guardrail trace via ``finally``.
        """
        start_time = datetime.now()
        analyze_results: Optional[Union[List[PresidioAnalyzeResponseItem], Dict]] = None
        status: Literal["success", "failure"] = "success"
        masked_entity_count: Dict[str, int] = {}  # filled in by anonymize_text
        exception_str: str = ""
        try:
            if self.mock_redacted_text is not None:
                # Mock-testing path: pretend the anonymizer returned this dict;
                # no HTTP calls and no blocked-entity check are performed.
                redacted_text = self.mock_redacted_text
            else:
                # First get analysis results
                analyze_results = await self.analyze_text(
                    text=text,
                    presidio_config=presidio_config,
                    request_data=request_data,
                )
                verbose_proxy_logger.debug("analyze_results: %s", analyze_results)

                ####################################################
                # Blocked Entities check
                ####################################################
                self.raise_exception_if_blocked_entities_detected(
                    analyze_results=analyze_results
                )

                # Then anonymize the text using the analysis results
                return await self.anonymize_text(
                    text=text,
                    analyze_results=analyze_results,
                    output_parse_pii=output_parse_pii,
                    masked_entity_count=masked_entity_count,
                )
            # Mock path only: mock dict mirrors the anonymizer response shape.
            return redacted_text["text"]
        except Exception as e:
            status = "failure"
            exception_str = str(e)
            raise e
        finally:
            ####################################################
            # Create Guardrail Trace for logging on Langfuse, Datadog, etc.
            ####################################################
            guardrail_json_response: Union[Exception, str, dict, List[dict]] = {}
            if status == "success":
                if isinstance(analyze_results, List):
                    guardrail_json_response = [dict(item) for item in analyze_results]
            else:
                guardrail_json_response = exception_str
            self.add_standard_logging_guardrail_information_to_request_data(
                guardrail_json_response=guardrail_json_response,
                request_data=request_data,
                guardrail_status=status,
                start_time=start_time.timestamp(),
                end_time=datetime.now().timestamp(),
                duration=(datetime.now() - start_time).total_seconds(),
                masked_entity_count=masked_entity_count,
            )
async def async_pre_call_hook( | |
self, | |
user_api_key_dict: UserAPIKeyAuth, | |
cache: DualCache, | |
data: dict, | |
call_type: str, | |
): | |
""" | |
- Check if request turned off pii | |
- Check if user allowed to turn off pii (key permissions -> 'allow_pii_controls') | |
- Take the request data | |
- Call /analyze -> get the results | |
- Call /anonymize w/ the analyze results -> get the redacted text | |
For multiple messages in /chat/completions, we'll need to call them in parallel. | |
""" | |
try: | |
content_safety = data.get("content_safety", None) | |
verbose_proxy_logger.debug("content_safety: %s", content_safety) | |
presidio_config = self.get_presidio_settings_from_request_data(data) | |
if call_type in [ | |
LitellmCallTypes.completion.value, | |
LitellmCallTypes.acompletion.value, | |
]: | |
messages = data["messages"] | |
tasks = [] | |
for m in messages: | |
content = m.get("content", None) | |
if content is None: | |
continue | |
if isinstance(content, str): | |
tasks.append( | |
self.check_pii( | |
text=content, | |
output_parse_pii=self.output_parse_pii, | |
presidio_config=presidio_config, | |
request_data=data, | |
) | |
) | |
responses = await asyncio.gather(*tasks) | |
for index, r in enumerate(responses): | |
content = messages[index].get("content", None) | |
if content is None: | |
continue | |
if isinstance(content, str): | |
messages[index][ | |
"content" | |
] = r # replace content with redacted string | |
verbose_proxy_logger.info( | |
f"Presidio PII Masking: Redacted pii message: {data['messages']}" | |
) | |
data["messages"] = messages | |
else: | |
verbose_proxy_logger.debug( | |
f"Not running async_pre_call_hook for call_type={call_type}" | |
) | |
return data | |
except Exception as e: | |
raise e | |
def logging_hook( | |
self, kwargs: dict, result: Any, call_type: str | |
) -> Tuple[dict, Any]: | |
from concurrent.futures import ThreadPoolExecutor | |
def run_in_new_loop(): | |
"""Run the coroutine in a new event loop within this thread.""" | |
new_loop = asyncio.new_event_loop() | |
try: | |
asyncio.set_event_loop(new_loop) | |
return new_loop.run_until_complete( | |
self.async_logging_hook( | |
kwargs=kwargs, result=result, call_type=call_type | |
) | |
) | |
finally: | |
new_loop.close() | |
asyncio.set_event_loop(None) | |
try: | |
# First, try to get the current event loop | |
_ = asyncio.get_running_loop() | |
# If we're already in an event loop, run in a separate thread | |
# to avoid nested event loop issues | |
with ThreadPoolExecutor(max_workers=1) as executor: | |
future = executor.submit(run_in_new_loop) | |
return future.result() | |
except RuntimeError: | |
# No running event loop, we can safely run in this thread | |
return run_in_new_loop() | |
async def async_logging_hook( | |
self, kwargs: dict, result: Any, call_type: str | |
) -> Tuple[dict, Any]: | |
""" | |
Masks the input before logging to langfuse, datadog, etc. | |
""" | |
if ( | |
call_type == "completion" or call_type == "acompletion" | |
): # /chat/completions requests | |
messages: Optional[List] = kwargs.get("messages", None) | |
tasks = [] | |
if messages is None: | |
return kwargs, result | |
presidio_config = self.get_presidio_settings_from_request_data(kwargs) | |
for m in messages: | |
text_str = "" | |
content = m.get("content", None) | |
if content is None: | |
continue | |
if isinstance(content, str): | |
text_str = content | |
tasks.append( | |
self.check_pii( | |
text=text_str, | |
output_parse_pii=False, | |
presidio_config=presidio_config, | |
request_data=kwargs, | |
) | |
) # need to pass separately b/c presidio has context window limits | |
responses = await asyncio.gather(*tasks) | |
for index, r in enumerate(responses): | |
content = messages[index].get("content", None) | |
if content is None: | |
continue | |
if isinstance(content, str): | |
messages[index][ | |
"content" | |
] = r # replace content with redacted string | |
verbose_proxy_logger.info( | |
f"Presidio PII Masking: Redacted pii message: {messages}" | |
) | |
kwargs["messages"] = messages | |
return kwargs, result | |
async def async_post_call_success_hook( # type: ignore | |
self, | |
data: dict, | |
user_api_key_dict: UserAPIKeyAuth, | |
response: Union[ModelResponse, EmbeddingResponse, ImageResponse], | |
): | |
""" | |
Output parse the response object to replace the masked tokens with user sent values | |
""" | |
verbose_proxy_logger.debug( | |
f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}" | |
) | |
if self.output_parse_pii is False and litellm.output_parse_pii is False: | |
return response | |
if isinstance(response, ModelResponse) and not isinstance( | |
response.choices[0], StreamingChoices | |
): # /chat/completions requests | |
if isinstance(response.choices[0].message.content, str): | |
verbose_proxy_logger.debug( | |
f"self.pii_tokens: {self.pii_tokens}; initial response: {response.choices[0].message.content}" | |
) | |
for key, value in self.pii_tokens.items(): | |
response.choices[0].message.content = response.choices[ | |
0 | |
].message.content.replace(key, value) | |
return response | |
    async def async_post_call_streaming_iterator_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response: Any,
        request_data: dict,
    ) -> AsyncGenerator[ModelResponseStream, None]:
        """
        Process streaming response chunks to unmask PII tokens when needed.

        If PII processing is enabled, this collects all chunks, applies PII unmasking,
        and returns a reconstructed stream. Otherwise, it passes through the original stream.

        NOTE: the reconstructed stream collapses the whole response into the
        chunks produced by ``MockResponseIterator`` from one final message, so
        per-token streaming granularity is lost on the unmasking path.
        """
        # If PII unmasking not needed, just pass through the original stream
        if not (self.output_parse_pii and self.pii_tokens):
            async for chunk in response:
                yield chunk
            return

        # Import here to avoid circular imports
        from litellm.llms.base_llm.base_model_iterator import MockResponseIterator
        from litellm.types.utils import Choices, Message

        try:
            # Collect all chunks to process them together
            collected_content = ""
            last_chunk = None

            async for chunk in response:
                last_chunk = chunk
                # Extract content safely with proper attribute checks
                if (
                    hasattr(chunk, "choices")
                    and chunk.choices
                    and hasattr(chunk.choices[0], "delta")
                    and hasattr(chunk.choices[0].delta, "content")
                    and isinstance(chunk.choices[0].delta.content, str)
                ):
                    collected_content += chunk.choices[0].delta.content

            # No need to proceed if we didn't capture a valid chunk
            if not last_chunk:
                # NOTE(review): ``response`` was fully consumed above, so for a
                # one-shot async iterator this re-iteration yields nothing and
                # the stream ends empty — confirm that is the intent.
                async for chunk in response:
                    yield chunk
                return

            # Apply PII unmasking to the complete content
            for token, original_text in self.pii_tokens.items():
                collected_content = collected_content.replace(token, original_text)

            # Reconstruct the response with unmasked content, reusing the last
            # chunk's identifying fields (id/object/created/model).
            mock_response = MockResponseIterator(
                model_response=ModelResponse(
                    id=last_chunk.id,
                    object=last_chunk.object,
                    created=last_chunk.created,
                    model=last_chunk.model,
                    choices=[
                        Choices(
                            message=Message(
                                role="assistant",
                                content=collected_content,
                            ),
                            index=0,
                            finish_reason="stop",
                        )
                    ],
                ),
                json_mode=False,
            )

            # Return the reconstructed stream
            async for chunk in mock_response:
                yield chunk
        except Exception as e:
            verbose_proxy_logger.error(f"Error in PII streaming processing: {str(e)}")
            # Fallback to original stream on error
            # NOTE(review): if the failure occurred after consumption began,
            # ``response`` is partially/fully exhausted and this fallback may
            # yield nothing — confirm.
            async for chunk in response:
                yield chunk
def get_presidio_settings_from_request_data( | |
self, data: dict | |
) -> Optional[PresidioPerRequestConfig]: | |
if "metadata" in data: | |
_metadata = data.get("metadata", None) | |
if _metadata is None: | |
return None | |
_guardrail_config = _metadata.get("guardrail_config") | |
if _guardrail_config: | |
_presidio_config = PresidioPerRequestConfig(**_guardrail_config) | |
return _presidio_config | |
return None | |
def print_verbose(self, print_statement): | |
try: | |
verbose_proxy_logger.debug(print_statement) | |
if litellm.set_verbose: | |
print(print_statement) # noqa | |
except Exception: | |
pass | |
async def apply_guardrail( | |
self, | |
text: str, | |
language: Optional[str] = None, | |
entities: Optional[List[PiiEntityType]] = None, | |
) -> str: | |
""" | |
UI will call this function to check: | |
1. If the connection to the guardrail is working | |
2. When Testing the guardrail with some text, this function will be called with the input text and returns a text after applying the guardrail | |
""" | |
text = await self.check_pii( | |
text=text, | |
output_parse_pii=self.output_parse_pii, | |
presidio_config=None, | |
request_data={}, | |
) | |
return text | |
def update_in_memory_litellm_params(self, litellm_params: LitellmParams) -> None: | |
""" | |
Update the guardrails litellm params in memory | |
""" | |
if litellm_params.pii_entities_config: | |
self.pii_entities_config = litellm_params.pii_entities_config | |