ash0ts commited on
Commit
3caf047
1 Parent(s): 13d2f14

Make PII evals work

Browse files

Weave evals run, but the results don't match the normal benchmark script.

application_pages/chat_app.py CHANGED
@@ -66,28 +66,28 @@ def initialize_guardrails():
66
  getattr(
67
  importlib.import_module("guardrails_genie.guardrails"),
68
  guardrail_name,
69
- )()
70
  )
71
  elif guardrail_name == "RegexEntityRecognitionGuardrail":
72
  st.session_state.guardrails.append(
73
  getattr(
74
  importlib.import_module("guardrails_genie.guardrails"),
75
  guardrail_name,
76
- )()
77
  )
78
  elif guardrail_name == "TransformersEntityRecognitionGuardrail":
79
  st.session_state.guardrails.append(
80
  getattr(
81
  importlib.import_module("guardrails_genie.guardrails"),
82
  guardrail_name,
83
- )()
84
  )
85
  elif guardrail_name == "RestrictedTermsJudge":
86
  st.session_state.guardrails.append(
87
  getattr(
88
  importlib.import_module("guardrails_genie.guardrails"),
89
  guardrail_name,
90
- )()
91
  )
92
  st.session_state.guardrails_manager = GuardrailManager(
93
  guardrails=st.session_state.guardrails
 
66
  getattr(
67
  importlib.import_module("guardrails_genie.guardrails"),
68
  guardrail_name,
69
+ )(should_anonymize=True)
70
  )
71
  elif guardrail_name == "RegexEntityRecognitionGuardrail":
72
  st.session_state.guardrails.append(
73
  getattr(
74
  importlib.import_module("guardrails_genie.guardrails"),
75
  guardrail_name,
76
+ )(should_anonymize=True)
77
  )
78
  elif guardrail_name == "TransformersEntityRecognitionGuardrail":
79
  st.session_state.guardrails.append(
80
  getattr(
81
  importlib.import_module("guardrails_genie.guardrails"),
82
  guardrail_name,
83
+ )(should_anonymize=True)
84
  )
85
  elif guardrail_name == "RestrictedTermsJudge":
86
  st.session_state.guardrails.append(
87
  getattr(
88
  importlib.import_module("guardrails_genie.guardrails"),
89
  guardrail_name,
90
+ )(should_anonymize=True)
91
  )
