Make PII evals work
Weave evals run but the results don't match the normal benchmark script (see the averaging sketch after the file list).
- application_pages/chat_app.py +4 -4
- guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark.py +75 -4
- guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark_weave.py +322 -0
- guardrails_genie/guardrails/entity_recognition/presidio_entity_recognition_guardrail.py +3 -6
- guardrails_genie/guardrails/entity_recognition/regex_entity_recognition_guardrail.py +30 -12
- guardrails_genie/guardrails/entity_recognition/transformers_entity_recognition_guardrail.py +1 -1
- guardrails_genie/regex_model.py +18 -13
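
The mismatch called out in the commit message is most likely an averaging artifact: pii_benchmark.py builds its overall numbers by summing true/false positive counts across all samples (micro-averaging), while the Weave summary consumed by print_metrics_summary exposes the mean of each sample's per-entity scores. A minimal sketch with made-up counts shows how far the two conventions can drift:

# Sketch (not part of the diff): micro-average vs. per-sample mean.
# Toy numbers below are hypothetical.
samples = [
    {"tp": 9, "fp": 1},   # sample with many entities
    {"tp": 0, "fp": 1},   # sample with a single false positive
]

# Micro-average (pii_benchmark.py): sum counts first, then divide
total_tp = sum(s["tp"] for s in samples)
total_fp = sum(s["fp"] for s in samples)
micro_precision = total_tp / (total_tp + total_fp)            # 9/11 ~= 0.82

# Per-sample mean (Weave-style summary): divide per sample, then average
per_sample = [s["tp"] / (s["tp"] + s["fp"]) for s in samples]
mean_precision = sum(per_sample) / len(per_sample)            # (0.9 + 0.0)/2 = 0.45

print(micro_precision, mean_precision)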
application_pages/chat_app.py
CHANGED
@@ -66,28 +66,28 @@ def initialize_guardrails():
                 getattr(
                     importlib.import_module("guardrails_genie.guardrails"),
                     guardrail_name,
-                )()
+                )(should_anonymize=True)
            )
        elif guardrail_name == "RegexEntityRecognitionGuardrail":
            st.session_state.guardrails.append(
                getattr(
                    importlib.import_module("guardrails_genie.guardrails"),
                    guardrail_name,
-                )()
+                )(should_anonymize=True)
            )
        elif guardrail_name == "TransformersEntityRecognitionGuardrail":
            st.session_state.guardrails.append(
                getattr(
                    importlib.import_module("guardrails_genie.guardrails"),
                    guardrail_name,
-                )()
+                )(should_anonymize=True)
            )
        elif guardrail_name == "RestrictedTermsJudge":
            st.session_state.guardrails.append(
                getattr(
                    importlib.import_module("guardrails_genie.guardrails"),
                    guardrail_name,
-                )()
+                )(should_anonymize=True)
            )
    st.session_state.guardrails_manager = GuardrailManager(
        guardrails=st.session_state.guardrails
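
For reference, the construction pattern this hunk changes, reduced to a standalone sketch (assumes guardrails_genie is importable; the helper name is hypothetical):

import importlib

def build_guardrail(guardrail_name: str):
    # Resolve the class by name from the guardrails package, then construct it
    # with anonymization enabled, as the updated branches above now do.
    module = importlib.import_module("guardrails_genie.guardrails")
    guardrail_cls = getattr(module, guardrail_name)
    return guardrail_cls(should_anonymize=True)

# e.g. build_guardrail("RegexEntityRecognitionGuardrail")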
guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark.py
CHANGED
@@ -6,6 +6,35 @@ import json
 from pathlib import Path
 import weave
 
+# Add this mapping dictionary near the top of the file
+PRESIDIO_TO_TRANSFORMER_MAPPING = {
+    "EMAIL_ADDRESS": "EMAIL",
+    "PHONE_NUMBER": "TELEPHONENUM",
+    "US_SSN": "SOCIALNUM",
+    "CREDIT_CARD": "CREDITCARDNUMBER",
+    "IP_ADDRESS": "IDCARDNUM",
+    "DATE_TIME": "DATEOFBIRTH",
+    "US_PASSPORT": "IDCARDNUM",
+    "US_DRIVER_LICENSE": "DRIVERLICENSENUM",
+    "US_BANK_NUMBER": "ACCOUNTNUM",
+    "LOCATION": "CITY",
+    "URL": "USERNAME",  # URLs often contain usernames
+    "IN_PAN": "TAXNUM",  # Indian Permanent Account Number
+    "UK_NHS": "IDCARDNUM",
+    "SG_NRIC_FIN": "IDCARDNUM",
+    "AU_ABN": "TAXNUM",  # Australian Business Number
+    "AU_ACN": "TAXNUM",  # Australian Company Number
+    "AU_TFN": "TAXNUM",  # Australian Tax File Number
+    "AU_MEDICARE": "IDCARDNUM",
+    "IN_AADHAAR": "IDCARDNUM",  # Indian national ID
+    "IN_VOTER": "IDCARDNUM",
+    "IN_PASSPORT": "IDCARDNUM",
+    "CRYPTO": "ACCOUNTNUM",  # Cryptocurrency addresses
+    "IBAN_CODE": "ACCOUNTNUM",
+    "MEDICAL_LICENSE": "IDCARDNUM",
+    "IN_VEHICLE_REGISTRATION": "IDCARDNUM"
+}
+
 def load_ai4privacy_dataset(num_samples: int = 100, split: str = "validation") -> List[Dict]:
     """
     Load and prepare samples from the ai4privacy dataset.
@@ -81,6 +110,17 @@ def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]
         detected = result.detected_entities
         expected = test_case['expected_entities']
 
+        # Map Presidio entities if this is the Presidio guardrail
+        if isinstance(guardrail, PresidioEntityRecognitionGuardrail):
+            mapped_detected = {}
+            for entity_type, values in detected.items():
+                mapped_type = PRESIDIO_TO_TRANSFORMER_MAPPING.get(entity_type)
+                if mapped_type:
+                    if mapped_type not in mapped_detected:
+                        mapped_detected[mapped_type] = []
+                    mapped_detected[mapped_type].extend(values)
+            detected = mapped_detected
+
         # Track entity-level metrics
         all_entity_types = set(list(detected.keys()) + list(expected.keys()))
         entity_results = {}
@@ -137,12 +177,20 @@ def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]
         else:
             metrics["failed"] += 1
 
-    # Calculate final entity metrics
+    # Calculate final entity metrics and track totals for overall metrics
+    total_tp = 0
+    total_fp = 0
+    total_fn = 0
+
     for entity_type, counts in metrics["entity_metrics"].items():
        tp = counts["total_true_positives"]
        fp = counts["total_false_positives"]
        fn = counts["total_false_negatives"]
 
+        total_tp += tp
+        total_fp += fp
+        total_fn += fn
+
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
@@ -153,6 +201,20 @@ def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]
             "f1": f1
         })
 
+    # Calculate overall metrics
+    overall_precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
+    overall_recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
+    overall_f1 = 2 * (overall_precision * overall_recall) / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0
+
+    metrics["overall"] = {
+        "precision": overall_precision,
+        "recall": overall_recall,
+        "f1": overall_f1,
+        "total_true_positives": total_tp,
+        "total_false_positives": total_fp,
+        "total_false_negatives": total_fn
+    }
+
     return metrics, detailed_results
 
 def save_results(metrics: Dict, detailed_results: List[Dict], model_name: str, output_dir: str = "evaluation_results"):
@@ -177,6 +239,15 @@ def print_metrics_summary(metrics: Dict):
     print(f"Failed: {metrics['failed']}")
     print(f"Success Rate: {(metrics['passed']/metrics['total'])*100:.1f}%")
 
+    # Print overall metrics
+    print("\nOverall Metrics:")
+    print("-" * 80)
+    print(f"{'Metric':<20} {'Value':>10}")
+    print("-" * 80)
+    print(f"{'Precision':<20} {metrics['overall']['precision']:>10.2f}")
+    print(f"{'Recall':<20} {metrics['overall']['recall']:>10.2f}")
+    print(f"{'F1':<20} {metrics['overall']['f1']:>10.2f}")
+
     print("\nEntity-level Metrics:")
     print("-" * 80)
     print(f"{'Entity Type':<20} {'Precision':>10} {'Recall':>10} {'F1':>10}")
@@ -193,9 +264,9 @@ def main():
 
     # Initialize models to evaluate
     models = {
-        "regex": RegexEntityRecognitionGuardrail(should_anonymize=True),
-        "presidio": PresidioEntityRecognitionGuardrail(should_anonymize=True),
-        "transformers": TransformersEntityRecognitionGuardrail(should_anonymize=True)
+        "regex": RegexEntityRecognitionGuardrail(should_anonymize=True, show_available_entities=True),
+        "presidio": PresidioEntityRecognitionGuardrail(should_anonymize=True, show_available_entities=True),
+        "transformers": TransformersEntityRecognitionGuardrail(should_anonymize=True, show_available_entities=True)
    }
 
    # Evaluate each model
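
To make the effect of the new mapping concrete, here is a short sketch (sample values are invented; assumes the PRESIDIO_TO_TRANSFORMER_MAPPING dict from the hunk above is in scope). Note that several Presidio labels intentionally collapse into a single transformer-style label such as IDCARDNUM, which also affects how matches are scored:

detected = {
    "US_PASSPORT": ["X1234567"],
    "UK_NHS": ["943-476-5919"],
    "EMAIL_ADDRESS": ["jane@example.com"],
}

mapped_detected = {}
for entity_type, values in detected.items():
    mapped_type = PRESIDIO_TO_TRANSFORMER_MAPPING.get(entity_type)
    if mapped_type:
        mapped_detected.setdefault(mapped_type, []).extend(values)

# {'IDCARDNUM': ['X1234567', '943-476-5919'], 'EMAIL': ['jane@example.com']}
print(mapped_detected)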
guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark_weave.py
ADDED
@@ -0,0 +1,322 @@
+from datasets import load_dataset
+from typing import Dict, List, Tuple, Optional
+import random
+from tqdm import tqdm
+import json
+from pathlib import Path
+import weave
+from weave.scorers import Scorer
+from weave import Evaluation
+import asyncio
+
+# Add this mapping dictionary near the top of the file
+PRESIDIO_TO_TRANSFORMER_MAPPING = {
+    "EMAIL_ADDRESS": "EMAIL",
+    "PHONE_NUMBER": "TELEPHONENUM",
+    "US_SSN": "SOCIALNUM",
+    "CREDIT_CARD": "CREDITCARDNUMBER",
+    "IP_ADDRESS": "IDCARDNUM",
+    "DATE_TIME": "DATEOFBIRTH",
+    "US_PASSPORT": "IDCARDNUM",
+    "US_DRIVER_LICENSE": "DRIVERLICENSENUM",
+    "US_BANK_NUMBER": "ACCOUNTNUM",
+    "LOCATION": "CITY",
+    "URL": "USERNAME",  # URLs often contain usernames
+    "IN_PAN": "TAXNUM",  # Indian Permanent Account Number
+    "UK_NHS": "IDCARDNUM",
+    "SG_NRIC_FIN": "IDCARDNUM",
+    "AU_ABN": "TAXNUM",  # Australian Business Number
+    "AU_ACN": "TAXNUM",  # Australian Company Number
+    "AU_TFN": "TAXNUM",  # Australian Tax File Number
+    "AU_MEDICARE": "IDCARDNUM",
+    "IN_AADHAAR": "IDCARDNUM",  # Indian national ID
+    "IN_VOTER": "IDCARDNUM",
+    "IN_PASSPORT": "IDCARDNUM",
+    "CRYPTO": "ACCOUNTNUM",  # Cryptocurrency addresses
+    "IBAN_CODE": "ACCOUNTNUM",
+    "MEDICAL_LICENSE": "IDCARDNUM",
+    "IN_VEHICLE_REGISTRATION": "IDCARDNUM"
+}
+
+class EntityRecognitionScorer(Scorer):
+    """Scorer for evaluating entity recognition performance"""
+
+    @weave.op()
+    async def score(self, model_output: Optional[dict], input_text: str, expected_entities: Dict) -> Dict:
+        """Score entity recognition results"""
+        if not model_output:
+            return {"f1": 0.0}
+
+        # Convert Pydantic model to dict if necessary
+        if hasattr(model_output, "model_dump"):
+            model_output = model_output.model_dump()
+        elif hasattr(model_output, "dict"):
+            model_output = model_output.dict()
+
+        detected = model_output.get("detected_entities", {})
+
+        # Map Presidio entities if needed
+        if model_output.get("model_type") == "presidio":
+            mapped_detected = {}
+            for entity_type, values in detected.items():
+                mapped_type = PRESIDIO_TO_TRANSFORMER_MAPPING.get(entity_type)
+                if mapped_type:
+                    if mapped_type not in mapped_detected:
+                        mapped_detected[mapped_type] = []
+                    mapped_detected[mapped_type].extend(values)
+            detected = mapped_detected
+
+        # Track entity-level metrics
+        all_entity_types = set(list(detected.keys()) + list(expected_entities.keys()))
+        entity_metrics = {}
+
+        for entity_type in all_entity_types:
+            detected_set = set(detected.get(entity_type, []))
+            expected_set = set(expected_entities.get(entity_type, []))
+
+            # Calculate metrics
+            true_positives = len(detected_set & expected_set)
+            false_positives = len(detected_set - expected_set)
+            false_negatives = len(expected_set - detected_set)
+
+            if entity_type not in entity_metrics:
+                entity_metrics[entity_type] = {
+                    "total_true_positives": 0,
+                    "total_false_positives": 0,
+                    "total_false_negatives": 0
+                }
+
+            entity_metrics[entity_type]["total_true_positives"] += true_positives
+            entity_metrics[entity_type]["total_false_positives"] += false_positives
+            entity_metrics[entity_type]["total_false_negatives"] += false_negatives
+
+            # Calculate per-entity metrics
+            precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
+            recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
+            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+
+            entity_metrics[entity_type].update({
+                "precision": precision,
+                "recall": recall,
+                "f1": f1
+            })
+
+        # Calculate overall metrics
+        total_tp = sum(metrics["total_true_positives"] for metrics in entity_metrics.values())
+        total_fp = sum(metrics["total_false_positives"] for metrics in entity_metrics.values())
+        total_fn = sum(metrics["total_false_negatives"] for metrics in entity_metrics.values())
+
+        overall_precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
+        overall_recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
+        overall_f1 = 2 * (overall_precision * overall_recall) / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0
+
+        entity_metrics["overall"] = {
+            "precision": overall_precision,
+            "recall": overall_recall,
+            "f1": overall_f1,
+            "total_true_positives": total_tp,
+            "total_false_positives": total_fp,
+            "total_false_negatives": total_fn
+        }
+
+        return entity_metrics
+
+def load_ai4privacy_dataset(num_samples: int = 100, split: str = "validation") -> List[Dict]:
+    """
+    Load and prepare samples from the ai4privacy dataset.
+
+    Args:
+        num_samples: Number of samples to evaluate
+        split: Dataset split to use ("train" or "validation")
+
+    Returns:
+        List of prepared test cases
+    """
+    # Load the dataset
+    dataset = load_dataset("ai4privacy/pii-masking-400k")
+
+    # Get the specified split
+    data_split = dataset[split]
+
+    # Randomly sample entries if num_samples is less than total
+    if num_samples < len(data_split):
+        indices = random.sample(range(len(data_split)), num_samples)
+        samples = [data_split[i] for i in indices]
+    else:
+        samples = data_split
+
+    # Convert to test case format
+    test_cases = []
+    for sample in samples:
+        # Extract entities from privacy_mask
+        entities: Dict[str, List[str]] = {}
+        for entity in sample['privacy_mask']:
+            label = entity['label']
+            value = entity['value']
+            if label not in entities:
+                entities[label] = []
+            entities[label].append(value)
+
+        test_case = {
+            "description": f"AI4Privacy Sample (ID: {sample['uid']})",
+            "input_text": sample['source_text'],
+            "expected_entities": entities,
+            "masked_text": sample['masked_text'],
+            "language": sample['language'],
+            "locale": sample['locale']
+        }
+        test_cases.append(test_case)
+
+    return test_cases
+
+def save_results(weave_results: Dict, model_name: str, output_dir: str = "evaluation_results"):
+    """Save evaluation results to files"""
+    output_dir = Path(output_dir)
+    output_dir.mkdir(exist_ok=True)
+
+    # Extract and process results
+    scorer_results = weave_results.get("EntityRecognitionScorer", [])
+    if not scorer_results or all(r is None for r in scorer_results):
+        print(f"No valid results to save for {model_name}")
+        return
+
+    # Calculate summary metrics
+    total_samples = len(scorer_results)
+    passed = sum(1 for r in scorer_results if r is not None and not isinstance(r, str))
+
+    # Aggregate entity-level metrics
+    entity_metrics = {}
+    for result in scorer_results:
+        try:
+            if isinstance(result, str) or not result:
+                continue
+
+            for entity_type, metrics in result.items():
+                if entity_type not in entity_metrics:
+                    entity_metrics[entity_type] = {
+                        "precision": [],
+                        "recall": [],
+                        "f1": []
+                    }
+                entity_metrics[entity_type]["precision"].append(metrics["precision"])
+                entity_metrics[entity_type]["recall"].append(metrics["recall"])
+                entity_metrics[entity_type]["f1"].append(metrics["f1"])
+        except (AttributeError, TypeError, KeyError):
+            continue
+
+    # Calculate averages
+    summary_metrics = {
+        "total": total_samples,
+        "passed": passed,
+        "failed": total_samples - passed,
+        "success_rate": (passed/total_samples) if total_samples > 0 else 0,
+        "entity_metrics": {
+            entity_type: {
+                "precision": sum(metrics["precision"]) / len(metrics["precision"]) if metrics["precision"] else 0,
+                "recall": sum(metrics["recall"]) / len(metrics["recall"]) if metrics["recall"] else 0,
+                "f1": sum(metrics["f1"]) / len(metrics["f1"]) if metrics["f1"] else 0
+            }
+            for entity_type, metrics in entity_metrics.items()
+        }
+    }
+
+    # Save files
+    with open(output_dir / f"{model_name}_metrics.json", "w") as f:
+        json.dump(summary_metrics, f, indent=2)
+
+    # Save detailed results, filtering out string results
+    detailed_results = [r for r in scorer_results if not isinstance(r, str) and r is not None]
+    with open(output_dir / f"{model_name}_detailed_results.json", "w") as f:
+        json.dump(detailed_results, f, indent=2)
+
+def print_metrics_summary(weave_results: Dict):
+    """Print a summary of the evaluation metrics"""
+    print("\nEvaluation Summary")
+    print("=" * 80)
+
+    # Extract results from Weave's evaluation format
+    scorer_results = weave_results.get("EntityRecognitionScorer", {})
+    if not scorer_results:
+        print("No valid results available")
+        return
+
+    # Calculate overall metrics
+    total_samples = int(weave_results.get("model_latency", {}).get("count", 0))
+    passed = total_samples  # Since we have results, all samples passed
+    failed = 0
+
+    print(f"Total Samples: {total_samples}")
+    print(f"Passed: {passed}")
+    print(f"Failed: {failed}")
+    print(f"Success Rate: {(passed/total_samples)*100:.2f}%")
+
+    # Print overall metrics
+    if "overall" in scorer_results:
+        overall = scorer_results["overall"]
+        print("\nOverall Metrics:")
+        print("-" * 80)
+        print(f"{'Metric':<20} {'Value':>10}")
+        print("-" * 80)
+        print(f"{'Precision':<20} {overall['precision']['mean']:>10.2f}")
+        print(f"{'Recall':<20} {overall['recall']['mean']:>10.2f}")
+        print(f"{'F1':<20} {overall['f1']['mean']:>10.2f}")
+
+    # Print entity-level metrics
+    print("\nEntity-Level Metrics:")
+    print("-" * 80)
+    print(f"{'Entity Type':<20} {'Precision':>10} {'Recall':>10} {'F1':>10}")
+    print("-" * 80)
+
+    for entity_type, metrics in scorer_results.items():
+        if entity_type == "overall":
+            continue
+
+        precision = metrics.get("precision", {}).get("mean", 0)
+        recall = metrics.get("recall", {}).get("mean", 0)
+        f1 = metrics.get("f1", {}).get("mean", 0)
+
+        print(f"{entity_type:<20} {precision:>10.2f} {recall:>10.2f} {f1:>10.2f}")
+
+def preprocess_model_input(example: Dict) -> Dict:
+    """Preprocess dataset example to match model input format."""
+    return {
+        "prompt": example["input_text"],
+        "model_type": example.get("model_type", "unknown")  # Add model type for Presidio mapping
+    }
+
+def main():
+    """Main evaluation function"""
+    weave.init("guardrails-genie-pii-evaluation")
+
+    # Load test cases
+    test_cases = load_ai4privacy_dataset(num_samples=100)
+
+    # Add model type to test cases for Presidio mapping
+    models = {
+        # "regex": RegexEntityRecognitionGuardrail(should_anonymize=True),
+        "presidio": PresidioEntityRecognitionGuardrail(should_anonymize=True),
+        # "transformers": TransformersEntityRecognitionGuardrail(should_anonymize=True)
+    }
+
+    scorer = EntityRecognitionScorer()
+
+    # Evaluate each model
+    for model_name, guardrail in models.items():
+        print(f"\nEvaluating {model_name} model...")
+        # Add model type to test cases
+        model_test_cases = [{**case, "model_type": model_name} for case in test_cases]
+
+        evaluation = Evaluation(
+            dataset=model_test_cases,
+            scorers=[scorer],
+            preprocess_model_input=preprocess_model_input
+        )
+
+        results = asyncio.run(evaluation.evaluate(guardrail))
+
+if __name__ == "__main__":
+    from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
+    from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
+    from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail
+
+    main()
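
For orientation, this is roughly the summary shape print_metrics_summary expects back from Evaluation.evaluate. The structure is inferred from the accesses in the function above (Weave's summarizer reduces numeric scorer outputs to aggregates such as a mean); all values below are made up:

example_summary = {
    "EntityRecognitionScorer": {
        "overall": {
            "precision": {"mean": 0.61},
            "recall": {"mean": 0.54},
            "f1": {"mean": 0.57},
        },
        "EMAIL": {"precision": {"mean": 0.98}, "recall": {"mean": 0.95}, "f1": {"mean": 0.96}},
    },
    "model_latency": {"count": 100, "mean": 0.42},
}

print_metrics_summary(example_summary)  # prints the tables shown above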
guardrails_genie/guardrails/entity_recognition/presidio_entity_recognition_guardrail.py
CHANGED
@@ -60,12 +60,9 @@ class PresidioEntityRecognitionGuardrail(Guardrail):
             print(f"- {entity}")
             print("=" * 25 + "\n")
 
-        # Initialize default values
+        # Initialize default values to all available entities
         if selected_entities is None:
-            selected_entities = [
-                "CREDIT_CARD", "US_SSN", "EMAIL_ADDRESS", "PHONE_NUMBER",
-                "IP_ADDRESS", "URL", "DATE_TIME"
-            ]
+            selected_entities = self.get_available_entities()
 
        # Get available entities dynamically
        available_entities = self.get_available_entities()
@@ -135,7 +132,7 @@ class PresidioEntityRecognitionGuardrail(Guardrail):
         """
         # Analyze text for entities
         analyzer_results = self.analyzer.analyze(
-            text=prompt,
+            text=str(prompt),
            entities=self.selected_entities,
            language=self.language
        )
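
The str(prompt) cast guards against non-string inputs reaching Presidio, whose analyze() expects a plain string. A self-contained sketch of the same call (requires the presidio-analyzer package; the entity list here is just an example):

from presidio_analyzer import AnalyzerEngine

analyzer = AnalyzerEngine()
prompt = "Contact me at jane@example.com"
results = analyzer.analyze(
    text=str(prompt),  # analyze() expects a plain str, hence the cast in the diff
    entities=["EMAIL_ADDRESS"],
    language="en",
)
for r in results:
    print(r.entity_type, r.start, r.end, round(r.score, 2))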
guardrails_genie/guardrails/entity_recognition/regex_entity_recognition_guardrail.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, ClassVar
+from typing import Dict, Optional, ClassVar, List
 
 import weave
 from pydantic import BaseModel
@@ -35,24 +35,34 @@ class RegexEntityRecognitionGuardrail(Guardrail):
     should_anonymize: bool = False
 
     DEFAULT_PATTERNS: ClassVar[Dict[str, str]] = {
-        …  # (previous default pattern entries, not recoverable from the page render)
+        "EMAIL": r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
+        "TELEPHONENUM": r'\b(\+\d{1,3}[-.]?)?\(?\d{3}\)?[-.]?\d{3}[-.]?\d{4}\b',
+        "SOCIALNUM": r'\b\d{3}[-]?\d{2}[-]?\d{4}\b',
+        "CREDITCARDNUMBER": r'\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b',
+        "DATEOFBIRTH": r'\b(0[1-9]|1[0-2])[-/](0[1-9]|[12]\d|3[01])[-/](19|20)\d{2}\b',
+        "DRIVERLICENSENUM": r'[A-Z]\d{7}',  # Example pattern, adjust for your needs
+        "ACCOUNTNUM": r'\b\d{10,12}\b',  # Example pattern for bank accounts
+        "ZIPCODE": r'\b\d{5}(?:-\d{4})?\b',
+        "GIVENNAME": r'\b[A-Z][a-z]+\b',  # Basic pattern for first names
+        "SURNAME": r'\b[A-Z][a-z]+\b',  # Basic pattern for last names
+        "CITY": r'\b[A-Z][a-z]+(?:[\s-][A-Z][a-z]+)*\b',
+        "STREET": r'\b\d+\s+[A-Z][a-z]+\s+(?:Street|St|Avenue|Ave|Road|Rd|Boulevard|Blvd|Lane|Ln|Drive|Dr)\b',
+        "IDCARDNUM": r'[A-Z]\d{7,8}',  # Generic pattern for ID cards
+        "USERNAME": r'@[A-Za-z]\w{3,}',  # Basic username pattern
+        "PASSWORD": r'[A-Za-z0-9@#$%^&+=]{8,}',  # Basic password pattern
+        "TAXNUM": r'\b\d{2}[-]\d{7}\b',  # Example tax number pattern
+        "BUILDINGNUM": r'\b\d+[A-Za-z]?\b'  # Basic building number pattern
     }
 
-    def __init__(self, use_defaults: bool = True, should_anonymize: bool = False, **kwargs):
+    def __init__(self, use_defaults: bool = True, should_anonymize: bool = False, show_available_entities: bool = False, **kwargs):
         patterns = {}
         if use_defaults:
             patterns = self.DEFAULT_PATTERNS.copy()
         if kwargs.get("patterns"):
            patterns.update(kwargs["patterns"])
+
+        if show_available_entities:
+            self._print_available_entities(patterns.keys())
 
        # Create the RegexModel instance
        regex_model = RegexModel(patterns=patterns)
@@ -72,6 +82,14 @@ class RegexEntityRecognitionGuardrail(Guardrail):
         escaped_text = re.escape(text)
         # Create a pattern that matches the exact text, case-insensitive
         return rf"\b{escaped_text}\b"
+
+    def _print_available_entities(self, entities: List[str]):
+        """Print available entities"""
+        print("\nAvailable entity types:")
+        print("=" * 25)
+        for entity in entities:
+            print(f"- {entity}")
+        print("=" * 25 + "\n")
 
    @weave.op()
    def guard(self, prompt: str, custom_terms: Optional[list[str]] = None, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
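
A quick sanity check of two of the new default patterns; the diff's own comments mark several of them as basic or example patterns, so false positives are expected:

import re

EMAIL = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b'
TELEPHONENUM = r'\b(\+\d{1,3}[-.]?)?\(?\d{3}\)?[-.]?\d{3}[-.]?\d{4}\b'

text = "Reach Jane at jane.doe@example.com or 555-123-4567."
print(re.findall(EMAIL, text))   # ['jane.doe@example.com']
m = re.search(TELEPHONENUM, text)
print(m.group(0))                # '555-123-4567'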
guardrails_genie/guardrails/entity_recognition/transformers_entity_recognition_guardrail.py
CHANGED
@@ -37,7 +37,7 @@ class TransformersEntityRecognitionGuardrail(Guardrail):
         model_name: str = "iiiorg/piiranha-v1-detect-personal-information",
         selected_entities: Optional[List[str]] = None,
         should_anonymize: bool = False,
-        show_available_entities: bool = True,
+        show_available_entities: bool = False,
    ):
        # Load model config and extract available entities
        config = AutoConfig.from_pretrained(model_name)
guardrails_genie/regex_model.py
CHANGED
@@ -28,7 +28,7 @@ class RegexModel(weave.Model):
     }
 
     @weave.op()
-    def check(self,
+    def check(self, prompt: str) -> RegexResult:
        """
        Check text against all patterns and return detailed results.
 
@@ -38,23 +38,28 @@ class RegexModel(weave.Model):
         Returns:
             RegexResult containing pass/fail status and details about matches
         """
-
-        failed_patterns
+        matched_patterns = {}
+        failed_patterns = []
 
-        for pattern_name,
-
-
-
+        for pattern_name, pattern in self.patterns.items():
+            matches = []
+            for match in re.finditer(pattern, prompt):
+                if match.groups():
+                    # If there are capture groups, join them with a separator
+                    matches.append('-'.join(str(g) for g in match.groups() if g is not None))
+                else:
+                    # If no capture groups, use the full match
+                    matches.append(match.group(0))
+
+            if matches:
+                matched_patterns[pattern_name] = matches
            else:
                failed_patterns.append(pattern_name)
 
-        # Consider it passed only if no patterns matched (no PII found)
-        passed = len(matches) == 0
-
        return RegexResult(
-
-
-
+            matched_patterns=matched_patterns,
+            failed_patterns=failed_patterns,
+            passed=len(matched_patterns) == 0
        )
 
    @weave.op()
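
The rewritten check() joins match.groups() because, for a pattern with capture groups, groups() returns only the captured fragments while group(0) is the full match. A minimal demonstration with a hypothetical grouped pattern (not one of the shipped defaults):

import re

# Hypothetical SSN-style pattern whose parts are captured separately
pattern = r'\b(\d{3})-(\d{2})-(\d{4})\b'
for match in re.finditer(pattern, "SSN on file: 078-05-1120"):
    print(match.group(0))   # 078-05-1120 (the full match)
    print(match.groups())   # ('078', '05', '1120') (captured pieces only)
    # what the rewritten check() stores for grouped patterns:
    print('-'.join(str(g) for g in match.groups() if g is not None))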