Anish Shah committed on
Commit
13d2f14
·
unverified ·
2 Parent(s): 8382f82 06b5c11

Merge pull request #6 from soumik12345/feat/pii-banned-words

Browse files
Files changed (21) hide show
  1. .gitignore +2 -1
  2. application_pages/chat_app.py +28 -0
  3. guardrails_genie/guardrails/ReadMe.md +136 -0
  4. guardrails_genie/guardrails/__init__.py +10 -0
  5. guardrails_genie/guardrails/entity_recognition/__init__.py +10 -0
  6. guardrails_genie/guardrails/entity_recognition/banned_terms_examples/banned_term_benchmark.py +0 -0
  7. guardrails_genie/guardrails/entity_recognition/banned_terms_examples/banned_term_examples.py +178 -0
  8. guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_llm_judge.py +47 -0
  9. guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_regex_model.py +46 -0
  10. guardrails_genie/guardrails/entity_recognition/llm_judge_entity_recognition_guardrail.py +166 -0
  11. guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark.py +215 -0
  12. guardrails_genie/guardrails/entity_recognition/pii_examples/pii_test_examples.py +150 -0
  13. guardrails_genie/guardrails/entity_recognition/pii_examples/run_presidio_model.py +42 -0
  14. guardrails_genie/guardrails/entity_recognition/pii_examples/run_regex_model.py +42 -0
  15. guardrails_genie/guardrails/entity_recognition/pii_examples/run_transformers.py +43 -0
  16. guardrails_genie/guardrails/entity_recognition/presidio_entity_recognition_guardrail.py +191 -0
  17. guardrails_genie/guardrails/entity_recognition/regex_entity_recognition_guardrail.py +138 -0
  18. guardrails_genie/guardrails/entity_recognition/transformers_entity_recognition_guardrail.py +190 -0
  19. guardrails_genie/guardrails/manager.py +7 -4
  20. guardrails_genie/regex_model.py +65 -0
  21. pyproject.toml +2 -0
.gitignore CHANGED
@@ -168,4 +168,5 @@ temp.txt
168
  **.csv
169
  binary-classifier/
170
  wandb/
171
- artifacts/
 
 
168
  **.csv
169
  binary-classifier/
170
  wandb/
171
+ artifacts/
172
+ evaluation_results/
application_pages/chat_app.py CHANGED
@@ -61,6 +61,34 @@ def initialize_guardrails():
61
  guardrail_name,
62
  )(model_name=classifier_model_name)
63
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  st.session_state.guardrails_manager = GuardrailManager(
65
  guardrails=st.session_state.guardrails
66
  )
 
61
  guardrail_name,
62
  )(model_name=classifier_model_name)
63
  )
64
+ elif guardrail_name == "PresidioEntityRecognitionGuardrail":
65
+ st.session_state.guardrails.append(
66
+ getattr(
67
+ importlib.import_module("guardrails_genie.guardrails"),
68
+ guardrail_name,
69
+ )()
70
+ )
71
+ elif guardrail_name == "RegexEntityRecognitionGuardrail":
72
+ st.session_state.guardrails.append(
73
+ getattr(
74
+ importlib.import_module("guardrails_genie.guardrails"),
75
+ guardrail_name,
76
+ )()
77
+ )
78
+ elif guardrail_name == "TransformersEntityRecognitionGuardrail":
79
+ st.session_state.guardrails.append(
80
+ getattr(
81
+ importlib.import_module("guardrails_genie.guardrails"),
82
+ guardrail_name,
83
+ )()
84
+ )
85
+ elif guardrail_name == "RestrictedTermsJudge":
86
+ st.session_state.guardrails.append(
87
+ getattr(
88
+ importlib.import_module("guardrails_genie.guardrails"),
89
+ guardrail_name,
90
+ )()
91
+ )
92
  st.session_state.guardrails_manager = GuardrailManager(
93
  guardrails=st.session_state.guardrails
94
  )
guardrails_genie/guardrails/ReadMe.md ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Entity Recognition Guardrails
2
+
3
+ A collection of guardrails for detecting and anonymizing various types of entities in text, including PII (Personally Identifiable Information), restricted terms, and custom entities.
4
+
5
+ ## Available Guardrails
6
+
7
+ ### 1. Regex Entity Recognition
8
+ Simple pattern-based entity detection using regular expressions.
9
+
10
+ ```python
11
+ from guardrails_genie.guardrails.entity_recognition import RegexEntityRecognitionGuardrail
12
+
13
+ # Initialize with default PII patterns
14
+ guardrail = RegexEntityRecognitionGuardrail(should_anonymize=True)
15
+
16
+ # Or with custom patterns
17
+ custom_patterns = {
18
+ "employee_id": r"EMP\d{6}",
19
+ "project_code": r"PRJ-[A-Z]{2}-\d{4}"
20
+ }
21
+ guardrail = RegexEntityRecognitionGuardrail(patterns=custom_patterns, should_anonymize=True)
22
+ ```
23
+
24
+ ### 2. Presidio Entity Recognition
25
+ Advanced entity detection using Microsoft's Presidio analyzer.
26
+
27
+ ```python
28
+ from guardrails_genie.guardrails.entity_recognition import PresidioEntityRecognitionGuardrail
29
+
30
+ # Initialize with default entities
31
+ guardrail = PresidioEntityRecognitionGuardrail(should_anonymize=True)
32
+
33
+ # Or with specific entities
34
+ selected_entities = ["CREDIT_CARD", "US_SSN", "EMAIL_ADDRESS"]
35
+ guardrail = PresidioEntityRecognitionGuardrail(
36
+ selected_entities=selected_entities,
37
+ should_anonymize=True
38
+ )
39
+ ```
40
+
41
+ ### 3. Transformers Entity Recognition
42
+ Entity detection using transformer-based models.
43
+
44
+ ```python
45
+ from guardrails_genie.guardrails.entity_recognition import TransformersEntityRecognitionGuardrail
46
+
47
+ # Initialize with default model
48
+ guardrail = TransformersEntityRecognitionGuardrail(should_anonymize=True)
49
+
50
+ # Or with specific model and entities
51
+ guardrail = TransformersEntityRecognitionGuardrail(
52
+ model_name="iiiorg/piiranha-v1-detect-personal-information",
53
+ selected_entities=["GIVENNAME", "SURNAME", "EMAIL"],
54
+ should_anonymize=True
55
+ )
56
+ ```
57
+
58
+ ### 4. LLM Judge for Restricted Terms
59
+ Advanced detection of restricted terms, competitor mentions, and brand protection using LLMs.
60
+
61
+ ```python
62
+ from guardrails_genie.guardrails.entity_recognition import RestrictedTermsJudge
63
+
64
+ # Initialize with OpenAI model
65
+ guardrail = RestrictedTermsJudge(should_anonymize=True)
66
+
67
+ # Check for specific terms
68
+ result = guardrail.guard(
69
+ text="Let's implement features like Salesforce",
70
+ custom_terms=["Salesforce", "Oracle", "AWS"]
71
+ )
72
+ ```
73
+
74
+ ## Usage
75
+
76
+ All guardrails follow a consistent interface:
77
+
78
+ ```python
79
+ # Initialize a guardrail
80
+ guardrail = RegexEntityRecognitionGuardrail(should_anonymize=True)
81
+
82
+ # Check text for entities
83
+ result = guardrail.guard("Hello, my email is john@example.com")
84
+
85
+ # Access results
86
+ print(f"Contains entities: {result.contains_entities}")
87
+ print(f"Detected entities: {result.detected_entities}")
88
+ print(f"Explanation: {result.explanation}")
89
+ print(f"Anonymized text: {result.anonymized_text}")
90
+ ```
91
+
92
+ ## Evaluation Tools
93
+
94
+ The module includes comprehensive evaluation tools and test cases:
95
+
96
+ - `pii_examples/`: Test cases for PII detection
97
+ - `banned_terms_examples/`: Test cases for restricted terms
98
+ - Benchmark scripts for evaluating model performance
99
+
100
+ ### Running Evaluations
101
+
102
+ ```python
103
+ # PII Detection Benchmark
104
+ from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_benchmark import main
105
+ main()
106
+
107
+ # (TODO): Restricted Terms Testing
108
+ from guardrails_genie.guardrails.entity_recognition.banned_terms_examples.banned_term_benchmark import main
109
+ main()
110
+ ```
111
+
112
+ ## Features
113
+
114
+ - Entity detection and anonymization
115
+ - Support for multiple detection methods (regex, Presidio, transformers, LLMs)
116
+ - Customizable entity types and patterns
117
+ - Detailed explanations of detected entities
118
+ - Comprehensive evaluation framework
119
+ - Support for custom terms and patterns
120
+ - Batch processing capabilities
121
+ - Performance metrics and benchmarking
122
+
123
+ ## Response Format
124
+
125
+ All guardrails return responses with the following structure:
126
+
127
+ ```python
128
+ {
129
+ "contains_entities": bool,
130
+ "detected_entities": {
131
+ "entity_type": ["detected_value_1", "detected_value_2"]
132
+ },
133
+ "explanation": str,
134
+ "anonymized_text": Optional[str]
135
+ }
136
+ ```
guardrails_genie/guardrails/__init__.py CHANGED
@@ -2,10 +2,20 @@ from .injection import (
2
  PromptInjectionClassifierGuardrail,
3
  PromptInjectionSurveyGuardrail,
4
  )
 
 
 
 
 
 
5
  from .manager import GuardrailManager
6
 