92
  st.session_state.guardrails_manager = GuardrailManager(
93
  guardrails=st.session_state.guardrails
guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark.py CHANGED
@@ -6,6 +6,35 @@ import json
6
  from pathlib import Path
7
  import weave
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def load_ai4privacy_dataset(num_samples: int = 100, split: str = "validation") -> List[Dict]:
10
  """
11
  Load and prepare samples from the ai4privacy dataset.
@@ -81,6 +110,17 @@ def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]
81
  detected = result.detected_entities
82
  expected = test_case['expected_entities']
83
 
 
 
 
 
 
 
 
 
 
 
 
84
  # Track entity-level metrics
85
  all_entity_types = set(list(detected.keys()) + list(expected.keys()))
86
  entity_results = {}
@@ -137,12 +177,20 @@ def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]
137
  else:
138
  metrics["failed"] += 1
139
 
140
- # Calculate final entity metrics
 
 
 
 
141
  for entity_type, counts in metrics["entity_metrics"].items():
142
  tp = counts["total_true_positives"]
143
  fp = counts["total_false_positives"]
144
  fn = counts["total_false_negatives"]
145
 
 
 
 
 
146
  precision = tp / (tp + fp) if (tp + fp) > 0 else 0
147
  recall = tp / (tp + fn) if (tp + fn) > 0 else 0
148
  f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
@@ -153,6 +201,20 @@ def evaluate_model(guardrail, test_cases: List[Dict]) -> Tuple[Dict, List[Dict]]
153
  "f1": f1
154
  })
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  return metrics, detailed_results
157
 
158
  def save_results(metrics: Dict, detailed_results: List[Dict], model_name: str, output_dir: str = "evaluation_results"):
@@ -177,6 +239,15 @@ def print_metrics_summary(metrics: Dict):
177
  print(f"Failed: {metrics['failed']}")
178
  print(f"Success Rate: {(metrics['passed']/metrics['total'])*100:.1f}%")
179
 
 
 
 
 
 
 
 
 
 
180
  print("\nEntity-level Metrics:")
181
  print("-" * 80)
182
  print(f"{'Entity Type':<20} {'Precision':>10} {'Recall':>10} {'F1':>10}")
@@ -193,9 +264,9 @@ def main():
193
 
194
  # Initialize models to evaluate
195
  models = {
196
- "regex": RegexEntityRecognitionGuardrail(should_anonymize=True),
197
- "presidio": PresidioEntityRecognitionGuardrail(should_anonymize=True),
198
- "transformers": TransformersEntityRecognitionGuardrail(should_anonymize=True)
199
  }
200
 
201
  # Evaluate each model
 
6
  from pathlib import Path
7
  import weave
8
 
9
+ # Add this mapping dictionary near the top of the file
10
+ PRESIDIO_TO_TRANSFORMER_MAPPING = {
11
+ "EMAIL_ADDRESS": "EMAIL",
12
+ "PHONE_NUMBER": "TELEPHONENUM",
13
+ "US_SSN": "SOCIALNUM",
14
+ "CREDIT_CARD": "CREDITCARDNUMBER",
15
+ "IP_ADDRESS": "IDCARDNUM",
16
+ "DATE_TIME": "DATEOFBIRTH",
17
+ "US_PASSPORT": "IDCARDNUM",
18
+ "US_DRIVER_LICENSE": "DRIVERLICENSENUM",
19
+ "US_BANK_NUMBER": "ACCOUNTNUM",
20
+ "LOCATION": "CITY",
21
+ "URL": "USERNAME", # URLs often contain usernames
22
+ "IN_PAN": "TAXNUM", # Indian Permanent Account Number
23
+ "UK_NHS": "IDCARDNUM",
24
+ "SG_NRIC_FIN": "IDCARDNUM",
25
+ "AU_ABN": "TAXNUM", # Australian Business Number
26
+ "AU_ACN": "TAXNUM", # Australian Company Number
27
+ "AU_TFN": "TAXNUM", # Australian Tax File Number
28
+ "AU_MEDICARE": "IDCARDNUM",
29
+ "IN_AADHAAR": "IDCARDNUM", # Indian national ID
30
+ "IN_VOTER": "IDCARDNUM",
31
+ "IN_PASSPORT": "IDCARDNUM",
32
+ "CRYPTO": "ACCOUNTNUM", # Cryptocurrency addresses
33
+ "IBAN_CODE": "ACCOUNTNUM",
34
+ "MEDICAL_LICENSE": "IDCARDNUM",
35
+ "IN_VEHICLE_REGISTRATION": "IDCARDNUM"
36
+ }
37
+
38
  def load_ai4privacy_dataset(num_samples: int = 100, split: str = "validation") -> List[Dict]:
39
  """
40
  Load and prepare samples from the ai4privacy dataset.
 
110
  detected = result.detected_entities
111
  expected = test_case['expected_entities']
112
 
113
+ # Map Presidio entities if this is the Presidio guardrail
114
+ if isinstance(guardrail, PresidioEntityRecognitionGuardrail):
115
+ mapped_detected = {}
116
+ for entity_type, values in detected.items():
117
+ mapped_type = PRESIDIO_TO_TRANSFORMER_MAPPING.get(entity_type)
118
+ if mapped_type:
119
+ if mapped_type not in mapped_detected:
120
+ mapped_detected[mapped_type] = []
121
+ mapped_detected[mapped_type].extend(values)
122
+ detected = mapped_detected
123
+
124
  # Track entity-level metrics
125
  all_entity_types = set(list(detected.keys()) + list(expected.keys()))
126
  entity_results = {}
 
177
  else:
178
  metrics["failed"] += 1
179
 
180
+ # Calculate final entity metrics and track totals for overall metrics
181
+ total_tp = 0
182
+ total_fp = 0
183
+ total_fn = 0
184
+
185
  for entity_type, counts in metrics["entity_metrics"].items():
186
  tp = counts["total_true_positives"]
187
  fp = counts["total_false_positives"]
188
  fn = counts["total_false_negatives"]
189
 
190
+ total_tp += tp
191
+ total_fp += fp
192
+ total_fn += fn
193
+
194
  precision = tp / (tp + fp) if (tp + fp) > 0 else 0
195
  recall = tp / (tp + fn) if (tp + fn) > 0 else 0
196
  f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
 
201
  "f1": f1
202
  })
203
 
