Spaces:
Running
Running
guardrails-genie
/
guardrails_genie
/guardrails
/entity_recognition
/presidio_entity_recognition_guardrail.py
from typing import Any, Dict, List, Optional | |
import weave | |
from presidio_analyzer import ( | |
AnalyzerEngine, | |
Pattern, | |
PatternRecognizer, | |
RecognizerRegistry, | |
) | |
from presidio_anonymizer import AnonymizerEngine | |
from pydantic import BaseModel | |
from ..base import Guardrail | |
class PresidioEntityRecognitionResponse(BaseModel): | |
contains_entities: bool | |
detected_entities: Dict[str, List[str]] | |
explanation: str | |
anonymized_text: Optional[str] = None | |
def safe(self) -> bool: | |
return not self.contains_entities | |
class PresidioEntityRecognitionSimpleResponse(BaseModel): | |
contains_entities: bool | |
explanation: str | |
anonymized_text: Optional[str] = None | |
def safe(self) -> bool: | |
return not self.contains_entities | |
# TODO: Add support for transformers workflow and not just Spacy | |
class PresidioEntityRecognitionGuardrail(Guardrail): | |
""" | |
A guardrail class for entity recognition and anonymization using Presidio. | |
This class extends the Guardrail base class to provide functionality for | |
detecting and optionally anonymizing entities in text using the Presidio | |
library. It leverages Presidio's AnalyzerEngine and AnonymizerEngine to | |
perform these tasks. | |
!!! example "Using PresidioEntityRecognitionGuardrail" | |
```python | |
from guardrails_genie.guardrails.entity_recognition import PresidioEntityRecognitionGuardrail | |
# Initialize with default entities | |
guardrail = PresidioEntityRecognitionGuardrail(should_anonymize=True) | |
# Or with specific entities | |
selected_entities = ["CREDIT_CARD", "US_SSN", "EMAIL_ADDRESS"] | |
guardrail = PresidioEntityRecognitionGuardrail( | |
selected_entities=selected_entities, | |
should_anonymize=True | |
) | |
``` | |
Attributes: | |
analyzer (AnalyzerEngine): The Presidio engine used for entity analysis. | |
anonymizer (AnonymizerEngine): The Presidio engine used for text anonymization. | |
selected_entities (List[str]): A list of entity types to detect in the text. | |
should_anonymize (bool): A flag indicating whether detected entities should be anonymized. | |
language (str): The language of the text to be analyzed. | |
Args: | |
selected_entities (Optional[List[str]]): A list of entity types to detect in the text. | |
should_anonymize (bool): A flag indicating whether detected entities should be anonymized. | |
language (str): The language of the text to be analyzed. | |
deny_lists (Optional[Dict[str, List[str]]]): A dictionary of entity types and their | |
corresponding deny lists. | |
regex_patterns (Optional[Dict[str, List[Dict[str, str]]]]): A dictionary of entity | |
types and their corresponding regex patterns. | |
custom_recognizers (Optional[List[Any]]): A list of custom recognizers to add to the | |
analyzer. | |
show_available_entities (bool): A flag indicating whether to print available entities. | |
""" | |
def get_available_entities() -> List[str]: | |
registry = RecognizerRegistry() | |
analyzer = AnalyzerEngine(registry=registry) | |
return [ | |
recognizer.supported_entities[0] | |
for recognizer in analyzer.registry.recognizers | |
] | |
analyzer: AnalyzerEngine | |
anonymizer: AnonymizerEngine | |
selected_entities: List[str] | |
should_anonymize: bool | |
language: str | |
def __init__( | |
self, | |
selected_entities: Optional[List[str]] = None, | |
should_anonymize: bool = False, | |
language: str = "en", | |
deny_lists: Optional[Dict[str, List[str]]] = None, | |
regex_patterns: Optional[Dict[str, List[Dict[str, str]]]] = None, | |
custom_recognizers: Optional[List[Any]] = None, | |
show_available_entities: bool = False, | |
): | |
# If show_available_entities is True, print available entities | |
if show_available_entities: | |
available_entities = self.get_available_entities() | |
print("\nAvailable entities:") | |
print("=" * 25) | |
for entity in available_entities: | |
print(f"- {entity}") | |
print("=" * 25 + "\n") | |
# Initialize default values to all available entities | |
if selected_entities is None: | |
selected_entities = self.get_available_entities() | |
# Get available entities dynamically | |
available_entities = self.get_available_entities() | |
# Filter out invalid entities and warn user | |
invalid_entities = [e for e in selected_entities if e not in available_entities] | |
valid_entities = [e for e in selected_entities if e in available_entities] | |
if invalid_entities: | |
print( | |
f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}" | |
) | |
print(f"Continuing with valid entities: {valid_entities}") | |
selected_entities = valid_entities | |
# Initialize analyzer with default recognizers | |
analyzer = AnalyzerEngine() | |
# Add custom recognizers if provided | |
if custom_recognizers: | |
for recognizer in custom_recognizers: | |
analyzer.registry.add_recognizer(recognizer) | |
# Add deny list recognizers if provided | |
if deny_lists: | |
for entity_type, tokens in deny_lists.items(): | |
deny_list_recognizer = PatternRecognizer( | |
supported_entity=entity_type, deny_list=tokens | |
) | |
analyzer.registry.add_recognizer(deny_list_recognizer) | |
# Add regex pattern recognizers if provided | |
if regex_patterns: | |
for entity_type, patterns in regex_patterns.items(): | |
presidio_patterns = [ | |
Pattern( | |
name=pattern.get("name", f"pattern_{i}"), | |
regex=pattern["regex"], | |
score=pattern.get("score", 0.5), | |
) | |
for i, pattern in enumerate(patterns) | |
] | |
regex_recognizer = PatternRecognizer( | |
supported_entity=entity_type, patterns=presidio_patterns | |
) | |
analyzer.registry.add_recognizer(regex_recognizer) | |
# Initialize Presidio engines | |
anonymizer = AnonymizerEngine() | |
# Call parent class constructor with all fields | |
super().__init__( | |
analyzer=analyzer, | |
anonymizer=anonymizer, | |
selected_entities=selected_entities, | |
should_anonymize=should_anonymize, | |
language=language, | |
) | |
def guard( | |
self, prompt: str, return_detected_types: bool = True, **kwargs | |
) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse: | |
""" | |
Analyzes the input prompt for entity recognition using the Presidio framework. | |
This function utilizes the Presidio AnalyzerEngine to detect entities within the | |
provided text prompt. It supports custom recognizers, deny lists, and regex patterns | |
for entity detection. The detected entities are grouped by their types and an | |
explanation of the findings is generated. If anonymization is enabled, the detected | |
entities in the text are anonymized. | |
Args: | |
prompt (str): The text to be analyzed for entity recognition. | |
return_detected_types (bool): Determines the type of response. If True, the | |
response includes detailed information about detected entity types. | |
Returns: | |
PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse: | |
A response object containing information about whether entities were detected, | |
the types and instances of detected entities, an explanation of the analysis, | |
and optionally, the anonymized text if anonymization is enabled. | |
""" | |
# Analyze text for entities | |
analyzer_results = self.analyzer.analyze( | |
text=str(prompt), entities=self.selected_entities, language=self.language | |
) | |
# Group results by entity type | |
detected_entities = {} | |
for result in analyzer_results: | |
entity_type = result.entity_type | |
text_slice = prompt[result.start : result.end] | |
if entity_type not in detected_entities: | |
detected_entities[entity_type] = [] | |
detected_entities[entity_type].append(text_slice) | |
# Create explanation | |
explanation_parts = [] | |
if detected_entities: | |
explanation_parts.append("Found the following entities in the text:") | |
for entity_type, instances in detected_entities.items(): | |
explanation_parts.append( | |
f"- {entity_type}: {len(instances)} instance(s)" | |
) | |
else: | |
explanation_parts.append("No entities detected in the text.") | |
# Add information about what was checked | |
explanation_parts.append("\nChecked for these entity types:") | |
for entity in self.selected_entities: | |
explanation_parts.append(f"- {entity}") | |
# Anonymize if requested | |
anonymized_text = None | |
if self.should_anonymize and detected_entities: | |
anonymized_result = self.anonymizer.anonymize( | |
text=prompt, analyzer_results=analyzer_results | |
) | |
anonymized_text = anonymized_result.text | |
if return_detected_types: | |
return PresidioEntityRecognitionResponse( | |
contains_entities=bool(detected_entities), | |
detected_entities=detected_entities, | |
explanation="\n".join(explanation_parts), | |
anonymized_text=anonymized_text, | |
) | |
else: | |
return PresidioEntityRecognitionSimpleResponse( | |
contains_entities=bool(detected_entities), | |
explanation="\n".join(explanation_parts), | |
anonymized_text=anonymized_text, | |
) | |
def predict( | |
self, prompt: str, return_detected_types: bool = True, **kwargs | |
) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse: | |
return self.guard(prompt, return_detected_types=return_detected_types, **kwargs) | |