7
  __all__ = [
8
  "PromptInjectionSurveyGuardrail",
9
  "PromptInjectionClassifierGuardrail",
 
 
 
 
10
  "GuardrailManager",
11
  ]
 
2
  PromptInjectionClassifierGuardrail,
3
  PromptInjectionSurveyGuardrail,
4
  )
5
+ from .entity_recognition import (
6
+ PresidioEntityRecognitionGuardrail,
7
+ RegexEntityRecognitionGuardrail,
8
+ TransformersEntityRecognitionGuardrail,
9
+ RestrictedTermsJudge,
10
+ )
11
  from .manager import GuardrailManager
12
 
13
  __all__ = [
14
  "PromptInjectionSurveyGuardrail",
15
  "PromptInjectionClassifierGuardrail",
16
+ "PresidioEntityRecognitionGuardrail",
17
+ "RegexEntityRecognitionGuardrail",
18
+ "TransformersEntityRecognitionGuardrail",
19
+ "RestrictedTermsJudge",
20
  "GuardrailManager",
21
  ]
guardrails_genie/guardrails/entity_recognition/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from .presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
2
+ from .regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
3
+ from .transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail
4
+ from .llm_judge_entity_recognition_guardrail import RestrictedTermsJudge
5
+ __all__ = [
6
+ "PresidioEntityRecognitionGuardrail",
7
+ "RegexEntityRecognitionGuardrail",
8
+ "TransformersEntityRecognitionGuardrail",
9
+ "RestrictedTermsJudge"
10
+ ]
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/banned_term_benchmark.py ADDED
File without changes
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/banned_term_examples.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Collection of restricted terms test examples with expected outcomes for entity recognition testing.
Focuses on banned terms, competitor mentions, and brand protection scenarios.
"""

# Each entry is consumed by run_test_case() and has the shape:
#   description       -- label printed alongside the test result
#   input_text        -- raw text handed to guardrail.guard()
#   custom_terms      -- restricted terms the guardrail is asked to look for
#   expected_entities -- mapping of restricted term -> list of surface forms
#                        the guardrail is expected to report for that term
RESTRICTED_TERMS_EXAMPLES = [
    {
        "description": "Competitor Product Discussion",
        "input_text": """
        I think we should implement features similar to Salesforce's Einstein AI
        and Oracle's Cloud Infrastructure. Maybe we could also look at how
        AWS handles their lambda functions.
        """,
        "custom_terms": ["Salesforce", "Oracle", "AWS", "Einstein AI", "Cloud Infrastructure", "lambda"],
        "expected_entities": {
            "Salesforce": ["Salesforce"],
            "Oracle": ["Oracle"],
            "AWS": ["AWS"],
            "Einstein AI": ["Einstein AI"],
            "Cloud Infrastructure": ["Cloud Infrastructure"],
            "lambda": ["lambda"]
        }
    },
    {
        "description": "Inappropriate Language in Support Ticket",
        "input_text": """
        This damn product keeps crashing! What the hell is wrong with your
        stupid service? I've wasted so much freaking time on this crap.
        """,
        "custom_terms": ["damn", "hell", "stupid", "crap"],
        "expected_entities": {
            "damn": ["damn"],
            "hell": ["hell"],
            "stupid": ["stupid"],
            "crap": ["crap"]
        }
    },
    {
        "description": "Confidential Project Names",
        "input_text": """
        Project Titan's launch date has been moved up. We should coordinate
        with Project Phoenix team and the Blue Dragon initiative for resource allocation.
        """,
        "custom_terms": ["Project Titan", "Project Phoenix", "Blue Dragon"],
        "expected_entities": {
            "Project Titan": ["Project Titan"],
            "Project Phoenix": ["Project Phoenix"],
            "Blue Dragon": ["Blue Dragon"]
        }
    }
]

# Edge cases and special formats
# NOTE(review): several of these expectations (stock tickers, l33t speak,
# typos) are intentionally hard; presumably only the LLM judge -- not the
# plain regex guardrail -- can satisfy them. Confirm against the benchmark.
EDGE_CASE_EXAMPLES = [
    {
        "description": "Common Corporate Abbreviations and Stock Symbols",
        "input_text": """
        MSFT's Azure and O365 platform is gaining market share.
        Have you seen what GOOGL/GOOG and FB/META are doing with their AI?
        CRM (Salesforce) and ORCL (Oracle) have interesting features too.
        """,
        "custom_terms": ["Microsoft", "Google", "Meta", "Facebook", "Salesforce", "Oracle"],
        "expected_entities": {
            "Microsoft": ["MSFT"],
            "Google": ["GOOGL", "GOOG"],
            "Meta": ["META"],
            "Facebook": ["FB"],
            "Salesforce": ["CRM", "Salesforce"],
            "Oracle": ["ORCL"]
        }
    },
    {
        "description": "L33t Speak and Intentional Obfuscation",
        "input_text": """
        S4l3sf0rc3 is better than 0r4cl3!
        M1cr0$oft and G00gl3 are the main competitors.
        Let's check F8book and Met@ too.
        """,
        "custom_terms": ["Salesforce", "Oracle", "Microsoft", "Google", "Facebook", "Meta"],
        "expected_entities": {
            "Salesforce": ["S4l3sf0rc3"],
            "Oracle": ["0r4cl3"],
            "Microsoft": ["M1cr0$oft"],
            "Google": ["G00gl3"],
            "Facebook": ["F8book"],
            "Meta": ["Met@"]
        }
    },
    {
        "description": "Case Variations and Partial Matches",
        "input_text": """
        salesFORCE and ORACLE are competitors.
        MicroSoft and google are too.
        Have you tried micro-soft or Google_Cloud?
        """,
        "custom_terms": ["Microsoft", "Google", "Salesforce", "Oracle"],
        "expected_entities": {
            "Microsoft": ["MicroSoft", "micro-soft"],
            "Google": ["google", "Google_Cloud"],
            "Salesforce": ["salesFORCE"],
            "Oracle": ["ORACLE"]
        }
    },
    {
        "description": "Common Misspellings and Typos",
        "input_text": """
        Microsft and Microsooft are common typos.
        Goggle, Googel, and Gooogle are search related.
        Salezforce and Oracel need checking too.
        """,
        "custom_terms": ["Microsoft", "Google", "Salesforce", "Oracle"],
        "expected_entities": {
            "Microsoft": ["Microsft", "Microsooft"],
            "Google": ["Goggle", "Googel", "Gooogle"],
            "Salesforce": ["Salezforce"],
            "Oracle": ["Oracel"]
        }
    },
    {
        "description": "Mixed Variations and Context",
        "input_text": """
        The M$ cloud competes with AWS (Amazon Web Services).
        FB/Meta's social platform and GOOGL's search dominate.
        SF.com and Oracle-DB are industry standards.
        """,
        "custom_terms": ["Microsoft", "Amazon Web Services", "Facebook", "Meta", "Google", "Salesforce", "Oracle"],
        "expected_entities": {
            "Microsoft": ["M$"],
            "Amazon Web Services": ["AWS"],
            "Facebook": ["FB"],
            "Meta": ["Meta"],
            "Google": ["GOOGL"],
            "Salesforce": ["SF.com"],
            "Oracle": ["Oracle-DB"]
        }
    }
]
138
+
139
def validate_entities(detected: dict, expected: dict) -> bool:
    """Return True when *detected* names exactly the expected entity types
    and each type carries the same set of surface forms as *expected*."""
    if set(detected) != set(expected):
        return False
    for entity_type in expected:
        if set(detected[entity_type]) != set(expected[entity_type]):
            return False
    return True


def run_test_case(guardrail, test_case, test_type="Main"):
    """Run one test case against *guardrail*, print a report, and return
    True when the detected entities match the expectation exactly."""
    print(f"\n{test_type} Test Case: {test_case['description']}")
    print("-" * 50)

    outcome = guardrail.guard(
        test_case["input_text"], custom_terms=test_case["custom_terms"]
    )
    wanted = test_case["expected_entities"]

    # Exact-match comparison of detected vs. expected entities.
    passed = validate_entities(outcome.detected_entities, wanted)

    print(f"Test Status: {'✓ PASS' if passed else '✗ FAIL'}")
    print(f"Contains Restricted Terms: {outcome.contains_entities}")

    if not passed:
        # On failure, print a per-entity-type diff to aid debugging.
        print("\nEntity Comparison:")
        for entity_type in set(outcome.detected_entities) | set(wanted):
            found = set(outcome.detected_entities.get(entity_type, []))
            target = set(wanted.get(entity_type, []))
            print(f"\nEntity Type: {entity_type}")
            print(f"  Expected: {sorted(target)}")
            print(f"  Detected: {sorted(found)}")
            if found != target:
                print(f"  Missing: {sorted(target - found)}")
                print(f"  Extra: {sorted(found - target)}")

    if outcome.anonymized_text:
        print(f"\nAnonymized Text:\n{outcome.anonymized_text}")

    return passed
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_llm_judge.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from guardrails_genie.guardrails.entity_recognition.llm_judge_entity_recognition_guardrail import RestrictedTermsJudge
2
+ from guardrails_genie.guardrails.entity_recognition.banned_terms_examples.banned_term_examples import (
3
+ RESTRICTED_TERMS_EXAMPLES,
4
+ EDGE_CASE_EXAMPLES,
5
+ run_test_case
6
+ )
7
+ from guardrails_genie.llm import OpenAIModel
8
+ import weave
9
+
10
def test_restricted_terms_detection():
    """Test restricted terms detection scenarios using predefined test cases"""
    weave.init("guardrails-genie-restricted-terms-llm-judge")

    # LLM-backed judge with anonymization enabled.
    judge = RestrictedTermsJudge(
        should_anonymize=True,
        llm_model=OpenAIModel()
    )

    # Main restricted-terms scenarios.
    print("\nRunning Main Restricted Terms Tests")
    print("=" * 80)
    passed = sum(
        1 for case in RESTRICTED_TERMS_EXAMPLES if run_test_case(judge, case)
    )

    # Harder edge cases (obfuscation, typos, tickers).
    print("\nRunning Edge Cases")
    print("=" * 80)
    passed += sum(
        1 for case in EDGE_CASE_EXAMPLES if run_test_case(judge, case, "Edge")
    )

    total = len(RESTRICTED_TERMS_EXAMPLES) + len(EDGE_CASE_EXAMPLES)

    print("\nTest Summary")
    print("=" * 80)
    print(f"Total Tests: {total}")
    print(f"Passed: {passed}")
    print(f"Failed: {total - passed}")
    print(f"Success Rate: {(passed/total)*100:.1f}%")

if __name__ == "__main__":
    test_restricted_terms_detection()
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_regex_model.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
2
+ from guardrails_genie.guardrails.entity_recognition.banned_terms_examples.banned_term_examples import (
3
+ RESTRICTED_TERMS_EXAMPLES,
4
+ EDGE_CASE_EXAMPLES,
5
+ run_test_case
6
+ )
7
+ import weave
8
+
9
+ def test_restricted_terms_detection():
10
+ """Test restricted terms detection scenarios using predefined test cases"""
11
+ weave.init("guardrails-genie-restricted-terms-regex-model")
12
+
13
+ # Create the guardrail with anonymization enabled
14
+ regex_guardrail = RegexEntityRecognitionGuardrail(
15
+ use_defaults=False, # Don't use default PII patterns
16
+ should_anonymize=True
17
+ )
18
+
19
+ # Test statistics
20
+ total_tests = len(RESTRICTED_TERMS_EXAMPLES) + len(EDGE_CASE_EXAMPLES)
21
+ passed_tests = 0
22
+
23
+ # Test main restricted terms examples
24
+ print("\nRunning Main Restricted Terms Tests")
25
+ print("=" * 80)
26
+ for test_case in RESTRICTED_TERMS_EXAMPLES:
27
+ if run_test_case(regex_guardrail, test_case):
28
+ passed_tests += 1
29
+
30
+ # Test edge cases
31
+ print("\nRunning Edge Cases")
32
+ print("=" * 80)
33
+ for test_case in EDGE_CASE_EXAMPLES:
34
+ if run_test_case(regex_guardrail, test_case, "Edge"):
35
+ passed_tests += 1
36
+
37
+ # Print summary
38
+ print("\nTest Summary")
39
+ print("=" * 80)
40
+ print(f"Total Tests: {total_tests}")
41
+ print(f"Passed: {passed_tests}")
42
+ print(f"Failed: {total_tests - passed_tests}")
43
+ print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")
44
+
45
+ if __name__ == "__main__":
46
+ test_restricted_terms_detection()
guardrails_genie/guardrails/entity_recognition/llm_judge_entity_recognition_guardrail.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional
2
+ import weave
3
+ from pydantic import BaseModel, Field
4
+ from typing_extensions import Annotated
5
+
6
+ from ...llm import OpenAIModel
7
+ from ..base import Guardrail
8
+ import instructor
9
+
10
+
11
class TermMatch(BaseModel):
    """Represents a matched term and its variations.

    One instance per detected occurrence: links the caller-supplied
    restricted term to the actual text span the LLM judged to match it.
    """
    # The restricted term from the caller-supplied list that was matched.
    original_term: str
    # The exact text in the analyzed input that triggered the match.
    matched_text: str
    match_type: str = Field(
        description="Type of match: EXACT, MISSPELLING, ABBREVIATION, or VARIANT"
    )
    explanation: str = Field(
        description="Explanation of why this is considered a match"
    )
21
+
22
+
23
class RestrictedTermsAnalysis(BaseModel):
    """Analysis result for restricted terms detection.

    Structured output parsed directly from the LLM response
    (used as ``response_format`` in RestrictedTermsJudge.predict).
    """
    contains_restricted_terms: bool = Field(
        description="Whether any restricted terms were detected"
    )
    detected_matches: List[TermMatch] = Field(
        default_factory=list,
        description="List of detected term matches with their variations"
    )
    explanation: str = Field(
        description="Detailed explanation of the analysis"
    )
    anonymized_text: Optional[str] = Field(
        default=None,
        description="Text with restricted terms replaced with category tags"
    )

    @property
    def safe(self) -> bool:
        """True when no restricted terms were detected."""
        return not self.contains_restricted_terms
43
+
44
+
45
class RestrictedTermsRecognitionResponse(BaseModel):
    """Public result returned by RestrictedTermsJudge.guard.

    Mirrors the response shape of the other entity-recognition guardrails
    (contains_entities / detected_entities / explanation / anonymized_text)
    so callers can treat all guardrails uniformly.
    """
    # True when at least one restricted term (or variant) was found.
    contains_entities: bool
    # Mapping of original restricted term -> matched text spans.
    detected_entities: Dict[str, List[str]]
    # Human-readable summary of the findings.
    explanation: str
    # Input text with matches redacted; None when anonymization is off
    # or nothing was detected.
    anonymized_text: Optional[str] = None

    @property
    def safe(self) -> bool:
        """True when no restricted terms were detected."""
        return not self.contains_entities
54
+
55
+
56
# Fallback restricted terms used when guard() is called without an explicit
# list. Kept as a module-level constant so the method signature does not
# carry a mutable default argument.
# TODO: Remove default custom_terms (callers should always supply their own).
_DEFAULT_CUSTOM_TERMS = [
    "Microsoft", "Amazon Web Services", "Facebook", "Meta",
    "Google", "Salesforce", "Oracle",
]


class RestrictedTermsJudge(Guardrail):
    """LLM-backed guardrail for restricted-term detection.

    Asks an OpenAI model (with structured output parsed into
    RestrictedTermsAnalysis) to find exact matches and variations
    (misspellings, abbreviations, l33t speak, case variants) of a
    caller-supplied list of restricted terms, and can optionally
    anonymize the offending spans.
    """
    llm_model: OpenAIModel = Field(default_factory=lambda: OpenAIModel())
    should_anonymize: bool = False

    def __init__(self, should_anonymize: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.should_anonymize = should_anonymize
        # Pass the OpenAI client to instructor.patch() so responses can be
        # parsed directly into pydantic models.
        instructor.patch(client=self.llm_model._openai_client)

    def format_prompts(self, text: str, custom_terms: List[str]) -> tuple[str, str]:
        """Build the (user_prompt, system_prompt) pair for the LLM call."""
        user_prompt = f"""Analyze the following text for restricted terms and variations:

<text_to_analyze>
{text}
</text_to_analyze>

Look for these specific terms and their variations:
{', '.join(custom_terms)}

Analyze the text carefully for:
1. Exact matches
2. Common misspellings
3. Abbreviations
4. Intentional variations (l33t speak, special characters, etc.)
5. Case variations
"""

        system_prompt = """You are an expert system for detecting restricted terms, brand mentions, and inappropriate content.

Your task is to:
1. Identify exact matches of restricted terms
2. Detect variations including:
   - Misspellings (both accidental and intentional)
   - Abbreviations and acronyms
   - Case variations
   - L33t speak or special character substitutions
   - Partial matches within larger words

For each match, you must:
1. Identify the original restricted term
2. Note the actual text that matched
3. Classify the match type
4. Provide a confidence score
5. Explain why it's considered a match

Be thorough but avoid false positives. Focus on meaningful matches that indicate actual attempts to use restricted terms.

Return your analysis in the structured format specified by the RestrictedTermsAnalysis model."""

        return user_prompt, system_prompt

    @weave.op()
    def predict(self, text: str, custom_terms: List[str], **kwargs) -> RestrictedTermsAnalysis:
        """Run the LLM analysis of *text* and return the parsed result."""
        user_prompt, system_prompt = self.format_prompts(text, custom_terms)

        response = self.llm_model.predict(
            user_prompts=user_prompt,
            system_prompt=system_prompt,
            response_format=RestrictedTermsAnalysis,
            temperature=0.1,  # Lower temperature for more consistent analysis
            **kwargs,
        )

        return response.choices[0].message.parsed

    @weave.op()
    def guard(
        self,
        text: str,
        custom_terms: Optional[List[str]] = None,
        aggregate_redaction: bool = True,
        **kwargs,
    ) -> RestrictedTermsRecognitionResponse:
        """
        Guard against restricted terms and their variations.

        Args:
            text: Text to analyze
            custom_terms: List of restricted terms to check for; falls back
                to _DEFAULT_CUSTOM_TERMS when omitted.
            aggregate_redaction: When True every hit is replaced with
                "[redacted]"; otherwise the match type is used as the tag.

        Returns:
            RestrictedTermsRecognitionResponse containing safety assessment and detailed analysis
        """
        # Resolve the default here instead of in the signature to avoid the
        # shared-mutable-default-argument pitfall.
        if custom_terms is None:
            custom_terms = _DEFAULT_CUSTOM_TERMS

        analysis = self.predict(text, custom_terms, **kwargs)

        # Create a summary of findings
        if analysis.contains_restricted_terms:
            summary_parts = ["Restricted terms detected:"]
            for match in analysis.detected_matches:
                summary_parts.append(f"\n- {match.original_term}: {match.matched_text} ({match.match_type})")
            summary = "\n".join(summary_parts)
        else:
            summary = "No restricted terms detected."

        # Replace each matched span in the text. Skip empty matched_text:
        # str.replace("") would insert the tag between every character.
        anonymized_text = None
        if self.should_anonymize and analysis.contains_restricted_terms:
            anonymized_text = text
            for match in analysis.detected_matches:
                if not match.matched_text:
                    continue
                replacement = "[redacted]" if aggregate_redaction else f"[{match.match_type.upper()}]"
                anonymized_text = anonymized_text.replace(match.matched_text, replacement)

        # Group matched spans by their originating restricted term.
        detected_entities: Dict[str, List[str]] = {}
        for match in analysis.detected_matches:
            detected_entities.setdefault(match.original_term, []).append(match.matched_text)

        return RestrictedTermsRecognitionResponse(
            contains_entities=analysis.contains_restricted_terms,
            detected_entities=detected_entities,
            explanation=summary,
            anonymized_text=anonymized_text,
        )
guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import weave
from datasets import load_dataset
from tqdm import tqdm
9
def load_ai4privacy_dataset(
    num_samples: int = 100,
    split: str = "validation",
    seed: Optional[int] = None,
) -> List[Dict]:
    """
    Load and prepare samples from the ai4privacy dataset.

    Args:
        num_samples: Number of samples to evaluate
        split: Dataset split to use ("train" or "validation")
        seed: Optional RNG seed; pass an int to make the sampled subset
            reproducible across benchmark runs. Default (None) keeps the
            previous behavior of a fresh random subset each call.

    Returns:
        List of prepared test cases
    """
    # Load the dataset
    dataset = load_dataset("ai4privacy/pii-masking-400k")

    # Get the specified split
    data_split = dataset[split]

    # Randomly sample entries if num_samples is less than total.
    # A dedicated Random instance avoids perturbing the global RNG state.
    if num_samples < len(data_split):
        rng = random.Random(seed)
        indices = rng.sample(range(len(data_split)), num_samples)
        samples = [data_split[i] for i in indices]
    else:
        samples = data_split

    # Convert to test case format
    test_cases = []
    for sample in samples:
        # Group the gold annotations from privacy_mask by entity label:
        # label -> list of masked surface values.
        entities: Dict[str, List[str]] = {}
        for entity in sample['privacy_mask']:
            entities.setdefault(entity['label'], []).append(entity['value'])

        test_cases.append({
            "description": f"AI4Privacy Sample (ID: {sample['uid']})",
            "input_text": sample['source_text'],
            "expected_entities": entities,
            "masked_text": sample['masked_text'],
            "language": sample['language'],
            "locale": sample['locale'],
        })

    return test_cases
56
+
57
@weave.op()
def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]:
    """
    Evaluate a model on the test cases.

    For every test case the guardrail's detections are compared, per entity
    type, against the expected entities as *sets* of surface strings.
    Precision/recall/F1 are computed per sample, and micro-averaged over
    all samples for the final per-entity-type metrics. A sample counts as
    "passed" only if every entity type it involves scores F1 == 1.0.

    Args:
        guardrail: Entity recognition guardrail to evaluate
        test_cases: List of test cases

    Returns:
        Tuple of (metrics dict, detailed results list)
    """
    metrics = {
        "total": len(test_cases),
        "passed": 0,
        "failed": 0,
        "entity_metrics": {}  # Will store precision/recall per entity type
    }

    detailed_results = []

    for test_case in tqdm(test_cases, desc="Evaluating samples"):
        # Run detection.
        # NOTE(review): assumes guardrail.guard returns an object exposing
        # .detected_entities (dict[str, list[str]]) and .anonymized_text.
        result = guardrail.guard(test_case['input_text'])
        detected = result.detected_entities
        expected = test_case['expected_entities']

        # Track entity-level metrics over the union of detected/expected types
        all_entity_types = set(list(detected.keys()) + list(expected.keys()))
        entity_results = {}

        for entity_type in all_entity_types:
            detected_set = set(detected.get(entity_type, []))
            expected_set = set(expected.get(entity_type, []))

            # Calculate metrics (set-based exact string matching)
            true_positives = len(detected_set & expected_set)
            false_positives = len(detected_set - expected_set)
            false_negatives = len(expected_set - detected_set)

            # Guard each ratio against a zero denominator.
            precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
            recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

            entity_results[entity_type] = {
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "true_positives": true_positives,
                "false_positives": false_positives,
                "false_negatives": false_negatives
            }

            # Aggregate raw counts across samples (micro-average inputs).
            if entity_type not in metrics["entity_metrics"]:
                metrics["entity_metrics"][entity_type] = {
                    "total_true_positives": 0,
                    "total_false_positives": 0,
                    "total_false_negatives": 0
                }
            metrics["entity_metrics"][entity_type]["total_true_positives"] += true_positives
            metrics["entity_metrics"][entity_type]["total_false_positives"] += false_positives
            metrics["entity_metrics"][entity_type]["total_false_negatives"] += false_negatives

        # Store detailed per-sample result for later inspection/saving.
        detailed_result = {
            "id": test_case.get("description", ""),
            "language": test_case.get("language", ""),
            "locale": test_case.get("locale", ""),
            "input_text": test_case["input_text"],
            "expected_entities": expected,
            "detected_entities": detected,
            "entity_metrics": entity_results,
            "anonymized_text": result.anonymized_text if result.anonymized_text else None
        }
        detailed_results.append(detailed_result)

        # Update pass/fail counts: pass only on a perfect sample.
        if all(entity_results[et]["f1"] == 1.0 for et in entity_results):
            metrics["passed"] += 1
        else:
            metrics["failed"] += 1

    # Calculate final entity metrics (micro-averaged over all samples).
    for entity_type, counts in metrics["entity_metrics"].items():
        tp = counts["total_true_positives"]
        fp = counts["total_false_positives"]
        fn = counts["total_false_negatives"]

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        metrics["entity_metrics"][entity_type].update({
            "precision": precision,
            "recall": recall,
            "f1": f1
        })

    return metrics, detailed_results
157
+
158
def save_results(metrics: Dict, detailed_results: List[Dict], model_name: str, output_dir: str = "evaluation_results"):
    """Save evaluation results to files.

    Writes two JSON files into *output_dir*:
      <model_name>_metrics.json           -- aggregate metrics summary
      <model_name>_detailed_results.json  -- per-sample detailed results

    Args:
        metrics: Aggregate metrics as produced by evaluate_model().
        detailed_results: Per-sample result dicts from evaluate_model().
        model_name: Prefix used for the output file names.
        output_dir: Target directory; created (including parents) if missing.
    """
    output_dir = Path(output_dir)
    # parents=True so a nested path such as "results/run1" also works;
    # the original mkdir(exist_ok=True) raised FileNotFoundError for those.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Explicit UTF-8 and ensure_ascii=False keep the multilingual dataset
    # text readable instead of \uXXXX-escaped.
    # Save metrics summary
    with open(output_dir / f"{model_name}_metrics.json", "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)

    # Save detailed results
    with open(output_dir / f"{model_name}_detailed_results.json", "w", encoding="utf-8") as f:
        json.dump(detailed_results, f, indent=2, ensure_ascii=False)
170
+
171
def print_metrics_summary(metrics: Dict):
    """Print a human-readable summary of the evaluation metrics to stdout."""
    rule = "=" * 80
    divider = "-" * 80

    # Overall pass/fail counts.
    print("\nEvaluation Summary")
    print(rule)
    print(f"Total Samples: {metrics['total']}")
    print(f"Passed: {metrics['passed']}")
    print(f"Failed: {metrics['failed']}")
    print(f"Success Rate: {(metrics['passed']/metrics['total'])*100:.1f}%")

    # Per-entity-type table.
    print("\nEntity-level Metrics:")
    print(divider)
    print(f"{'Entity Type':<20} {'Precision':>10} {'Recall':>10} {'F1':>10}")
    print(divider)
    for name, scores in metrics["entity_metrics"].items():
        print(f"{name:<20} {scores['precision']:>10.2f} {scores['recall']:>10.2f} {scores['f1']:>10.2f}")
186
+
187
def main():
    """Run the PII benchmark end to end.

    Loads evaluation samples from the ai4privacy dataset, runs each
    configured guardrail over them, then prints and saves the results.
    """
    # Import the guardrails here rather than under the __main__ guard:
    # the original placed these imports inside `if __name__ == "__main__"`,
    # so calling main() from another module (as the package ReadMe suggests)
    # raised NameError. Keeping them function-local also defers the heavy
    # model dependencies until the benchmark is actually run.
    from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
    from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
    from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail

    weave.init("guardrails-genie-pii-evaluation")

    # Load test cases
    test_cases = load_ai4privacy_dataset(num_samples=100)

    # Initialize models to evaluate
    models = {
        "regex": RegexEntityRecognitionGuardrail(should_anonymize=True),
        "presidio": PresidioEntityRecognitionGuardrail(should_anonymize=True),
        "transformers": TransformersEntityRecognitionGuardrail(should_anonymize=True)
    }

    # Evaluate each model
    for model_name, guardrail in models.items():
        print(f"\nEvaluating {model_name} model...")
        metrics, detailed_results = evaluate_model(guardrail, test_cases)

        # Print and save results
        print_metrics_summary(metrics)
        save_results(metrics, detailed_results, model_name)

if __name__ == "__main__":
    main()
guardrails_genie/guardrails/entity_recognition/pii_examples/pii_test_examples.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Collection of PII test examples with expected outcomes for entity recognition testing.
3
+ Each example includes the input text and expected entities to be detected.
4
+ """
5
+
6
+ PII_TEST_EXAMPLES = [
7
+ {
8
+ "description": "Business Context - Employee Record",
9
+ "input_text": """
10
+ Please update our records for employee John Smith:
11
+ Email: john.smith@company.com
12
+ Phone: 123-456-7890
13
+ SSN: 123-45-6789
14
+ Emergency Contact: Mary Johnson (Tel: 098-765-4321)
15
+ """,
16
+ "expected_entities": {
17
+ "GIVENNAME": ["John", "Mary"],
18
+ "SURNAME": ["Smith", "Johnson"],
19
+ "EMAIL": ["john.smith@company.com"],
20
+ "PHONE_NUMBER": ["123-456-7890", "098-765-4321"],
21
+ "SOCIALNUM": ["123-45-6789"]
22
+ }
23
+ },
24
+ {
25
+ "description": "Meeting Notes with Attendees",
26
+ "input_text": """
27
+ Meeting Notes - Project Alpha
28
+ Date: 2024-03-15
29
+ Attendees:
30
+ - Sarah Williams (sarah.w@company.com)
31
+ - Robert Brown (bobby@email.com)
32
+ - Tom Wilson (555-0123-4567)
33
+
34
+ Action Items:
35
+ 1. Sarah to review documentation
36
+ 2. Contact Bob at his alternate number: 777-888-9999
37
+ """,
38
+ "expected_entities": {
39
+ "GIVENNAME": ["Sarah", "Robert", "Tom", "Bob"],
40
+ "SURNAME": ["Williams", "Brown", "Wilson"],
41
+ "EMAIL": ["sarah.w@company.com", "bobby@email.com"],
42
+ "PHONE_NUMBER": ["555-0123-4567", "777-888-9999"]
43
+ }
44
+ },
45
+ {
46
+ "description": "Medical Record",
47
+ "input_text": """
48
+ Patient: Emma Thompson
49
+ DOB: 05/15/1980
50
+ Medical Record #: MR-12345
51
+ Primary Care: Dr. James Wilson
52
+ Contact: emma.t@email.com
53
+ Insurance ID: INS-987654321
54
+ Emergency Contact: Michael Thompson (555-123-4567)
55
+ """,
56
+ "expected_entities": {
57
+ "GIVENNAME": ["Emma", "James", "Michael"],
58
+ "SURNAME": ["Thompson", "Wilson", "Thompson"],
59
+ "EMAIL": ["emma.t@email.com"],
60
+ "PHONE_NUMBER": ["555-123-4567"]
61
+ }
62
+ },
63
+ {
64
+ "description": "No PII Content",
65
+ "input_text": """
66
+ Project Status Update:
67
+ - All deliverables are on track
68
+ - Budget is within limits
69
+ - Next review scheduled for next week
70
+ """,
71
+ "expected_entities": {}
72
+ },
73
+ {
74
+ "description": "Mixed Format Phone Numbers",
75
+ "input_text": """
76
+ Contact Directory:
77
+ Main Office: (555) 123-4567
78
+ Support: 555.987.6543
79
+ International: +1-555-321-7890
80
+ Emergency: 555 444 3333
81
+ """,
82
+ "expected_entities": {
83
+ "PHONE_NUMBER": [
84
+ "(555) 123-4567",
85
+ "555.987.6543",
86
+ "+1-555-321-7890",
87
+ "555 444 3333"
88
+ ]
89
+ }
90
+ }
91
+ ]
92
+
93
# Additional examples can be added to test specific edge cases or formats
# (same schema as PII_TEST_EXAMPLES).
EDGE_CASE_EXAMPLES = [
    {
        "description": "Mixed Case and Special Characters",
        "input_text": """
        JoHn.DoE@company.com
        JANE_SMITH@email.com
        bob.jones123@domain.co.uk
        """,
        "expected_entities": {
            "EMAIL": [
                "JoHn.DoE@company.com",
                "JANE_SMITH@email.com",
                "bob.jones123@domain.co.uk"
            ],
            # NOTE(review): names are expected in normalized casing even though
            # the input is mixed-case — confirm detectors are meant to
            # normalize, otherwise these cases can never pass.
            "GIVENNAME": ["John", "Jane", "Bob"],
            "SURNAME": ["Doe", "Smith", "Jones"]
        }
    }
]
113
+
114
def validate_entities(detected: dict, expected: dict) -> bool:
    """Return True iff *detected* matches *expected* exactly.

    Both arguments map entity type -> list of values. The two must contain
    the same entity types and, per type, the same set of values; ordering
    and duplicates are ignored.
    """
    expected_types = set(expected)
    if set(detected) != expected_types:
        return False
    for entity_type in expected_types:
        if set(detected[entity_type]) != set(expected[entity_type]):
            return False
    return True