204
+ # Calculate overall metrics
205
+ overall_precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
206
+ overall_recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
207
+ overall_f1 = 2 * (overall_precision * overall_recall) / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0
208
+
209
+ metrics["overall"] = {
210
+ "precision": overall_precision,
211
+ "recall": overall_recall,
212
+ "f1": overall_f1,
213
+ "total_true_positives": total_tp,
214
+ "total_false_positives": total_fp,
215
+ "total_false_negatives": total_fn
216
+ }
217
+
218
  return metrics, detailed_results
219
 
220
  def save_results(metrics: Dict, detailed_results: List[Dict], model_name: str, output_dir: str = "evaluation_results"):
 
239
  print(f"Failed: {metrics['failed']}")
240
  print(f"Success Rate: {(metrics['passed']/metrics['total'])*100:.1f}%")
241
 
242
+ # Print overall metrics
243
+ print("\nOverall Metrics:")
244
+ print("-" * 80)
245
+ print(f"{'Metric':<20} {'Value':>10}")
246
+ print("-" * 80)
247
+ print(f"{'Precision':<20} {metrics['overall']['precision']:>10.2f}")
248
+ print(f"{'Recall':<20} {metrics['overall']['recall']:>10.2f}")
249
+ print(f"{'F1':<20} {metrics['overall']['f1']:>10.2f}")
250
+
251
  print("\nEntity-level Metrics:")
252
  print("-" * 80)
253
  print(f"{'Entity Type':<20} {'Precision':>10} {'Recall':>10} {'F1':>10}")
 
264
 
265
  # Initialize models to evaluate
266
  models = {
267
+ "regex": RegexEntityRecognitionGuardrail(should_anonymize=True, show_available_entities=True),
268
+ "presidio": PresidioEntityRecognitionGuardrail(should_anonymize=True, show_available_entities=True),
269
+ "transformers": TransformersEntityRecognitionGuardrail(should_anonymize=True, show_available_entities=True)
270
  }
271
 
272
  # Evaluate each model
