ash0ts committed
Commit 084dab1
1 Parent(s): 28d8897

add running pii eval script

.gitignore CHANGED
@@ -168,4 +168,5 @@ temp.txt
  **.csv
  binary-classifier/
  wandb/
- artifacts/
+ artifacts/
+ evaluation_results/
guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark.py CHANGED
@@ -0,0 +1,215 @@
from datasets import load_dataset
from typing import Dict, List, Tuple
import random
from tqdm import tqdm
import json
from pathlib import Path
import weave

def load_ai4privacy_dataset(num_samples: int = 100, split: str = "validation") -> List[Dict]:
    """
    Load and prepare samples from the ai4privacy dataset.

    Args:
        num_samples: Number of samples to evaluate
        split: Dataset split to use ("train" or "validation")

    Returns:
        List of prepared test cases
    """
    # Load the dataset
    dataset = load_dataset("ai4privacy/pii-masking-400k")

    # Get the specified split
    data_split = dataset[split]

    # Randomly sample entries if num_samples is less than total
    if num_samples < len(data_split):
        indices = random.sample(range(len(data_split)), num_samples)
        samples = [data_split[i] for i in indices]
    else:
        samples = data_split

    # Convert to test case format
    test_cases = []
    for sample in samples:
        # Extract entities from privacy_mask
        entities: Dict[str, List[str]] = {}
        for entity in sample['privacy_mask']:
            label = entity['label']
            value = entity['value']
            if label not in entities:
                entities[label] = []
            entities[label].append(value)

        test_case = {
            "description": f"AI4Privacy Sample (ID: {sample['uid']})",
            "input_text": sample['source_text'],
            "expected_entities": entities,
            "masked_text": sample['masked_text'],
            "language": sample['language'],
            "locale": sample['locale']
        }
        test_cases.append(test_case)

    return test_cases

@weave.op()
def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]:
    """
    Evaluate a model on the test cases.

    Args:
        guardrail: Entity recognition guardrail to evaluate
        test_cases: List of test cases

    Returns:
        Tuple of (metrics dict, detailed results list)
    """
    metrics = {
        "total": len(test_cases),
        "passed": 0,
        "failed": 0,
        "entity_metrics": {}  # Will store precision/recall per entity type
    }

    detailed_results = []

    for test_case in tqdm(test_cases, desc="Evaluating samples"):
        # Run detection
        result = guardrail.guard(test_case['input_text'])
        detected = result.detected_entities
        expected = test_case['expected_entities']

        # Track entity-level metrics
        all_entity_types = set(list(detected.keys()) + list(expected.keys()))
        entity_results = {}

        for entity_type in all_entity_types:
            detected_set = set(detected.get(entity_type, []))
            expected_set = set(expected.get(entity_type, []))

            # Calculate metrics
            true_positives = len(detected_set & expected_set)
            false_positives = len(detected_set - expected_set)
            false_negatives = len(expected_set - detected_set)

            precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
            recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

            entity_results[entity_type] = {
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "true_positives": true_positives,
                "false_positives": false_positives,
                "false_negatives": false_negatives
            }

            # Aggregate metrics
            if entity_type not in metrics["entity_metrics"]:
                metrics["entity_metrics"][entity_type] = {
                    "total_true_positives": 0,
                    "total_false_positives": 0,
                    "total_false_negatives": 0
                }
            metrics["entity_metrics"][entity_type]["total_true_positives"] += true_positives
            metrics["entity_metrics"][entity_type]["total_false_positives"] += false_positives
            metrics["entity_metrics"][entity_type]["total_false_negatives"] += false_negatives

        # Store detailed result
        detailed_result = {
            "id": test_case.get("description", ""),
            "language": test_case.get("language", ""),
            "locale": test_case.get("locale", ""),
            "input_text": test_case["input_text"],
            "expected_entities": expected,
            "detected_entities": detected,
            "entity_metrics": entity_results,
            "anonymized_text": result.anonymized_text if result.anonymized_text else None
        }
        detailed_results.append(detailed_result)

        # Update pass/fail counts
        if all(entity_results[et]["f1"] == 1.0 for et in entity_results):
            metrics["passed"] += 1
        else:
            metrics["failed"] += 1

    # Calculate final entity metrics
    for entity_type, counts in metrics["entity_metrics"].items():
        tp = counts["total_true_positives"]
        fp = counts["total_false_positives"]
        fn = counts["total_false_negatives"]

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        metrics["entity_metrics"][entity_type].update({
            "precision": precision,
            "recall": recall,
            "f1": f1
        })

    return metrics, detailed_results

def save_results(metrics: Dict, detailed_results: List[Dict], model_name: str, output_dir: str = "evaluation_results"):
    """Save evaluation results to files"""
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True)

    # Save metrics summary
    with open(output_dir / f"{model_name}_metrics.json", "w") as f:
        json.dump(metrics, f, indent=2)

    # Save detailed results
    with open(output_dir / f"{model_name}_detailed_results.json", "w") as f:
        json.dump(detailed_results, f, indent=2)

def print_metrics_summary(metrics: Dict):
    """Print a summary of the evaluation metrics"""
    print("\nEvaluation Summary")
    print("=" * 80)
    print(f"Total Samples: {metrics['total']}")
    print(f"Passed: {metrics['passed']}")
    print(f"Failed: {metrics['failed']}")
    print(f"Success Rate: {(metrics['passed']/metrics['total'])*100:.1f}%")

    print("\nEntity-level Metrics:")
    print("-" * 80)
    print(f"{'Entity Type':<20} {'Precision':>10} {'Recall':>10} {'F1':>10}")
    print("-" * 80)
    for entity_type, entity_metrics in metrics["entity_metrics"].items():
        print(f"{entity_type:<20} {entity_metrics['precision']:>10.2f} {entity_metrics['recall']:>10.2f} {entity_metrics['f1']:>10.2f}")

def main():
    """Main evaluation function"""
    weave.init("guardrails-genie-pii-evaluation")

    # Load test cases
    test_cases = load_ai4privacy_dataset(num_samples=100)

    # Initialize models to evaluate
    models = {
        "regex": RegexEntityRecognitionGuardrail(should_anonymize=True),
        "presidio": PresidioEntityRecognitionGuardrail(should_anonymize=True),
        "transformers": TransformersEntityRecognitionGuardrail(should_anonymize=True)
    }

    # Evaluate each model
    for model_name, guardrail in models.items():
        print(f"\nEvaluating {model_name} model...")
        metrics, detailed_results = evaluate_model(guardrail, test_cases)

        # Print and save results
        print_metrics_summary(metrics)
        save_results(metrics, detailed_results, model_name)

if __name__ == "__main__":
    from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
    from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
    from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail

    main()
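A note on how evaluate_model scores each guardrail: per-entity true positives, false positives, and false negatives are summed across all samples, and precision, recall, and F1 are then computed once from those totals (micro-averaging) rather than averaged per sample. A toy illustration with made-up counts and hypothetical entity labels:

# Toy illustration (made-up counts, hypothetical labels) of the final
# aggregation step in evaluate_model: metrics are derived from summed counts.
counts = {"EMAIL": {"tp": 8, "fp": 2, "fn": 1}, "PHONE": {"tp": 3, "fp": 0, "fn": 3}}

for entity_type, c in counts.items():
    tp, fp, fn = c["tp"], c["fp"], c["fn"]
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    print(f"{entity_type}: precision={precision:.2f} recall={recall:.2f} f1={f1:.2f}")
# EMAIL: precision=0.80 recall=0.89 f1=0.84
# PHONE: precision=1.00 recall=0.50 f1=0.67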
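For a quick check against a single guardrail, the benchmark functions can also be driven programmatically instead of running the script's main(). The sketch below is illustrative only: it assumes the guardrails_genie package is importable as laid out in this repository, and the sample count and model name are arbitrary choices.

import weave

from guardrails_genie.guardrails.entity_recognition.pii_examples.pii_benchmark import (
    evaluate_model,
    load_ai4privacy_dataset,
    print_metrics_summary,
    save_results,
)
from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail

# Initialize Weave tracking (same project name the script uses).
weave.init("guardrails-genie-pii-evaluation")

# Small sample budget for a smoke test; the script itself defaults to 100.
test_cases = load_ai4privacy_dataset(num_samples=20, split="validation")

guardrail = RegexEntityRecognitionGuardrail(should_anonymize=True)
metrics, detailed_results = evaluate_model(guardrail, test_cases)

print_metrics_summary(metrics)
save_results(metrics, detailed_results, model_name="regex")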