ash0ts committed on
Commit
28d8897
·
1 Parent(s): 3ad3f59

Add banned terms guardrails

Browse files
application_pages/chat_app.py CHANGED
@@ -82,6 +82,13 @@ def initialize_guardrails():
82
  guardrail_name,
83
  )()
84
  )
 
 
 
 
 
 
 
85
  st.session_state.guardrails_manager = GuardrailManager(
86
  guardrails=st.session_state.guardrails
87
  )
 
82
  guardrail_name,
83
  )()
84
  )
85
+ elif guardrail_name == "RestrictedTermsJudge":
86
+ st.session_state.guardrails.append(
87
+ getattr(
88
+ importlib.import_module("guardrails_genie.guardrails"),
89
+ guardrail_name,
90
+ )()
91
+ )
92
  st.session_state.guardrails_manager = GuardrailManager(
93
  guardrails=st.session_state.guardrails
94
  )
guardrails_genie/guardrails/__init__.py CHANGED
@@ -6,6 +6,7 @@ from .entity_recognition import (
6
  PresidioEntityRecognitionGuardrail,
7
  RegexEntityRecognitionGuardrail,
8
  TransformersEntityRecognitionGuardrail,
 
9
  )
10
  from .manager import GuardrailManager
11
 
@@ -15,5 +16,6 @@ __all__ = [
15
  "PresidioEntityRecognitionGuardrail",
16
  "RegexEntityRecognitionGuardrail",
17
  "TransformersEntityRecognitionGuardrail",
 
18
  "GuardrailManager",
19
  ]
 
6
  PresidioEntityRecognitionGuardrail,
7
  RegexEntityRecognitionGuardrail,
8
  TransformersEntityRecognitionGuardrail,
9
+ RestrictedTermsJudge,
10
  )
11
  from .manager import GuardrailManager
12
 
 
16
  "PresidioEntityRecognitionGuardrail",
17
  "RegexEntityRecognitionGuardrail",
18
  "TransformersEntityRecognitionGuardrail",
19
+ "RestrictedTermsJudge",
20
  "GuardrailManager",
21
  ]
guardrails_genie/guardrails/entity_recognition/__init__.py CHANGED
@@ -1,9 +1,10 @@
1
  from .presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
2
  from .regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
3
  from .transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail
4
-
5
  __all__ = [
6
  "PresidioEntityRecognitionGuardrail",
7
  "RegexEntityRecognitionGuardrail",
8
  "TransformersEntityRecognitionGuardrail",
 
9
  ]
 
1
  from .presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
2
  from .regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
3
  from .transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail
4
+ from .llm_judge_entity_recognition_guardrail import RestrictedTermsJudge
5
  __all__ = [
6
  "PresidioEntityRecognitionGuardrail",
7
  "RegexEntityRecognitionGuardrail",
8
  "TransformersEntityRecognitionGuardrail",
9
+ "RestrictedTermsJudge"
10
  ]
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/banned_term_benchmark.py ADDED
File without changes
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/banned_term_examples.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Collection of restricted terms test examples with expected outcomes for entity recognition testing.
3
+ Focuses on banned terms, competitor mentions, and brand protection scenarios.
4
+ """
5
+
6
def _exact_match_case(description, input_text, terms):
    """Build a test case in which every restricted term is expected to match itself verbatim."""
    return {
        "description": description,
        "input_text": input_text,
        "custom_terms": list(terms),
        # Each term is its own (and only) expected match.
        "expected_entities": {term: [term] for term in terms},
    }


# Baseline scenarios: the restricted terms appear in the text exactly as written.
RESTRICTED_TERMS_EXAMPLES = [
    _exact_match_case(
        "Competitor Product Discussion",
        """
        I think we should implement features similar to Salesforce's Einstein AI
        and Oracle's Cloud Infrastructure. Maybe we could also look at how
        AWS handles their lambda functions.
        """,
        ["Salesforce", "Oracle", "AWS", "Einstein AI", "Cloud Infrastructure", "lambda"],
    ),
    _exact_match_case(
        "Inappropriate Language in Support Ticket",
        """
        This damn product keeps crashing! What the hell is wrong with your
        stupid service? I've wasted so much freaking time on this crap.
        """,
        ["damn", "hell", "stupid", "crap"],
    ),
    _exact_match_case(
        "Confidential Project Names",
        """
        Project Titan's launch date has been moved up. We should coordinate
        with Project Phoenix team and the Blue Dragon initiative for resource allocation.
        """,
        ["Project Titan", "Project Phoenix", "Blue Dragon"],
    ),
]
52
+
53
+ # Edge cases and special formats
54
# Harder scenarios: the restricted term shows up as a ticker symbol, l33t speak,
# odd casing, a typo, or an abbreviation rather than verbatim.
#
# Each case has the same shape as the main examples:
#   description       - human-readable label printed by the runner
#   input_text        - text handed to the guardrail
#   custom_terms      - restricted terms to look for
#   expected_entities - restricted term -> list of substrings of input_text
#                       that should be reported for that term
EDGE_CASE_EXAMPLES = [
    {
        "description": "Common Corporate Abbreviations and Stock Symbols",
        "input_text": """
        MSFT's Azure and O365 platform is gaining market share.
        Have you seen what GOOGL/GOOG and FB/META are doing with their AI?
        CRM (Salesforce) and ORCL (Oracle) have interesting features too.
        """,
        "custom_terms": ["Microsoft", "Google", "Meta", "Facebook", "Salesforce", "Oracle"],
        "expected_entities": {
            "Microsoft": ["MSFT"],
            "Google": ["GOOGL", "GOOG"],
            "Meta": ["META"],
            "Facebook": ["FB"],
            "Salesforce": ["CRM", "Salesforce"],
            # "Oracle" is spelled out in the text next to its ticker, exactly
            # like "Salesforce" above, so it must be expected as well.
            "Oracle": ["ORCL", "Oracle"],
        },
    },
    {
        "description": "L33t Speak and Intentional Obfuscation",
        "input_text": """
        S4l3sf0rc3 is better than 0r4cl3!
        M1cr0$oft and G00gl3 are the main competitors.
        Let's check F8book and Met@ too.
        """,
        "custom_terms": ["Salesforce", "Oracle", "Microsoft", "Google", "Facebook", "Meta"],
        "expected_entities": {
            "Salesforce": ["S4l3sf0rc3"],
            "Oracle": ["0r4cl3"],
            "Microsoft": ["M1cr0$oft"],
            "Google": ["G00gl3"],
            "Facebook": ["F8book"],
            "Meta": ["Met@"],
        },
    },
    {
        "description": "Case Variations and Partial Matches",
        "input_text": """
        salesFORCE and ORACLE are competitors.
        MicroSoft and google are too.
        Have you tried micro-soft or Google_Cloud?
        """,
        "custom_terms": ["Microsoft", "Google", "Salesforce", "Oracle"],
        "expected_entities": {
            "Microsoft": ["MicroSoft", "micro-soft"],
            "Google": ["google", "Google_Cloud"],
            "Salesforce": ["salesFORCE"],
            "Oracle": ["ORACLE"],
        },
    },
    {
        "description": "Common Misspellings and Typos",
        "input_text": """
        Microsft and Microsooft are common typos.
        Goggle, Googel, and Gooogle are search related.
        Salezforce and Oracel need checking too.
        """,
        "custom_terms": ["Microsoft", "Google", "Salesforce", "Oracle"],
        "expected_entities": {
            "Microsoft": ["Microsft", "Microsooft"],
            "Google": ["Goggle", "Googel", "Gooogle"],
            "Salesforce": ["Salezforce"],
            "Oracle": ["Oracel"],
        },
    },
    {
        "description": "Mixed Variations and Context",
        "input_text": """
        The M$ cloud competes with AWS (Amazon Web Services).
        FB/Meta's social platform and GOOGL's search dominate.
        SF.com and Oracle-DB are industry standards.
        """,
        "custom_terms": ["Microsoft", "Amazon Web Services", "Facebook", "Meta", "Google", "Salesforce", "Oracle"],
        "expected_entities": {
            "Microsoft": ["M$"],
            # The full name appears verbatim in parentheses after the
            # abbreviation, so both forms are expected.
            "Amazon Web Services": ["AWS", "Amazon Web Services"],
            "Facebook": ["FB"],
            "Meta": ["Meta"],
            "Google": ["GOOGL"],
            "Salesforce": ["SF.com"],
            "Oracle": ["Oracle-DB"],
        },
    },
]
138
+
139
def validate_entities(detected: dict, expected: dict) -> bool:
    """Return True when both mappings hold the same keys and, per key, the same set of values.

    Order and duplicates inside each value list are ignored because the lists
    are compared as sets.
    """
    if set(detected) != set(expected):
        return False
    for entity, wanted in expected.items():
        if set(detected[entity]) != set(wanted):
            return False
    return True
144
+
145
def run_test_case(guardrail, test_case, test_type="Main"):
    """Run one example through ``guardrail.guard`` and print a pass/fail report.

    Args:
        guardrail: Object exposing ``guard(text, custom_terms=...)`` whose result
            has ``detected_entities``, ``contains_entities`` and ``anonymized_text``.
        test_case: Dict with ``description``, ``input_text``, ``custom_terms``
            and ``expected_entities`` keys.
        test_type: Label printed in the case header.

    Returns:
        The result of ``validate_entities`` for detected vs. expected entities.
    """
    print(f"\n{test_type} Test Case: {test_case['description']}")
    print("-" * 50)

    result = guardrail.guard(
        test_case['input_text'],
        custom_terms=test_case['custom_terms'],
    )
    expected = test_case['expected_entities']
    found = result.detected_entities

    matches = validate_entities(found, expected)
    status = '✓ PASS' if matches else '✗ FAIL'

    print(f"Test Status: {status}")
    print(f"Contains Restricted Terms: {result.contains_entities}")

    if not matches:
        # Per-entity breakdown over every key seen on either side.
        print("\nEntity Comparison:")
        for entity_type in set([*found, *expected]):
            detected = set(found.get(entity_type, []))
            expected_set = set(expected.get(entity_type, []))
            print(f"\nEntity Type: {entity_type}")
            print(f" Expected: {sorted(expected_set)}")
            print(f" Detected: {sorted(detected)}")
            if detected != expected_set:
                print(f" Missing: {sorted(expected_set - detected)}")
                print(f" Extra: {sorted(detected - expected_set)}")

    if result.anonymized_text:
        print(f"\nAnonymized Text:\n{result.anonymized_text}")

    return matches
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_llm_judge.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from guardrails_genie.guardrails.entity_recognition.llm_judge_entity_recognition_guardrail import RestrictedTermsJudge
2
+ from guardrails_genie.guardrails.entity_recognition.banned_terms_examples.banned_term_examples import (
3
+ RESTRICTED_TERMS_EXAMPLES,
4
+ EDGE_CASE_EXAMPLES,
5
+ run_test_case
6
+ )
7
+ from guardrails_genie.llm import OpenAIModel
8
+ import weave
9
+
10
def test_restricted_terms_detection():
    """Run every predefined example through the LLM judge and print a summary."""
    weave.init("guardrails-genie-restricted-terms-llm-judge")

    # Guardrail under test: LLM judge backed by the OpenAI model, with anonymization on.
    llm_judge = RestrictedTermsJudge(should_anonymize=True, llm_model=OpenAIModel())

    total_tests = len(RESTRICTED_TERMS_EXAMPLES) + len(EDGE_CASE_EXAMPLES)

    # Main examples
    print("\nRunning Main Restricted Terms Tests")
    print("=" * 80)
    passed_tests = sum(
        1 for test_case in RESTRICTED_TERMS_EXAMPLES if run_test_case(llm_judge, test_case)
    )

    # Edge cases
    print("\nRunning Edge Cases")
    print("=" * 80)
    passed_tests += sum(
        1 for test_case in EDGE_CASE_EXAMPLES if run_test_case(llm_judge, test_case, "Edge")
    )

    # Summary
    failed_tests = total_tests - passed_tests
    print("\nTest Summary")
    print("=" * 80)
    print(f"Total Tests: {total_tests}")
    print(f"Passed: {passed_tests}")
    print(f"Failed: {failed_tests}")
    print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")
45
+
46
+ if __name__ == "__main__":
47
+ test_restricted_terms_detection()
guardrails_genie/guardrails/entity_recognition/banned_terms_examples/run_regex_model.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
2
+ from guardrails_genie.guardrails.entity_recognition.banned_terms_examples.banned_term_examples import (
3
+ RESTRICTED_TERMS_EXAMPLES,
4
+ EDGE_CASE_EXAMPLES,
5
+ run_test_case
6
+ )
7
+ import weave
8
+
9
def test_restricted_terms_detection():
    """Run every predefined example through the regex guardrail and print a summary."""
    weave.init("guardrails-genie-restricted-terms-regex-model")

    # Guardrail under test: no default PII patterns, anonymization on.
    regex_guardrail = RegexEntityRecognitionGuardrail(
        use_defaults=False,
        should_anonymize=True,
    )

    total_tests = len(RESTRICTED_TERMS_EXAMPLES) + len(EDGE_CASE_EXAMPLES)

    # Main examples
    print("\nRunning Main Restricted Terms Tests")
    print("=" * 80)
    passed_tests = sum(
        1 for test_case in RESTRICTED_TERMS_EXAMPLES if run_test_case(regex_guardrail, test_case)
    )

    # Edge cases
    print("\nRunning Edge Cases")
    print("=" * 80)
    passed_tests += sum(
        1 for test_case in EDGE_CASE_EXAMPLES if run_test_case(regex_guardrail, test_case, "Edge")
    )

    # Summary
    failed_tests = total_tests - passed_tests
    print("\nTest Summary")
    print("=" * 80)
    print(f"Total Tests: {total_tests}")
    print(f"Passed: {passed_tests}")
    print(f"Failed: {failed_tests}")
    print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")
44
+
45
+ if __name__ == "__main__":
46
+ test_restricted_terms_detection()
guardrails_genie/guardrails/entity_recognition/llm_judge_entity_recognition_guardrail.py CHANGED
@@ -1,3 +1,166 @@
1
- ## Word conssitentcy
2
- # - Scent -> Odor
3
- # - odour -> Odor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional
2
+ import weave
3
+ from pydantic import BaseModel, Field
4
+ from typing_extensions import Annotated
5
+
6
+ from ...llm import OpenAIModel
7
+ from ..base import Guardrail
8
+ import instructor
9
+
10
+
11
class TermMatch(BaseModel):
    """Represents a matched term and its variations"""

    # The restricted term this match is attributed to.
    original_term: str
    # The text reported as matching that term.
    matched_text: str
    match_type: str = Field(description="Type of match: EXACT, MISSPELLING, ABBREVIATION, or VARIANT")
    explanation: str = Field(description="Explanation of why this is considered a match")
21
+
22
+
23
class RestrictedTermsAnalysis(BaseModel):
    """Analysis result for restricted terms detection"""

    contains_restricted_terms: bool = Field(description="Whether any restricted terms were detected")
    detected_matches: List[TermMatch] = Field(
        description="List of detected term matches with their variations",
        default_factory=list,
    )
    explanation: str = Field(description="Detailed explanation of the analysis")
    anonymized_text: Optional[str] = Field(
        description="Text with restricted terms replaced with category tags",
        default=None,
    )

    @property
    def safe(self) -> bool:
        # Safe means nothing restricted was detected.
        return not self.contains_restricted_terms
43
+
44
+
45
class RestrictedTermsRecognitionResponse(BaseModel):
    # Result returned by RestrictedTermsJudge.guard.
    # True when the analysis reported restricted terms.
    contains_entities: bool
    # Restricted term -> list of matched texts attributed to it.
    detected_entities: Dict[str, List[str]]
    # Human-readable summary of the findings.
    explanation: str
    # Redacted copy of the input; None when no anonymization was performed.
    anonymized_text: Optional[str] = None

    @property
    def safe(self) -> bool:
        # Safe means no restricted entities were found.
        return not self.contains_entities
54
+
55
+
56
class RestrictedTermsJudge(Guardrail):
    """Guardrail that asks an LLM to find restricted terms and their variations in text.

    The model is prompted with the text and the list of restricted terms and
    returns a structured ``RestrictedTermsAnalysis``; ``guard`` converts that
    into a ``RestrictedTermsRecognitionResponse`` and optionally redacts the
    matched text.
    """

    llm_model: OpenAIModel = Field(default_factory=lambda: OpenAIModel())
    should_anonymize: bool = False

    def __init__(self, should_anonymize: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.should_anonymize = should_anonymize
        # Pass the OpenAI client to instructor.patch()
        instructor.patch(client=self.llm_model._openai_client)

    def format_prompts(self, text: str, custom_terms: List[str]) -> tuple[str, str]:
        """Build the prompts for one analysis request.

        Args:
            text: Text to analyze.
            custom_terms: Restricted terms to look for.

        Returns:
            ``(user_prompt, system_prompt)`` - note the order.
        """
        user_prompt = f"""Analyze the following text for restricted terms and variations:

<text_to_analyze>
{text}
</text_to_analyze>

Look for these specific terms and their variations:
{', '.join(custom_terms)}

Analyze the text carefully for:
1. Exact matches
2. Common misspellings
3. Abbreviations
4. Intentional variations (l33t speak, special characters, etc.)
5. Case variations
"""

        # NOTE(review): the prompt asks for a confidence score, but TermMatch
        # has no field to carry one - confirm whether a field should be added
        # or the instruction dropped.
        system_prompt = """You are an expert system for detecting restricted terms, brand mentions, and inappropriate content.

Your task is to:
1. Identify exact matches of restricted terms
2. Detect variations including:
- Misspellings (both accidental and intentional)
- Abbreviations and acronyms
- Case variations
- L33t speak or special character substitutions
- Partial matches within larger words

For each match, you must:
1. Identify the original restricted term
2. Note the actual text that matched
3. Classify the match type
4. Provide a confidence score
5. Explain why it's considered a match

Be thorough but avoid false positives. Focus on meaningful matches that indicate actual attempts to use restricted terms.

Return your analysis in the structured format specified by the RestrictedTermsAnalysis model."""

        return user_prompt, system_prompt

    @weave.op()
    def predict(self, text: str, custom_terms: List[str], **kwargs) -> RestrictedTermsAnalysis:
        """Query the LLM and return its parsed ``RestrictedTermsAnalysis``.

        Args:
            text: Text to analyze.
            custom_terms: Restricted terms to look for.
            **kwargs: Forwarded to ``llm_model.predict``.
        """
        user_prompt, system_prompt = self.format_prompts(text, custom_terms)

        response = self.llm_model.predict(
            user_prompts=user_prompt,
            system_prompt=system_prompt,
            response_format=RestrictedTermsAnalysis,
            temperature=0.1,  # Lower temperature for more consistent analysis
            **kwargs
        )

        return response.choices[0].message.parsed

    #TODO: Remove default custom_terms
    @weave.op()
    def guard(self, text: str, custom_terms: Optional[List[str]] = None, aggregate_redaction: bool = True, **kwargs) -> RestrictedTermsRecognitionResponse:
        """
        Guard against restricted terms and their variations.

        Args:
            text: Text to analyze
            custom_terms: List of restricted terms to check for. When None, a
                built-in list of company names is used (see the TODO above).
                An explicitly passed empty list is used as-is.
            aggregate_redaction: When anonymizing, replace every match with
                "[redacted]" if True, otherwise with "[<MATCH_TYPE>]".
            **kwargs: Forwarded to ``predict``.

        Returns:
            RestrictedTermsRecognitionResponse containing safety assessment and detailed analysis
        """
        if custom_terms is None:
            # Fresh list per call instead of a shared mutable default argument.
            custom_terms = ["Microsoft", "Amazon Web Services", "Facebook", "Meta", "Google", "Salesforce", "Oracle"]

        analysis = self.predict(text, custom_terms, **kwargs)

        # Create a summary of findings
        if analysis.contains_restricted_terms:
            summary_parts = ["Restricted terms detected:"]
            for match in analysis.detected_matches:
                summary_parts.append(f"\n- {match.original_term}: {match.matched_text} ({match.match_type})")
            summary = "\n".join(summary_parts)
        else:
            summary = "No restricted terms detected."

        anonymized_text = None
        if self.should_anonymize and analysis.contains_restricted_terms:
            anonymized_text = text
            # Replace longer matches first so a short match (e.g. "Google")
            # cannot clobber part of a longer one (e.g. "Google_Cloud").
            longest_first = sorted(
                analysis.detected_matches,
                key=lambda m: len(m.matched_text),
                reverse=True,
            )
            for match in longest_first:
                if not match.matched_text:
                    # str.replace("", tag) would insert the tag between every character.
                    continue
                replacement = "[redacted]" if aggregate_redaction else f"[{match.match_type.upper()}]"
                anonymized_text = anonymized_text.replace(match.matched_text, replacement)

        # Group matched texts by the restricted term they belong to
        detected_entities: Dict[str, List[str]] = {}
        for match in analysis.detected_matches:
            detected_entities.setdefault(match.original_term, []).append(match.matched_text)

        return RestrictedTermsRecognitionResponse(
            contains_entities=analysis.contains_restricted_terms,
            detected_entities=detected_entities,
            explanation=summary,
            anonymized_text=anonymized_text
        )
guardrails_genie/guardrails/entity_recognition/regex_entity_recognition_guardrail.py CHANGED
@@ -5,6 +5,7 @@ from pydantic import BaseModel
5
 
6
  from ...regex_model import RegexModel
7
  from ..base import Guardrail
 
8
 
9
 
10
  class RegexEntityRecognitionResponse(BaseModel):
@@ -63,19 +64,37 @@ class RegexEntityRecognitionGuardrail(Guardrail):
63
  should_anonymize=should_anonymize
64
  )
65
 
 
 
 
 
 
 
 
 
 
66
  @weave.op()
67
- def guard(self, prompt: str, return_detected_types: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
68
  """