guardrails_genie/guardrails/entity_recognition/pii_examples/pii_benchmark_weave.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from typing import Dict, List, Tuple, Optional
3
+ import random
4
+ from tqdm import tqdm
5
+ import json
6
+ from pathlib import Path
7
+ import weave
8
+ from weave.scorers import Scorer
9
+ from weave import Evaluation
10
+ import asyncio
11
+
12
# Maps Presidio analyzer entity labels to the transformer model's label set so
# that detections from both guardrails can be scored against the same
# expected-entity keys. Several distinct Presidio labels intentionally collapse
# onto one transformer label (e.g. many national IDs -> IDCARDNUM).
PRESIDIO_TO_TRANSFORMER_MAPPING = {
    "EMAIL_ADDRESS": "EMAIL",
    "PHONE_NUMBER": "TELEPHONENUM",
    "US_SSN": "SOCIALNUM",
    "CREDIT_CARD": "CREDITCARDNUMBER",
    "IP_ADDRESS": "IDCARDNUM",  # NOTE(review): lossy mapping — confirm intended
    "DATE_TIME": "DATEOFBIRTH",  # NOTE(review): DATE_TIME is broader than DOB — confirm
    "US_PASSPORT": "IDCARDNUM",
    "US_DRIVER_LICENSE": "DRIVERLICENSENUM",
    "US_BANK_NUMBER": "ACCOUNTNUM",
    "LOCATION": "CITY",
    "URL": "USERNAME",  # URLs often contain usernames
    "IN_PAN": "TAXNUM",  # Indian Permanent Account Number
    "UK_NHS": "IDCARDNUM",
    "SG_NRIC_FIN": "IDCARDNUM",
    "AU_ABN": "TAXNUM",  # Australian Business Number
    "AU_ACN": "TAXNUM",  # Australian Company Number
    "AU_TFN": "TAXNUM",  # Australian Tax File Number
    "AU_MEDICARE": "IDCARDNUM",
    "IN_AADHAAR": "IDCARDNUM",  # Indian national ID
    "IN_VOTER": "IDCARDNUM",
    "IN_PASSPORT": "IDCARDNUM",
    "CRYPTO": "ACCOUNTNUM",  # Cryptocurrency addresses
    "IBAN_CODE": "ACCOUNTNUM",
    "MEDICAL_LICENSE": "IDCARDNUM",
    "IN_VEHICLE_REGISTRATION": "IDCARDNUM"
}
41
class EntityRecognitionScorer(Scorer):
    """Scorer for evaluating entity recognition performance.

    Compares the guardrail's detected entities against the expected entities
    for one sample and returns per-entity-type and overall (micro-averaged)
    precision / recall / F1.
    """

    @weave.op()
    async def score(
        self,
        model_output: Optional[dict],
        input_text: str,
        expected_entities: Dict,
        model_type: Optional[str] = None,
    ) -> Dict:
        """Score entity recognition results for a single sample.

        Args:
            model_output: Guardrail response (Pydantic model or plain dict)
                carrying ``detected_entities`` as ``{entity_type: [values]}``.
            input_text: Original prompt text (not used for scoring; kept so
                Weave can match the dataset column).
            expected_entities: Ground-truth ``{entity_type: [values]}``.
            model_type: Which guardrail produced the output (e.g. "presidio").
                Filled from the dataset row by Weave's column matching; a
                ``model_type`` key on the output dict is used as fallback.

        Returns:
            Dict keyed by entity type (plus "overall") with precision, recall,
            f1 and raw true/false positive/negative counts; ``{"f1": 0.0}``
            when there is no model output.
        """
        if not model_output:
            return {"f1": 0.0}

        # Convert a Pydantic model to a plain dict if necessary.
        if hasattr(model_output, "model_dump"):
            model_output = model_output.model_dump()
        elif hasattr(model_output, "dict"):
            model_output = model_output.dict()

        detected = model_output.get("detected_entities", {})

        # BUGFIX: "model_type" is a dataset column added in main(), not a field
        # of the guardrail response, so the previous check
        # `model_output.get("model_type") == "presidio"` never fired and
        # Presidio labels were scored unmapped — diverging from the
        # pii_benchmark.py script. Take model_type from the dataset row first
        # and keep the output-dict lookup as a fallback.
        effective_model_type = model_type or model_output.get("model_type")
        if effective_model_type == "presidio":
            mapped_detected = {}
            for entity_type, values in detected.items():
                mapped_type = PRESIDIO_TO_TRANSFORMER_MAPPING.get(entity_type)
                if mapped_type:
                    mapped_detected.setdefault(mapped_type, []).extend(values)
            detected = mapped_detected

        # Score over the union of detected and expected entity types so both
        # false positives and false negatives are counted.
        all_entity_types = set(list(detected.keys()) + list(expected_entities.keys()))
        entity_metrics = {}

        for entity_type in all_entity_types:
            detected_set = set(detected.get(entity_type, []))
            expected_set = set(expected_entities.get(entity_type, []))

            true_positives = len(detected_set & expected_set)
            false_positives = len(detected_set - expected_set)
            false_negatives = len(expected_set - detected_set)

            entity_metrics[entity_type] = {
                "total_true_positives": true_positives,
                "total_false_positives": false_positives,
                "total_false_negatives": false_negatives,
            }

            # Per-entity metrics; guard every ratio against empty denominators.
            precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
            recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

            entity_metrics[entity_type].update({
                "precision": precision,
                "recall": recall,
                "f1": f1,
            })

        # Micro-average: pool raw counts across entity types.
        total_tp = sum(m["total_true_positives"] for m in entity_metrics.values())
        total_fp = sum(m["total_false_positives"] for m in entity_metrics.values())
        total_fn = sum(m["total_false_negatives"] for m in entity_metrics.values())

        overall_precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
        overall_recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
        overall_f1 = 2 * (overall_precision * overall_recall) / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0

        entity_metrics["overall"] = {
            "precision": overall_precision,
            "recall": overall_recall,
            "f1": overall_f1,
            "total_true_positives": total_tp,
            "total_false_positives": total_fp,
            "total_false_negatives": total_fn,
        }

        return entity_metrics
123
+
124
def load_ai4privacy_dataset(num_samples: int = 100, split: str = "validation") -> List[Dict]:
    """Load and prepare samples from the ai4privacy PII-masking dataset.

    Args:
        num_samples: Number of samples to evaluate.
        split: Dataset split to use ("train" or "validation").

    Returns:
        List of prepared test cases.
    """
    dataset = load_dataset("ai4privacy/pii-masking-400k")
    data_split = dataset[split]

    # Down-sample randomly when more rows are available than requested.
    if num_samples < len(data_split):
        chosen = random.sample(range(len(data_split)), num_samples)
        rows = [data_split[i] for i in chosen]
    else:
        rows = data_split

    test_cases = []
    for row in rows:
        # Group the ground-truth entity values by their label.
        entities: Dict[str, List[str]] = {}
        for span in row['privacy_mask']:
            entities.setdefault(span['label'], []).append(span['value'])

        test_cases.append({
            "description": f"AI4Privacy Sample (ID: {row['uid']})",
            "input_text": row['source_text'],
            "expected_entities": entities,
            "masked_text": row['masked_text'],
            "language": row['language'],
            "locale": row['locale'],
        })

    return test_cases