119
+
120
def run_test_case(guardrail, test_case, test_type="Main"):
    """Run one test case against *guardrail* and print a pass/fail report.

    Args:
        guardrail: Object exposing ``guard(text)`` returning a response with
            ``detected_entities``, ``contains_entities`` and ``anonymized_text``.
        test_case: Dict with ``description``, ``input_text`` and
            ``expected_entities`` keys.
        test_type: Label used in the report header ("Main" or "Edge").

    Returns:
        True when detected entities exactly match the expectation.
    """
    print(f"\n{test_type} Test Case: {test_case['description']}")
    print("-" * 50)

    response = guardrail.guard(test_case['input_text'])
    expected_entities = test_case['expected_entities']

    # Exact set comparison of expected vs. detected entities.
    passed = validate_entities(response.detected_entities, expected_entities)

    print(f"Test Status: {'✓ PASS' if passed else '✗ FAIL'}")
    print(f"Contains PII: {response.contains_entities}")

    if not passed:
        # On failure, show a per-type diff of expected vs. detected values.
        print("\nEntity Comparison:")
        entity_types = set(response.detected_entities) | set(expected_entities)
        for entity_type in entity_types:
            found = set(response.detected_entities.get(entity_type, []))
            wanted = set(expected_entities.get(entity_type, []))
            print(f"\nEntity Type: {entity_type}")
            print(f"  Expected: {sorted(wanted)}")
            print(f"  Detected: {sorted(found)}")
            if found != wanted:
                print(f"  Missing: {sorted(wanted - found)}")
                print(f"  Extra: {sorted(found - wanted)}")

    if response.anonymized_text:
        print(f"\nAnonymized Text:\n{response.anonymized_text}")

    return passed
