Add banned terms guardrails
- application_pages/chat_app.py +7 -0
- guardrails_genie/guardrails/__init__.py +2 -0
- guardrails_genie/guardrails/entity_recognition/__init__.py +2 -1
- guardrails_genie/guardrails/entity_recognition/banned_terms_examples/banned_term_benchmark.py +0 -0
- guardrails_genie/guardrails/entity_recognition/banned_terms_examples/banned_term_examples.py +178 -0
- guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_llm_judge.py +47 -0
- guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_regex_model.py +46 -0
- guardrails_genie/guardrails/entity_recognition/llm_judge_entity_recognition_guardrail.py +166 -3
- guardrails_genie/guardrails/entity_recognition/regex_entity_recognition_guardrail.py +25 -6
application_pages/chat_app.py
CHANGED
@@ -82,6 +82,13 @@ def initialize_guardrails():
                     guardrail_name,
                 )()
             )
+        elif guardrail_name == "RestrictedTermsJudge":
+            st.session_state.guardrails.append(
+                getattr(
+                    importlib.import_module("guardrails_genie.guardrails"),
+                    guardrail_name,
+                )()
+            )
     st.session_state.guardrails_manager = GuardrailManager(
         guardrails=st.session_state.guardrails
     )
guardrails_genie/guardrails/__init__.py
CHANGED
@@ -6,6 +6,7 @@ from .entity_recognition import (
     PresidioEntityRecognitionGuardrail,
     RegexEntityRecognitionGuardrail,
     TransformersEntityRecognitionGuardrail,
+    RestrictedTermsJudge,
 )
 from .manager import GuardrailManager
 
@@ -15,5 +16,6 @@ __all__ = [
     "PresidioEntityRecognitionGuardrail",
     "RegexEntityRecognitionGuardrail",
     "TransformersEntityRecognitionGuardrail",
+    "RestrictedTermsJudge",
     "GuardrailManager",
 ]
guardrails_genie/guardrails/entity_recognition/__init__.py
CHANGED
@@ -1,9 +1,10 @@
 from .presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
 from .regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
 from .transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail
-
+from .llm_judge_entity_recognition_guardrail import RestrictedTermsJudge
 __all__ = [
     "PresidioEntityRecognitionGuardrail",
     "RegexEntityRecognitionGuardrail",
     "TransformersEntityRecognitionGuardrail",
+    "RestrictedTermsJudge"
 ]
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/banned_term_benchmark.py
ADDED
(empty file)
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/banned_term_examples.py
ADDED
@@ -0,0 +1,178 @@
+"""
+Collection of restricted terms test examples with expected outcomes for entity recognition testing.
+Focuses on banned terms, competitor mentions, and brand protection scenarios.
+"""
+
+RESTRICTED_TERMS_EXAMPLES = [
+    {
+        "description": "Competitor Product Discussion",
+        "input_text": """
+        I think we should implement features similar to Salesforce's Einstein AI
+        and Oracle's Cloud Infrastructure. Maybe we could also look at how
+        AWS handles their lambda functions.
+        """,
+        "custom_terms": ["Salesforce", "Oracle", "AWS", "Einstein AI", "Cloud Infrastructure", "lambda"],
+        "expected_entities": {
+            "Salesforce": ["Salesforce"],
+            "Oracle": ["Oracle"],
+            "AWS": ["AWS"],
+            "Einstein AI": ["Einstein AI"],
+            "Cloud Infrastructure": ["Cloud Infrastructure"],
+            "lambda": ["lambda"]
+        }
+    },
+    {
+        "description": "Inappropriate Language in Support Ticket",
+        "input_text": """
+        This damn product keeps crashing! What the hell is wrong with your
+        stupid service? I've wasted so much freaking time on this crap.
+        """,
+        "custom_terms": ["damn", "hell", "stupid", "crap"],
+        "expected_entities": {
+            "damn": ["damn"],
+            "hell": ["hell"],
+            "stupid": ["stupid"],
+            "crap": ["crap"]
+        }
+    },
+    {
+        "description": "Confidential Project Names",
+        "input_text": """
+        Project Titan's launch date has been moved up. We should coordinate
+        with Project Phoenix team and the Blue Dragon initiative for resource allocation.
+        """,
+        "custom_terms": ["Project Titan", "Project Phoenix", "Blue Dragon"],
+        "expected_entities": {
+            "Project Titan": ["Project Titan"],
+            "Project Phoenix": ["Project Phoenix"],
+            "Blue Dragon": ["Blue Dragon"]
+        }
+    }
+]
+
+# Edge cases and special formats
+EDGE_CASE_EXAMPLES = [
+    {
+        "description": "Common Corporate Abbreviations and Stock Symbols",
+        "input_text": """
+        MSFT's Azure and O365 platform is gaining market share.
+        Have you seen what GOOGL/GOOG and FB/META are doing with their AI?
+        CRM (Salesforce) and ORCL (Oracle) have interesting features too.
+        """,
+        "custom_terms": ["Microsoft", "Google", "Meta", "Facebook", "Salesforce", "Oracle"],
+        "expected_entities": {
+            "Microsoft": ["MSFT"],
+            "Google": ["GOOGL", "GOOG"],
+            "Meta": ["META"],
+            "Facebook": ["FB"],
+            "Salesforce": ["CRM", "Salesforce"],
+            "Oracle": ["ORCL"]
+        }
+    },
+    {
+        "description": "L33t Speak and Intentional Obfuscation",
+        "input_text": """
+        S4l3sf0rc3 is better than 0r4cl3!
+        M1cr0$oft and G00gl3 are the main competitors.
+        Let's check F8book and Met@ too.
+        """,
+        "custom_terms": ["Salesforce", "Oracle", "Microsoft", "Google", "Facebook", "Meta"],
+        "expected_entities": {
+            "Salesforce": ["S4l3sf0rc3"],
+            "Oracle": ["0r4cl3"],
+            "Microsoft": ["M1cr0$oft"],
+            "Google": ["G00gl3"],
+            "Facebook": ["F8book"],
+            "Meta": ["Met@"]
+        }
+    },
+    {
+        "description": "Case Variations and Partial Matches",
+        "input_text": """
+        salesFORCE and ORACLE are competitors.
+        MicroSoft and google are too.
+        Have you tried micro-soft or Google_Cloud?
+        """,
+        "custom_terms": ["Microsoft", "Google", "Salesforce", "Oracle"],
+        "expected_entities": {
+            "Microsoft": ["MicroSoft", "micro-soft"],
+            "Google": ["google", "Google_Cloud"],
+            "Salesforce": ["salesFORCE"],
+            "Oracle": ["ORACLE"]
+        }
+    },
+    {
+        "description": "Common Misspellings and Typos",
+        "input_text": """
+        Microsft and Microsooft are common typos.
+        Goggle, Googel, and Gooogle are search related.
+        Salezforce and Oracel need checking too.
+        """,
+        "custom_terms": ["Microsoft", "Google", "Salesforce", "Oracle"],
+        "expected_entities": {
+            "Microsoft": ["Microsft", "Microsooft"],
+            "Google": ["Goggle", "Googel", "Gooogle"],
+            "Salesforce": ["Salezforce"],
+            "Oracle": ["Oracel"]
+        }
+    },
+    {
+        "description": "Mixed Variations and Context",
+        "input_text": """
+        The M$ cloud competes with AWS (Amazon Web Services).
+        FB/Meta's social platform and GOOGL's search dominate.
+        SF.com and Oracle-DB are industry standards.
+        """,
+        "custom_terms": ["Microsoft", "Amazon Web Services", "Facebook", "Meta", "Google", "Salesforce", "Oracle"],
+        "expected_entities": {
+            "Microsoft": ["M$"],
+            "Amazon Web Services": ["AWS"],
+            "Facebook": ["FB"],
+            "Meta": ["Meta"],
+            "Google": ["GOOGL"],
+            "Salesforce": ["SF.com"],
+            "Oracle": ["Oracle-DB"]
+        }
+    }
+]
+
+def validate_entities(detected: dict, expected: dict) -> bool:
+    """Compare detected entities with expected entities"""
+    if set(detected.keys()) != set(expected.keys()):
+        return False
+    return all(set(detected[k]) == set(expected[k]) for k in expected.keys())
+
+def run_test_case(guardrail, test_case, test_type="Main"):
+    """Run a single test case and print results"""
+    print(f"\n{test_type} Test Case: {test_case['description']}")
+    print("-" * 50)
+
+    result = guardrail.guard(
+        test_case['input_text'],
+        custom_terms=test_case['custom_terms']
+    )
+    expected = test_case['expected_entities']
+
+    # Validate results
+    matches = validate_entities(result.detected_entities, expected)
+
+    print(f"Test Status: {'✓ PASS' if matches else '✗ FAIL'}")
+    print(f"Contains Restricted Terms: {result.contains_entities}")
+
+    if not matches:
+        print("\nEntity Comparison:")
+        all_entity_types = set(list(result.detected_entities.keys()) + list(expected.keys()))
+        for entity_type in all_entity_types:
+            detected = set(result.detected_entities.get(entity_type, []))
+            expected_set = set(expected.get(entity_type, []))
+            print(f"\nEntity Type: {entity_type}")
+            print(f"  Expected: {sorted(expected_set)}")
+            print(f"  Detected: {sorted(detected)}")
+            if detected != expected_set:
+                print(f"  Missing: {sorted(expected_set - detected)}")
+                print(f"  Extra: {sorted(detected - expected_set)}")
+
+    if result.anonymized_text:
+        print(f"\nAnonymized Text:\n{result.anonymized_text}")
+
+    return matches
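These example sets and helpers can also be exercised one case at a time, outside the two runner scripts below. A minimal sketch, assuming the guardrails_genie package is importable from the repo root (any guardrail exposing the same guard() interface works):

# Minimal sketch: run a single example case against the regex guardrail.
from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import (
    RegexEntityRecognitionGuardrail,
)
from guardrails_genie.guardrails.entity_recognition.banned_terms_examples.banned_term_examples import (
    RESTRICTED_TERMS_EXAMPLES,
    run_test_case,
)

guardrail = RegexEntityRecognitionGuardrail(use_defaults=False, should_anonymize=True)
passed = run_test_case(guardrail, RESTRICTED_TERMS_EXAMPLES[0])  # prints a pass/fail report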
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_llm_judge.py
ADDED
@@ -0,0 +1,47 @@
+from guardrails_genie.guardrails.entity_recognition.llm_judge_entity_recognition_guardrail import RestrictedTermsJudge
+from guardrails_genie.guardrails.entity_recognition.banned_terms_examples.banned_term_examples import (
+    RESTRICTED_TERMS_EXAMPLES,
+    EDGE_CASE_EXAMPLES,
+    run_test_case
+)
+from guardrails_genie.llm import OpenAIModel
+import weave
+
+def test_restricted_terms_detection():
+    """Test restricted terms detection scenarios using predefined test cases"""
+    weave.init("guardrails-genie-restricted-terms-llm-judge")
+
+    # Create the guardrail with OpenAI model
+    llm_judge = RestrictedTermsJudge(
+        should_anonymize=True,
+        llm_model=OpenAIModel()
+    )
+
+    # Test statistics
+    total_tests = len(RESTRICTED_TERMS_EXAMPLES) + len(EDGE_CASE_EXAMPLES)
+    passed_tests = 0
+
+    # Test main restricted terms examples
+    print("\nRunning Main Restricted Terms Tests")
+    print("=" * 80)
+    for test_case in RESTRICTED_TERMS_EXAMPLES:
+        if run_test_case(llm_judge, test_case):
+            passed_tests += 1
+
+    # Test edge cases
+    print("\nRunning Edge Cases")
+    print("=" * 80)
+    for test_case in EDGE_CASE_EXAMPLES:
+        if run_test_case(llm_judge, test_case, "Edge"):
+            passed_tests += 1
+
+    # Print summary
+    print("\nTest Summary")
+    print("=" * 80)
+    print(f"Total Tests: {total_tests}")
+    print(f"Passed: {passed_tests}")
+    print(f"Failed: {total_tests - passed_tests}")
+    print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")
+
+if __name__ == "__main__":
+    test_restricted_terms_detection()
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_regex_model.py
ADDED
@@ -0,0 +1,46 @@
+from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
+from guardrails_genie.guardrails.entity_recognition.banned_terms_examples.banned_term_examples import (
+    RESTRICTED_TERMS_EXAMPLES,
+    EDGE_CASE_EXAMPLES,
+    run_test_case
+)
+import weave
+
+def test_restricted_terms_detection():
+    """Test restricted terms detection scenarios using predefined test cases"""
+    weave.init("guardrails-genie-restricted-terms-regex-model")
+
+    # Create the guardrail with anonymization enabled
+    regex_guardrail = RegexEntityRecognitionGuardrail(
+        use_defaults=False,  # Don't use default PII patterns
+        should_anonymize=True
+    )
+
+    # Test statistics
+    total_tests = len(RESTRICTED_TERMS_EXAMPLES) + len(EDGE_CASE_EXAMPLES)
+    passed_tests = 0
+
+    # Test main restricted terms examples
+    print("\nRunning Main Restricted Terms Tests")
+    print("=" * 80)
+    for test_case in RESTRICTED_TERMS_EXAMPLES:
+        if run_test_case(regex_guardrail, test_case):
+            passed_tests += 1
+
+    # Test edge cases
+    print("\nRunning Edge Cases")
+    print("=" * 80)
+    for test_case in EDGE_CASE_EXAMPLES:
+        if run_test_case(regex_guardrail, test_case, "Edge"):
+            passed_tests += 1
+
+    # Print summary
+    print("\nTest Summary")
+    print("=" * 80)
+    print(f"Total Tests: {total_tests}")
+    print(f"Passed: {passed_tests}")
+    print(f"Failed: {total_tests - passed_tests}")
+    print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")
+
+if __name__ == "__main__":
+    test_restricted_terms_detection()
guardrails_genie/guardrails/entity_recognition/llm_judge_entity_recognition_guardrail.py
CHANGED
@@ -1,3 +1,166 @@
-
-
-
+from typing import Dict, List, Optional
+import weave
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated
+
+from ...llm import OpenAIModel
+from ..base import Guardrail
+import instructor
+
+
+class TermMatch(BaseModel):
+    """Represents a matched term and its variations"""
+    original_term: str
+    matched_text: str
+    match_type: str = Field(
+        description="Type of match: EXACT, MISSPELLING, ABBREVIATION, or VARIANT"
+    )
+    explanation: str = Field(
+        description="Explanation of why this is considered a match"
+    )
+
+
+class RestrictedTermsAnalysis(BaseModel):
+    """Analysis result for restricted terms detection"""
+    contains_restricted_terms: bool = Field(
+        description="Whether any restricted terms were detected"
+    )
+    detected_matches: List[TermMatch] = Field(
+        default_factory=list,
+        description="List of detected term matches with their variations"
+    )
+    explanation: str = Field(
+        description="Detailed explanation of the analysis"
+    )
+    anonymized_text: Optional[str] = Field(
+        default=None,
+        description="Text with restricted terms replaced with category tags"
+    )
+
+    @property
+    def safe(self) -> bool:
+        return not self.contains_restricted_terms
+
+
+class RestrictedTermsRecognitionResponse(BaseModel):
+    contains_entities: bool
+    detected_entities: Dict[str, List[str]]
+    explanation: str
+    anonymized_text: Optional[str] = None
+
+    @property
+    def safe(self) -> bool:
+        return not self.contains_entities
+
+
+class RestrictedTermsJudge(Guardrail):
+    llm_model: OpenAIModel = Field(default_factory=lambda: OpenAIModel())
+    should_anonymize: bool = False
+
+    def __init__(self, should_anonymize: bool = False, **kwargs):
+        super().__init__(**kwargs)
+        self.should_anonymize = should_anonymize
+        # Pass the OpenAI client to instructor.patch()
+        instructor.patch(client=self.llm_model._openai_client)
+
+    def format_prompts(self, text: str, custom_terms: List[str]) -> tuple[str, str]:
+        user_prompt = f"""Analyze the following text for restricted terms and variations:
+
+<text_to_analyze>
+{text}
+</text_to_analyze>
+
+Look for these specific terms and their variations:
+{', '.join(custom_terms)}
+
+Analyze the text carefully for:
+1. Exact matches
+2. Common misspellings
+3. Abbreviations
+4. Intentional variations (l33t speak, special characters, etc.)
+5. Case variations
+"""
+
+        system_prompt = """You are an expert system for detecting restricted terms, brand mentions, and inappropriate content.
+
+Your task is to:
+1. Identify exact matches of restricted terms
+2. Detect variations including:
+   - Misspellings (both accidental and intentional)
+   - Abbreviations and acronyms
+   - Case variations
+   - L33t speak or special character substitutions
+   - Partial matches within larger words
+
+For each match, you must:
+1. Identify the original restricted term
+2. Note the actual text that matched
+3. Classify the match type
+4. Provide a confidence score
+5. Explain why it's considered a match
+
+Be thorough but avoid false positives. Focus on meaningful matches that indicate actual attempts to use restricted terms.
+
+Return your analysis in the structured format specified by the RestrictedTermsAnalysis model."""
+
+        return user_prompt, system_prompt
+
+    @weave.op()
+    def predict(self, text: str, custom_terms: List[str], **kwargs) -> RestrictedTermsAnalysis:
+        user_prompt, system_prompt = self.format_prompts(text, custom_terms)
+
+        response = self.llm_model.predict(
+            user_prompts=user_prompt,
+            system_prompt=system_prompt,
+            response_format=RestrictedTermsAnalysis,
+            temperature=0.1,  # Lower temperature for more consistent analysis
+            **kwargs
+        )
+
+        return response.choices[0].message.parsed
+
+    # TODO: Remove default custom_terms
+    @weave.op()
+    def guard(self, text: str, custom_terms: List[str] = ["Microsoft", "Amazon Web Services", "Facebook", "Meta", "Google", "Salesforce", "Oracle"], aggregate_redaction: bool = True, **kwargs) -> RestrictedTermsRecognitionResponse:
+        """
+        Guard against restricted terms and their variations.
+
+        Args:
+            text: Text to analyze
+            custom_terms: List of restricted terms to check for
+
+        Returns:
+            RestrictedTermsRecognitionResponse containing safety assessment and detailed analysis
+        """
+        analysis = self.predict(text, custom_terms, **kwargs)
+
+        # Create a summary of findings
+        if analysis.contains_restricted_terms:
+            summary_parts = ["Restricted terms detected:"]
+            for match in analysis.detected_matches:
+                summary_parts.append(f"\n- {match.original_term}: {match.matched_text} ({match.match_type})")
+            summary = "\n".join(summary_parts)
+        else:
+            summary = "No restricted terms detected."
+
+        # Updated anonymization logic
+        anonymized_text = None
+        if self.should_anonymize and analysis.contains_restricted_terms:
+            anonymized_text = text
+            for match in analysis.detected_matches:
+                replacement = "[redacted]" if aggregate_redaction else f"[{match.match_type.upper()}]"
+                anonymized_text = anonymized_text.replace(match.matched_text, replacement)
+
+        # Convert detected_matches to a dictionary format
+        detected_entities = {}
+        for match in analysis.detected_matches:
+            if match.original_term not in detected_entities:
+                detected_entities[match.original_term] = []
+            detected_entities[match.original_term].append(match.matched_text)
+
+        return RestrictedTermsRecognitionResponse(
+            contains_entities=analysis.contains_restricted_terms,
+            detected_entities=detected_entities,
+            explanation=summary,
+            anonymized_text=anonymized_text
+        )
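For reference, a minimal usage sketch of the new judge. It assumes an OpenAI API key is configured for OpenAIModel; the input text and term list are illustrative, not part of the commit:

# Minimal usage sketch for RestrictedTermsJudge (input and terms are illustrative).
from guardrails_genie.guardrails import RestrictedTermsJudge

judge = RestrictedTermsJudge(should_anonymize=True)
response = judge.guard(
    "S4l3sf0rc3 is better than 0r4cl3!",
    custom_terms=["Salesforce", "Oracle"],
    aggregate_redaction=True,
)
print(response.safe)               # False when any term or variant is matched
print(response.detected_entities)  # e.g. {"Salesforce": ["S4l3sf0rc3"], "Oracle": ["0r4cl3"]}
print(response.anonymized_text)    # matched spans replaced with "[redacted]"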
guardrails_genie/guardrails/entity_recognition/regex_entity_recognition_guardrail.py
CHANGED
@@ -5,6 +5,7 @@ from pydantic import BaseModel
 
 from ...regex_model import RegexModel
 from ..base import Guardrail
+import re
 
 
 class RegexEntityRecognitionResponse(BaseModel):
@@ -63,19 +64,37 @@ class RegexEntityRecognitionGuardrail(Guardrail):
             should_anonymize=should_anonymize
         )
 
+    def text_to_pattern(self, text: str) -> str:
+        """
+        Convert input text into a regex pattern that matches the exact text.
+        """
+        # Escape special regex characters in the text
+        escaped_text = re.escape(text)
+        # Create a pattern that matches the exact text, case-insensitive
+        return rf"\b{escaped_text}\b"
+
     @weave.op()
-    def guard(self, prompt: str, return_detected_types: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
+    def guard(self, prompt: str, custom_terms: Optional[list[str]] = None, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
         """
         Check if the input prompt contains any entities based on the regex patterns.
 
         Args:
             prompt: Input text to check for entities
+            custom_terms: List of custom terms to be converted into regex patterns. If provided,
+                only these terms will be checked, ignoring default patterns.
             return_detected_types: If True, returns detailed entity type information
 
         Returns:
             RegexEntityRecognitionResponse or RegexEntityRecognitionSimpleResponse containing detection results
         """
-        result = self.regex_model.check(prompt)
+        if custom_terms:
+            # Create a temporary RegexModel with only the custom patterns
+            temp_patterns = {term: self.text_to_pattern(term) for term in custom_terms}
+            temp_model = RegexModel(patterns=temp_patterns)
+            result = temp_model.check(prompt)
+        else:
+            # Use the original regex_model if no custom terms provided
+            result = self.regex_model.check(prompt)
 
         # Create detailed explanation
         explanation_parts = []
@@ -91,13 +110,13 @@ class RegexEntityRecognitionGuardrail(Guardrail):
         for pattern in result.failed_patterns:
             explanation_parts.append(f"- {pattern}")
 
-        #
+        # Updated anonymization logic
         anonymized_text = None
         if getattr(self, 'should_anonymize', False) and result.matched_patterns:
             anonymized_text = prompt
             for entity_type, matches in result.matched_patterns.items():
                 for match in matches:
-                    replacement = f"[{entity_type.upper()}]"
+                    replacement = "[redacted]" if aggregate_redaction else f"[{entity_type.upper()}]"
                     anonymized_text = anonymized_text.replace(match, replacement)
 
         if return_detected_types:
@@ -115,5 +134,5 @@ class RegexEntityRecognitionGuardrail(Guardrail):
         )
 
     @weave.op()
-    def predict(self, prompt: str, return_detected_types: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
-        return self.guard(prompt, return_detected_types=return_detected_types, **kwargs)
+    def predict(self, prompt: str, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
+        return self.guard(prompt, return_detected_types=return_detected_types, aggregate_redaction=aggregate_redaction, **kwargs)