Merge pull request #6 from soumik12345/feat/pii-banned-words
- .gitignore +2 -1
- application_pages/chat_app.py +28 -0
- guardrails_genie/guardrails/ReadMe.md +136 -0
- guardrails_genie/guardrails/__init__.py +10 -0
- guardrails_genie/guardrails/entity_recognition/__init__.py +10 -0
- guardrails_genie/guardrails/entity_recognition/banned_terms_examples/banned_term_benchmark.py +0 -0
- guardrails_genie/guardrails/entity_recognition/banned_terms_examples/banned_term_examples.py +178 -0
- guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_llm_judge.py +47 -0
- guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_regex_model.py +46 -0
- guardrails_genie/guardrails/entity_recognition/llm_judge_entity_recognition_guardrail.py +166 -0
- guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark.py +215 -0
- guardrails_genie/guardrails/entity_recognition/pii_examples/pii_test_examples.py +150 -0
- guardrails_genie/guardrails/entity_recognition/pii_examples/run_presidio_model.py +42 -0
- guardrails_genie/guardrails/entity_recognition/pii_examples/run_regex_model.py +42 -0
- guardrails_genie/guardrails/entity_recognition/pii_examples/run_transformers.py +43 -0
- guardrails_genie/guardrails/entity_recognition/presidio_entity_recognition_guardrail.py +191 -0
- guardrails_genie/guardrails/entity_recognition/regex_entity_recognition_guardrail.py +138 -0
- guardrails_genie/guardrails/entity_recognition/transformers_entity_recognition_guardrail.py +190 -0
- guardrails_genie/guardrails/manager.py +7 -4
- guardrails_genie/regex_model.py +65 -0
- pyproject.toml +2 -0
.gitignore
CHANGED
@@ -168,4 +168,5 @@ temp.txt
 **.csv
 binary-classifier/
 wandb/
-artifacts/
+artifacts/
+evaluation_results/
application_pages/chat_app.py
CHANGED
@@ -61,6 +61,34 @@ def initialize_guardrails():
                         guardrail_name,
                     )(model_name=classifier_model_name)
                 )
+            elif guardrail_name == "PresidioEntityRecognitionGuardrail":
+                st.session_state.guardrails.append(
+                    getattr(
+                        importlib.import_module("guardrails_genie.guardrails"),
+                        guardrail_name,
+                    )()
+                )
+            elif guardrail_name == "RegexEntityRecognitionGuardrail":
+                st.session_state.guardrails.append(
+                    getattr(
+                        importlib.import_module("guardrails_genie.guardrails"),
+                        guardrail_name,
+                    )()
+                )
+            elif guardrail_name == "TransformersEntityRecognitionGuardrail":
+                st.session_state.guardrails.append(
+                    getattr(
+                        importlib.import_module("guardrails_genie.guardrails"),
+                        guardrail_name,
+                    )()
+                )
+            elif guardrail_name == "RestrictedTermsJudge":
+                st.session_state.guardrails.append(
+                    getattr(
+                        importlib.import_module("guardrails_genie.guardrails"),
+                        guardrail_name,
+                    )()
+                )
     st.session_state.guardrails_manager = GuardrailManager(
         guardrails=st.session_state.guardrails
     )
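All four new branches resolve a guardrail class from `guardrails_genie.guardrails` by name and call it with no arguments. A minimal sketch of that dynamic-instantiation pattern, assuming the package is importable; the `build_guardrails` helper and `ZERO_ARG_GUARDRAILS` list below are illustrative, not part of this PR:

```python
import importlib

# Guardrail classes added in this PR that are constructed with no arguments.
ZERO_ARG_GUARDRAILS = [
    "PresidioEntityRecognitionGuardrail",
    "RegexEntityRecognitionGuardrail",
    "TransformersEntityRecognitionGuardrail",
    "RestrictedTermsJudge",
]

def build_guardrails(selected_names):
    """Instantiate each selected guardrail class by name, mirroring the new elif branches."""
    module = importlib.import_module("guardrails_genie.guardrails")
    return [getattr(module, name)() for name in selected_names if name in ZERO_ARG_GUARDRAILS]
```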
guardrails_genie/guardrails/ReadMe.md
ADDED
@@ -0,0 +1,136 @@
# Entity Recognition Guardrails

A collection of guardrails for detecting and anonymizing various types of entities in text, including PII (Personally Identifiable Information), restricted terms, and custom entities.

## Available Guardrails

### 1. Regex Entity Recognition
Simple pattern-based entity detection using regular expressions.

```python
from guardrails_genie.guardrails.entity_recognition import RegexEntityRecognitionGuardrail

# Initialize with default PII patterns
guardrail = RegexEntityRecognitionGuardrail(should_anonymize=True)

# Or with custom patterns
custom_patterns = {
    "employee_id": r"EMP\d{6}",
    "project_code": r"PRJ-[A-Z]{2}-\d{4}"
}
guardrail = RegexEntityRecognitionGuardrail(patterns=custom_patterns, should_anonymize=True)
```

### 2. Presidio Entity Recognition
Advanced entity detection using Microsoft's Presidio analyzer.

```python
from guardrails_genie.guardrails.entity_recognition import PresidioEntityRecognitionGuardrail

# Initialize with default entities
guardrail = PresidioEntityRecognitionGuardrail(should_anonymize=True)

# Or with specific entities
selected_entities = ["CREDIT_CARD", "US_SSN", "EMAIL_ADDRESS"]
guardrail = PresidioEntityRecognitionGuardrail(
    selected_entities=selected_entities,
    should_anonymize=True
)
```

### 3. Transformers Entity Recognition
Entity detection using transformer-based models.

```python
from guardrails_genie.guardrails.entity_recognition import TransformersEntityRecognitionGuardrail

# Initialize with default model
guardrail = TransformersEntityRecognitionGuardrail(should_anonymize=True)

# Or with specific model and entities
guardrail = TransformersEntityRecognitionGuardrail(
    model_name="iiiorg/piiranha-v1-detect-personal-information",
    selected_entities=["GIVENNAME", "SURNAME", "EMAIL"],
    should_anonymize=True
)
```

### 4. LLM Judge for Restricted Terms
Advanced detection of restricted terms, competitor mentions, and brand protection using LLMs.

```python
from guardrails_genie.guardrails.entity_recognition import RestrictedTermsJudge

# Initialize with OpenAI model
guardrail = RestrictedTermsJudge(should_anonymize=True)

# Check for specific terms
result = guardrail.guard(
    text="Let's implement features like Salesforce",
    custom_terms=["Salesforce", "Oracle", "AWS"]
)
```

## Usage

All guardrails follow a consistent interface:

```python
# Initialize a guardrail
guardrail = RegexEntityRecognitionGuardrail(should_anonymize=True)

# Check text for entities
result = guardrail.guard("Hello, my email is john@example.com")

# Access results
print(f"Contains entities: {result.contains_entities}")
print(f"Detected entities: {result.detected_entities}")
print(f"Explanation: {result.explanation}")
print(f"Anonymized text: {result.anonymized_text}")
```

## Evaluation Tools

The module includes comprehensive evaluation tools and test cases:

- `pii_examples/`: Test cases for PII detection
- `banned_terms_examples/`: Test cases for restricted terms
- Benchmark scripts for evaluating model performance

### Running Evaluations

```python
# PII Detection Benchmark
from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_benchmark import main
main()

# (TODO): Restricted Terms Testing
from guardrails_genie.guardrails.entity_recognition.banned_terms_examples.banned_term_benchmark import main
main()
```

## Features

- Entity detection and anonymization
- Support for multiple detection methods (regex, Presidio, transformers, LLMs)
- Customizable entity types and patterns
- Detailed explanations of detected entities
- Comprehensive evaluation framework
- Support for custom terms and patterns
- Batch processing capabilities
- Performance metrics and benchmarking

## Response Format

All guardrails return responses with the following structure:

```python
{
    "contains_entities": bool,
    "detected_entities": {
        "entity_type": ["detected_value_1", "detected_value_2"]
    },
    "explanation": str,
    "anonymized_text": Optional[str]
}
```
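As a usage note on the response format above, a caller can branch on `contains_entities` and fall back to the anonymized text when something is found. A minimal sketch, assuming this package and the regex guardrail from the PR are installed and importable:

```python
from guardrails_genie.guardrails.entity_recognition import RegexEntityRecognitionGuardrail

guardrail = RegexEntityRecognitionGuardrail(should_anonymize=True)
result = guardrail.guard("Contact me at jane.doe@example.com or 555-123-4567")

if not result.contains_entities:
    # Nothing sensitive detected; forward the prompt unchanged.
    print("Prompt is clean.")
else:
    # Use the redacted version when PII was found.
    print(f"Detected: {result.detected_entities}")
    print(f"Sanitized prompt: {result.anonymized_text}")
```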
guardrails_genie/guardrails/__init__.py
CHANGED
@@ -2,10 +2,20 @@ from .injection import (
     PromptInjectionClassifierGuardrail,
     PromptInjectionSurveyGuardrail,
 )
+from .entity_recognition import (
+    PresidioEntityRecognitionGuardrail,
+    RegexEntityRecognitionGuardrail,
+    TransformersEntityRecognitionGuardrail,
+    RestrictedTermsJudge,
+)
 from .manager import GuardrailManager
 
 __all__ = [
     "PromptInjectionSurveyGuardrail",
     "PromptInjectionClassifierGuardrail",
+    "PresidioEntityRecognitionGuardrail",
+    "RegexEntityRecognitionGuardrail",
+    "TransformersEntityRecognitionGuardrail",
+    "RestrictedTermsJudge",
     "GuardrailManager",
 ]
guardrails_genie/guardrails/entity_recognition/__init__.py
ADDED
@@ -0,0 +1,10 @@
from .presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
from .regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
from .transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail
from .llm_judge_entity_recognition_guardrail import RestrictedTermsJudge
__all__ = [
    "PresidioEntityRecognitionGuardrail",
    "RegexEntityRecognitionGuardrail",
    "TransformersEntityRecognitionGuardrail",
    "RestrictedTermsJudge"
]
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/banned_term_benchmark.py
ADDED
File without changes
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/banned_term_examples.py
ADDED
@@ -0,0 +1,178 @@
"""
Collection of restricted terms test examples with expected outcomes for entity recognition testing.
Focuses on banned terms, competitor mentions, and brand protection scenarios.
"""

RESTRICTED_TERMS_EXAMPLES = [
    {
        "description": "Competitor Product Discussion",
        "input_text": """
        I think we should implement features similar to Salesforce's Einstein AI
        and Oracle's Cloud Infrastructure. Maybe we could also look at how
        AWS handles their lambda functions.
        """,
        "custom_terms": ["Salesforce", "Oracle", "AWS", "Einstein AI", "Cloud Infrastructure", "lambda"],
        "expected_entities": {
            "Salesforce": ["Salesforce"],
            "Oracle": ["Oracle"],
            "AWS": ["AWS"],
            "Einstein AI": ["Einstein AI"],
            "Cloud Infrastructure": ["Cloud Infrastructure"],
            "lambda": ["lambda"]
        }
    },
    {
        "description": "Inappropriate Language in Support Ticket",
        "input_text": """
        This damn product keeps crashing! What the hell is wrong with your
        stupid service? I've wasted so much freaking time on this crap.
        """,
        "custom_terms": ["damn", "hell", "stupid", "crap"],
        "expected_entities": {
            "damn": ["damn"],
            "hell": ["hell"],
            "stupid": ["stupid"],
            "crap": ["crap"]
        }
    },
    {
        "description": "Confidential Project Names",
        "input_text": """
        Project Titan's launch date has been moved up. We should coordinate
        with Project Phoenix team and the Blue Dragon initiative for resource allocation.
        """,
        "custom_terms": ["Project Titan", "Project Phoenix", "Blue Dragon"],
        "expected_entities": {
            "Project Titan": ["Project Titan"],
            "Project Phoenix": ["Project Phoenix"],
            "Blue Dragon": ["Blue Dragon"]
        }
    }
]

# Edge cases and special formats
EDGE_CASE_EXAMPLES = [
    {
        "description": "Common Corporate Abbreviations and Stock Symbols",
        "input_text": """
        MSFT's Azure and O365 platform is gaining market share.
        Have you seen what GOOGL/GOOG and FB/META are doing with their AI?
        CRM (Salesforce) and ORCL (Oracle) have interesting features too.
        """,
        "custom_terms": ["Microsoft", "Google", "Meta", "Facebook", "Salesforce", "Oracle"],
        "expected_entities": {
            "Microsoft": ["MSFT"],
            "Google": ["GOOGL", "GOOG"],
            "Meta": ["META"],
            "Facebook": ["FB"],
            "Salesforce": ["CRM", "Salesforce"],
            "Oracle": ["ORCL"]
        }
    },
    {
        "description": "L33t Speak and Intentional Obfuscation",
        "input_text": """
        S4l3sf0rc3 is better than 0r4cl3!
        M1cr0$oft and G00gl3 are the main competitors.
        Let's check F8book and Met@ too.
        """,
        "custom_terms": ["Salesforce", "Oracle", "Microsoft", "Google", "Facebook", "Meta"],
        "expected_entities": {
            "Salesforce": ["S4l3sf0rc3"],
            "Oracle": ["0r4cl3"],
            "Microsoft": ["M1cr0$oft"],
            "Google": ["G00gl3"],
            "Facebook": ["F8book"],
            "Meta": ["Met@"]
        }
    },
    {
        "description": "Case Variations and Partial Matches",
        "input_text": """
        salesFORCE and ORACLE are competitors.
        MicroSoft and google are too.
        Have you tried micro-soft or Google_Cloud?
        """,
        "custom_terms": ["Microsoft", "Google", "Salesforce", "Oracle"],
        "expected_entities": {
            "Microsoft": ["MicroSoft", "micro-soft"],
            "Google": ["google", "Google_Cloud"],
            "Salesforce": ["salesFORCE"],
            "Oracle": ["ORACLE"]
        }
    },
    {
        "description": "Common Misspellings and Typos",
        "input_text": """
        Microsft and Microsooft are common typos.
        Goggle, Googel, and Gooogle are search related.
        Salezforce and Oracel need checking too.
        """,
        "custom_terms": ["Microsoft", "Google", "Salesforce", "Oracle"],
        "expected_entities": {
            "Microsoft": ["Microsft", "Microsooft"],
            "Google": ["Goggle", "Googel", "Gooogle"],
            "Salesforce": ["Salezforce"],
            "Oracle": ["Oracel"]
        }
    },
    {
        "description": "Mixed Variations and Context",
        "input_text": """
        The M$ cloud competes with AWS (Amazon Web Services).
        FB/Meta's social platform and GOOGL's search dominate.
        SF.com and Oracle-DB are industry standards.
        """,
        "custom_terms": ["Microsoft", "Amazon Web Services", "Facebook", "Meta", "Google", "Salesforce", "Oracle"],
        "expected_entities": {
            "Microsoft": ["M$"],
            "Amazon Web Services": ["AWS"],
            "Facebook": ["FB"],
            "Meta": ["Meta"],
            "Google": ["GOOGL"],
            "Salesforce": ["SF.com"],
            "Oracle": ["Oracle-DB"]
        }
    }
]

def validate_entities(detected: dict, expected: dict) -> bool:
    """Compare detected entities with expected entities"""
    if set(detected.keys()) != set(expected.keys()):
        return False
    return all(set(detected[k]) == set(expected[k]) for k in expected.keys())

def run_test_case(guardrail, test_case, test_type="Main"):
    """Run a single test case and print results"""
    print(f"\n{test_type} Test Case: {test_case['description']}")
    print("-" * 50)

    result = guardrail.guard(
        test_case['input_text'],
        custom_terms=test_case['custom_terms']
    )
    expected = test_case['expected_entities']

    # Validate results
    matches = validate_entities(result.detected_entities, expected)

    print(f"Test Status: {'✓ PASS' if matches else '✗ FAIL'}")
    print(f"Contains Restricted Terms: {result.contains_entities}")

    if not matches:
        print("\nEntity Comparison:")
        all_entity_types = set(list(result.detected_entities.keys()) + list(expected.keys()))
        for entity_type in all_entity_types:
            detected = set(result.detected_entities.get(entity_type, []))
            expected_set = set(expected.get(entity_type, []))
            print(f"\nEntity Type: {entity_type}")
            print(f"  Expected: {sorted(expected_set)}")
            print(f"  Detected: {sorted(detected)}")
            if detected != expected_set:
                print(f"  Missing: {sorted(expected_set - detected)}")
                print(f"  Extra: {sorted(detected - expected_set)}")

    if result.anonymized_text:
        print(f"\nAnonymized Text:\n{result.anonymized_text}")

    return matches
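For reviewers, the pass/fail criterion in `validate_entities` is strict: the detected and expected dictionaries must have exactly the same entity keys, and each key's values must match as sets (order and duplicates are ignored). A small self-contained sketch of that comparison, restated inline so it runs without the package installed:

```python
def validate_entities(detected: dict, expected: dict) -> bool:
    # Same logic as in banned_term_examples.py: identical keys, set-equal values per key.
    if set(detected.keys()) != set(expected.keys()):
        return False
    return all(set(detected[k]) == set(expected[k]) for k in expected.keys())

expected = {"Google": ["GOOGL", "GOOG"], "Oracle": ["ORCL"]}

# Order and duplicates within a value list do not matter...
print(validate_entities({"Google": ["GOOG", "GOOGL", "GOOG"], "Oracle": ["ORCL"]}, expected))  # True
# ...but an extra entity type fails the whole case.
print(validate_entities({"Google": ["GOOGL", "GOOG"], "Oracle": ["ORCL"], "Meta": ["META"]}, expected))  # False
```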
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_llm_judge.py
ADDED
@@ -0,0 +1,47 @@
from guardrails_genie.guardrails.entity_recognition.llm_judge_entity_recognition_guardrail import RestrictedTermsJudge
from guardrails_genie.guardrails.entity_recognition.banned_terms_examples.banned_term_examples import (
    RESTRICTED_TERMS_EXAMPLES,
    EDGE_CASE_EXAMPLES,
    run_test_case
)
from guardrails_genie.llm import OpenAIModel
import weave

def test_restricted_terms_detection():
    """Test restricted terms detection scenarios using predefined test cases"""
    weave.init("guardrails-genie-restricted-terms-llm-judge")

    # Create the guardrail with OpenAI model
    llm_judge = RestrictedTermsJudge(
        should_anonymize=True,
        llm_model=OpenAIModel()
    )

    # Test statistics
    total_tests = len(RESTRICTED_TERMS_EXAMPLES) + len(EDGE_CASE_EXAMPLES)
    passed_tests = 0

    # Test main restricted terms examples
    print("\nRunning Main Restricted Terms Tests")
    print("=" * 80)
    for test_case in RESTRICTED_TERMS_EXAMPLES:
        if run_test_case(llm_judge, test_case):
            passed_tests += 1

    # Test edge cases
    print("\nRunning Edge Cases")
    print("=" * 80)
    for test_case in EDGE_CASE_EXAMPLES:
        if run_test_case(llm_judge, test_case, "Edge"):
            passed_tests += 1

    # Print summary
    print("\nTest Summary")
    print("=" * 80)
    print(f"Total Tests: {total_tests}")
    print(f"Passed: {passed_tests}")
    print(f"Failed: {total_tests - passed_tests}")
    print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")

if __name__ == "__main__":
    test_restricted_terms_detection()
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_regex_model.py
ADDED
@@ -0,0 +1,46 @@
from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
from guardrails_genie.guardrails.entity_recognition.banned_terms_examples.banned_term_examples import (
    RESTRICTED_TERMS_EXAMPLES,
    EDGE_CASE_EXAMPLES,
    run_test_case
)
import weave

def test_restricted_terms_detection():
    """Test restricted terms detection scenarios using predefined test cases"""
    weave.init("guardrails-genie-restricted-terms-regex-model")

    # Create the guardrail with anonymization enabled
    regex_guardrail = RegexEntityRecognitionGuardrail(
        use_defaults=False,  # Don't use default PII patterns
        should_anonymize=True
    )

    # Test statistics
    total_tests = len(RESTRICTED_TERMS_EXAMPLES) + len(EDGE_CASE_EXAMPLES)
    passed_tests = 0

    # Test main restricted terms examples
    print("\nRunning Main Restricted Terms Tests")
    print("=" * 80)
    for test_case in RESTRICTED_TERMS_EXAMPLES:
        if run_test_case(regex_guardrail, test_case):
            passed_tests += 1

    # Test edge cases
    print("\nRunning Edge Cases")
    print("=" * 80)
    for test_case in EDGE_CASE_EXAMPLES:
        if run_test_case(regex_guardrail, test_case, "Edge"):
            passed_tests += 1

    # Print summary
    print("\nTest Summary")
    print("=" * 80)
    print(f"Total Tests: {total_tests}")
    print(f"Passed: {passed_tests}")
    print(f"Failed: {total_tests - passed_tests}")
    print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")

if __name__ == "__main__":
    test_restricted_terms_detection()
guardrails_genie/guardrails/entity_recognition/llm_judge_entity_recognition_guardrail.py
ADDED
@@ -0,0 +1,166 @@
from typing import Dict, List, Optional
import weave
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from ...llm import OpenAIModel
from ..base import Guardrail
import instructor


class TermMatch(BaseModel):
    """Represents a matched term and its variations"""
    original_term: str
    matched_text: str
    match_type: str = Field(
        description="Type of match: EXACT, MISSPELLING, ABBREVIATION, or VARIANT"
    )
    explanation: str = Field(
        description="Explanation of why this is considered a match"
    )


class RestrictedTermsAnalysis(BaseModel):
    """Analysis result for restricted terms detection"""
    contains_restricted_terms: bool = Field(
        description="Whether any restricted terms were detected"
    )
    detected_matches: List[TermMatch] = Field(
        default_factory=list,
        description="List of detected term matches with their variations"
    )
    explanation: str = Field(
        description="Detailed explanation of the analysis"
    )
    anonymized_text: Optional[str] = Field(
        default=None,
        description="Text with restricted terms replaced with category tags"
    )

    @property
    def safe(self) -> bool:
        return not self.contains_restricted_terms


class RestrictedTermsRecognitionResponse(BaseModel):
    contains_entities: bool
    detected_entities: Dict[str, List[str]]
    explanation: str
    anonymized_text: Optional[str] = None

    @property
    def safe(self) -> bool:
        return not self.contains_entities


class RestrictedTermsJudge(Guardrail):
    llm_model: OpenAIModel = Field(default_factory=lambda: OpenAIModel())
    should_anonymize: bool = False

    def __init__(self, should_anonymize: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.should_anonymize = should_anonymize
        # Pass the OpenAI client to instructor.patch()
        instructor.patch(client=self.llm_model._openai_client)

    def format_prompts(self, text: str, custom_terms: List[str]) -> tuple[str, str]:
        user_prompt = f"""Analyze the following text for restricted terms and variations:

<text_to_analyze>
{text}
</text_to_analyze>

Look for these specific terms and their variations:
{', '.join(custom_terms)}

Analyze the text carefully for:
1. Exact matches
2. Common misspellings
3. Abbreviations
4. Intentional variations (l33t speak, special characters, etc.)
5. Case variations
"""

        system_prompt = """You are an expert system for detecting restricted terms, brand mentions, and inappropriate content.

Your task is to:
1. Identify exact matches of restricted terms
2. Detect variations including:
   - Misspellings (both accidental and intentional)
   - Abbreviations and acronyms
   - Case variations
   - L33t speak or special character substitutions
   - Partial matches within larger words

For each match, you must:
1. Identify the original restricted term
2. Note the actual text that matched
3. Classify the match type
4. Provide a confidence score
5. Explain why it's considered a match

Be thorough but avoid false positives. Focus on meaningful matches that indicate actual attempts to use restricted terms.

Return your analysis in the structured format specified by the RestrictedTermsAnalysis model."""

        return user_prompt, system_prompt

    @weave.op()
    def predict(self, text: str, custom_terms: List[str], **kwargs) -> RestrictedTermsAnalysis:
        user_prompt, system_prompt = self.format_prompts(text, custom_terms)

        response = self.llm_model.predict(
            user_prompts=user_prompt,
            system_prompt=system_prompt,
            response_format=RestrictedTermsAnalysis,
            temperature=0.1,  # Lower temperature for more consistent analysis
            **kwargs
        )

        return response.choices[0].message.parsed

    # TODO: Remove default custom_terms
    @weave.op()
    def guard(self, text: str, custom_terms: List[str] = ["Microsoft", "Amazon Web Services", "Facebook", "Meta", "Google", "Salesforce", "Oracle"], aggregate_redaction: bool = True, **kwargs) -> RestrictedTermsRecognitionResponse:
        """
        Guard against restricted terms and their variations.

        Args:
            text: Text to analyze
            custom_terms: List of restricted terms to check for

        Returns:
            RestrictedTermsRecognitionResponse containing safety assessment and detailed analysis
        """
        analysis = self.predict(text, custom_terms, **kwargs)

        # Create a summary of findings
        if analysis.contains_restricted_terms:
            summary_parts = ["Restricted terms detected:"]
            for match in analysis.detected_matches:
                summary_parts.append(f"\n- {match.original_term}: {match.matched_text} ({match.match_type})")
            summary = "\n".join(summary_parts)
        else:
            summary = "No restricted terms detected."

        # Updated anonymization logic
        anonymized_text = None
        if self.should_anonymize and analysis.contains_restricted_terms:
            anonymized_text = text
            for match in analysis.detected_matches:
                replacement = "[redacted]" if aggregate_redaction else f"[{match.match_type.upper()}]"
                anonymized_text = anonymized_text.replace(match.matched_text, replacement)

        # Convert detected_matches to a dictionary format
        detected_entities = {}
        for match in analysis.detected_matches:
            if match.original_term not in detected_entities:
                detected_entities[match.original_term] = []
            detected_entities[match.original_term].append(match.matched_text)

        return RestrictedTermsRecognitionResponse(
            contains_entities=analysis.contains_restricted_terms,
            detected_entities=detected_entities,
            explanation=summary,
            anonymized_text=anonymized_text
        )
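A note on the `aggregate_redaction` flag in `guard`: when anonymization is on, every matched span is replaced with the flat `[redacted]` tag, or with its upper-cased match type (for example `[EXACT]` or `[MISSPELLING]`) when `aggregate_redaction=False`. A minimal sketch of that replacement step in isolation, using made-up match data instead of a live OpenAI call (the `FakeMatch` class is illustrative only):

```python
from dataclasses import dataclass

@dataclass
class FakeMatch:
    # Mirrors the TermMatch fields the redaction loop reads.
    matched_text: str
    match_type: str

def redact(text: str, matches, aggregate_redaction: bool = True) -> str:
    for match in matches:
        replacement = "[redacted]" if aggregate_redaction else f"[{match.match_type.upper()}]"
        text = text.replace(match.matched_text, replacement)
    return text

matches = [FakeMatch("S4l3sf0rc3", "VARIANT"), FakeMatch("ORCL", "ABBREVIATION")]
text = "S4l3sf0rc3 beats ORCL this quarter."
print(redact(text, matches))                             # "[redacted] beats [redacted] this quarter."
print(redact(text, matches, aggregate_redaction=False))  # "[VARIANT] beats [ABBREVIATION] this quarter."
```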
guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark.py
ADDED
@@ -0,0 +1,215 @@
from datasets import load_dataset
from typing import Dict, List, Tuple
import random
from tqdm import tqdm
import json
from pathlib import Path
import weave

def load_ai4privacy_dataset(num_samples: int = 100, split: str = "validation") -> List[Dict]:
    """
    Load and prepare samples from the ai4privacy dataset.

    Args:
        num_samples: Number of samples to evaluate
        split: Dataset split to use ("train" or "validation")

    Returns:
        List of prepared test cases
    """
    # Load the dataset
    dataset = load_dataset("ai4privacy/pii-masking-400k")

    # Get the specified split
    data_split = dataset[split]

    # Randomly sample entries if num_samples is less than total
    if num_samples < len(data_split):
        indices = random.sample(range(len(data_split)), num_samples)
        samples = [data_split[i] for i in indices]
    else:
        samples = data_split

    # Convert to test case format
    test_cases = []
    for sample in samples:
        # Extract entities from privacy_mask
        entities: Dict[str, List[str]] = {}
        for entity in sample['privacy_mask']:
            label = entity['label']
            value = entity['value']
            if label not in entities:
                entities[label] = []
            entities[label].append(value)

        test_case = {
            "description": f"AI4Privacy Sample (ID: {sample['uid']})",
            "input_text": sample['source_text'],
            "expected_entities": entities,
            "masked_text": sample['masked_text'],
            "language": sample['language'],
            "locale": sample['locale']
        }
        test_cases.append(test_case)

    return test_cases

@weave.op()
def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]:
    """
    Evaluate a model on the test cases.

    Args:
        guardrail: Entity recognition guardrail to evaluate
        test_cases: List of test cases

    Returns:
        Tuple of (metrics dict, detailed results list)
    """
    metrics = {
        "total": len(test_cases),
        "passed": 0,
        "failed": 0,
        "entity_metrics": {}  # Will store precision/recall per entity type
    }

    detailed_results = []

    for test_case in tqdm(test_cases, desc="Evaluating samples"):
        # Run detection
        result = guardrail.guard(test_case['input_text'])
        detected = result.detected_entities
        expected = test_case['expected_entities']

        # Track entity-level metrics
        all_entity_types = set(list(detected.keys()) + list(expected.keys()))
        entity_results = {}

        for entity_type in all_entity_types:
            detected_set = set(detected.get(entity_type, []))
            expected_set = set(expected.get(entity_type, []))

            # Calculate metrics
            true_positives = len(detected_set & expected_set)
            false_positives = len(detected_set - expected_set)
            false_negatives = len(expected_set - detected_set)

            precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
            recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

            entity_results[entity_type] = {
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "true_positives": true_positives,
                "false_positives": false_positives,
                "false_negatives": false_negatives
            }

            # Aggregate metrics
            if entity_type not in metrics["entity_metrics"]:
                metrics["entity_metrics"][entity_type] = {
                    "total_true_positives": 0,
                    "total_false_positives": 0,
                    "total_false_negatives": 0
                }
            metrics["entity_metrics"][entity_type]["total_true_positives"] += true_positives
            metrics["entity_metrics"][entity_type]["total_false_positives"] += false_positives
            metrics["entity_metrics"][entity_type]["total_false_negatives"] += false_negatives

        # Store detailed result
        detailed_result = {
            "id": test_case.get("description", ""),
            "language": test_case.get("language", ""),
            "locale": test_case.get("locale", ""),
            "input_text": test_case["input_text"],
            "expected_entities": expected,
            "detected_entities": detected,
            "entity_metrics": entity_results,
            "anonymized_text": result.anonymized_text if result.anonymized_text else None
        }
        detailed_results.append(detailed_result)

        # Update pass/fail counts
        if all(entity_results[et]["f1"] == 1.0 for et in entity_results):
            metrics["passed"] += 1
        else:
            metrics["failed"] += 1

    # Calculate final entity metrics
    for entity_type, counts in metrics["entity_metrics"].items():
        tp = counts["total_true_positives"]
        fp = counts["total_false_positives"]
        fn = counts["total_false_negatives"]

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        metrics["entity_metrics"][entity_type].update({
            "precision": precision,
            "recall": recall,
            "f1": f1
        })

    return metrics, detailed_results

def save_results(metrics: Dict, detailed_results: List[Dict], model_name: str, output_dir: str = "evaluation_results"):
    """Save evaluation results to files"""
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True)

    # Save metrics summary
    with open(output_dir / f"{model_name}_metrics.json", "w") as f:
        json.dump(metrics, f, indent=2)

    # Save detailed results
    with open(output_dir / f"{model_name}_detailed_results.json", "w") as f:
        json.dump(detailed_results, f, indent=2)

def print_metrics_summary(metrics: Dict):
    """Print a summary of the evaluation metrics"""
    print("\nEvaluation Summary")
    print("=" * 80)
    print(f"Total Samples: {metrics['total']}")
    print(f"Passed: {metrics['passed']}")
    print(f"Failed: {metrics['failed']}")
    print(f"Success Rate: {(metrics['passed']/metrics['total'])*100:.1f}%")

    print("\nEntity-level Metrics:")
    print("-" * 80)
    print(f"{'Entity Type':<20} {'Precision':>10} {'Recall':>10} {'F1':>10}")
    print("-" * 80)
    for entity_type, entity_metrics in metrics["entity_metrics"].items():
        print(f"{entity_type:<20} {entity_metrics['precision']:>10.2f} {entity_metrics['recall']:>10.2f} {entity_metrics['f1']:>10.2f}")

def main():
    """Main evaluation function"""
    weave.init("guardrails-genie-pii-evaluation")

    # Load test cases
    test_cases = load_ai4privacy_dataset(num_samples=100)

    # Initialize models to evaluate
    models = {
        "regex": RegexEntityRecognitionGuardrail(should_anonymize=True),
        "presidio": PresidioEntityRecognitionGuardrail(should_anonymize=True),
        "transformers": TransformersEntityRecognitionGuardrail(should_anonymize=True)
    }

    # Evaluate each model
    for model_name, guardrail in models.items():
        print(f"\nEvaluating {model_name} model...")
        metrics, detailed_results = evaluate_model(guardrail, test_cases)

        # Print and save results
        print_metrics_summary(metrics)
        save_results(metrics, detailed_results, model_name)

if __name__ == "__main__":
    from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
    from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
    from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail

    main()
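To make the scoring in `evaluate_model` concrete: for each entity type the benchmark compares detected and expected values as sets and derives precision, recall, and F1 from the overlap. A small worked example with made-up values:

```python
expected = {"alice@example.com", "bob@example.com"}
detected = {"alice@example.com", "carol@example.com"}

tp = len(detected & expected)   # 1 correct detection
fp = len(detected - expected)   # 1 spurious detection
fn = len(expected - detected)   # 1 missed entity

precision = tp / (tp + fp) if (tp + fp) > 0 else 0                                      # 0.5
recall = tp / (tp + fn) if (tp + fn) > 0 else 0                                         # 0.5
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0  # 0.5

print(precision, recall, f1)
```

A sample counts as "passed" only if every entity type in it reaches F1 == 1.0, so the headline success rate is a strict exact-match measure while the per-entity precision/recall table gives the softer view.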
guardrails_genie/guardrails/entity_recognition/pii_examples/pii_test_examples.py
ADDED
@@ -0,0 +1,150 @@
"""
Collection of PII test examples with expected outcomes for entity recognition testing.
Each example includes the input text and expected entities to be detected.
"""

PII_TEST_EXAMPLES = [
    {
        "description": "Business Context - Employee Record",
        "input_text": """
        Please update our records for employee John Smith:
        Email: john.smith@company.com
        Phone: 123-456-7890
        SSN: 123-45-6789
        Emergency Contact: Mary Johnson (Tel: 098-765-4321)
        """,
        "expected_entities": {
            "GIVENNAME": ["John", "Mary"],
            "SURNAME": ["Smith", "Johnson"],
            "EMAIL": ["john.smith@company.com"],
            "PHONE_NUMBER": ["123-456-7890", "098-765-4321"],
            "SOCIALNUM": ["123-45-6789"]
        }
    },
    {
        "description": "Meeting Notes with Attendees",
        "input_text": """
        Meeting Notes - Project Alpha
        Date: 2024-03-15
        Attendees:
        - Sarah Williams (sarah.w@company.com)
        - Robert Brown (bobby@email.com)
        - Tom Wilson (555-0123-4567)

        Action Items:
        1. Sarah to review documentation
        2. Contact Bob at his alternate number: 777-888-9999
        """,
        "expected_entities": {
            "GIVENNAME": ["Sarah", "Robert", "Tom", "Bob"],
            "SURNAME": ["Williams", "Brown", "Wilson"],
            "EMAIL": ["sarah.w@company.com", "bobby@email.com"],
            "PHONE_NUMBER": ["555-0123-4567", "777-888-9999"]
        }
    },
    {
        "description": "Medical Record",
        "input_text": """
        Patient: Emma Thompson
        DOB: 05/15/1980
        Medical Record #: MR-12345
        Primary Care: Dr. James Wilson
        Contact: emma.t@email.com
        Insurance ID: INS-987654321
        Emergency Contact: Michael Thompson (555-123-4567)
        """,
        "expected_entities": {
            "GIVENNAME": ["Emma", "James", "Michael"],
            "SURNAME": ["Thompson", "Wilson", "Thompson"],
            "EMAIL": ["emma.t@email.com"],
            "PHONE_NUMBER": ["555-123-4567"]
        }
    },
    {
        "description": "No PII Content",
        "input_text": """
        Project Status Update:
        - All deliverables are on track
        - Budget is within limits
        - Next review scheduled for next week
        """,
        "expected_entities": {}
    },
    {
        "description": "Mixed Format Phone Numbers",
        "input_text": """
        Contact Directory:
        Main Office: (555) 123-4567
        Support: 555.987.6543
        International: +1-555-321-7890
        Emergency: 555 444 3333
        """,
        "expected_entities": {
            "PHONE_NUMBER": [
                "(555) 123-4567",
                "555.987.6543",
                "+1-555-321-7890",
                "555 444 3333"
            ]
        }
    }
]

# Additional examples can be added to test specific edge cases or formats
EDGE_CASE_EXAMPLES = [
    {
        "description": "Mixed Case and Special Characters",
        "input_text": """
        JoHn.DoE@company.com
        JANE_SMITH@email.com
        bob.jones123@domain.co.uk
        """,
        "expected_entities": {
            "EMAIL": [
                "JoHn.DoE@company.com",
                "JANE_SMITH@email.com",
                "bob.jones123@domain.co.uk"
            ],
            "GIVENNAME": ["John", "Jane", "Bob"],
            "SURNAME": ["Doe", "Smith", "Jones"]
        }
    }
]

def validate_entities(detected: dict, expected: dict) -> bool:
    """Compare detected entities with expected entities"""
    if set(detected.keys()) != set(expected.keys()):
        return False
    return all(set(detected[k]) == set(expected[k]) for k in expected.keys())

def run_test_case(guardrail, test_case, test_type="Main"):
    """Run a single test case and print results"""
    print(f"\n{test_type} Test Case: {test_case['description']}")
    print("-" * 50)

    result = guardrail.guard(test_case['input_text'])
    expected = test_case['expected_entities']

    # Validate results
    matches = validate_entities(result.detected_entities, expected)

    print(f"Test Status: {'✓ PASS' if matches else '✗ FAIL'}")
    print(f"Contains PII: {result.contains_entities}")

    if not matches:
        print("\nEntity Comparison:")
        all_entity_types = set(list(result.detected_entities.keys()) + list(expected.keys()))
        for entity_type in all_entity_types:
            detected = set(result.detected_entities.get(entity_type, []))
            expected_set = set(expected.get(entity_type, []))
            print(f"\nEntity Type: {entity_type}")
            print(f"  Expected: {sorted(expected_set)}")
            print(f"  Detected: {sorted(detected)}")
            if detected != expected_set:
                print(f"  Missing: {sorted(expected_set - detected)}")
                print(f"  Extra: {sorted(detected - expected_set)}")

    if result.anonymized_text:
        print(f"\nAnonymized Text:\n{result.anonymized_text}")

    return matches
guardrails_genie/guardrails/entity_recognition/pii_examples/run_presidio_model.py
ADDED
@@ -0,0 +1,42 @@
from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import PII_TEST_EXAMPLES, EDGE_CASE_EXAMPLES, run_test_case, validate_entities
import weave

def test_pii_detection():
    """Test PII detection scenarios using predefined test cases"""
    weave.init("guardrails-genie-pii-presidio-model")

    # Create the guardrail with default entities and anonymization enabled
    pii_guardrail = PresidioEntityRecognitionGuardrail(
        should_anonymize=True,
        show_available_entities=True
    )

    # Test statistics
    total_tests = len(PII_TEST_EXAMPLES) + len(EDGE_CASE_EXAMPLES)
    passed_tests = 0

    # Test main PII examples
    print("\nRunning Main PII Tests")
    print("=" * 80)
    for test_case in PII_TEST_EXAMPLES:
        if run_test_case(pii_guardrail, test_case):
            passed_tests += 1

    # Test edge cases
    print("\nRunning Edge Cases")
    print("=" * 80)
    for test_case in EDGE_CASE_EXAMPLES:
        if run_test_case(pii_guardrail, test_case, "Edge"):
            passed_tests += 1

    # Print summary
    print("\nTest Summary")
    print("=" * 80)
    print(f"Total Tests: {total_tests}")
    print(f"Passed: {passed_tests}")
    print(f"Failed: {total_tests - passed_tests}")
    print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")

if __name__ == "__main__":
    test_pii_detection()
guardrails_genie/guardrails/entity_recognition/pii_examples/run_regex_model.py
ADDED
@@ -0,0 +1,42 @@
from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import PII_TEST_EXAMPLES, EDGE_CASE_EXAMPLES, run_test_case, validate_entities
import weave

def test_pii_detection():
    """Test PII detection scenarios using predefined test cases"""
    weave.init("guardrails-genie-pii-regex-model")

    # Create the guardrail with default entities and anonymization enabled
    pii_guardrail = RegexEntityRecognitionGuardrail(
        should_anonymize=True,
        show_available_entities=True
    )

    # Test statistics
    total_tests = len(PII_TEST_EXAMPLES) + len(EDGE_CASE_EXAMPLES)
    passed_tests = 0

    # Test main PII examples
    print("\nRunning Main PII Tests")
    print("=" * 80)
    for test_case in PII_TEST_EXAMPLES:
        if run_test_case(pii_guardrail, test_case):
            passed_tests += 1

    # Test edge cases
    print("\nRunning Edge Cases")
    print("=" * 80)
    for test_case in EDGE_CASE_EXAMPLES:
        if run_test_case(pii_guardrail, test_case, "Edge"):
            passed_tests += 1

    # Print summary
    print("\nTest Summary")
    print("=" * 80)
    print(f"Total Tests: {total_tests}")
    print(f"Passed: {passed_tests}")
    print(f"Failed: {total_tests - passed_tests}")
    print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")

if __name__ == "__main__":
    test_pii_detection()
guardrails_genie/guardrails/entity_recognition/pii_examples/run_transformers.py
ADDED
@@ -0,0 +1,43 @@
from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail
from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import PII_TEST_EXAMPLES, EDGE_CASE_EXAMPLES, run_test_case, validate_entities
import weave

def test_pii_detection():
    """Test PII detection scenarios using predefined test cases"""
    weave.init("guardrails-genie-pii-transformers-pipeline-model")

    # Create the guardrail with default entities and anonymization enabled
    pii_guardrail = TransformersEntityRecognitionGuardrail(
        selected_entities=["GIVENNAME", "SURNAME", "EMAIL", "TELEPHONENUM", "SOCIALNUM"],
        should_anonymize=True,
        show_available_entities=True
    )

    # Test statistics
    total_tests = len(PII_TEST_EXAMPLES) + len(EDGE_CASE_EXAMPLES)
    passed_tests = 0

    # Test main PII examples
    print("\nRunning Main PII Tests")
    print("=" * 80)
    for test_case in PII_TEST_EXAMPLES:
        if run_test_case(pii_guardrail, test_case):
            passed_tests += 1

    # Test edge cases
    print("\nRunning Edge Cases")
    print("=" * 80)
    for test_case in EDGE_CASE_EXAMPLES:
        if run_test_case(pii_guardrail, test_case, "Edge"):
            passed_tests += 1

    # Print summary
    print("\nTest Summary")
    print("=" * 80)
    print(f"Total Tests: {total_tests}")
    print(f"Passed: {passed_tests}")
    print(f"Failed: {total_tests - passed_tests}")
    print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")

if __name__ == "__main__":
    test_pii_detection()
guardrails_genie/guardrails/entity_recognition/presidio_entity_recognition_guardrail.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+from typing import List, Dict, Optional, ClassVar, Any
+import weave
+from pydantic import BaseModel
+
+from presidio_analyzer import AnalyzerEngine, RecognizerRegistry, Pattern, PatternRecognizer
+from presidio_anonymizer import AnonymizerEngine
+
+from ..base import Guardrail
+
+class PresidioEntityRecognitionResponse(BaseModel):
+    contains_entities: bool
+    detected_entities: Dict[str, List[str]]
+    explanation: str
+    anonymized_text: Optional[str] = None
+
+    @property
+    def safe(self) -> bool:
+        return not self.contains_entities
+
+class PresidioEntityRecognitionSimpleResponse(BaseModel):
+    contains_entities: bool
+    explanation: str
+    anonymized_text: Optional[str] = None
+
+    @property
+    def safe(self) -> bool:
+        return not self.contains_entities
+
+#TODO: Add support for transformers workflow and not just Spacy
+class PresidioEntityRecognitionGuardrail(Guardrail):
+    @staticmethod
+    def get_available_entities() -> List[str]:
+        registry = RecognizerRegistry()
+        analyzer = AnalyzerEngine(registry=registry)
+        return [recognizer.supported_entities[0]
+                for recognizer in analyzer.registry.recognizers]
+
+    analyzer: AnalyzerEngine
+    anonymizer: AnonymizerEngine
+    selected_entities: List[str]
+    should_anonymize: bool
+    language: str
+
+    def __init__(
+        self,
+        selected_entities: Optional[List[str]] = None,
+        should_anonymize: bool = False,
+        language: str = "en",
+        deny_lists: Optional[Dict[str, List[str]]] = None,
+        regex_patterns: Optional[Dict[str, List[Dict[str, str]]]] = None,
+        custom_recognizers: Optional[List[Any]] = None,
+        show_available_entities: bool = False
+    ):
+        # If show_available_entities is True, print available entities
+        if show_available_entities:
+            available_entities = self.get_available_entities()
+            print("\nAvailable entities:")
+            print("=" * 25)
+            for entity in available_entities:
+                print(f"- {entity}")
+            print("=" * 25 + "\n")
+
+        # Initialize default values
+        if selected_entities is None:
+            selected_entities = [
+                "CREDIT_CARD", "US_SSN", "EMAIL_ADDRESS", "PHONE_NUMBER",
+                "IP_ADDRESS", "URL", "DATE_TIME"
+            ]
+
+        # Get available entities dynamically
+        available_entities = self.get_available_entities()
+
+        # Filter out invalid entities and warn user
+        invalid_entities = [e for e in selected_entities if e not in available_entities]
+        valid_entities = [e for e in selected_entities if e in available_entities]
+
+        if invalid_entities:
+            print(f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}")
+            print(f"Continuing with valid entities: {valid_entities}")
+            selected_entities = valid_entities
+
+        # Initialize analyzer with default recognizers
+        analyzer = AnalyzerEngine()
+
+        # Add custom recognizers if provided
+        if custom_recognizers:
+            for recognizer in custom_recognizers:
+                analyzer.registry.add_recognizer(recognizer)
+
+        # Add deny list recognizers if provided
+        if deny_lists:
+            for entity_type, tokens in deny_lists.items():
+                deny_list_recognizer = PatternRecognizer(
+                    supported_entity=entity_type,
+                    deny_list=tokens
+                )
+                analyzer.registry.add_recognizer(deny_list_recognizer)
+
+        # Add regex pattern recognizers if provided
+        if regex_patterns:
+            for entity_type, patterns in regex_patterns.items():
+                presidio_patterns = [
+                    Pattern(
+                        name=pattern.get("name", f"pattern_{i}"),
+                        regex=pattern["regex"],
+                        score=pattern.get("score", 0.5)
+                    ) for i, pattern in enumerate(patterns)
+                ]
+                regex_recognizer = PatternRecognizer(
+                    supported_entity=entity_type,
+                    patterns=presidio_patterns
+                )
+                analyzer.registry.add_recognizer(regex_recognizer)
+
+        # Initialize Presidio engines
+        anonymizer = AnonymizerEngine()
+
+        # Call parent class constructor with all fields
+        super().__init__(
+            analyzer=analyzer,
+            anonymizer=anonymizer,
+            selected_entities=selected_entities,
+            should_anonymize=should_anonymize,
+            language=language
+        )
+
+    @weave.op()
+    def guard(self, prompt: str, return_detected_types: bool = True, **kwargs) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
+        """
+        Check if the input prompt contains any entities using Presidio.
+
+        Args:
+            prompt: The text to analyze
+            return_detected_types: If True, returns detailed entity type information
+        """
+        # Analyze text for entities
+        analyzer_results = self.analyzer.analyze(
+            text=prompt,
+            entities=self.selected_entities,
+            language=self.language
+        )
+
+        # Group results by entity type
+        detected_entities = {}
+        for result in analyzer_results:
+            entity_type = result.entity_type
+            text_slice = prompt[result.start:result.end]
+            if entity_type not in detected_entities:
+                detected_entities[entity_type] = []
+            detected_entities[entity_type].append(text_slice)
+
+        # Create explanation
+        explanation_parts = []
+        if detected_entities:
+            explanation_parts.append("Found the following entities in the text:")
+            for entity_type, instances in detected_entities.items():
+                explanation_parts.append(f"- {entity_type}: {len(instances)} instance(s)")
+        else:
+            explanation_parts.append("No entities detected in the text.")
+
+        # Add information about what was checked
+        explanation_parts.append("\nChecked for these entity types:")
+        for entity in self.selected_entities:
+            explanation_parts.append(f"- {entity}")
+
+        # Anonymize if requested
+        anonymized_text = None
+        if self.should_anonymize and detected_entities:
+            anonymized_result = self.anonymizer.anonymize(
+                text=prompt,
+                analyzer_results=analyzer_results
+            )
+            anonymized_text = anonymized_result.text
+
+        if return_detected_types:
+            return PresidioEntityRecognitionResponse(
+                contains_entities=bool(detected_entities),
+                detected_entities=detected_entities,
+                explanation="\n".join(explanation_parts),
+                anonymized_text=anonymized_text
+            )
+        else:
+            return PresidioEntityRecognitionSimpleResponse(
+                contains_entities=bool(detected_entities),
+                explanation="\n".join(explanation_parts),
+                anonymized_text=anonymized_text
+            )
+
+    @weave.op()
+    def predict(self, prompt: str, return_detected_types: bool = True, **kwargs) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
+        return self.guard(prompt, return_detected_types=return_detected_types, **kwargs)
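For reviewers, a minimal usage sketch of the guardrail added above. This is not part of the diff; the import path is assumed from the chat app's importlib lookup, and the sample prompt, entity selection, and printed values are illustrative.

# Illustrative sketch only: assumes PresidioEntityRecognitionGuardrail is exported
# from guardrails_genie.guardrails, as the chat app's dynamic import suggests.
from guardrails_genie.guardrails import PresidioEntityRecognitionGuardrail

guardrail = PresidioEntityRecognitionGuardrail(
    selected_entities=["EMAIL_ADDRESS", "PHONE_NUMBER"],
    should_anonymize=True,
)
response = guardrail.guard("Contact me at jane.doe@example.com or 415-555-0199")
print(response.safe)               # False when any selected entity is detected
print(response.detected_entities)  # e.g. {"EMAIL_ADDRESS": [...], "PHONE_NUMBER": [...]}
print(response.anonymized_text)    # Presidio-anonymized copy of the prompt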
guardrails_genie/guardrails/entity_recognition/regex_entity_recognition_guardrail.py
ADDED
@@ -0,0 +1,138 @@
+from typing import Dict, Optional, ClassVar
+
+import weave
+from pydantic import BaseModel
+
+from ...regex_model import RegexModel
+from ..base import Guardrail
+import re
+
+
+class RegexEntityRecognitionResponse(BaseModel):
+    contains_entities: bool
+    detected_entities: Dict[str, list[str]]
+    explanation: str
+    anonymized_text: Optional[str] = None
+
+    @property
+    def safe(self) -> bool:
+        return not self.contains_entities
+
+
+class RegexEntityRecognitionSimpleResponse(BaseModel):
+    contains_entities: bool
+    explanation: str
+    anonymized_text: Optional[str] = None
+
+    @property
+    def safe(self) -> bool:
+        return not self.contains_entities
+
+
+class RegexEntityRecognitionGuardrail(Guardrail):
+    regex_model: RegexModel
+    patterns: Dict[str, str] = {}
+    should_anonymize: bool = False
+
+    DEFAULT_PATTERNS: ClassVar[Dict[str, str]] = {
+        "email": r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
+        "phone_number": r"\b(?:\+?1[-.]?)?\(?(?:[0-9]{3})\)?[-.]?(?:[0-9]{3})[-.]?(?:[0-9]{4})\b",
+        "ssn": r"\b\d{3}[-]?\d{2}[-]?\d{4}\b",
+        "credit_card": r"\b\d{4}[-.]?\d{4}[-.]?\d{4}[-.]?\d{4}\b",
+        "ip_address": r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b",
+        "date_of_birth": r"\b\d{2}[-/]\d{2}[-/]\d{4}\b",
+        "passport": r"\b[A-Z]{1,2}[0-9]{6,9}\b",
+        "drivers_license": r"\b[A-Z]\d{7}\b",
+        "bank_account": r"\b\d{8,17}\b",
+        "zip_code": r"\b\d{5}(?:[-]\d{4})?\b"
+    }
+
+    def __init__(self, use_defaults: bool = True, should_anonymize: bool = False, **kwargs):
+        patterns = {}
+        if use_defaults:
+            patterns = self.DEFAULT_PATTERNS.copy()
+        if kwargs.get("patterns"):
+            patterns.update(kwargs["patterns"])
+
+        # Create the RegexModel instance
+        regex_model = RegexModel(patterns=patterns)
+
+        # Initialize the base class with both the regex_model and patterns
+        super().__init__(
+            regex_model=regex_model,
+            patterns=patterns,
+            should_anonymize=should_anonymize
+        )
+
+    def text_to_pattern(self, text: str) -> str:
+        """
+        Convert input text into a regex pattern that matches the exact text.
+        """
+        # Escape special regex characters in the text
+        escaped_text = re.escape(text)
+        # Create a pattern that matches the exact text, case-insensitive
+        return rf"\b{escaped_text}\b"
+
+    @weave.op()
+    def guard(self, prompt: str, custom_terms: Optional[list[str]] = None, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
+        """
+        Check if the input prompt contains any entities based on the regex patterns.
+
+        Args:
+            prompt: Input text to check for entities
+            custom_terms: List of custom terms to be converted into regex patterns. If provided,
+                only these terms will be checked, ignoring default patterns.
+            return_detected_types: If True, returns detailed entity type information
+
+        Returns:
+            RegexEntityRecognitionResponse or RegexEntityRecognitionSimpleResponse containing detection results
+        """
+        if custom_terms:
+            # Create a temporary RegexModel with only the custom patterns
+            temp_patterns = {term: self.text_to_pattern(term) for term in custom_terms}
+            temp_model = RegexModel(patterns=temp_patterns)
+            result = temp_model.check(prompt)
+        else:
+            # Use the original regex_model if no custom terms provided
+            result = self.regex_model.check(prompt)
+
+        # Create detailed explanation
+        explanation_parts = []
+        if result.matched_patterns:
+            explanation_parts.append("Found the following entities in the text:")
+            for entity_type, matches in result.matched_patterns.items():
+                explanation_parts.append(f"- {entity_type}: {len(matches)} instance(s)")
+        else:
+            explanation_parts.append("No entities detected in the text.")
+
+        if result.failed_patterns:
+            explanation_parts.append("\nChecked but did not find these entity types:")
+            for pattern in result.failed_patterns:
+                explanation_parts.append(f"- {pattern}")
+
+        # Updated anonymization logic
+        anonymized_text = None
+        if getattr(self, 'should_anonymize', False) and result.matched_patterns:
+            anonymized_text = prompt
+            for entity_type, matches in result.matched_patterns.items():
+                for match in matches:
+                    replacement = "[redacted]" if aggregate_redaction else f"[{entity_type.upper()}]"
+                    anonymized_text = anonymized_text.replace(match, replacement)
+
+        if return_detected_types:
+            return RegexEntityRecognitionResponse(
+                contains_entities=not result.passed,
+                detected_entities=result.matched_patterns,
+                explanation="\n".join(explanation_parts),
+                anonymized_text=anonymized_text
+            )
+        else:
+            return RegexEntityRecognitionSimpleResponse(
+                contains_entities=not result.passed,
+                explanation="\n".join(explanation_parts),
+                anonymized_text=anonymized_text
+            )
+
+    @weave.op()
+    def predict(self, prompt: str, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
+        return self.guard(prompt, return_detected_types=return_detected_types, aggregate_redaction=aggregate_redaction, **kwargs)
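A quick sketch of how the regex guardrail above is meant to be called, both with the default PII patterns and with custom_terms. The import path, sample inputs, and printed values are illustrative assumptions, not taken from the diff.

# Illustrative sketch only: assumes RegexEntityRecognitionGuardrail is exported
# from guardrails_genie.guardrails.
from guardrails_genie.guardrails import RegexEntityRecognitionGuardrail

guardrail = RegexEntityRecognitionGuardrail(use_defaults=True, should_anonymize=True)

# Default patterns (email, ssn, credit_card, ...) are applied here.
response = guardrail.guard("My SSN is 123-45-6789")
print(response.contains_entities)  # True: the ssn pattern matches
print(response.anonymized_text)    # e.g. "My SSN is [redacted]"

# With custom_terms, only those exact terms are matched and the defaults are ignored.
response = guardrail.guard("Project Aurora ships next week", custom_terms=["Project Aurora"])
print(response.detected_entities)  # e.g. {"Project Aurora": ["Project Aurora"]}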
guardrails_genie/guardrails/entity_recognition/transformers_entity_recognition_guardrail.py
ADDED
@@ -0,0 +1,190 @@
+from typing import List, Dict, Optional, ClassVar
+from transformers import pipeline, AutoConfig
+import json
+from pydantic import BaseModel
+from ..base import Guardrail
+import weave
+
+class TransformersEntityRecognitionResponse(BaseModel):
+    contains_entities: bool
+    detected_entities: Dict[str, List[str]]
+    explanation: str
+    anonymized_text: Optional[str] = None
+
+    @property
+    def safe(self) -> bool:
+        return not self.contains_entities
+
+class TransformersEntityRecognitionSimpleResponse(BaseModel):
+    contains_entities: bool
+    explanation: str
+    anonymized_text: Optional[str] = None
+
+    @property
+    def safe(self) -> bool:
+        return not self.contains_entities
+
+class TransformersEntityRecognitionGuardrail(Guardrail):
+    """Generic guardrail for detecting entities using any token classification model."""
+
+    _pipeline: Optional[object] = None
+    selected_entities: List[str]
+    should_anonymize: bool
+    available_entities: List[str]
+
+    def __init__(
+        self,
+        model_name: str = "iiiorg/piiranha-v1-detect-personal-information",
+        selected_entities: Optional[List[str]] = None,
+        should_anonymize: bool = False,
+        show_available_entities: bool = True,
+    ):
+        # Load model config and extract available entities
+        config = AutoConfig.from_pretrained(model_name)
+        entities = self._extract_entities_from_config(config)
+
+        if show_available_entities:
+            self._print_available_entities(entities)
+
+        # Initialize default values if needed
+        if selected_entities is None:
+            selected_entities = entities  # Use all available entities by default
+
+        # Filter out invalid entities and warn user
+        invalid_entities = [e for e in selected_entities if e not in entities]
+        valid_entities = [e for e in selected_entities if e in entities]
+
+        if invalid_entities:
+            print(f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}")
+            print(f"Continuing with valid entities: {valid_entities}")
+            selected_entities = valid_entities
+
+        # Call parent class constructor
+        super().__init__(
+            selected_entities=selected_entities,
+            should_anonymize=should_anonymize,
+            available_entities=entities
+        )
+
+        # Initialize pipeline
+        self._pipeline = pipeline(
+            task="token-classification",
+            model=model_name,
+            aggregation_strategy="simple"  # Merge same entities
+        )
+
+    def _extract_entities_from_config(self, config) -> List[str]:
+        """Extract unique entity types from the model config."""
+        # Get id2label mapping from config
+        id2label = config.id2label
+
+        # Extract unique entity types (removing B- and I- prefixes)
+        entities = set()
+        for label in id2label.values():
+            if label.startswith(('B-', 'I-')):
+                entities.add(label[2:])  # Remove prefix
+            elif label != 'O':  # Skip the 'O' (Outside) label
+                entities.add(label)
+
+        return sorted(list(entities))
+
+    def _print_available_entities(self, entities: List[str]):
+        """Print all available entity types that can be detected by the model."""
+        print("\nAvailable entity types:")
+        print("=" * 25)
+        for entity in entities:
+            print(f"- {entity}")
+        print("=" * 25 + "\n")
+
+    def print_available_entities(self):
+        """Print all available entity types that can be detected by the model."""
+        self._print_available_entities(self.available_entities)
+
+    def _detect_entities(self, text: str) -> Dict[str, List[str]]:
+        """Detect entities in the text using the pipeline."""
+        results = self._pipeline(text)
+
+        # Group findings by entity type
+        detected_entities = {}
+        for entity in results:
+            entity_type = entity['entity_group']
+            if entity_type in self.selected_entities:
+                if entity_type not in detected_entities:
+                    detected_entities[entity_type] = []
+                detected_entities[entity_type].append(entity['word'])
+
+        return detected_entities
+
+    def _anonymize_text(self, text: str, aggregate_redaction: bool = True) -> str:
+        """Anonymize detected entities in text using the pipeline."""
+        results = self._pipeline(text)
+
+        # Sort entities by start position in reverse order to avoid offset issues
+        entities = sorted(results, key=lambda x: x['start'], reverse=True)
+
+        # Create a mutable list of characters
+        chars = list(text)
+
+        # Apply redactions
+        for entity in entities:
+            if entity['entity_group'] in self.selected_entities:
+                start, end = entity['start'], entity['end']
+                replacement = ' [redacted] ' if aggregate_redaction else f" [{entity['entity_group']}] "
+
+                # Replace the entity with the redaction marker
+                chars[start:end] = replacement
+
+        # Join characters and clean up only consecutive spaces (preserving newlines)
+        result = ''.join(chars)
+        # Replace multiple spaces with single space, but preserve newlines
+        lines = result.split('\n')
+        cleaned_lines = [' '.join(line.split()) for line in lines]
+        return '\n'.join(cleaned_lines)
+
+    @weave.op()
+    def guard(self, prompt: str, return_detected_types: bool = True, aggregate_redaction: bool = True) -> TransformersEntityRecognitionResponse | TransformersEntityRecognitionSimpleResponse:
+        """Check if the input prompt contains any entities using the transformer pipeline.
+
+        Args:
+            prompt: The text to analyze
+            return_detected_types: If True, returns detailed entity type information
+            aggregate_redaction: If True, uses generic [redacted] instead of entity type
+        """
+        # Detect entities
+        detected_entities = self._detect_entities(prompt)
+
+        # Create explanation
+        explanation_parts = []
+        if detected_entities:
+            explanation_parts.append("Found the following entities in the text:")
+            for entity_type, instances in detected_entities.items():
+                explanation_parts.append(f"- {entity_type}: {len(instances)} instance(s)")
+        else:
+            explanation_parts.append("No entities detected in the text.")
+
+        explanation_parts.append("\nChecked for these entities:")
+        for entity in self.selected_entities:
+            explanation_parts.append(f"- {entity}")
+
+        # Anonymize if requested
+        anonymized_text = None
+        if self.should_anonymize and detected_entities:
+            anonymized_text = self._anonymize_text(prompt, aggregate_redaction)
+
+        if return_detected_types:
+            return TransformersEntityRecognitionResponse(
+                contains_entities=bool(detected_entities),
+                detected_entities=detected_entities,
+                explanation="\n".join(explanation_parts),
+                anonymized_text=anonymized_text
+            )
+        else:
+            return TransformersEntityRecognitionSimpleResponse(
+                contains_entities=bool(detected_entities),
+                explanation="\n".join(explanation_parts),
+                anonymized_text=anonymized_text
+            )
+
+    @weave.op()
+    def predict(self, prompt: str, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> TransformersEntityRecognitionResponse | TransformersEntityRecognitionSimpleResponse:
+        return self.guard(prompt, return_detected_types=return_detected_types, aggregate_redaction=aggregate_redaction, **kwargs)
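A usage sketch for the transformer-based guardrail above. The detected labels depend entirely on the token-classification model that gets loaded (the default piiranha checkpoint here), so the import path, sample prompt, and outputs are illustrative assumptions.

# Illustrative sketch only: the first call downloads the default
# "iiiorg/piiranha-v1-detect-personal-information" checkpoint from the Hugging Face Hub.
from guardrails_genie.guardrails import TransformersEntityRecognitionGuardrail

guardrail = TransformersEntityRecognitionGuardrail(should_anonymize=True)
response = guardrail.guard(
    "My name is Jane Doe and my email is jane.doe@example.com",
    aggregate_redaction=False,  # redact as [ENTITY_TYPE] instead of the generic [redacted]
)
print(response.contains_entities)
print(response.detected_entities)  # keys are whatever entity labels the model defines
print(response.anonymized_text)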
guardrails_genie/guardrails/manager.py
CHANGED
@@ -1,5 +1,6 @@
 import weave
 from rich.progress import track
+from pydantic import BaseModel
 
 from .base import Guardrail
 
@@ -20,10 +21,12 @@ class GuardrailManager(weave.Model):
             alerts.append(
                 {"guardrail_name": guardrail.__class__.__name__, "response": response}
             )
-            safe = safe and response.safe
-            summaries += (
-                f"**{guardrail.__class__.__name__}**: {response.explanation}\n\n---\n\n"
-            )
+            if isinstance(response, BaseModel):
+                safe = safe and response.safe
+                summaries += f"**{guardrail.__class__.__name__}**: {response.explanation}\n\n---\n\n"
+            else:
+                safe = safe and response["safe"]
+                summaries += f"**{guardrail.__class__.__name__}**: {response['summary']}\n\n---\n\n"
         return {"safe": safe, "alerts": alerts, "summary": summaries}
 
     @weave.op()
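The manager change above is what lets the new entity-recognition guardrails (which return Pydantic responses exposing .safe and .explanation) coexist with guardrails that return plain dicts. A sketch of the aggregation follows; the guardrail combination, the guard(prompt) call, and the printed values are illustrative assumptions.

# Illustrative sketch only: GuardrailManager is defined in
# guardrails_genie/guardrails/manager.py; the guardrail classes are assumed
# to be re-exported from guardrails_genie.guardrails.
from guardrails_genie.guardrails.manager import GuardrailManager
from guardrails_genie.guardrails import (
    RegexEntityRecognitionGuardrail,
    TransformersEntityRecognitionGuardrail,
)

manager = GuardrailManager(
    guardrails=[
        RegexEntityRecognitionGuardrail(should_anonymize=True),
        TransformersEntityRecognitionGuardrail(should_anonymize=True),
    ]
)
result = manager.guard("Please wire the refund to jane.doe@example.com")
print(result["safe"])     # False if any guardrail flagged the prompt
print(result["summary"])  # per-guardrail explanations separated by "---"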
guardrails_genie/regex_model.py
ADDED
@@ -0,0 +1,65 @@
+from typing import List, Dict, Optional
+import re
+import weave
+from pydantic import BaseModel
+
+
+class RegexResult(BaseModel):
+    passed: bool
+    matched_patterns: Dict[str, List[str]]
+    failed_patterns: List[str]
+
+
+class RegexModel(weave.Model):
+    patterns: Dict[str, str]
+
+    def __init__(self, patterns: Dict[str, str]) -> None:
+        """
+        Initialize RegexModel with a dictionary of patterns.
+
+        Args:
+            patterns: Dictionary where key is pattern name and value is regex pattern
+                Example: {"email": r"[^@ \t\r\n]+@[^@ \t\r\n]+\.[^@ \t\r\n]+",
+                          "phone": r"\b\d{3}[-.]?\d{3}[-.]?\d{4}\b"}
+        """
+        super().__init__(patterns=patterns)
+        self._compiled_patterns = {
+            name: re.compile(pattern) for name, pattern in patterns.items()
+        }
+
+    @weave.op()
+    def check(self, text: str) -> RegexResult:
+        """
+        Check text against all patterns and return detailed results.
+
+        Args:
+            text: Input text to check against patterns
+
+        Returns:
+            RegexResult containing pass/fail status and details about matches
+        """
+        matches: Dict[str, List[str]] = {}
+        failed_patterns: List[str] = []
+
+        for pattern_name, compiled_pattern in self._compiled_patterns.items():
+            found_matches = compiled_pattern.findall(text)
+            if found_matches:
+                matches[pattern_name] = found_matches
+            else:
+                failed_patterns.append(pattern_name)
+
+        # Consider it passed only if no patterns matched (no PII found)
+        passed = len(matches) == 0
+
+        return RegexResult(
+            passed=passed,
+            matched_patterns=matches,
+            failed_patterns=failed_patterns
+        )
+
+    @weave.op()
+    def predict(self, text: str) -> RegexResult:
+        """
+        Alias for check() to maintain consistency with other models.
+        """
+        return self.check(text)
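RegexModel is the small reusable core the regex guardrail builds on: check() reports passed=False as soon as any pattern matches. A minimal sketch of calling it directly (the pattern and text are illustrative):

# Illustrative sketch only.
from guardrails_genie.regex_model import RegexModel

model = RegexModel(patterns={"email": r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"})
result = model.check("Reach me at jane.doe@example.com")
print(result.passed)            # False: a pattern matched, so the text is not clean
print(result.matched_patterns)  # e.g. {"email": ["jane.doe@example.com"]}
print(result.failed_patterns)   # e.g. []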
pyproject.toml
CHANGED
@@ -20,6 +20,8 @@ dependencies = [
     "pymupdf4llm>=0.0.17",
     "transformers>=4.46.3",
     "torch>=2.5.1",
+    "presidio-analyzer>=2.2.355",
+    "presidio-anonymizer>=2.2.355",
 ]
 
 [tool.setuptools]