guardrails_genie/guardrails/entity_recognition/pii_examples/run_presidio_model.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
2
+ from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import PII_TEST_EXAMPLES, EDGE_CASE_EXAMPLES, run_test_case, validate_entities
3
+ import weave
4
+
5
def test_pii_detection():
    """Run the predefined PII test suites against the Presidio guardrail
    and print a per-case report plus an overall summary."""
    weave.init("guardrails-genie-pii-presidio-model")

    # Guardrail under test: default entity set, anonymization enabled so the
    # redacted text is shown for each case.
    pii_guardrail = PresidioEntityRecognitionGuardrail(
        should_anonymize=True,
        show_available_entities=True
    )

    # (header printed, cases, label passed to run_test_case)
    suites = [
        ("Main PII Tests", PII_TEST_EXAMPLES, "Main"),
        ("Edge Cases", EDGE_CASE_EXAMPLES, "Edge"),
    ]

    total_tests = sum(len(cases) for _, cases, _ in suites)
    passed_tests = 0

    for title, cases, label in suites:
        print(f"\nRunning {title}")
        print("=" * 80)
        for test_case in cases:
            if run_test_case(pii_guardrail, test_case, label):
                passed_tests += 1

    # Aggregate summary.
    print("\nTest Summary")
    print("=" * 80)
    print(f"Total Tests: {total_tests}")
    print(f"Passed: {passed_tests}")
    print(f"Failed: {total_tests - passed_tests}")
    print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")
40
+
41
+ if __name__ == "__main__":
42
+ test_pii_detection()
guardrails_genie/guardrails/entity_recognition/pii_examples/run_regex_model.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
2
+ from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import PII_TEST_EXAMPLES, EDGE_CASE_EXAMPLES, run_test_case, validate_entities
3
+ import weave
4
+
5
def test_pii_detection():
    """Run the predefined PII test suites against the regex guardrail
    and print a per-case report plus an overall summary."""
    weave.init("guardrails-genie-pii-regex-model")

    # Guardrail under test: default patterns, anonymization enabled so the
    # redacted text is shown for each case.
    pii_guardrail = RegexEntityRecognitionGuardrail(
        should_anonymize=True,
        show_available_entities=True
    )

    # (header printed, cases, label passed to run_test_case)
    suites = [
        ("Main PII Tests", PII_TEST_EXAMPLES, "Main"),
        ("Edge Cases", EDGE_CASE_EXAMPLES, "Edge"),
    ]

    total_tests = sum(len(cases) for _, cases, _ in suites)
    passed_tests = 0

    for title, cases, label in suites:
        print(f"\nRunning {title}")
        print("=" * 80)
        for test_case in cases:
            if run_test_case(pii_guardrail, test_case, label):
                passed_tests += 1

    # Aggregate summary.
    print("\nTest Summary")
    print("=" * 80)
    print(f"Total Tests: {total_tests}")
    print(f"Passed: {passed_tests}")
    print(f"Failed: {total_tests - passed_tests}")
    print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")