171
+
172
def save_results(weave_results: Dict, model_name: str, output_dir: str = "evaluation_results"):
    """Write summary and per-sample evaluation results to JSON files.

    Args:
        weave_results: Result dict whose "EntityRecognitionScorer" key holds a
            list of per-sample metric dicts (or None / error strings).
        model_name: Prefix for the output file names.
        output_dir: Directory (created if missing) for the JSON files.
    """
    out_path = Path(output_dir)
    out_path.mkdir(exist_ok=True)

    scorer_results = weave_results.get("EntityRecognitionScorer", [])
    if not scorer_results or all(r is None for r in scorer_results):
        print(f"No valid results to save for {model_name}")
        return

    total_samples = len(scorer_results)
    # A sample "passed" if its scorer result is a real dict (not None/error str).
    passed = sum(1 for r in scorer_results if r is not None and not isinstance(r, str))

    # Collect per-entity-type metric lists across all valid samples.
    entity_metrics = {}
    for result in scorer_results:
        try:
            if isinstance(result, str) or not result:
                continue
            for entity_type, metrics in result.items():
                bucket = entity_metrics.setdefault(
                    entity_type, {"precision": [], "recall": [], "f1": []}
                )
                bucket["precision"].append(metrics["precision"])
                bucket["recall"].append(metrics["recall"])
                bucket["f1"].append(metrics["f1"])
        except (AttributeError, TypeError, KeyError):
            # Skip malformed per-sample results rather than failing the save.
            continue

    def _mean(values):
        # Average of a metric list; 0 when no samples contributed.
        return sum(values) / len(values) if values else 0

    summary_metrics = {
        "total": total_samples,
        "passed": passed,
        "failed": total_samples - passed,
        "success_rate": (passed / total_samples) if total_samples > 0 else 0,
        "entity_metrics": {
            entity_type: {
                "precision": _mean(bucket["precision"]),
                "recall": _mean(bucket["recall"]),
                "f1": _mean(bucket["f1"]),
            }
            for entity_type, bucket in entity_metrics.items()
        },
    }

    with open(out_path / f"{model_name}_metrics.json", "w") as f:
        json.dump(summary_metrics, f, indent=2)

    # Persist only the valid (dict) per-sample results.
    detailed_results = [r for r in scorer_results if not isinstance(r, str) and r is not None]
    with open(out_path / f"{model_name}_detailed_results.json", "w") as f:
        json.dump(detailed_results, f, indent=2)
231
+
232
def print_metrics_summary(weave_results: Dict):
    """Print a human-readable summary of Weave evaluation metrics.

    Args:
        weave_results: Result dict from ``Evaluation.evaluate``; entity metrics
            are read from the "EntityRecognitionScorer" key, and the sample
            count from ``weave_results["model_latency"]["count"]``.
    """
    print("\nEvaluation Summary")
    print("=" * 80)

    # Extract results from Weave's evaluation format.
    scorer_results = weave_results.get("EntityRecognitionScorer", {})
    if not scorer_results:
        print("No valid results available")
        return

    total_samples = int(weave_results.get("model_latency", {}).get("count", 0))
    # NOTE(review): assumes every sample with a latency entry passed scoring —
    # confirm against Weave's result schema.
    passed = total_samples
    failed = 0

    print(f"Total Samples: {total_samples}")
    print(f"Passed: {passed}")
    print(f"Failed: {failed}")
    # BUGFIX: guard the percentage against ZeroDivisionError when scorer
    # results exist but no "model_latency" count was reported.
    success_rate = (passed / total_samples) * 100 if total_samples > 0 else 0.0
    print(f"Success Rate: {success_rate:.2f}%")

    # Print overall (micro-averaged) metrics if present.
    if "overall" in scorer_results:
        overall = scorer_results["overall"]
        print("\nOverall Metrics:")
        print("-" * 80)
        print(f"{'Metric':<20} {'Value':>10}")
        print("-" * 80)
        print(f"{'Precision':<20} {overall['precision']['mean']:>10.2f}")
        print(f"{'Recall':<20} {overall['recall']['mean']:>10.2f}")
        print(f"{'F1':<20} {overall['f1']['mean']:>10.2f}")

    # Print per-entity-type metrics.
    print("\nEntity-Level Metrics:")
    print("-" * 80)
    print(f"{'Entity Type':<20} {'Precision':>10} {'Recall':>10} {'F1':>10}")
    print("-" * 80)

    for entity_type, metrics in scorer_results.items():
        if entity_type == "overall":
            continue

        precision = metrics.get("precision", {}).get("mean", 0)
        recall = metrics.get("recall", {}).get("mean", 0)
        f1 = metrics.get("f1", {}).get("mean", 0)

        print(f"{entity_type:<20} {precision:>10.2f} {recall:>10.2f} {f1:>10.2f}")