69
  Check if the input prompt contains any entities based on the regex patterns.
70
 
71
  Args:
72
  prompt: Input text to check for entities
 
 
73
  return_detected_types: If True, returns detailed entity type information
74
 
75
  Returns:
76
  RegexEntityRecognitionResponse or RegexEntityRecognitionSimpleResponse containing detection results
77
  """
78
- result = self.regex_model.check(prompt)
 
 
 
 
 
 
 
79
 
80
  # Create detailed explanation
81
  explanation_parts = []
@@ -91,13 +110,13 @@ class RegexEntityRecognitionGuardrail(Guardrail):
91
  for pattern in result.failed_patterns:
92
  explanation_parts.append(f"- {pattern}")
93
 
94
- # Add anonymization logic
95
  anonymized_text = None
96
  if getattr(self, 'should_anonymize', False) and result.matched_patterns:
97
  anonymized_text = prompt
98
  for entity_type, matches in result.matched_patterns.items():
99
  for match in matches:
100
- replacement = f"[{entity_type.upper()}]"
101
  anonymized_text = anonymized_text.replace(match, replacement)
102
 
103
  if return_detected_types:
@@ -115,5 +134,5 @@ class RegexEntityRecognitionGuardrail(Guardrail):
115
  )
116
 
117
  @weave.op()
118
- def predict(self, prompt: str, return_detected_types: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
119
- return self.guard(prompt, return_detected_types=return_detected_types, **kwargs)
 
5
 
6
  from ...regex_model import RegexModel
7
  from ..base import Guardrail
8
+ import re
9
 
10
 
11
  class RegexEntityRecognitionResponse(BaseModel):
 
64
  should_anonymize=should_anonymize
65
  )
66
 
67
def text_to_pattern(self, text: str) -> str:
    """
    Convert a literal term into a regex pattern that matches it as a whole token.

    Regex metacharacters in ``text`` are escaped so the term is matched
    literally. The term must not be directly preceded or followed by a word
    character, so "Oracle" matches in "Oracle-DB" but not in "Oracles".

    The returned pattern carries no inline flags; whether matching is
    case-sensitive depends on how the caller compiles it.

    Args:
        text: Literal term to match.

    Returns:
        The regex pattern as a string.
    """
    # Escape special regex characters in the text
    escaped_text = re.escape(text)
    # Look-arounds instead of \b: \b needs a word character on one side, so it
    # can never match at the edge of a term that starts or ends with a
    # non-word character (e.g. "C++", "M$"). For terms with word-character
    # edges the two forms are equivalent.
    return rf"(?<!\w){escaped_text}(?!\w)"
75
+
76
  @weave.op()
77
+ def guard(self, prompt: str, custom_terms: Optional[list[str]] = None, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
78
  """