40
+
41
+ if __name__ == "__main__":
42
+ test_pii_detection()
guardrails_genie/guardrails/entity_recognition/pii_examples/run_transformers.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail
2
+ from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_test_examples import PII_TEST_EXAMPLES, EDGE_CASE_EXAMPLES, run_test_case, validate_entities
3
+ import weave
4
+
5
def test_pii_detection():
    """Run the predefined PII test suites against the transformers guardrail
    and print a per-case report plus an overall summary."""
    weave.init("guardrails-genie-pii-transformers-pipeline-model")

    # Guardrail under test: restricted to the entity types the test examples
    # use, anonymization enabled so the redacted text is shown for each case.
    pii_guardrail = TransformersEntityRecognitionGuardrail(
        selected_entities=["GIVENNAME", "SURNAME", "EMAIL", "TELEPHONENUM", "SOCIALNUM"],
        should_anonymize=True,
        show_available_entities=True
    )

    # (header printed, cases, label passed to run_test_case)
    suites = [
        ("Main PII Tests", PII_TEST_EXAMPLES, "Main"),
        ("Edge Cases", EDGE_CASE_EXAMPLES, "Edge"),
    ]

    total_tests = sum(len(cases) for _, cases, _ in suites)
    passed_tests = 0

    for title, cases, label in suites:
        print(f"\nRunning {title}")
        print("=" * 80)
        for test_case in cases:
            if run_test_case(pii_guardrail, test_case, label):
                passed_tests += 1

    # Aggregate summary.
    print("\nTest Summary")
    print("=" * 80)
    print(f"Total Tests: {total_tests}")
    print(f"Passed: {passed_tests}")
    print(f"Failed: {total_tests - passed_tests}")
    print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")
41
+
42
+ if __name__ == "__main__":
43
+ test_pii_detection()
guardrails_genie/guardrails/entity_recognition/presidio_entity_recognition_guardrail.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional, ClassVar, Any
2
+ import weave
3
+ from pydantic import BaseModel
4
+
5
+ from presidio_analyzer import AnalyzerEngine, RecognizerRegistry, Pattern, PatternRecognizer
6
+ from presidio_anonymizer import AnonymizerEngine
7
+
8
+ from ..base import Guardrail
9
+
10
class PresidioEntityRecognitionResponse(BaseModel):
    """Detailed result of a Presidio entity scan (see
    PresidioEntityRecognitionGuardrail.guard)."""
    # Whether any selected entity type was found in the text.
    contains_entities: bool
    # Mapping of entity type -> list of matched text spans.
    detected_entities: Dict[str, List[str]]
    # Human-readable summary of what was found and what was checked.
    explanation: str
    # Prompt with entities masked; None when anonymization is disabled.
    anonymized_text: Optional[str] = None

    @property
    def safe(self) -> bool:
        """True when no entities were detected."""
        return not self.contains_entities
19
+
20
class PresidioEntityRecognitionSimpleResponse(BaseModel):
    """Like PresidioEntityRecognitionResponse, but without the per-type
    entity breakdown."""
    # Whether any selected entity type was found in the text.
    contains_entities: bool
    # Human-readable summary of what was found and what was checked.
    explanation: str
    # Prompt with entities masked; None when anonymization is disabled.
    anonymized_text: Optional[str] = None

    @property
    def safe(self) -> bool:
        """True when no entities were detected."""
        return not self.contains_entities
28
+
29
#TODO: Add support for transformers workflow and not just Spacy
class PresidioEntityRecognitionGuardrail(Guardrail):
    """Entity-recognition guardrail backed by Microsoft Presidio.

    Scans prompts with Presidio's ``AnalyzerEngine`` for the configured
    entity types and, when ``should_anonymize`` is set, also returns an
    anonymized copy of the prompt produced by the ``AnonymizerEngine``.
    """

    @staticmethod
    def get_available_entities() -> List[str]:
        """Return entity types reported by Presidio's registered recognizers.

        NOTE(review): only the first element of each recognizer's
        ``supported_entities`` is collected, so recognizers supporting
        multiple entity types may be under-reported — confirm this is
        intentional.
        """
        registry = RecognizerRegistry()
        analyzer = AnalyzerEngine(registry=registry)
        return [recognizer.supported_entities[0]
                for recognizer in analyzer.registry.recognizers]

    # Fields declared for the pydantic-based Guardrail model; values are
    # supplied via super().__init__ below.
    analyzer: AnalyzerEngine
    anonymizer: AnonymizerEngine
    selected_entities: List[str]
    should_anonymize: bool
    language: str

    def __init__(
        self,
        selected_entities: Optional[List[str]] = None,
        should_anonymize: bool = False,
        language: str = "en",
        deny_lists: Optional[Dict[str, List[str]]] = None,
        regex_patterns: Optional[Dict[str, List[Dict[str, str]]]] = None,
        custom_recognizers: Optional[List[Any]] = None,
        show_available_entities: bool = False
    ):
        """Build the analyzer/anonymizer pair and validate selected entities.

        Args:
            selected_entities: Entity types to detect; defaults to a common
                PII subset. Unknown types are dropped with a printed warning.
            should_anonymize: When True, guard() also returns anonymized text.
            language: Language code passed to the Presidio analyzer.
            deny_lists: Optional map of entity type -> exact tokens to flag.
            regex_patterns: Optional map of entity type -> list of pattern
                dicts with "regex" (required), "name" and "score" (optional).
            custom_recognizers: Pre-built Presidio recognizers to register.
            show_available_entities: Print supported entity types on init.
        """
        # If show_available_entities is True, print available entities
        if show_available_entities:
            available_entities = self.get_available_entities()
            print("\nAvailable entities:")
            print("=" * 25)
            for entity in available_entities:
                print(f"- {entity}")
            print("=" * 25 + "\n")

        # Initialize default values
        if selected_entities is None:
            selected_entities = [
                "CREDIT_CARD", "US_SSN", "EMAIL_ADDRESS", "PHONE_NUMBER",
                "IP_ADDRESS", "URL", "DATE_TIME"
            ]

        # Get available entities dynamically
        available_entities = self.get_available_entities()

        # Filter out invalid entities and warn user
        invalid_entities = [e for e in selected_entities if e not in available_entities]
        valid_entities = [e for e in selected_entities if e in available_entities]

        if invalid_entities:
            print(f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}")
            print(f"Continuing with valid entities: {valid_entities}")
            selected_entities = valid_entities

        # Initialize analyzer with default recognizers
        analyzer = AnalyzerEngine()

        # Add custom recognizers if provided
        if custom_recognizers:
            for recognizer in custom_recognizers:
                analyzer.registry.add_recognizer(recognizer)

        # Add deny list recognizers if provided
        if deny_lists:
            for entity_type, tokens in deny_lists.items():
                deny_list_recognizer = PatternRecognizer(
                    supported_entity=entity_type,
                    deny_list=tokens
                )
                analyzer.registry.add_recognizer(deny_list_recognizer)

        # Add regex pattern recognizers if provided
        if regex_patterns:
            for entity_type, patterns in regex_patterns.items():
                presidio_patterns = [
                    Pattern(
                        name=pattern.get("name", f"pattern_{i}"),
                        regex=pattern["regex"],
                        score=pattern.get("score", 0.5)
                    ) for i, pattern in enumerate(patterns)
                ]
                regex_recognizer = PatternRecognizer(
                    supported_entity=entity_type,
                    patterns=presidio_patterns
                )
                analyzer.registry.add_recognizer(regex_recognizer)

        # Initialize Presidio engines
        anonymizer = AnonymizerEngine()

        # Call parent class constructor with all fields
        super().__init__(
            analyzer=analyzer,
            anonymizer=anonymizer,
            selected_entities=selected_entities,
            should_anonymize=should_anonymize,
            language=language
        )

    @weave.op()
    def guard(self, prompt: str, return_detected_types: bool = True, **kwargs) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
        """
        Check if the input prompt contains any entities using Presidio.

        Args:
            prompt: The text to analyze
            return_detected_types: If True, returns detailed entity type information
        """
        # Analyze text for entities
        analyzer_results = self.analyzer.analyze(
            text=prompt,
            entities=self.selected_entities,
            language=self.language
        )

        # Group results by entity type, slicing the matched spans out of the
        # prompt with the analyzer's start/end offsets.
        detected_entities = {}
        for result in analyzer_results:
            entity_type = result.entity_type
            text_slice = prompt[result.start:result.end]
            if entity_type not in detected_entities:
                detected_entities[entity_type] = []
            detected_entities[entity_type].append(text_slice)

        # Create explanation
        explanation_parts = []
        if detected_entities:
            explanation_parts.append("Found the following entities in the text:")
            for entity_type, instances in detected_entities.items():
                explanation_parts.append(f"- {entity_type}: {len(instances)} instance(s)")
        else:
            explanation_parts.append("No entities detected in the text.")

        # Add information about what was checked
        explanation_parts.append("\nChecked for these entity types:")
        for entity in self.selected_entities:
            explanation_parts.append(f"- {entity}")

        # Anonymize if requested
        anonymized_text = None
        if self.should_anonymize and detected_entities:
            anonymized_result = self.anonymizer.anonymize(
                text=prompt,
                analyzer_results=analyzer_results
            )
            anonymized_text = anonymized_result.text

        if return_detected_types:
            return PresidioEntityRecognitionResponse(
                contains_entities=bool(detected_entities),
                detected_entities=detected_entities,
                explanation="\n".join(explanation_parts),
                anonymized_text=anonymized_text
            )
        else:
            return PresidioEntityRecognitionSimpleResponse(
                contains_entities=bool(detected_entities),
                explanation="\n".join(explanation_parts),
                anonymized_text=anonymized_text
            )

    @weave.op()
    def predict(self, prompt: str, return_detected_types: bool = True, **kwargs) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
        """Delegate to guard(); separate op so it can be traced/called as a model."""
        return self.guard(prompt, return_detected_types=return_detected_types, **kwargs)
