# +-----------------------------------------------+
# | |
# | PII Masking |
# | with Microsoft Presidio |
# | https://github.com/BerriAI/litellm/issues/ |
# +-----------------------------------------------+
#
# Tell us how we can improve! - Krrish & Ishaan
import asyncio
import json
import uuid
from datetime import datetime
from typing import (
Any,
AsyncGenerator,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
cast,
)
import aiohttp
import litellm # noqa: E401
from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.exceptions import BlockedPiiEntityError
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.guardrails import (
GuardrailEventHooks,
LitellmParams,
PiiAction,
PiiEntityType,
PresidioPerRequestConfig,
)
from litellm.types.proxy.guardrails.guardrail_hooks.presidio import (
PresidioAnalyzeRequest,
PresidioAnalyzeResponseItem,
)
from litellm.types.utils import CallTypes as LitellmCallTypes
from litellm.utils import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
ModelResponseStream,
StreamingChoices,
)
class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
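    """
    LiteLLM guardrail that masks PII with a self-hosted Microsoft Presidio
    deployment: text is sent to the Presidio analyzer to detect entities,
    then to the anonymizer to redact/replace them.

    A representative proxy `config.yaml` sketch (entity names and mode are
    illustrative - adjust to your deployment):

        guardrails:
          - guardrail_name: "presidio-pii"
            litellm_params:
              guardrail: presidio
              mode: "pre_call"
              pii_entities_config:
                CREDIT_CARD: "MASK"
                EMAIL_ADDRESS: "MASK"
    """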
user_api_key_cache = None
ad_hoc_recognizers = None
# Class variables or attributes
def __init__(
self,
mock_testing: bool = False,
mock_redacted_text: Optional[dict] = None,
presidio_analyzer_api_base: Optional[str] = None,
presidio_anonymizer_api_base: Optional[str] = None,
output_parse_pii: Optional[bool] = False,
presidio_ad_hoc_recognizers: Optional[str] = None,
logging_only: Optional[bool] = None,
pii_entities_config: Optional[Dict[PiiEntityType, PiiAction]] = None,
presidio_language: Optional[str] = None,
**kwargs,
):
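        """
        Args are typically populated from the guardrail's `litellm_params`.
        Notable ones:
        - presidio_analyzer_api_base / presidio_anonymizer_api_base: base URLs
          of the two Presidio services; fall back to env vars (see
          `validate_environment`).
        - output_parse_pii: if True, `replace`d tokens are remembered in
          `self.pii_tokens` and swapped back into the LLM response output.
        - presidio_ad_hoc_recognizers: path to a JSON file of custom
          recognizers forwarded to the analyzer.
        - pii_entities_config: per-entity action, e.g. `PiiAction.BLOCK` to
          reject requests containing that entity.
        - presidio_language: default analysis language (defaults to "en").
        """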
if logging_only is True:
self.logging_only = True
kwargs["event_hook"] = GuardrailEventHooks.logging_only
super().__init__(**kwargs)
self.pii_tokens: dict = (
{}
) # mapping of PII token to original text - only used with Presidio `replace` operation
self.mock_redacted_text = mock_redacted_text
self.output_parse_pii = output_parse_pii or False
self.pii_entities_config: Dict[PiiEntityType, PiiAction] = (
pii_entities_config or {}
)
self.presidio_language = presidio_language or "en"
if mock_testing is True: # for testing purposes only
return
ad_hoc_recognizers = presidio_ad_hoc_recognizers
if ad_hoc_recognizers is not None:
try:
with open(ad_hoc_recognizers, "r") as file:
self.ad_hoc_recognizers = json.load(file)
except FileNotFoundError:
raise Exception(f"File not found. file_path={ad_hoc_recognizers}")
except json.JSONDecodeError as e:
raise Exception(
f"Error decoding JSON file: {str(e)}, file_path={ad_hoc_recognizers}"
)
except Exception as e:
raise Exception(
f"An error occurred: {str(e)}, file_path={ad_hoc_recognizers}"
)
self.validate_environment(
presidio_analyzer_api_base=presidio_analyzer_api_base,
presidio_anonymizer_api_base=presidio_anonymizer_api_base,
)
def validate_environment(
self,
presidio_analyzer_api_base: Optional[str] = None,
presidio_anonymizer_api_base: Optional[str] = None,
):
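        """
        Resolve the analyzer/anonymizer base URLs from the constructor args or
        the `PRESIDIO_ANALYZER_API_BASE` / `PRESIDIO_ANONYMIZER_API_BASE`
        secrets, then normalize them: ensure a trailing slash and prepend
        `http://` when no scheme is given, e.g.
        "presidio-analyzer:3000" -> "http://presidio-analyzer:3000/".
        """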
self.presidio_analyzer_api_base: Optional[
str
] = presidio_analyzer_api_base or get_secret(
"PRESIDIO_ANALYZER_API_BASE", None
) # type: ignore
        self.presidio_anonymizer_api_base: Optional[
            str
        ] = presidio_anonymizer_api_base or get_secret(
            "PRESIDIO_ANONYMIZER_API_BASE", None
        )  # type: ignore
if self.presidio_analyzer_api_base is None:
raise Exception("Missing `PRESIDIO_ANALYZER_API_BASE` from environment")
if not self.presidio_analyzer_api_base.endswith("/"):
self.presidio_analyzer_api_base += "/"
if not (
self.presidio_analyzer_api_base.startswith("http://")
or self.presidio_analyzer_api_base.startswith("https://")
):
# add http:// if unset, assume communicating over private network - e.g. render
self.presidio_analyzer_api_base = (
"http://" + self.presidio_analyzer_api_base
)
if self.presidio_anonymizer_api_base is None:
raise Exception("Missing `PRESIDIO_ANONYMIZER_API_BASE` from environment")
if not self.presidio_anonymizer_api_base.endswith("/"):
self.presidio_anonymizer_api_base += "/"
if not (
self.presidio_anonymizer_api_base.startswith("http://")
or self.presidio_anonymizer_api_base.startswith("https://")
):
# add http:// if unset, assume communicating over private network - e.g. render
self.presidio_anonymizer_api_base = (
"http://" + self.presidio_anonymizer_api_base
)
def _get_presidio_analyze_request_payload(
self,
text: str,
presidio_config: Optional[PresidioPerRequestConfig],
request_data: dict,
) -> PresidioAnalyzeRequest:
"""
Construct the payload for the Presidio analyze request
API Ref: https://microsoft.github.io/presidio/api-docs/api-docs.html#tag/Analyzer/paths/~1analyze/post
"""
analyze_payload: PresidioAnalyzeRequest = PresidioAnalyzeRequest(
text=text,
language=self.presidio_language,
)
        ##################################################################
        ###### Check if user has configured any params for this guardrail
        ##################################################################
if self.ad_hoc_recognizers is not None:
analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers
if self.pii_entities_config:
analyze_payload["entities"] = list(self.pii_entities_config.keys())
##################################################################
######### End of adding config params
##################################################################
# Check if client side request passed any dynamic params
if presidio_config and presidio_config.language:
analyze_payload["language"] = presidio_config.language
casted_analyze_payload: dict = cast(dict, analyze_payload)
casted_analyze_payload.update(
self.get_guardrail_dynamic_request_body_params(request_data=request_data)
)
return cast(PresidioAnalyzeRequest, casted_analyze_payload)
async def analyze_text(
self,
text: str,
presidio_config: Optional[PresidioPerRequestConfig],
request_data: dict,
) -> Union[List[PresidioAnalyzeResponseItem], Dict]:
"""
Send text to the Presidio analyzer endpoint and get analysis results
"""
        if self.mock_redacted_text is not None:
            return self.mock_redacted_text
        async with aiohttp.ClientSession() as session:
            # Make the request to /analyze
            analyze_url = f"{self.presidio_analyzer_api_base}analyze"
            analyze_payload: PresidioAnalyzeRequest = (
                self._get_presidio_analyze_request_payload(
                    text=text,
                    presidio_config=presidio_config,
                    request_data=request_data,
                )
            )
            verbose_proxy_logger.debug(
                "Making request to: %s with payload: %s",
                analyze_url,
                analyze_payload,
            )
            async with session.post(analyze_url, json=analyze_payload) as response:
                analyze_results = await response.json()
            verbose_proxy_logger.debug("analyze_results: %s", analyze_results)
            return [PresidioAnalyzeResponseItem(**item) for item in analyze_results]
async def anonymize_text(
self,
text: str,
analyze_results: Any,
output_parse_pii: bool,
masked_entity_count: Dict[str, int],
) -> str:
"""
Send analysis results to the Presidio anonymizer endpoint to get redacted text
"""
try:
async with aiohttp.ClientSession() as session:
# Make the request to /anonymize
anonymize_url = f"{self.presidio_anonymizer_api_base}anonymize"
verbose_proxy_logger.debug("Making request to: %s", anonymize_url)
anonymize_payload = {
"text": text,
"analyzer_results": analyze_results,
}
async with session.post(
anonymize_url, json=anonymize_payload
) as response:
redacted_text = await response.json()
new_text = text
if redacted_text is not None:
verbose_proxy_logger.debug("redacted_text: %s", redacted_text)
for item in redacted_text["items"]:
start = item["start"]
end = item["end"]
replacement = item["text"] # replacement token
if item["operator"] == "replace" and output_parse_pii is True:
# check if token in dict
# if exists, add a uuid to the replacement token for swapping back to the original text in llm response output parsing
if replacement in self.pii_tokens:
replacement = replacement + str(uuid.uuid4())
self.pii_tokens[replacement] = new_text[
start:end
] # get text it'll replace
new_text = new_text[:start] + replacement + new_text[end:]
entity_type = item.get("entity_type", None)
if entity_type is not None:
masked_entity_count[entity_type] = (
masked_entity_count.get(entity_type, 0) + 1
)
return redacted_text["text"]
else:
raise Exception(f"Invalid anonymizer response: {redacted_text}")
except Exception as e:
raise e
def raise_exception_if_blocked_entities_detected(
self, analyze_results: Union[List[PresidioAnalyzeResponseItem], Dict]
):
"""
Raise an exception if blocked entities are detected
"""
if self.pii_entities_config is None:
return
if isinstance(analyze_results, Dict):
# if mock testing is enabled, analyze_results is a dict
# we don't need to raise an exception in this case
return
for result in analyze_results:
entity_type = result.get("entity_type")
if entity_type:
casted_entity_type: PiiEntityType = cast(PiiEntityType, entity_type)
if (
casted_entity_type in self.pii_entities_config
and self.pii_entities_config[casted_entity_type] == PiiAction.BLOCK
):
raise BlockedPiiEntityError(
entity_type=entity_type,
guardrail_name=self.guardrail_name,
)
async def check_pii(
self,
text: str,
output_parse_pii: bool,
presidio_config: Optional[PresidioPerRequestConfig],
request_data: dict,
) -> str:
"""
Calls Presidio Analyze + Anonymize endpoints for PII Analysis + Masking
"""
start_time = datetime.now()
analyze_results: Optional[Union[List[PresidioAnalyzeResponseItem], Dict]] = None
status: Literal["success", "failure"] = "success"
masked_entity_count: Dict[str, int] = {}
exception_str: str = ""
try:
if self.mock_redacted_text is not None:
redacted_text = self.mock_redacted_text
else:
# First get analysis results
analyze_results = await self.analyze_text(
text=text,
presidio_config=presidio_config,
request_data=request_data,
)
verbose_proxy_logger.debug("analyze_results: %s", analyze_results)
####################################################
# Blocked Entities check
####################################################
self.raise_exception_if_blocked_entities_detected(
analyze_results=analyze_results
)
                # Then anonymize the text using the analysis results
                return await self.anonymize_text(
                    text=text,
                    analyze_results=analyze_results,
                    output_parse_pii=output_parse_pii,
                    masked_entity_count=masked_entity_count,
                )
            # Mock path: `mock_redacted_text` mimics the anonymizer response shape
            return redacted_text["text"]
except Exception as e:
status = "failure"
exception_str = str(e)
raise e
finally:
####################################################
# Create Guardrail Trace for logging on Langfuse, Datadog, etc.
####################################################
guardrail_json_response: Union[Exception, str, dict, List[dict]] = {}
if status == "success":
if isinstance(analyze_results, List):
guardrail_json_response = [dict(item) for item in analyze_results]
else:
guardrail_json_response = exception_str
self.add_standard_logging_guardrail_information_to_request_data(
guardrail_json_response=guardrail_json_response,
request_data=request_data,
guardrail_status=status,
start_time=start_time.timestamp(),
end_time=datetime.now().timestamp(),
duration=(datetime.now() - start_time).total_seconds(),
masked_entity_count=masked_entity_count,
)
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
"""
- Check if request turned off pii
- Check if user allowed to turn off pii (key permissions -> 'allow_pii_controls')
- Take the request data
- Call /analyze -> get the results
- Call /anonymize w/ the analyze results -> get the redacted text
For multiple messages in /chat/completions, we'll need to call them in parallel.
"""
try:
content_safety = data.get("content_safety", None)
verbose_proxy_logger.debug("content_safety: %s", content_safety)
presidio_config = self.get_presidio_settings_from_request_data(data)
if call_type in [
LitellmCallTypes.completion.value,
LitellmCallTypes.acompletion.value,
]:
messages = data["messages"]
tasks = []
for m in messages:
content = m.get("content", None)
if content is None:
continue
if isinstance(content, str):
tasks.append(
self.check_pii(
text=content,
output_parse_pii=self.output_parse_pii,
presidio_config=presidio_config,
request_data=data,
)
)
responses = await asyncio.gather(*tasks)
for index, r in enumerate(responses):
content = messages[index].get("content", None)
if content is None:
continue
if isinstance(content, str):
messages[index][
"content"
] = r # replace content with redacted string
verbose_proxy_logger.info(
f"Presidio PII Masking: Redacted pii message: {data['messages']}"
)
data["messages"] = messages
else:
verbose_proxy_logger.debug(
f"Not running async_pre_call_hook for call_type={call_type}"
)
return data
except Exception as e:
raise e
def logging_hook(
self, kwargs: dict, result: Any, call_type: str
) -> Tuple[dict, Any]:
from concurrent.futures import ThreadPoolExecutor
def run_in_new_loop():
"""Run the coroutine in a new event loop within this thread."""
new_loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(new_loop)
return new_loop.run_until_complete(
self.async_logging_hook(
kwargs=kwargs, result=result, call_type=call_type
)
)
finally:
new_loop.close()
asyncio.set_event_loop(None)
try:
# First, try to get the current event loop
_ = asyncio.get_running_loop()
# If we're already in an event loop, run in a separate thread
# to avoid nested event loop issues
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
return future.result()
except RuntimeError:
# No running event loop, we can safely run in this thread
return run_in_new_loop()
async def async_logging_hook(
self, kwargs: dict, result: Any, call_type: str
) -> Tuple[dict, Any]:
"""
Masks the input before logging to langfuse, datadog, etc.
"""
if (
call_type == "completion" or call_type == "acompletion"
): # /chat/completions requests
            messages: Optional[List] = kwargs.get("messages", None)
            if messages is None:
                return kwargs, result
            presidio_config = self.get_presidio_settings_from_request_data(kwargs)
            tasks = []
            task_indices = []  # message index for each queued task
            for idx, m in enumerate(messages):
                content = m.get("content", None)
                if not isinstance(content, str):
                    continue
                task_indices.append(idx)
                tasks.append(
                    self.check_pii(
                        text=content,
                        output_parse_pii=False,
                        presidio_config=presidio_config,
                        request_data=kwargs,
                    )
                )  # each message is masked separately b/c presidio has context window limits
            responses = await asyncio.gather(*tasks)
            # Zip against the recorded indices so skipped messages don't shift the mapping
            for idx, r in zip(task_indices, responses):
                messages[idx]["content"] = r  # replace content with redacted string
            verbose_proxy_logger.info(
                f"Presidio PII Masking: Redacted pii message: {messages}"
            )
            kwargs["messages"] = messages
return kwargs, result
async def async_post_call_success_hook( # type: ignore
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
):
"""
Output parse the response object to replace the masked tokens with user sent values
"""
verbose_proxy_logger.debug(
f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}"
)
if self.output_parse_pii is False and litellm.output_parse_pii is False:
return response
if isinstance(response, ModelResponse) and not isinstance(
response.choices[0], StreamingChoices
): # /chat/completions requests
if isinstance(response.choices[0].message.content, str):
verbose_proxy_logger.debug(
f"self.pii_tokens: {self.pii_tokens}; initial response: {response.choices[0].message.content}"
)
for key, value in self.pii_tokens.items():
response.choices[0].message.content = response.choices[
0
].message.content.replace(key, value)
return response
async def async_post_call_streaming_iterator_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
response: Any,
request_data: dict,
) -> AsyncGenerator[ModelResponseStream, None]:
"""
Process streaming response chunks to unmask PII tokens when needed.
If PII processing is enabled, this collects all chunks, applies PII unmasking,
and returns a reconstructed stream. Otherwise, it passes through the original stream.
"""
# If PII unmasking not needed, just pass through the original stream
if not (self.output_parse_pii and self.pii_tokens):
async for chunk in response:
yield chunk
return
# Import here to avoid circular imports
from litellm.llms.base_llm.base_model_iterator import MockResponseIterator
from litellm.types.utils import Choices, Message
try:
# Collect all chunks to process them together
collected_content = ""
last_chunk = None
async for chunk in response:
last_chunk = chunk
# Extract content safely with proper attribute checks
if (
hasattr(chunk, "choices")
and chunk.choices
and hasattr(chunk.choices[0], "delta")
and hasattr(chunk.choices[0].delta, "content")
and isinstance(chunk.choices[0].delta.content, str)
):
collected_content += chunk.choices[0].delta.content
            # Stream produced no chunks - nothing to reconstruct, pass through
            if not last_chunk:
                async for chunk in response:
                    yield chunk
                return
# Apply PII unmasking to the complete content
for token, original_text in self.pii_tokens.items():
collected_content = collected_content.replace(token, original_text)
# Reconstruct the response with unmasked content
mock_response = MockResponseIterator(
model_response=ModelResponse(
id=last_chunk.id,
object=last_chunk.object,
created=last_chunk.created,
model=last_chunk.model,
choices=[
Choices(
message=Message(
role="assistant",
content=collected_content,
),
index=0,
finish_reason="stop",
)
],
),
json_mode=False,
)
# Return the reconstructed stream
async for chunk in mock_response:
yield chunk
except Exception as e:
verbose_proxy_logger.error(f"Error in PII streaming processing: {str(e)}")
# Fallback to original stream on error
async for chunk in response:
yield chunk
def get_presidio_settings_from_request_data(
self, data: dict
) -> Optional[PresidioPerRequestConfig]:
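        # Clients can override Presidio settings per request, e.g. (shape
        # inferred from the parsing below):
        #   {"messages": [...],
        #    "metadata": {"guardrail_config": {"language": "es"}}}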
if "metadata" in data:
_metadata = data.get("metadata", None)
if _metadata is None:
return None
_guardrail_config = _metadata.get("guardrail_config")
if _guardrail_config:
_presidio_config = PresidioPerRequestConfig(**_guardrail_config)
return _presidio_config
return None
def print_verbose(self, print_statement):
try:
verbose_proxy_logger.debug(print_statement)
if litellm.set_verbose:
print(print_statement) # noqa
except Exception:
pass
async def apply_guardrail(
self,
text: str,
language: Optional[str] = None,
entities: Optional[List[PiiEntityType]] = None,
) -> str:
"""
UI will call this function to check:
1. If the connection to the guardrail is working
2. When Testing the guardrail with some text, this function will be called with the input text and returns a text after applying the guardrail
"""
text = await self.check_pii(
text=text,
output_parse_pii=self.output_parse_pii,
presidio_config=None,
request_data={},
)
return text
def update_in_memory_litellm_params(self, litellm_params: LitellmParams) -> None:
"""
Update the guardrails litellm params in memory
"""
if litellm_params.pii_entities_config:
self.pii_entities_config = litellm_params.pii_entities_config