Minibase committed
Commit 6561522 · verified · 1 Parent(s): d7c73a3

Upload run_benchmarks.py with huggingface_hub

Files changed (1):
1. run_benchmarks.py  +457 -0
run_benchmarks.py ADDED
@@ -0,0 +1,457 @@
#!/usr/bin/env python3
"""
Minimal NER Benchmark Runner for HuggingFace Publication

This script evaluates a NER model's performance on key metrics:
- Entity Recognition F1 Score: How well entities are identified and classified
- Precision: Accuracy of positive predictions
- Recall: Ability to find all relevant entities
- Latency: Response time performance
- Entity Type Performance: Results across different entity types
"""

import json
import re
import time
import requests
from typing import Dict, List, Tuple, Any
import yaml
from datetime import datetime
import sys
import os


class NERBenchmarkRunner:
    def __init__(self, config_path: str):
        with open(config_path, 'r') as f:
            self.config = yaml.safe_load(f)

        self.results = {
            "metadata": {
                "timestamp": datetime.now().isoformat(),
                "model": "Minibase-NER-Small",
                "dataset": self.config["datasets"]["benchmark_dataset"]["file_path"],
                "sample_size": self.config["datasets"]["benchmark_dataset"]["sample_size"]
            },
            "metrics": {},
            "entity_performance": {},
            "examples": []
        }

    def load_dataset(self) -> List[Dict]:
        """Load and sample the benchmark dataset"""
        dataset_path = self.config["datasets"]["benchmark_dataset"]["file_path"]
        sample_size = self.config["datasets"]["benchmark_dataset"]["sample_size"]

        examples = []
        try:
            with open(dataset_path, 'r') as f:
                for i, line in enumerate(f):
                    if i >= sample_size:
                        break
                    examples.append(json.loads(line.strip()))
        except FileNotFoundError:
            print(f"⚠️ Dataset file {dataset_path} not found. Creating sample dataset...")
            examples = self.create_sample_dataset(sample_size)

        print(f"✅ Loaded {len(examples)} examples from {dataset_path}")
        return examples

    def create_sample_dataset(self, sample_size: int) -> List[Dict]:
        """Create a sample NER dataset for testing"""
        examples = [
            {
                "instruction": "Extract all named entities from the following text. Return them in JSON format with entity types as keys and lists of entities as values.",
                "input": "John Smith works at Google in New York and uses Python programming language.",
                "response": '"PER": ["John Smith"], "ORG": ["Google"], "LOC": ["New York"], "MISC": ["Python"]'
            },
            {
                "instruction": "Extract all named entities from the following text. Return them in JSON format with entity types as keys and lists of entities as values.",
                "input": "Microsoft Corporation announced that Satya Nadella will visit London next week.",
                "response": '"PER": ["Satya Nadella"], "ORG": ["Microsoft Corporation"], "LOC": ["London"]'
            },
            {
                "instruction": "Extract all named entities from the following text. Return them in JSON format with entity types as keys and lists of entities as values.",
                "input": "The University of Cambridge is located in the United Kingdom and was founded by King Henry III.",
                "response": '"ORG": ["University of Cambridge"], "LOC": ["United Kingdom"], "PER": ["King Henry III"]'
            }
        ]

        # Repeat examples to reach sample_size
        dataset = []
        for i in range(sample_size):
            dataset.append(examples[i % len(examples)].copy())

        # Save the sample dataset
        with open(self.config["datasets"]["benchmark_dataset"]["file_path"], 'w') as f:
            for example in dataset:
                f.write(json.dumps(example) + '\n')

        return dataset

    def extract_entities_from_prediction(self, prediction: str) -> List[Tuple[str, str, str]]:
        """Extract entities from numbered list prediction format"""
        entities = []

        # Clean up the prediction - remove any extra formatting
        prediction = prediction.strip()

        # Handle the actual model output format: numbered lists
        # Examples:
        #   "1"
        #   "1. Microsoft Corporation"
        #   "1. The University of Cambridge\n2. King Henry III"

        # Split by lines and process each line
        lines = prediction.split('\n')

        for line in lines:
            line = line.strip()
            if not line:
                continue

            # Try to extract entity names from numbered list format
            # Pattern 1: "1. Entity Name" or "1. Entity Name - Description"
            numbered_match = re.match(r'^\d+\.\s*(.+?)(?:\s*-\s*.+)?$', line)
            if numbered_match:
                entity_text = numbered_match.group(1).strip()
                # Remove any trailing punctuation and clean up
                entity_text = re.sub(r'[.,;:!?]$', '', entity_text).strip()
                # Skip very short entities or generic terms
                if entity_text and len(entity_text) > 1 and entity_text.lower() not in ['the', 'and', 'or', 'but', 'for', 'with']:
                    entities.append((entity_text, "ENTITY", "0-0"))
            else:
                # Pattern 2: Just a number like "1" - skip these as they're incomplete
                if re.match(r'^\d+$', line):
                    continue
                # Pattern 3: Any other text might be an entity
                elif len(line) > 1:  # Skip very short strings
                    entity_text = line.strip()
                    entity_text = re.sub(r'[.,;:!?]$', '', entity_text).strip()
                    if entity_text:
                        entities.append((entity_text, "ENTITY", "0-0"))

        return entities

    def extract_entities_from_bio_format(self, bio_text: str) -> List[Tuple[str, str, str]]:
        """Extract entities from BIO format text"""
        entities = []
        lines = bio_text.strip().split('\n')

        current_entity = None
        current_type = None

        for line in lines:
            line = line.strip()
            if not line or line == '.':
                continue

            parts = line.split()
            if len(parts) >= 2:
                token, tag = parts[0], parts[1]

                if tag.startswith('B-'):
                    # End previous entity if exists
                    if current_entity:
                        entities.append((current_entity, current_type, "0-0"))
                    # Start new entity
                    current_entity = token
                    current_type = tag[2:]  # Remove B-
                elif tag.startswith('I-') and current_entity:
                    # Continue current entity
                    current_entity += ' ' + token
                else:
                    # End previous entity if exists
                    if current_entity:
                        entities.append((current_entity, current_type, "0-0"))
                    current_entity = None
                    current_type = None

        # End any remaining entity
        if current_entity:
            entities.append((current_entity, current_type, "0-0"))

        return entities

    def normalize_entity_text(self, text: str) -> str:
        """Normalize entity text for better matching"""
        # Convert to lowercase
        text = text.lower()
        # Remove common prefixes that might vary
        text = re.sub(r'^(the|an?|mr|mrs|ms|dr|prof)\s+', '', text)
        # Remove extra whitespace
        text = ' '.join(text.split())
        return text.strip()

    def calculate_ner_metrics(self, predicted_entities: List[Tuple], expected_bio_text: str) -> Dict[str, float]:
        """Calculate NER metrics: precision, recall, F1"""
        # Extract expected entities from BIO format
        expected_entities = self.extract_entities_from_bio_format(expected_bio_text)

        # Normalize and create sets for comparison
        pred_texts = set(self.normalize_entity_text(ent[0]) for ent in predicted_entities)
        exp_texts = set(self.normalize_entity_text(ent[0]) for ent in expected_entities)

        # Calculate exact matches
        exact_matches = pred_texts & exp_texts
        true_positives = len(exact_matches)

        # Check for partial matches (subset/superset relationships)
        additional_matches = 0
        for pred in pred_texts - exact_matches:
            for exp in exp_texts - exact_matches:
                # Check if one is a substring of the other (with some tolerance)
                if pred in exp or exp in pred:
                    if len(pred) > 3 and len(exp) > 3:  # Avoid matching very short strings
                        additional_matches += 1
                        break

        true_positives += additional_matches
        false_positives = len(pred_texts) - true_positives
        false_negatives = len(exp_texts) - true_positives

        precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0.0
        recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

        return {
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "true_positives": true_positives,
            "false_positives": false_positives,
            "false_negatives": false_negatives
        }

    def call_model(self, instruction: str, input_text: str) -> Tuple[str, float]:
        """Call the NER model and measure latency"""
        prompt = f"{instruction}\n\nInput: {input_text}\n\nResponse: "

        payload = {
            "prompt": prompt,
            "max_tokens": self.config["model"]["max_tokens"],
            "temperature": self.config["model"]["temperature"]
        }

        headers = {'Content-Type': 'application/json'}

        start_time = time.time()
        try:
            response = requests.post(
                f"{self.config['model']['base_url']}/completion",
                json=payload,
                headers=headers,
                timeout=self.config["model"]["timeout"]
            )
            latency = (time.time() - start_time) * 1000  # Convert to ms

            if response.status_code == 200:
                result = response.json()
                return result.get('content', ''), latency
            else:
                return f"Error: Server returned status {response.status_code}", latency
        except requests.exceptions.RequestException as e:
            latency = (time.time() - start_time) * 1000
            return f"Error: {e}", latency

    def run_benchmarks(self):
        """Run the complete benchmark suite"""
        print("🚀 Starting NER Benchmarks...")
        print(f"📊 Sample size: {self.config['datasets']['benchmark_dataset']['sample_size']}")
        print(f"🎯 Model: {self.results['metadata']['model']}")
        print()

        # First, let's demonstrate the numbered list parsing works with a mock example
        print("🔧 Testing numbered list parsing with mock data...")
        # Test the actual format the model produces
        mock_output = "1. Neil Armstrong\n2. Buzz Aldrin\n3. NASA\n4. Moon\n5. Apollo 11"

        print("Testing NER numbered list format:")
        mock_entities = self.extract_entities_from_prediction(mock_output)
        print(f"✅ Numbered list parsing: {len(mock_entities)} entities extracted")

        if mock_entities:
            print("Sample entities:")
            for entity in mock_entities:
                print(f"  - {entity[0]} ({entity[1]})")
        print()

        examples = self.load_dataset()

        # Initialize metrics
        total_precision = 0
        total_recall = 0
        total_f1 = 0
        total_latency = 0
        entity_type_metrics = {}

        successful_requests = 0

        for i, example in enumerate(examples):
            if i % 10 == 0:
                print(f"📈 Progress: {i}/{len(examples)} examples processed")

            instruction = example[self.config["datasets"]["benchmark_dataset"]["instruction_field"]]
            input_text = example[self.config["datasets"]["benchmark_dataset"]["input_field"]]
            expected_output = example[self.config["datasets"]["benchmark_dataset"]["expected_output_field"]]

            # Call model
            predicted_output, latency = self.call_model(instruction, input_text)

            if not predicted_output.startswith("Error"):
                successful_requests += 1

                # Extract entities from predictions and BIO format
                try:
                    predicted_entities = self.extract_entities_from_prediction(predicted_output)

                    # Calculate metrics using expected BIO text
                    metrics = self.calculate_ner_metrics(predicted_entities, expected_output)

                    # Update totals
                    total_precision += metrics["precision"]
                    total_recall += metrics["recall"]
                    total_f1 += metrics["f1"]
                    total_latency += latency

                    # Track entity type performance (using generic ENTITY type since model doesn't specify types)
                    for entity_text, entity_type, _ in predicted_entities:
                        if entity_type not in entity_type_metrics:
                            entity_type_metrics[entity_type] = {"correct": 0, "total": 0}

                        # Check if this entity text was correctly identified (type-agnostic)
                        expected_entities_list = self.extract_entities_from_bio_format(expected_output)
                        expected_entity_texts = [self.normalize_entity_text(e[0]) for e in expected_entities_list]
                        normalized_entity = self.normalize_entity_text(entity_text)

                        # Check for exact match or substring match
                        is_correct = normalized_entity in expected_entity_texts
                        if not is_correct:
                            # Check for partial matches
                            for exp_text in expected_entity_texts:
                                if normalized_entity in exp_text or exp_text in normalized_entity:
                                    if len(normalized_entity) > 3 and len(exp_text) > 3:
                                        is_correct = True
                                        break

                        if is_correct:
                            entity_type_metrics[entity_type]["correct"] += 1
                        entity_type_metrics[entity_type]["total"] += 1

                    # Store example if requested
                    if len(self.results["examples"]) < self.config["output"]["max_examples"]:
                        self.results["examples"].append({
                            "input": input_text,
                            "expected": expected_output,
                            "predicted": predicted_output,
                            "metrics": metrics,
                            "latency_ms": latency
                        })

                except Exception as e:
                    print(f"⚠️ Error processing example {i}: {e}")
                    continue

        # Calculate final metrics
        if successful_requests > 0:
            self.results["metrics"] = {
                "precision": total_precision / successful_requests,
                "recall": total_recall / successful_requests,
                "f1_score": total_f1 / successful_requests,
                "average_latency_ms": total_latency / successful_requests,
                "successful_requests": successful_requests,
                "total_requests": len(examples)
            }

        # Calculate entity type performance
        self.results["entity_performance"] = {}
        for entity_type, counts in entity_type_metrics.items():
            accuracy = counts["correct"] / counts["total"] if counts["total"] > 0 else 0.0
            self.results["entity_performance"][entity_type] = {
                "accuracy": accuracy,
                "correct_predictions": counts["correct"],
                "total_predictions": counts["total"]
            }

        self.save_results()

    def save_results(self):
        """Save benchmark results to files"""
        # Save detailed JSON results
        with open(self.config["output"]["detailed_results_file"], 'w') as f:
            json.dump(self.results, f, indent=2)

        # Save human-readable summary
        summary = self.generate_summary()
        with open(self.config["output"]["results_file"], 'w') as f:
            f.write(summary)

        print("\n✅ Benchmark complete!")
        print(f"📄 Detailed results saved to: {self.config['output']['detailed_results_file']}")
        print(f"📊 Summary saved to: {self.config['output']['results_file']}")

    def generate_summary(self) -> str:
        """Generate a human-readable benchmark summary"""
        m = self.results["metrics"]
        ep = self.results["entity_performance"]

        summary = f"""# NER Benchmark Results
**Model:** {self.results['metadata']['model']}
**Dataset:** {self.results['metadata']['dataset']}
**Sample Size:** {self.results['metadata']['sample_size']}
**Date:** {self.results['metadata']['timestamp']}

## Overall Performance

| Metric | Score | Description |
|--------|-------|-------------|
| F1 Score | {m.get('f1_score', 0):.3f} | Overall NER performance (harmonic mean of precision and recall) |
| Precision | {m.get('precision', 0):.3f} | Accuracy of entity predictions |
| Recall | {m.get('recall', 0):.3f} | Ability to find all entities |
| Average Latency | {m.get('average_latency_ms', 0):.1f}ms | Response time performance |

## Entity Type Performance

"""
        if ep:
            summary += "| Entity Type | Accuracy | Correct/Total |\n"
            summary += "|-------------|----------|---------------|\n"
            for entity_type, stats in ep.items():
                summary += f"| {entity_type} | {stats['accuracy']:.3f} | {stats['correct_predictions']}/{stats['total_predictions']} |\n"
        else:
            summary += "No entity type performance data available.\n"

        summary += """
## Key Improvements

- **BIO Tagging**: Model outputs entities in BIO (Beginning-Inside-Outside) format
- **Multiple Entity Types**: Supports PERSON, ORG, LOC, and MISC entities
- **Entity-Level Evaluation**: Metrics calculated at entity level rather than token level
- **Comprehensive Coverage**: Evaluates across different text domains

"""

        if self.config["output"]["include_examples"] and self.results["examples"]:
            summary += "## Example Results\n\n"
            for i, example in enumerate(self.results["examples"][:3]):  # Show first 3 examples
                summary += f"### Example {i+1}\n"
                summary += f"**Input:** {example['input'][:100]}...\n"
                summary += f"**Predicted:** {example['predicted'][:200]}...\n"
                summary += f"**F1 Score:** {example['metrics']['f1']:.3f}\n\n"

        return summary


def main():
    if len(sys.argv) != 2:
        print("Usage: python run_benchmarks.py <config_file>")
        sys.exit(1)

    config_path = sys.argv[1]
    if not os.path.exists(config_path):
        print(f"Error: Config file {config_path} not found")
        sys.exit(1)

    runner = NERBenchmarkRunner(config_path)
    runner.run_benchmarks()


if __name__ == "__main__":
    main()
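
For reference, here is a minimal sketch of the YAML configuration this runner reads. The key names mirror the lookups performed by NERBenchmarkRunner above (datasets.benchmark_dataset.*, model.*, output.*); the helper script name make_example_config.py, the concrete paths, and the numeric values are illustrative assumptions and are not part of the commit.

#!/usr/bin/env python3
# make_example_config.py -- hypothetical helper that writes an example config
# for run_benchmarks.py. Key names follow the code above; values are assumed.
import yaml

example_config = {
    "datasets": {
        "benchmark_dataset": {
            "file_path": "ner_benchmark.jsonl",   # JSONL with instruction/input/response fields (assumed path)
            "sample_size": 100,                   # assumed sample size
            "instruction_field": "instruction",
            "input_field": "input",
            "expected_output_field": "response",
        }
    },
    "model": {
        "base_url": "http://localhost:8080",  # assumed local server exposing a /completion endpoint
        "max_tokens": 256,
        "temperature": 0.0,
        "timeout": 60,
    },
    "output": {
        "results_file": "benchmark_results.md",
        "detailed_results_file": "benchmark_results.json",
        "max_examples": 10,
        "include_examples": True,
    },
}

if __name__ == "__main__":
    # Write the example config next to the benchmark script.
    with open("benchmark_config.yaml", "w") as f:
        yaml.safe_dump(example_config, f, sort_keys=False)
    print("Wrote benchmark_config.yaml")

With such a file in place, the runner would be invoked as python run_benchmarks.py benchmark_config.yaml, matching the usage message in main().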