guardrails_genie/guardrails/entity_recognition/regex_entity_recognition_guardrail.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, ClassVar
2
+
3
+ import weave
4
+ from pydantic import BaseModel
5
+
6
+ from ...regex_model import RegexModel
7
+ from ..base import Guardrail
8
+ import re
9
+
10
+
11
class RegexEntityRecognitionResponse(BaseModel):
    """Detailed result of a regex entity scan (see
    RegexEntityRecognitionGuardrail.guard)."""
    # Whether any configured pattern matched the text.
    contains_entities: bool
    # Mapping of pattern name -> list of matched substrings.
    detected_entities: Dict[str, list[str]]
    # Human-readable summary of what matched and what was checked.
    explanation: str
    # Prompt with matches replaced; None when anonymization is disabled.
    anonymized_text: Optional[str] = None

    @property
    def safe(self) -> bool:
        """True when no pattern matched."""
        return not self.contains_entities
20
+
21
+
22
class RegexEntityRecognitionSimpleResponse(BaseModel):
    """Like RegexEntityRecognitionResponse, but without the per-pattern
    match breakdown."""
    # Whether any configured pattern matched the text.
    contains_entities: bool
    # Human-readable summary of what matched and what was checked.
    explanation: str
    # Prompt with matches replaced; None when anonymization is disabled.
    anonymized_text: Optional[str] = None

    @property
    def safe(self) -> bool:
        """True when no pattern matched."""
        return not self.contains_entities
