# DeId-Small / run_benchmarks.py
# Minibase's picture
# Upload run_benchmarks.py with huggingface_hub
# 89fb98c verified
#!/usr/bin/env python3
"""
Minimal De-identification Benchmark Runner for HuggingFace Publication
This script evaluates a de-identification model's performance on key metrics:
- PII Detection Rate: Whether the model emits placeholder tags when the input contains PII
- Completeness: Whether the output is free of PII patterns detectable by simple heuristics
- Semantic Preservation: How well meaning and text structure are preserved
- Latency: Response time performance
- Domain Performance: Not currently computed; all examples are evaluated together
"""
import json
import re
import time
import requests
from typing import Dict, List, Tuple, Any
import yaml
from datetime import datetime
import sys
import os
class DeIdBenchmarkRunner:
    """Run a de-identification model over a JSONL benchmark and score its outputs.

    The runner reads a YAML config, sends each example to an HTTP completion
    endpoint, scores the responses with regex/heuristic metrics, and writes a
    detailed JSON results file plus a Markdown summary.
    """

    # Heuristic patterns used to decide whether an INPUT text contains PII.
    _INPUT_PII_REGEXES = (
        r'\b\d{4}-\d{2}-\d{2}\b',  # Dates like 1985-03-15
        r'\b\d{1,3}/\d{1,2}/\d{4}\b',  # Dates like 05/12/1980
        r'\b\d{1,3}\s+[A-Z][a-z]+\s+(?:St|Street|Ave|Avenue|Rd|Road|Blvd|Boulevard)\b',  # Addresses
        r'\(\d{3}\)\s*\d{3}-\d{4}\b',  # Phone numbers like (555) 123-4567
        r'\+?\d{1,3}[-.\s]?\d{3}[-.\s]?\d{4}\b',  # International phone numbers
        r'\b[A-Z][a-z]+\s+[A-Z][a-z]+\b',  # Names (First Last)
        r'\b[A-Z][a-z]+\s+[A-Z]\.\s*[A-Z][a-z]+\b',  # Names with middle initial
        r'\b[\w.+-]+@[\w-]+(?:\.[\w-]+)+\b',  # Email addresses like jane.doe@example.com
        r'\b[A-Z]{2,}-?\d+\b',  # IDs like AB123 or EMP-001
        r'\$\d{1,3}(?:,\d{3})*(?:\.\d{2})?\b',  # Monetary amounts like $85,000
        r'\b\d{3}-\d{2}-\d{4}\b',  # SSN-like patterns
        r'\b(?:Mr|Mrs|Ms|Dr|Prof)\.\s+[A-Z][a-z]+\b',  # Titles with names
        r'\b\d{5}(?:-\d{4})?\b',  # ZIP codes
        r'\b[A-Z][a-z]+,\s+[A-Z]{2}\s+\d{5}\b',  # City, State ZIP
    )
    # Smaller pattern set used to detect PII still present in model OUTPUT.
    _OUTPUT_PII_REGEXES = (
        r'\b\d{4}-\d{2}-\d{2}\b',  # Dates like 1985-03-15
        r'\b\d{1,3}\s+[A-Z][a-z]+\s+(?:St|Street|Ave|Avenue|Rd|Road|Blvd|Boulevard)\b',  # Addresses
        r'\(\d{3}\)\s*\d{3}-\d{4}\b',  # Phone numbers
        r'\b[A-Z][a-z]+\s+[A-Z][a-z]+\b',  # Names (simplified)
        r'\b[\w.+-]+@[\w-]+(?:\.[\w-]+)+\b',  # Email addresses
    )
    # Compiled once at class creation rather than on every metric call.
    _INPUT_PII_PATTERNS = tuple(re.compile(p) for p in _INPUT_PII_REGEXES)
    _OUTPUT_PII_PATTERNS = tuple(re.compile(p) for p in _OUTPUT_PII_REGEXES)
    # Placeholder tags such as [NAME_1] or [DOB_1]; group 1 is the tag name.
    _PLACEHOLDER_PATTERN = re.compile(r'\[([A-Z_]+_\d+)\]')

    def __init__(self, config_path: str):
        """Load the YAML config and initialise the results skeleton.

        Args:
            config_path: Path to a YAML file with ``datasets``, ``model`` and
                ``output`` sections.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)
        dataset_cfg = self.config["datasets"]["benchmark_dataset"]
        self.results = {
            "metadata": {
                "timestamp": datetime.now().isoformat(),
                "model": "Minibase-DeId-Small",
                "dataset": dataset_cfg["file_path"],
                "sample_size": dataset_cfg["sample_size"],
            },
            "metrics": {},
            # Kept for output-format compatibility; the domain breakdown is not computed.
            "domain_performance": {},
            "examples": [],
        }

    def load_dataset(self) -> List[Dict]:
        """Load the first ``sample_size`` records of the JSONL benchmark dataset.

        Blank lines are skipped and do not count towards ``sample_size``.

        Returns:
            The parsed JSON objects, in file order.
        """
        dataset_cfg = self.config["datasets"]["benchmark_dataset"]
        dataset_path = dataset_cfg["file_path"]
        sample_size = dataset_cfg["sample_size"]
        examples = []
        with open(dataset_path, 'r', encoding='utf-8') as f:
            for raw_line in f:
                if len(examples) >= sample_size:
                    break
                line = raw_line.strip()
                if not line:
                    continue  # json.loads("") would raise on blank lines
                examples.append(json.loads(line))
        print(f"✅ Loaded {len(examples)} examples from {dataset_path}")
        return examples

    def extract_placeholders(self, text: str) -> List[str]:
        """Return all placeholder tag names in ``text`` (e.g. ``NAME_1`` for ``[NAME_1]``)."""
        return self._PLACEHOLDER_PATTERN.findall(text)

    def calculate_pii_detection_rate(self, input_text: str, predicted: str) -> float:
        """Score PII detection for one example as 1.0 or 0.0.

        Returns 1.0 when the input has no heuristically detectable PII, or when
        it does and the prediction contains at least one placeholder tag;
        otherwise 0.0.
        """
        if not self._input_contains_pii(input_text):
            return 1.0  # No PII in input, so detection is trivially perfect
        return 1.0 if self.extract_placeholders(predicted) else 0.0

    def _input_contains_pii(self, input_text: str) -> bool:
        """Return True if any input PII heuristic pattern matches ``input_text``."""
        return any(p.search(input_text) for p in self._INPUT_PII_PATTERNS)

    def calculate_completeness(self, predicted: str) -> bool:
        """Return True if no output PII heuristic pattern matches ``predicted``."""
        return not any(p.search(predicted) for p in self._OUTPUT_PII_PATTERNS)

    def calculate_semantic_preservation(self, input_text: str, predicted: str, expected: str) -> float:
        """Heuristically score how well the prediction preserves meaning, in [0, 1].

        Weighted average of three sub-scores:
          * length (weight 0.4): predicted vs expected word count. Ratios outside
            [0.5, 2.0] score 0.1, otherwise ``1 - |1 - ratio| * 0.5``.
          * placeholder density (weight 0.4): placeholders per predicted word.
            1.0 inside [0.05, 0.3], scaled down linearly below, floored at 0.1 above.
          * structure (weight 0.2): sentence-ending punctuation in the prediction
            relative to the input, capped at 1.0.
        """
        expected_words = len(expected.split())
        predicted_words = len(predicted.split())

        # Length preservation (closer to 1.0 is better)
        if expected_words == 0:
            length_preservation = 1.0
        else:
            length_ratio = predicted_words / expected_words
            if 0.5 <= length_ratio <= 2.0:
                length_preservation = 1.0 - abs(1.0 - length_ratio) * 0.5
            else:
                length_preservation = 0.1  # Heavily penalize extreme length differences

        # Placeholder density (should be reasonable, not too sparse or dense)
        placeholder_ratio = len(self.extract_placeholders(predicted)) / max(predicted_words, 1)
        if 0.05 <= placeholder_ratio <= 0.3:
            density_score = 1.0
        elif placeholder_ratio < 0.05:  # Too few placeholders
            density_score = placeholder_ratio / 0.05
        else:  # Too many placeholders
            density_score = max(0.1, 1.0 - (placeholder_ratio - 0.3) * 2)

        # Structure preservation: sentence-ending punctuation as a proxy
        input_punct = len(re.findall(r'[.!?]', input_text))
        predicted_punct = len(re.findall(r'[.!?]', predicted))
        if input_punct == 0:
            structure_score = 1.0
        else:
            structure_ratio = min(predicted_punct, input_punct * 1.5) / input_punct
            structure_score = min(1.0, structure_ratio)

        final_score = (length_preservation * 0.4) + (density_score * 0.4) + (structure_score * 0.2)
        return max(0.0, min(1.0, final_score))  # Clamp to [0,1]

    def call_model(self, instruction: str, input_text: str) -> Tuple[str, float]:
        """Call the de-identification model and measure latency.

        Returns:
            ``(text, latency_ms)``. On failure ``text`` is an ``"Error: ..."``
            message. Use ``_request_completion`` when failures must be told apart
            from model output that happens to start with "Error".
        """
        text, latency, _ok = self._request_completion(instruction, input_text)
        return text, latency

    def _request_completion(self, instruction: str, input_text: str) -> Tuple[str, float, bool]:
        """POST the prompt to ``{base_url}/completion``.

        Returns:
            ``(text, latency_ms, ok)``. ``ok`` is False for transport errors,
            non-200 responses and non-JSON bodies; ``text`` is then an error message.
        """
        model_cfg = self.config["model"]
        prompt = f"{instruction}\n\nInput: {input_text}\n\nResponse: "
        payload = {
            "prompt": prompt,
            "max_tokens": model_cfg["max_tokens"],
            "temperature": model_cfg["temperature"],
        }
        headers = {'Content-Type': 'application/json'}
        # perf_counter is monotonic, so latency is immune to wall-clock adjustments.
        start = time.perf_counter()
        try:
            response = requests.post(
                f"{model_cfg['base_url']}/completion",
                json=payload,
                headers=headers,
                timeout=model_cfg["timeout"],
            )
            latency = (time.perf_counter() - start) * 1000  # Convert to ms
            if response.status_code != 200:
                return f"Error: Server returned status {response.status_code}", latency, False
            result = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers a 200 response whose body is not valid JSON.
            latency = (time.perf_counter() - start) * 1000
            return f"Error: {e}", latency, False
        return result.get('content', ''), latency, True

    def run_benchmarks(self):
        """Run every loaded example through the model, score it, and save results.

        Metrics are averaged over successful requests only. Failed requests are
        reported on stdout and counted in ``failed_requests``.
        """
        dataset_cfg = self.config["datasets"]["benchmark_dataset"]
        print("🚀 Starting De-identification Benchmarks...")
        print(f"📊 Sample size: {dataset_cfg['sample_size']}")
        print(f"🎯 Model: {self.results['metadata']['model']}")
        print()
        examples = self.load_dataset()

        # Field names and limits are loop-invariant; look them up once.
        instruction_field = dataset_cfg["instruction_field"]
        input_field = dataset_cfg["input_field"]
        expected_field = dataset_cfg["expected_output_field"]
        max_examples = self.config["output"]["max_examples"]

        total_pii_detection = 0.0
        total_completeness = 0
        total_semantic_preservation = 0.0
        total_latency = 0.0
        successful_requests = 0

        for i, example in enumerate(examples):
            if i % 10 == 0:
                print(f"📈 Progress: {i}/{len(examples)} examples processed")
            instruction = example[instruction_field]
            input_text = example[input_field]
            expected_output = example[expected_field]

            predicted_output, latency, ok = self._request_completion(instruction, input_text)
            if not ok:
                print(f"⚠️  Example {i} failed: {predicted_output}")
                continue
            successful_requests += 1

            pii_detection = self.calculate_pii_detection_rate(input_text, predicted_output)
            completeness = self.calculate_completeness(predicted_output)
            semantic_preservation = self.calculate_semantic_preservation(
                input_text, predicted_output, expected_output
            )

            total_pii_detection += pii_detection
            total_completeness += completeness
            total_semantic_preservation += semantic_preservation
            total_latency += latency

            if len(self.results["examples"]) < max_examples:
                self.results["examples"].append({
                    "input": input_text,
                    "expected": expected_output,
                    "predicted": predicted_output,
                    "metrics": {
                        "pii_detection": pii_detection,
                        "completeness": completeness,
                        "semantic_preservation": semantic_preservation,
                        "latency_ms": latency,
                    },
                })

        metrics: Dict[str, Any] = {}
        if successful_requests > 0:
            metrics.update({
                "pii_detection_rate": total_pii_detection / successful_requests,
                "completeness_score": total_completeness / successful_requests,
                "semantic_preservation": total_semantic_preservation / successful_requests,
                "average_latency_ms": total_latency / successful_requests,
            })
        # Request counts are always recorded so an all-failure run is visible.
        metrics.update({
            "successful_requests": successful_requests,
            "failed_requests": len(examples) - successful_requests,
            "total_requests": len(examples),
        })
        self.results["metrics"] = metrics
        self.save_results()

    def save_results(self):
        """Write the detailed JSON results and the Markdown summary to the configured paths."""
        output_cfg = self.config["output"]
        with open(output_cfg["detailed_results_file"], 'w', encoding='utf-8') as f:
            json.dump(self.results, f, indent=2)
        summary = self.generate_summary()
        # Explicit UTF-8: the summary embeds raw model text, which may be non-ASCII.
        with open(output_cfg["results_file"], 'w', encoding='utf-8') as f:
            f.write(summary)
        print("\n✅ Benchmark complete!")
        print(f"📄 Detailed results saved to: {output_cfg['detailed_results_file']}")
        print(f"📊 Summary saved to: {output_cfg['results_file']}")

    @staticmethod
    def _preview(text: str, limit: int = 100) -> str:
        """Return ``text`` cut to ``limit`` characters, appending "..." only if it was cut."""
        return text if len(text) <= limit else text[:limit] + "..."

    def generate_summary(self) -> str:
        """Build a Markdown summary of the run.

        Sections are separated by blank lines so Markdown renders each field on
        its own line. Up to three stored examples are included when
        ``output.include_examples`` is set.
        """
        meta = self.results["metadata"]
        m = self.results["metrics"]
        lines = [
            "# De-identification Benchmark Results",
            "",
            f"**Model:** {meta['model']}",
            "",
            f"**Dataset:** {meta['dataset']}",
            "",
            f"**Sample Size:** {meta['sample_size']}",
            "",
            f"**Successful Requests:** {m.get('successful_requests', 0)}/{m.get('total_requests', 0)}",
            "",
            f"**Date:** {meta['timestamp']}",
            "",
            "## Overall Performance",
            "",
            "| Metric | Score | Description |",
            "|--------|-------|-------------|",
            f"| PII Detection Rate | {m.get('pii_detection_rate', 0):.3f} | How well personal identifiers are detected |",
            f"| Completeness Score | {m.get('completeness_score', 0):.3f} | Fraction of texts fully de-identified |",
            f"| Semantic Preservation | {m.get('semantic_preservation', 0):.3f} | How well meaning is preserved |",
            f"| Average Latency | {m.get('average_latency_ms', 0):.1f}ms | Response time performance |",
            "",
            "## Key Improvements",
            "",
            "- **PII Detection**: Now measures if model generates ANY placeholders when PII is present in input",
            "- **Unified Evaluation**: All examples evaluated together (no domain separation)",
            "- **Lenient Scoring**: Focuses on detection capability rather than exact placeholder matching",
            "",
        ]
        if self.config["output"]["include_examples"] and self.results["examples"]:
            lines += ["## Example Results", ""]
            for number, example in enumerate(self.results["examples"][:3], start=1):  # first 3 only
                lines += [
                    f"### Example {number}",
                    "",
                    f"**Input:** {self._preview(example['input'])}",
                    "",
                    f"**Expected:** {self._preview(example['expected'])}",
                    "",
                    f"**Predicted:** {self._preview(example['predicted'])}",
                    "",
                    f"**PII Detection:** {example['metrics']['pii_detection']:.3f}",
                    "",
                ]
        return "\n".join(lines) + "\n"
def main():
    """Command-line entry point: ``python run_benchmarks.py <config_file>``.

    Exits with status 1, writing the message to stderr, on wrong usage or when
    the config path is not an existing file.
    """
    if len(sys.argv) != 2:
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit("Usage: python run_benchmarks.py <config_file>")
    config_path = sys.argv[1]
    # isfile (not exists) so a directory path is rejected here, not at open().
    if not os.path.isfile(config_path):
        sys.exit(f"Error: Config file {config_path} not found")
    runner = DeIdBenchmarkRunner(config_path)
    runner.run_benchmarks()
if __name__ == "__main__":
main()