279
+
280
def preprocess_model_input(example: Dict) -> Dict:
    """Convert a dataset row into the keyword arguments the model expects."""
    # model_type is carried through so the scorer can apply Presidio mapping.
    model_type = example.get("model_type", "unknown")
    return {"prompt": example["input_text"], "model_type": model_type}
286
+
287
def main():
    """Run the Weave PII-entity-recognition evaluation for each guardrail model."""
    # Initialize the Weave project so evaluation runs are tracked.
    weave.init("guardrails-genie-pii-evaluation")

    # Load test cases
    test_cases = load_ai4privacy_dataset(num_samples=100)

    # Add model type to test cases for Presidio mapping
    # NOTE(review): the guardrail classes are imported inside the __main__
    # guard below, so calling main() from an import without those names in
    # scope raises NameError — confirm intended.
    models = {
        # "regex": RegexEntityRecognitionGuardrail(should_anonymize=True),
        "presidio": PresidioEntityRecognitionGuardrail(should_anonymize=True),
        # "transformers": TransformersEntityRecognitionGuardrail(should_anonymize=True)
    }

    scorer = EntityRecognitionScorer()

    # Evaluate each model
    for model_name, guardrail in models.items():
        print(f"\nEvaluating {model_name} model...")
        # Add model type to test cases
        model_test_cases = [{**case, "model_type": model_name} for case in test_cases]

        evaluation = Evaluation(
            dataset=model_test_cases,
            scorers=[scorer],
            preprocess_model_input=preprocess_model_input
        )

        # results is not used locally; presumably the evaluation output is
        # recorded by Weave itself — confirm, or pass it to save_results /
        # print_metrics_summary.
        results = asyncio.run(evaluation.evaluate(guardrail))
316
+
317
if __name__ == "__main__":
    # Guardrail classes are imported here rather than at module top —
    # presumably to avoid importing heavy model dependencies when this module
    # is merely imported; confirm. main() relies on these names being in scope.
    from guardrails_genie.guardrails.entity_recognition.regex_entity_recognition_guardrail import RegexEntityRecognitionGuardrail
    from guardrails_genie.guardrails.entity_recognition.presidio_entity_recognition_guardrail import PresidioEntityRecognitionGuardrail
    from guardrails_genie.guardrails.entity_recognition.transformers_entity_recognition_guardrail import TransformersEntityRecognitionGuardrail

    main()
guardrails_genie/guardrails/entity_recognition/presidio_entity_recognition_guardrail.py CHANGED
@@ -60,12 +60,9 @@ class PresidioEntityRecognitionGuardrail(Guardrail):
60
  print(f"- {entity}")
61
  print("=" * 25 + "\n")
62
 
63
- # Initialize default values
64
  if selected_entities is None:
65
- selected_entities = [
66
- "CREDIT_CARD", "US_SSN", "EMAIL_ADDRESS", "PHONE_NUMBER",
67
- "IP_ADDRESS", "URL", "DATE_TIME"
68
- ]
69
 
70
  # Get available entities dynamically
71
  available_entities = self.get_available_entities()
@@ -135,7 +132,7 @@ class PresidioEntityRecognitionGuardrail(Guardrail):
135
  """
136
  # Analyze text for entities
137
  analyzer_results = self.analyzer.analyze(
138
- text=prompt,
139
  entities=self.selected_entities,
140
  language=self.language
141
  )
 
60
  print(f"- {entity}")
61
  print("=" * 25 + "\n")
62
 
63
+ # Initialize default values to all available entities
64
  if selected_entities is None:
65
+ selected_entities = self.get_available_entities()
 
 
 
66
 
67
  # Get available entities dynamically
68
  available_entities = self.get_available_entities()
 
132
  """
133
  # Analyze text for entities
134
  analyzer_results = self.analyzer.analyze(
135
+ text=str(prompt),
136
  entities=self.selected_entities,
137
  language=self.language
138
  )