79
  Check if the input prompt contains any entities based on the regex patterns.
80
 
81
  Args:
82
  prompt: Input text to check for entities
83
+ custom_terms: List of custom terms to be converted into regex patterns. If provided,
84
+ only these terms will be checked, ignoring default patterns.
85
  return_detected_types: If True, returns detailed entity type information
86
 
87
  Returns:
88
  RegexEntityRecognitionResponse or RegexEntityRecognitionSimpleResponse containing detection results
89
  """
90
+ if custom_terms:
91
+ # Create a temporary RegexModel with only the custom patterns
92
+ temp_patterns = {term: self.text_to_pattern(term) for term in custom_terms}
93
+ temp_model = RegexModel(patterns=temp_patterns)
94
+ result = temp_model.check(prompt)
95
+ else:
96
+ # Use the original regex_model if no custom terms provided
97
+ result = self.regex_model.check(prompt)
98
 
99
  # Create detailed explanation
100
  explanation_parts = []
 
110
  for pattern in result.failed_patterns:
111
  explanation_parts.append(f"- {pattern}")
112
 
113
+ # Updated anonymization logic
114
  anonymized_text = None
115
  if getattr(self, 'should_anonymize', False) and result.matched_patterns:
116
  anonymized_text = prompt
117
  for entity_type, matches in result.matched_patterns.items():
118
  for match in matches:
119
+ replacement = "[redacted]" if aggregate_redaction else f"[{entity_type.upper()}]"
120
  anonymized_text = anonymized_text.replace(match, replacement)
121
 
122
  if return_detected_types:
 
134
  )
135
 
136
  @weave.op()
137
def predict(self, prompt: str, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
    """Delegate to ``guard`` with the same arguments and return its result."""
    options = {
        "return_detected_types": return_detected_types,
        "aggregate_redaction": aggregate_redaction,
        **kwargs,
    }
    return self.guard(prompt, **options)