#!/usr/bin/env python3
"""
Minimal De-identification Benchmark Runner for HuggingFace Publication

This script evaluates a de-identification model's performance on key metrics:
- PII Detection Rate: How well it identifies personal identifiers
- Completeness: Whether all PII is successfully masked
- Semantic Preservation: How well meaning is preserved
- Latency: Response time performance
"""

import json
import re
import time
import requests
from typing import Dict, List, Tuple
import yaml
from datetime import datetime
import sys
import os
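
# The YAML config consumed below is not documented in this file, so here is a
# minimal sketch reconstructed from the keys the script actually reads; every
# value shown is an illustrative placeholder, not a shipped default:
#
#   model:
#     base_url: "http://localhost:8080"    # server exposing POST /completion
#     max_tokens: 256
#     temperature: 0.1
#     timeout: 30
#   datasets:
#     benchmark_dataset:
#       file_path: "benchmark.jsonl"       # JSONL, one example per line
#       sample_size: 100
#       instruction_field: "instruction"
#       input_field: "input"
#       expected_output_field: "output"
#   output:
#     results_file: "results.md"
#     detailed_results_file: "results.json"
#     max_examples: 10
#     include_examples: true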

class DeIdBenchmarkRunner:
    def __init__(self, config_path: str):
        with open(config_path, 'r') as f:
            self.config = yaml.safe_load(f)

        self.results = {
            "metadata": {
                "timestamp": datetime.now().isoformat(),
                "model": "Minibase-DeId-Small",
                "dataset": self.config["datasets"]["benchmark_dataset"]["file_path"],
                "sample_size": self.config["datasets"]["benchmark_dataset"]["sample_size"]
            },
            "metrics": {},
            "domain_performance": {},
            "examples": []
        }

    def load_dataset(self) -> List[Dict]:
        """Load and sample the benchmark dataset"""
        dataset_path = self.config["datasets"]["benchmark_dataset"]["file_path"]
        sample_size = self.config["datasets"]["benchmark_dataset"]["sample_size"]

        examples = []
        with open(dataset_path, 'r') as f:
            for i, line in enumerate(f):
                if i >= sample_size:
                    break
                examples.append(json.loads(line.strip()))

        print(f"βœ… Loaded {len(examples)} examples from {dataset_path}")
        return examples

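    # Each dataset line is expected to be a standalone JSON object whose keys
    # match the *_field names in the config. A hypothetical record (field names
    # follow the config sketch near the imports):
    #   {"instruction": "De-identify the text.",
    #    "input": "Jane Doe called on 1985-03-15.",
    #    "output": "[NAME_1] called on [DOB_1]."}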

    def extract_placeholders(self, text: str) -> List[str]:
        """Extract all placeholder tags from text (e.g., [NAME_1], [DOB_1])"""
        # Match patterns like [WORD_1], [WORD_NUMBER], etc.
        pattern = r'\[([A-Z_]+_\d+)\]'
        return re.findall(pattern, text)
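
    # Sanity check for the regex above (illustrative, not executed during a run):
    #   extract_placeholders("Call [NAME_1] at [PHONE_1].") -> ["NAME_1", "PHONE_1"]
    # Unnumbered or lowercase tags such as [NAME] or [name_1] are deliberately ignored.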

    def calculate_pii_detection_rate(self, input_text: str, predicted: str) -> float:
        """Calculate PII detection rate - if input has PII and output has placeholders, count as success"""
        # Check if input contains any PII patterns
        input_has_pii = self._input_contains_pii(input_text)

        if not input_has_pii:
            return 1.0  # No PII in input, so detection is perfect

        # Check if output contains any placeholders at all
        predicted_placeholders = self.extract_placeholders(predicted)
        output_has_placeholders = len(predicted_placeholders) > 0

        # If input has PII and output has placeholders, count as successful detection
        return 1.0 if output_has_placeholders else 0.0
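
    # The rule above, traced on a hypothetical pair:
    #   input:  "Jane Doe lives at 12 Main St"   -> PII detected in input
    #   output: "[NAME_1] lives at [ADDRESS_1]"  -> placeholders present -> 1.0
    #   output: "Someone lives somewhere"        -> no placeholders      -> 0.0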

    def _input_contains_pii(self, input_text: str) -> bool:
        """Check if input text contains personal identifiable information"""
        pii_patterns = [
            r'\b\d{4}-\d{2}-\d{2}\b',  # Dates like 1985-03-15
            r'\b\d{1,2}/\d{1,2}/\d{4}\b',  # Dates like 05/12/1980
            r'\b\d{1,3}\s+[A-Z][a-z]+\s+(?:St|Street|Ave|Avenue|Rd|Road|Blvd|Boulevard)\b',  # Addresses
            r'\(\d{3}\)\s*\d{3}-\d{4}\b',  # Phone numbers like (555) 123-4567
            r'\+?\d{1,3}[-.\s]?\d{3}[-.\s]?\d{4}\b',  # International phone numbers
            r'\b[A-Z][a-z]+\s+[A-Z][a-z]+\b',  # Names (First Last)
            r'\b[A-Z][a-z]+\s+[A-Z]\.\s*[A-Z][a-z]+\b',  # Names with middle initial
            r'\b[\w.+-]+@[\w-]+\.\w+\b',  # Email addresses
            r'\b[A-Z]{2,}\d+\b',  # Alphanumeric IDs like ABC123
            r'\$\d{1,3}(?:,\d{3})*(?:\.\d{2})?\b',  # Monetary amounts like $85,000
            r'\b\d{3}-\d{2}-\d{4}\b',  # SSN-like patterns
            r'\b(?:Mr|Mrs|Ms|Dr|Prof)\.\s+[A-Z][a-z]+\b',  # Titles with names
            r'\b\d{5}(?:-\d{4})?\b',  # ZIP codes
            r'\b[A-Z][a-z]+,\s+[A-Z]{2}\s+\d{5}\b',  # City, State ZIP
        ]

        return any(re.search(pattern, input_text) for pattern in pii_patterns)

    def calculate_completeness(self, predicted: str) -> bool:
        """Check if response appears to have no obvious PII remaining"""
        # Simple heuristics for detecting remaining PII
        pii_patterns = [
            r'\b\d{4}-\d{2}-\d{2}\b',  # Dates like 1985-03-15
            r'\b\d{1,3}\s+[A-Z][a-z]+\s+(?:St|Street|Ave|Avenue|Rd|Road)\b',  # Addresses
            r'\(\d{3}\)\s*\d{3}-\d{4}\b',  # Phone numbers
            r'\b[A-Z][a-z]+\s+[A-Z][a-z]+\b',  # Names (simplified)
            r'\b[\w.+-]+@[\w-]+\.\w+\b'  # Email addresses
        ]

        # If any PII patterns remain, it's incomplete
        for pattern in pii_patterns:
            if re.search(pattern, predicted):
                return False

        return True
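
    # Caveat on the name heuristic above: any capitalized bigram trips it, so a
    # clean output containing "New York" would be scored incomplete. The metric
    # deliberately errs toward flagging leakage (false positives lower the score).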

    def calculate_semantic_preservation(self, input_text: str, predicted: str, expected: str) -> float:
        """Calculate semantic preservation - how well the meaning is preserved after de-identification"""
        # For de-identification, semantic preservation should focus on:
        # 1. Whether the core message/content is maintained
        # 2. Whether the text structure remains coherent
        # 3. Whether placeholder density is reasonable

        # Simple approach: compare text length and placeholder density
        expected_words = len(expected.split())
        predicted_words = len(predicted.split())

        # Length preservation (closer to 1.0 is better)
        if expected_words == 0:
            length_preservation = 1.0
        else:
            length_ratio = predicted_words / expected_words
            # Linear penalty inside the 0.5-2.0 band; ratios outside it are heavily penalized
            if 0.5 <= length_ratio <= 2.0:
                length_preservation = 1.0 - abs(1.0 - length_ratio) * 0.5
            else:
                length_preservation = 0.1  # Heavily penalize extreme length differences

        # Placeholder density (should be reasonable, not too sparse or dense)
        pred_placeholders = self.extract_placeholders(predicted)
        placeholder_ratio = len(pred_placeholders) / max(predicted_words, 1)

        if 0.05 <= placeholder_ratio <= 0.3:  # Reasonable placeholder density
            density_score = 1.0
        elif placeholder_ratio < 0.05:  # Too few placeholders
            density_score = placeholder_ratio / 0.05
        else:  # Too many placeholders
            density_score = max(0.1, 1.0 - (placeholder_ratio - 0.3) * 2)

        # Structure preservation (check if basic sentence structure is maintained)
        # Simple check: count punctuation marks as proxy for structure
        input_punct = len(re.findall(r'[.!?]', input_text))
        predicted_punct = len(re.findall(r'[.!?]', predicted))

        if input_punct == 0:
            structure_score = 1.0
        else:
            structure_ratio = min(predicted_punct, input_punct * 1.5) / input_punct
            structure_score = min(1.0, structure_ratio)

        # Combine scores (weighted average)
        final_score = (length_preservation * 0.4) + (density_score * 0.4) + (structure_score * 0.2)

        return max(0.0, min(1.0, final_score))  # Clamp to [0,1]
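
    # Worked example (hypothetical numbers) for the weighted score above:
    #   expected 50 words, predicted 45 -> ratio 0.9 -> length = 1 - 0.1*0.5 = 0.95
    #   4 placeholders / 45 words ~ 0.089 (inside 0.05-0.3) -> density = 1.0
    #   equal sentence-ending punctuation -> structure = 1.0
    #   final = 0.95*0.4 + 1.0*0.4 + 1.0*0.2 = 0.98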

    def call_model(self, instruction: str, input_text: str) -> Tuple[str, float]:
        """Call the de-identification model and measure latency"""
        prompt = f"{instruction}\n\nInput: {input_text}\n\nResponse: "

        payload = {
            "prompt": prompt,
            "max_tokens": self.config["model"]["max_tokens"],
            "temperature": self.config["model"]["temperature"]
        }

        headers = {'Content-Type': 'application/json'}

        start_time = time.time()
        try:
            response = requests.post(
                f"{self.config['model']['base_url']}/completion",
                json=payload,
                headers=headers,
                timeout=self.config["model"]["timeout"]
            )
            latency = (time.time() - start_time) * 1000  # Convert to ms

            if response.status_code == 200:
                result = response.json()
                return result.get('content', ''), latency
            else:
                return f"Error: Server returned status {response.status_code}", latency
        except requests.exceptions.RequestException as e:
            latency = (time.time() - start_time) * 1000
            return f"Error: {e}", latency

    def run_benchmarks(self):
        """Run the complete benchmark suite"""
        print("πŸš€ Starting De-identification Benchmarks...")
        print(f"πŸ“Š Sample size: {self.config['datasets']['benchmark_dataset']['sample_size']}")
        print(f"🎯 Model: {self.results['metadata']['model']}")
        print()

        examples = self.load_dataset()

        # Initialize metrics
        total_pii_detection = 0
        total_completeness = 0
        total_semantic_preservation = 0
        total_latency = 0

        successful_requests = 0

        for i, example in enumerate(examples):
            if i % 10 == 0:
                print(f"πŸ“ˆ Progress: {i}/{len(examples)} examples processed")

            instruction = example[self.config["datasets"]["benchmark_dataset"]["instruction_field"]]
            input_text = example[self.config["datasets"]["benchmark_dataset"]["input_field"]]
            expected_output = example[self.config["datasets"]["benchmark_dataset"]["expected_output_field"]]

            # Call model
            predicted_output, latency = self.call_model(instruction, input_text)

            if not predicted_output.startswith("Error"):
                successful_requests += 1

                # Calculate metrics
                pii_detection = self.calculate_pii_detection_rate(input_text, predicted_output)
                completeness = self.calculate_completeness(predicted_output)
                semantic_preservation = self.calculate_semantic_preservation(input_text, predicted_output, expected_output)

                # Update totals
                total_pii_detection += pii_detection
                total_completeness += completeness
                total_semantic_preservation += semantic_preservation
                total_latency += latency

                # Store example if requested
                if len(self.results["examples"]) < self.config["output"]["max_examples"]:
                    self.results["examples"].append({
                        "input": input_text,
                        "expected": expected_output,
                        "predicted": predicted_output,
                        "metrics": {
                            "pii_detection": pii_detection,
                            "completeness": completeness,
                            "semantic_preservation": semantic_preservation,
                            "latency_ms": latency
                        }
                    })

        # Calculate final metrics
        if successful_requests > 0:
            self.results["metrics"] = {
                "pii_detection_rate": total_pii_detection / successful_requests,
                "completeness_score": total_completeness / successful_requests,
                "semantic_preservation": total_semantic_preservation / successful_requests,
                "average_latency_ms": total_latency / successful_requests,
                "successful_requests": successful_requests,
                "total_requests": len(examples)
            }

        self.save_results()

    def save_results(self):
        """Save benchmark results to files"""
        # Save detailed JSON results
        with open(self.config["output"]["detailed_results_file"], 'w') as f:
            json.dump(self.results, f, indent=2)

        # Save human-readable summary
        summary = self.generate_summary()
        with open(self.config["output"]["results_file"], 'w') as f:
            f.write(summary)

        print("\nβœ… Benchmark complete!")
        print(f"πŸ“„ Detailed results saved to: {self.config['output']['detailed_results_file']}")
        print(f"πŸ“Š Summary saved to: {self.config['output']['results_file']}")

    def generate_summary(self) -> str:
        """Generate a human-readable benchmark summary"""
        m = self.results["metrics"]

        summary = f"""# De-identification Benchmark Results
**Model:** {self.results['metadata']['model']}
**Dataset:** {self.results['metadata']['dataset']}
**Sample Size:** {self.results['metadata']['sample_size']}
**Date:** {self.results['metadata']['timestamp']}

## Overall Performance

| Metric | Score | Description |
|--------|-------|-------------|
| PII Detection Rate | {m.get('pii_detection_rate', 0):.3f} | How well personal identifiers are detected |
| Completeness Score | {m.get('completeness_score', 0):.3f} | Percentage of texts fully de-identified |
| Semantic Preservation | {m.get('semantic_preservation', 0):.3f} | How well meaning is preserved |
| Average Latency | {m.get('average_latency_ms', 0):.1f}ms | Response time performance |

## Scoring Notes

- **PII Detection**: counts a success whenever the model emits any placeholder for an input that contains PII
- **Unified Evaluation**: all examples are scored together (no per-domain separation)
- **Lenient Scoring**: measures detection capability rather than exact placeholder matching

"""

        if self.config["output"]["include_examples"] and self.results["examples"]:
            summary += "## Example Results\n\n"
            for i, example in enumerate(self.results["examples"][:3]):  # Show first 3 examples
                summary += f"### Example {i+1}\n"
                summary += f"**Input:** {example['input'][:100]}...\n"
                summary += f"**Expected:** {example['expected'][:100]}...\n"
                summary += f"**Predicted:** {example['predicted'][:100]}...\n"
                summary += f"**PII Detection:** {example['metrics']['pii_detection']:.3f}\n\n"

        return summary

def main():
    if len(sys.argv) != 2:
        print("Usage: python run_benchmarks.py <config_file>")
        sys.exit(1)

    config_path = sys.argv[1]
    if not os.path.exists(config_path):
        print(f"Error: Config file {config_path} not found")
        sys.exit(1)

    runner = DeIdBenchmarkRunner(config_path)
    runner.run_benchmarks()

if __name__ == "__main__":
    main()
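
# Example invocation (the config file name is illustrative):
#   python run_benchmarks.py benchmark_config.yaml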