guardrails_genie/guardrails/entity_recognition/regex_entity_recognition_guardrail.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Dict, Optional, ClassVar
2
 
3
  import weave
4
  from pydantic import BaseModel
@@ -35,24 +35,34 @@ class RegexEntityRecognitionGuardrail(Guardrail):
35
  should_anonymize: bool = False
36
 
37
  DEFAULT_PATTERNS: ClassVar[Dict[str, str]] = {
38
- "email": r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
39
- "phone_number": r"\b(?:\+?1[-.]?)?\(?(?:[0-9]{3})\)?[-.]?(?:[0-9]{3})[-.]?(?:[0-9]{4})\b",
40
- "ssn": r"\b\d{3}[-]?\d{2}[-]?\d{4}\b",
41
- "credit_card": r"\b\d{4}[-.]?\d{4}[-.]?\d{4}[-.]?\d{4}\b",
42
- "ip_address": r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b",
43
- "date_of_birth": r"\b\d{2}[-/]\d{2}[-/]\d{4}\b",
44
- "passport": r"\b[A-Z]{1,2}[0-9]{6,9}\b",
45
- "drivers_license": r"\b[A-Z]\d{7}\b",
46
- "bank_account": r"\b\d{8,17}\b",
47
- "zip_code": r"\b\d{5}(?:[-]\d{4})?\b"
 
 
 
 
 
 
 
48
  }
49
 
50
- def __init__(self, use_defaults: bool = True, should_anonymize: bool = False, **kwargs):
51
  patterns = {}
52
  if use_defaults:
53
  patterns = self.DEFAULT_PATTERNS.copy()
54
  if kwargs.get("patterns"):
55
  patterns.update(kwargs["patterns"])
 
 
 
56
 
57
  # Create the RegexModel instance
58
  regex_model = RegexModel(patterns=patterns)
@@ -72,6 +82,14 @@ class RegexEntityRecognitionGuardrail(Guardrail):
72
  escaped_text = re.escape(text)
73
  # Create a pattern that matches the exact text, case-insensitive
74
  return rf"\b{escaped_text}\b"
 
 
 
 
 
 
 
 
75
 
76
  @weave.op()
77
  def guard(self, prompt: str, custom_terms: Optional[list[str]] = None, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
 
1
+ from typing import Dict, Optional, ClassVar, List
2
 
3
  import weave
4
  from pydantic import BaseModel
 
35
  should_anonymize: bool = False
36
 
37
  DEFAULT_PATTERNS: ClassVar[Dict[str, str]] = {
38
+ "EMAIL": r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
39
+ "TELEPHONENUM": r'\b(\+\d{1,3}[-.]?)?\(?\d{3}\)?[-.]?\d{3}[-.]?\d{4}\b',
40
+ "SOCIALNUM": r'\b\d{3}[-]?\d{2}[-]?\d{4}\b',
41
+ "CREDITCARDNUMBER": r'\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b',
42
+ "DATEOFBIRTH": r'\b(0[1-9]|1[0-2])[-/](0[1-9]|[12]\d|3[01])[-/](19|20)\d{2}\b',
43
+ "DRIVERLICENSENUM": r'[A-Z]\d{7}', # Example pattern, adjust for your needs
44
+ "ACCOUNTNUM": r'\b\d{10,12}\b', # Example pattern for bank accounts
45
+ "ZIPCODE": r'\b\d{5}(?:-\d{4})?\b',
46
+ "GIVENNAME": r'\b[A-Z][a-z]+\b', # Basic pattern for first names
47
+ "SURNAME": r'\b[A-Z][a-z]+\b', # Basic pattern for last names
48
+ "CITY": r'\b[A-Z][a-z]+(?:[\s-][A-Z][a-z]+)*\b',
49
+ "STREET": r'\b\d+\s+[A-Z][a-z]+\s+(?:Street|St|Avenue|Ave|Road|Rd|Boulevard|Blvd|Lane|Ln|Drive|Dr)\b',
50
+ "IDCARDNUM": r'[A-Z]\d{7,8}', # Generic pattern for ID cards
51
+ "USERNAME": r'@[A-Za-z]\w{3,}', # Basic username pattern
52
+ "PASSWORD": r'[A-Za-z0-9@#$%^&+=]{8,}', # Basic password pattern
53
+ "TAXNUM": r'\b\d{2}[-]\d{7}\b', # Example tax number pattern
54
+ "BUILDINGNUM": r'\b\d+[A-Za-z]?\b' # Basic building number pattern
55
  }
56
 
57
+ def __init__(self, use_defaults: bool = True, should_anonymize: bool = False, show_available_entities: bool = False, **kwargs):
58
  patterns = {}
59
  if use_defaults:
60
  patterns = self.DEFAULT_PATTERNS.copy()
61
  if kwargs.get("patterns"):
62
  patterns.update(kwargs["patterns"])
63
+
64
+ if show_available_entities:
65
+ self._print_available_entities(patterns.keys())
66
 
67
  # Create the RegexModel instance
68
  regex_model = RegexModel(patterns=patterns)
 
82
  escaped_text = re.escape(text)
83
  # Create a pattern that matches the exact text, case-insensitive
84
  return rf"\b{escaped_text}\b"
85
+
86
+ def _print_available_entities(self, entities: List[str]):
87
+ """Print available entities"""
88
+ print("\nAvailable entity types:")
89
+ print("=" * 25)
90
+ for entity in entities:
91
+ print(f"- {entity}")
92
+ print("=" * 25 + "\n")
93
 
94
  @weave.op()
95
  def guard(self, prompt: str, custom_terms: Optional[list[str]] = None, return_detected_types: bool = True, aggregate_redaction: bool = True, **kwargs) -> RegexEntityRecognitionResponse | RegexEntityRecognitionSimpleResponse:
guardrails_genie/guardrails/entity_recognition/transformers_entity_recognition_guardrail.py CHANGED
@@ -37,7 +37,7 @@ class TransformersEntityRecognitionGuardrail(Guardrail):
37
  model_name: str = "iiiorg/piiranha-v1-detect-personal-information",
38
  selected_entities: Optional[List[str]] = None,
39
  should_anonymize: bool = False,
40
- show_available_entities: bool = True,
41
  ):