30
+
31
+
32
class RegexEntityRecognitionGuardrail(Guardrail):
    """Entity-recognition guardrail driven by named regular expressions.

    Wraps a RegexModel built from DEFAULT_PATTERNS (optionally extended via
    a ``patterns`` kwarg) and reports which named patterns matched the
    prompt; can also replace matches in the returned anonymized text.
    """

    regex_model: RegexModel
    patterns: Dict[str, str] = {}
    should_anonymize: bool = False

    # Built-in PII patterns used when use_defaults=True. Keys double as the
    # entity-type names reported in responses.
    DEFAULT_PATTERNS: ClassVar[Dict[str, str]] = {
        "email": r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
        "phone_number": r"\b(?:\+?1[-.]?)?\(?(?:[0-9]{3})\)?[-.]?(?:[0-9]{3})[-.]?(?:[0-9]{4})\b",
        "ssn": r"\b\d{3}[-]?\d{2}[-]?\d{4}\b",
        "credit_card": r"\b\d{4}[-.]?\d{4}[-.]?\d{4}[-.]?\d{4}\b",
        "ip_address": r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b",
        "date_of_birth": r"\b\d{2}[-/]\d{2}[-/]\d{4}\b",
        "passport": r"\b[A-Z]{1,2}[0-9]{6,9}\b",
        "drivers_license": r"\b[A-Z]\d{7}\b",
        "bank_account": r"\b\d{8,17}\b",
        "zip_code": r"\b\d{5}(?:[-]\d{4})?\b"
    }

    def __init__(self, use_defaults: bool = True, should_anonymize: bool = False, **kwargs):
        """Assemble the pattern set and initialize the underlying RegexModel.

        Args:
            use_defaults: Start from DEFAULT_PATTERNS when True.
            should_anonymize: Replace matches in the returned text when True.
            **kwargs: May carry ``patterns``, a dict of extra/overriding
                {name: regex} entries merged over the defaults.
        """
        patterns = {}
        if use_defaults:
            patterns = self.DEFAULT_PATTERNS.copy()
        if kwargs.get("patterns"):
            patterns.update(kwargs["patterns"])

        # Create the RegexModel instance
        regex_model = RegexModel(patterns=patterns)

        # Initialize the base class with both the regex_model and patterns
        super().__init__(
            regex_model=regex_model,
            patterns=patterns,
            should_anonymize=should_anonymize
        )

    def text_to_pattern(self, text: str) -> str:
        """
        Convert input text into a regex pattern that matches the exact text.
        """
        # Escape special regex characters in the text
        escaped_text = re.escape(text)
        # Match the exact text as a whole word. NOTE(review): no case-folding
        # flag is applied here, so matching is case-sensitive.
        return rf"\b{escaped_text}\b"

    @weave.op()
    def guard(self, prompt: str, custom_terms: Optional[list[str]] = None, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
        """
        Check if the input prompt contains any entities based on the regex patterns.

        Args:
            prompt: Input text to check for entities
            custom_terms: List of custom terms to be converted into regex patterns. If provided,
                only these terms will be checked, ignoring default patterns.
            return_detected_types: If True, returns detailed entity type information
            aggregate_redaction: If True, matches are replaced with the generic
                "[redacted]" marker instead of "[ENTITY_TYPE]"

        Returns:
            RegexEntityRecognitionResponse or RegexEntityRecognitionSimpleResponse containing detection results
        """
        if custom_terms:
            # Create a temporary RegexModel with only the custom patterns
            temp_patterns = {term: self.text_to_pattern(term) for term in custom_terms}
            temp_model = RegexModel(patterns=temp_patterns)
            result = temp_model.check(prompt)
        else:
            # Use the original regex_model if no custom terms provided
            result = self.regex_model.check(prompt)

        # Create detailed explanation
        explanation_parts = []
        if result.matched_patterns:
            explanation_parts.append("Found the following entities in the text:")
            for entity_type, matches in result.matched_patterns.items():
                explanation_parts.append(f"- {entity_type}: {len(matches)} instance(s)")
        else:
            explanation_parts.append("No entities detected in the text.")

        if result.failed_patterns:
            explanation_parts.append("\nChecked but did not find these entity types:")
            for pattern in result.failed_patterns:
                explanation_parts.append(f"- {pattern}")

        # Updated anonymization logic. NOTE(review): str.replace substitutes
        # every occurrence of the matched string in the whole prompt, not just
        # the span the regex found — confirm this is acceptable.
        anonymized_text = None
        if getattr(self, 'should_anonymize', False) and result.matched_patterns:
            anonymized_text = prompt
            for entity_type, matches in result.matched_patterns.items():
                for match in matches:
                    replacement = "[redacted]" if aggregate_redaction else f"[{entity_type.upper()}]"
                    anonymized_text = anonymized_text.replace(match, replacement)

        if return_detected_types:
            return RegexEntityRecognitionResponse(
                contains_entities=not result.passed,
                detected_entities=result.matched_patterns,
                explanation="\n".join(explanation_parts),
                anonymized_text=anonymized_text
            )
        else:
            return RegexEntityRecognitionSimpleResponse(
                contains_entities=not result.passed,
                explanation="\n".join(explanation_parts),
                anonymized_text=anonymized_text
            )

    @weave.op()
    def predict(self, prompt: str, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
        """Delegate to guard(); separate op so it can be traced/called as a model."""
        return self.guard(prompt, return_detected_types=return_detected_types, aggregate_redaction=aggregate_redaction, **kwargs)
guardrails_genie/guardrails/entity_recognition/transformers_entity_recognition_guardrail.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional, ClassVar
2
+ from transformers import pipeline, AutoConfig
3
+ import json
4
+ from pydantic import BaseModel
5
+ from ..base import Guardrail
6
+ import weave
7
+
8
class TransformersEntityRecognitionResponse(BaseModel):
    """Detailed result of a transformer-pipeline entity scan (see
    TransformersEntityRecognitionGuardrail.guard)."""
    # Whether any selected entity type was found in the text.
    contains_entities: bool
    # Mapping of entity type -> list of detected word spans.
    detected_entities: Dict[str, List[str]]
    # Human-readable summary of what was found and what was checked.
    explanation: str
    # Prompt with entities redacted; None when anonymization is disabled.
    anonymized_text: Optional[str] = None

    @property
    def safe(self) -> bool:
        """True when no entities were detected."""
        return not self.contains_entities
17
+
18
class TransformersEntityRecognitionSimpleResponse(BaseModel):
    """Like TransformersEntityRecognitionResponse, but without the per-type
    entity breakdown."""
    # Whether any selected entity type was found in the text.
    contains_entities: bool
    # Human-readable summary of what was found and what was checked.
    explanation: str
    # Prompt with entities redacted; None when anonymization is disabled.
    anonymized_text: Optional[str] = None

    @property
    def safe(self) -> bool:
        """True when no entities were detected."""
        return not self.contains_entities
26
+
27
class TransformersEntityRecognitionGuardrail(Guardrail):
    """Generic guardrail for detecting entities using any token classification model."""

    # Private attribute (underscore-prefixed, so not a pydantic field);
    # holds the HF token-classification pipeline once constructed.
    _pipeline: Optional[object] = None
    selected_entities: List[str]
    should_anonymize: bool
    available_entities: List[str]

    def __init__(
        self,
        model_name: str = "iiiorg/piiranha-v1-detect-personal-information",
        selected_entities: Optional[List[str]] = None,
        should_anonymize: bool = False,
        show_available_entities: bool = True,
    ):
        """Configure the guardrail around a HF token-classification model.

        Args:
            model_name: Hub id of the token-classification model to load.
            selected_entities: Entity types to report; defaults to every type
                the model supports. Unknown types are dropped with a warning.
            should_anonymize: When True, guard() also returns redacted text.
            show_available_entities: Print the model's entity types on init.
        """
        # Load model config and extract available entities
        config = AutoConfig.from_pretrained(model_name)
        entities = self._extract_entities_from_config(config)

        if show_available_entities:
            self._print_available_entities(entities)

        # Initialize default values if needed
        if selected_entities is None:
            selected_entities = entities  # Use all available entities by default

        # Filter out invalid entities and warn user
        invalid_entities = [e for e in selected_entities if e not in entities]
        valid_entities = [e for e in selected_entities if e in entities]

        if invalid_entities:
            print(f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}")
            print(f"Continuing with valid entities: {valid_entities}")
            selected_entities = valid_entities

        # Call parent class constructor
        super().__init__(
            selected_entities=selected_entities,
            should_anonymize=should_anonymize,
            available_entities=entities
        )

        # Initialize pipeline
        self._pipeline = pipeline(
            task="token-classification",
            model=model_name,
            aggregation_strategy="simple"  # Merge same entities
        )

    def _extract_entities_from_config(self, config) -> List[str]:
        """Extract unique entity types from the model config."""
        # Get id2label mapping from config
        id2label = config.id2label

        # Extract unique entity types (removing B- and I- prefixes)
        entities = set()
        for label in id2label.values():
            if label.startswith(('B-', 'I-')):
                entities.add(label[2:])  # Remove prefix
            elif label != 'O':  # Skip the 'O' (Outside) label
                entities.add(label)

        return sorted(list(entities))

    def _print_available_entities(self, entities: List[str]):
        """Print all available entity types that can be detected by the model."""
        print("\nAvailable entity types:")
        print("=" * 25)
        for entity in entities:
            print(f"- {entity}")
        print("=" * 25 + "\n")

    def print_available_entities(self):
        """Print all available entity types that can be detected by the model."""
        self._print_available_entities(self.available_entities)

    def _detect_entities(self, text: str) -> Dict[str, List[str]]:
        """Detect entities in the text using the pipeline.

        Returns a mapping of entity type -> list of detected word spans,
        restricted to ``selected_entities``.
        """
        results = self._pipeline(text)

        # Group findings by entity type
        detected_entities = {}
        for entity in results:
            entity_type = entity['entity_group']
            if entity_type in self.selected_entities:
                if entity_type not in detected_entities:
                    detected_entities[entity_type] = []
                detected_entities[entity_type].append(entity['word'])

        return detected_entities

    def _anonymize_text(self, text: str, aggregate_redaction: bool = True) -> str:
        """Anonymize detected entities in text using the pipeline."""
        results = self._pipeline(text)

        # Sort entities by start position in reverse order to avoid offset issues
        entities = sorted(results, key=lambda x: x['start'], reverse=True)

        # Create a mutable list of characters
        chars = list(text)

        # Apply redactions
        for entity in entities:
            if entity['entity_group'] in self.selected_entities:
                start, end = entity['start'], entity['end']
                replacement = ' [redacted] ' if aggregate_redaction else f" [{entity['entity_group']}] "

                # Replace the entity with the redaction marker
                chars[start:end] = replacement

        # Join characters and clean up only consecutive spaces (preserving newlines)
        result = ''.join(chars)
        # Replace multiple spaces with single space, but preserve newlines
        lines = result.split('\n')
        cleaned_lines = [' '.join(line.split()) for line in lines]
        return '\n'.join(cleaned_lines)

    @weave.op()
    def guard(self, prompt: str, return_detected_types: bool = True, aggregate_redaction: bool = True) -> TransformersEntityRecognitionResponse | TransformersEntityRecognitionSimpleResponse:
        """Check if the input prompt contains any entities using the transformer pipeline.

        Args:
            prompt: The text to analyze
            return_detected_types: If True, returns detailed entity type information
            aggregate_redaction: If True, uses generic [redacted] instead of entity type
        """
        # Detect entities
        detected_entities = self._detect_entities(prompt)

        # Create explanation
        explanation_parts = []
        if detected_entities:
            explanation_parts.append("Found the following entities in the text:")
            for entity_type, instances in detected_entities.items():
                explanation_parts.append(f"- {entity_type}: {len(instances)} instance(s)")
        else:
            explanation_parts.append("No entities detected in the text.")

        explanation_parts.append("\nChecked for these entities:")
        for entity in self.selected_entities:
            explanation_parts.append(f"- {entity}")

        # Anonymize if requested
        anonymized_text = None
        if self.should_anonymize and detected_entities:
            anonymized_text = self._anonymize_text(prompt, aggregate_redaction)

        if return_detected_types:
            return TransformersEntityRecognitionResponse(
                contains_entities=bool(detected_entities),
                detected_entities=detected_entities,
                explanation="\n".join(explanation_parts),
                anonymized_text=anonymized_text
            )
        else:
            return TransformersEntityRecognitionSimpleResponse(
                contains_entities=bool(detected_entities),
                explanation="\n".join(explanation_parts),
                anonymized_text=anonymized_text
            )

    @weave.op()
    def predict(self, prompt: str, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> TransformersEntityRecognitionResponse | TransformersEntityRecognitionSimpleResponse:
        """Delegate to guard(); separate op so it can be traced/called as a model."""
        return self.guard(prompt, return_detected_types=return_detected_types, aggregate_redaction=aggregate_redaction, **kwargs)
guardrails_genie/guardrails/manager.py CHANGED
@@ -1,5 +1,6 @@
1
  import weave
2
  from rich.progress import track
 
3
 
4
  from .base import Guardrail
5
 
@@ -20,10 +21,12 @@ class GuardrailManager(weave.Model):
20
  alerts.append(
21
  {"guardrail_name": guardrail.__class__.__name__, "response": response}
22
  )
23
- safe = safe and response["safe"]
24
- summaries += (
25
- f"**{guardrail.__class__.__name__}**: {response['summary']}\n\n---\n\n"
26
- )
 
 
27
  return {"safe": safe, "alerts": alerts, "summary": summaries}
28
 
29
  @weave.op()
 
1
  import weave
2
  from rich.progress import track
3
+ from pydantic import BaseModel
4
 
5
  from .base import Guardrail
6
 
 
21
  alerts.append(
22
  {"guardrail_name": guardrail.__class__.__name__, "response": response}
23
  )
24
+ if isinstance(response, BaseModel):
25
+ safe = safe and response.safe
26
+ summaries += f"**{guardrail.__class__.__name__}**: {response.explanation}\n\n---\n\n"
27
+ else:
28
+ safe = safe and response["safe"]
29
+ summaries += f"**{guardrail.__class__.__name__}**: {response['summary']}\n\n---\n\n"
30
  return {"safe": safe, "alerts": alerts, "summary": summaries}
31
 
32
  @weave.op()
guardrails_genie/regex_model.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional
2
+ import re
3
+ import weave
4
+ from pydantic import BaseModel
5
+
6
+
7
+ class RegexResult(BaseModel):
8
+ passed: bool
9
+ matched_patterns: Dict[str, List[str]]
10
+ failed_patterns: List[str]
11
+
12
+
13
class RegexModel(weave.Model):
    """Screens text against a dictionary of named regex patterns.

    A text "passes" when none of the patterns match (i.e. no PII-like
    content was found).

    Attributes:
        patterns: Mapping from a human-readable pattern name (e.g. "email")
            to its regex source string.
    """

    patterns: Dict[str, str]

    def __init__(self, patterns: Dict[str, str]) -> None:
        """
        Initialize RegexModel with a dictionary of patterns.

        Args:
            patterns: Dictionary where key is pattern name and value is regex pattern
                      Example: {"email": r"[^@ \t\r\n]+@[^@ \t\r\n]+\.[^@ \t\r\n]+",
                                "phone": r"\b\d{3}[-.]?\d{3}[-.]?\d{4}\b"}
        """
        super().__init__(patterns=patterns)
        # Compile once up front so check() does not recompile per call.
        self._compiled_patterns = {
            name: re.compile(pattern) for name, pattern in patterns.items()
        }

    @weave.op()
    def check(self, text: str) -> RegexResult:
        """
        Check text against all patterns and return detailed results.

        Args:
            text: Input text to check against patterns

        Returns:
            RegexResult containing pass/fail status and details about matches
        """
        matches: Dict[str, List[str]] = {}
        failed_patterns: List[str] = []

        for pattern_name, compiled_pattern in self._compiled_patterns.items():
            # Use finditer + group(0) rather than findall: findall returns the
            # capture-group contents (not the full match) whenever the pattern
            # contains groups, which would report fragments of the matched text.
            found_matches = [m.group(0) for m in compiled_pattern.finditer(text)]
            if found_matches:
                matches[pattern_name] = found_matches
            else:
                failed_patterns.append(pattern_name)

        # Consider it passed only if no patterns matched (no PII found).
        passed = len(matches) == 0

        return RegexResult(
            passed=passed,
            matched_patterns=matches,
            failed_patterns=failed_patterns
        )

    @weave.op()
    def predict(self, text: str) -> RegexResult:
        """
        Alias for check() to maintain consistency with other models.
        """
        return self.check(text)
pyproject.toml CHANGED
@@ -20,6 +20,8 @@ dependencies = [
20
  "pymupdf4llm>=0.0.17",
21
  "transformers>=4.46.3",
22
  "torch>=2.5.1",
 
 
23
  ]
24
 
25
  [tool.setuptools]
 
20
  "pymupdf4llm>=0.0.17",
21
  "transformers>=4.46.3",
22
  "torch>=2.5.1",
23
+ "presidio-analyzer>=2.2.355",
24
+ "presidio-anonymizer>=2.2.355",
25
  ]
26
 
27
  [tool.setuptools]