42
  # Load model config and extract available entities
43
  config = AutoConfig.from_pretrained(model_name)
 
37
  model_name: str = "iiiorg/piiranha-v1-detect-personal-information",
38
  selected_entities: Optional[List[str]] = None,
39
  should_anonymize: bool = False,
40
+ show_available_entities: bool = False,
41
  ):
42
  # Load model config and extract available entities
43
  config = AutoConfig.from_pretrained(model_name)
guardrails_genie/regex_model.py CHANGED
@@ -28,7 +28,7 @@ class RegexModel(weave.Model):
28
  }
29
 
30
  @weave.op()
31
- def check(self, text: str) -> RegexResult:
32
  """
33
  Check text against all patterns and return detailed results.
34
 
@@ -38,23 +38,28 @@ class RegexModel(weave.Model):
38
  Returns:
39
  RegexResult containing pass/fail status and details about matches
40
  """
41
- matches: Dict[str, List[str]] = {}
42
- failed_patterns: List[str] = []
43
 
44
- for pattern_name, compiled_pattern in self._compiled_patterns.items():
45
- found_matches = compiled_pattern.findall(text)
46
- if found_matches:
47
- matches[pattern_name] = found_matches
 
 
 
 
 
 
 
 
48
  else:
49
  failed_patterns.append(pattern_name)
50
 
51
- # Consider it passed only if no patterns matched (no PII found)
52
- passed = len(matches) == 0
53
-
54
  return RegexResult(
55
- passed=passed,
56
- matched_patterns=matches,
57
- failed_patterns=failed_patterns
58
  )
59
 
60
  @weave.op()
 
28
  }
29
 
30
  @weave.op()
31
+ def check(self, prompt: str) -> RegexResult:
32
  """
33
  Check text against all patterns and return detailed results.
34
 
 
38
  Returns:
39
  RegexResult containing pass/fail status and details about matches
40
  """
41
+ matched_patterns = {}
42
+ failed_patterns = []
43
 
44
+ for pattern_name, pattern in self.patterns.items():
45
+ matches = []
46
+ for match in re.finditer(pattern, prompt):
47
+ if match.groups():
48
+ # If there are capture groups, join them with a separator
49
+ matches.append('-'.join(str(g) for g in match.groups() if g is not None))
50
+ else:
51
+ # If no capture groups, use the full match
52
+ matches.append(match.group(0))
53
+
54
+ if matches:
55
+ matched_patterns[pattern_name] = matches
56
  else:
57
  failed_patterns.append(pattern_name)
58
 
 
 
 
59
  return RegexResult(
60
+ matched_patterns=matched_patterns,
61
+ failed_patterns=failed_patterns,
62
+ passed=len(matched_patterns) == 0
63
  )
64
 
65
  @weave.op()