Awarebeyond committed on
Commit 241fbbb · verified · 1 Parent(s): 763f78a

Add field-level confusion matrix evaluation script

Files changed (1)
  1. evaluate_model.py +264 -0
evaluate_model.py ADDED
@@ -0,0 +1,264 @@
"""
Evaluate the fine-tuned Donut model and generate a Field-Level Confusion Matrix.
Run this on the Workbench where the model and datasets are located.

Usage:
    python scripts/evaluate_model.py \
        --model_path outputs/receipt_donut_gcp_enterprise/best_model \
        --config configs/gcp_l4_enterprise.yaml \
        --output_dir evaluation_results

Outputs:
    - evaluation_results/field_confusion_matrix.png
    - evaluation_results/field_accuracy.json
    - evaluation_results/error_analysis.html
"""

import os
import sys
import json
import argparse
from pathlib import Path

import Levenshtein
import numpy as np
import torch
import yaml
import matplotlib.pyplot as plt
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

sys.path.insert(0, str(Path(__file__).parent.parent))
from core.unified_dataset import UnifiedReceiptDataset


FIELDS = ["merchant", "date", "subtotal", "tax", "total", "address"]


def normalize_text(text):
    """Lowercase, strip whitespace, and drop "$" and "," so values compare fairly."""
    if text is None:
        return ""
    return str(text).lower().strip().replace("$", "").replace(",", "")


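# Worked example of the normalization (verifiable by hand): "$1,234.50" and
# "1234.50" both normalize to "1234.50", so values that differ only in currency
# formatting compare as equal, and None normalizes to "".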
def categorize_match(gt, pred):
    """
    Categorize a single field prediction into:
    - correct: exact match after normalization
    - minor_typo: Levenshtein distance under 20% of the longer string's length
    - incorrect: everything else
    """
    gt_norm = normalize_text(gt)
    pred_norm = normalize_text(pred)

    if not gt_norm and not pred_norm:
        return "correct"  # Both missing = agreement
    if not gt_norm or not pred_norm:
        return "incorrect"  # One missing, one present

    if gt_norm == pred_norm:
        return "correct"

    dist = Levenshtein.distance(gt_norm, pred_norm)
    max_len = max(len(gt_norm), len(pred_norm))
    ratio = dist / max_len if max_len > 0 else 0

    if ratio < 0.20:
        return "minor_typo"
    return "incorrect"


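# Worked example of the 20% threshold: "walmart" vs. "wallmart" has Levenshtein
# distance 1 over max length 8 (ratio 0.125), so it is a minor_typo, while
# "walmart" vs. "walmart inc" has distance 4 over length 11 (ratio ~0.36),
# so it is incorrect.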
def run_inference(model, processor, image_path, device):
    """Run model inference on a single image and return the parsed JSON dict."""
    img = Image.open(image_path).convert("RGB")
    pixel_values = processor(img, return_tensors="pt").pixel_values.to(device)
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]]).to(device)

    with torch.no_grad():
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=512,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,  # needed so outputs.sequences exists below
        )

    seq = processor.tokenizer.batch_decode(outputs.sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    seq = seq.replace(
        processor.tokenizer.decode([model.config.decoder_start_token_id]), ""
    ).strip()

    try:
        return json.loads(seq)
    except json.JSONDecodeError:
        return {}  # Unparseable output counts as all fields missing


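# Note: json.loads above assumes the model was fine-tuned to emit flat JSON such
# as {"merchant": "WALMART", "total": "12.99", ...}. A checkpoint trained on
# Donut's XML-style tag format would need processor.token2json(seq) here instead.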
def evaluate(model, processor, dataset, device, max_samples=None):
    """
    Evaluate the model on a dataset and return per-field statistics.
    """
    counts = {field: {"correct": 0, "minor_typo": 0, "incorrect": 0} for field in FIELDS}
    errors = []

    n = min(len(dataset), max_samples) if max_samples else len(dataset)
    print(f"Evaluating on {n} samples...")

    for i in range(n):
        sample = dataset[i]
        image_path = sample["image_path"]
        gt = sample["ground_truth"]

        pred = run_inference(model, processor, image_path, device)

        sample_error = {"image": image_path, "gt": gt, "pred": pred, "fields": {}}
        all_correct = True

        for field in FIELDS:
            gt_val = gt.get(field, "")
            pred_val = pred.get(field, "")
            cat = categorize_match(gt_val, pred_val)
            counts[field][cat] += 1
            sample_error["fields"][field] = cat
            if cat != "correct":
                all_correct = False

        if not all_correct:
            errors.append(sample_error)

        if (i + 1) % 50 == 0:
            print(f"  Processed {i + 1}/{n}")

    return counts, errors


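# The helpers below consume `counts` in this shape (numbers illustrative only):
#   {"merchant": {"correct": 312, "minor_typo": 20, "incorrect": 8}, ...}
# and `errors` as a list of per-sample dicts whose "fields" entry maps each
# field name to its match category.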
def plot_confusion_matrix(counts, output_dir):
    """Generate a grouped bar chart of match categories per field."""
    categories = ["correct", "minor_typo", "incorrect"]
    colors = ["#4CAF50", "#FFC107", "#F44336"]

    fig, ax = plt.subplots(figsize=(10, 6))
    x = np.arange(len(FIELDS))
    width = 0.25

    for i, cat in enumerate(categories):
        values = [counts[f][cat] for f in FIELDS]
        ax.bar(x + i * width, values, width, label=cat.replace("_", " ").title(), color=colors[i])

    ax.set_xlabel("Field")
    ax.set_ylabel("Count")
    ax.set_title("Field-Level Confusion Matrix (Validation/Test Set)")
    ax.set_xticks(x + width)
    ax.set_xticklabels(FIELDS, rotation=15, ha="right")
    ax.legend()
    ax.grid(axis="y", linestyle="--", alpha=0.5)
    plt.tight_layout()

    save_path = os.path.join(output_dir, "field_confusion_matrix.png")
    plt.savefig(save_path, dpi=150)
    print(f"Saved confusion matrix to {save_path}")
    plt.close()


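# Note on the figure above: each field gets three side-by-side bars
# (green / amber / red), and the x ticks sit at x + width, i.e. centered under
# the middle bar of each group of three.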
def save_accuracy_json(counts, output_dir):
    """Save the numerical accuracy breakdown per field."""
    results = {}
    for field in FIELDS:
        total = sum(counts[field].values())
        results[field] = {
            "correct_pct": round(counts[field]["correct"] / total * 100, 1),
            "minor_typo_pct": round(counts[field]["minor_typo"] / total * 100, 1),
            "incorrect_pct": round(counts[field]["incorrect"] / total * 100, 1),
            "counts": counts[field],
        }

    save_path = os.path.join(output_dir, "field_accuracy.json")
    with open(save_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"Saved accuracy JSON to {save_path}")


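# field_accuracy.json then looks like this (values illustrative only):
#   {"merchant": {"correct_pct": 91.8, "minor_typo_pct": 5.9, "incorrect_pct": 2.4,
#                 "counts": {"correct": 312, "minor_typo": 20, "incorrect": 8}}, ...}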
def save_error_html(errors, output_dir, max_display=50):
    """Generate an HTML file showing side-by-side GT vs Pred errors."""
    html = ["<html><head><style>",
            "body{font-family:sans-serif;margin:20px}",
            "table{border-collapse:collapse;width:100%}",
            "th,td{border:1px solid #ccc;padding:8px;text-align:left}",
            "th{background:#f0f0f0}",
            ".minor{color:orange}.incorrect{color:red}",
            "</style></head><body>",
            f"<h1>Error Analysis ({min(len(errors), max_display)} of {len(errors)} failures)</h1>",
            "<table><tr><th>Image</th><th>Field</th><th>Ground Truth</th><th>Predicted</th><th>Status</th></tr>"]

    for err in errors[:max_display]:
        img_name = os.path.basename(err["image"])
        for field in FIELDS:
            status = err["fields"][field]
            if status == "correct":
                continue  # Only failed fields are listed
            css_class = "minor" if status == "minor_typo" else "incorrect"
            html.append(f"<tr><td>{img_name}</td><td>{field}</td>"
                        f"<td>{err['gt'].get(field, 'N/A')}</td>"
                        f"<td>{err['pred'].get(field, 'N/A')}</td>"
                        f"<td class='{css_class}'>{status}</td></tr>")

    html.append("</table></body></html>")

    save_path = os.path.join(output_dir, "error_analysis.html")
    with open(save_path, "w") as f:
        f.write("\n".join(html))
    print(f"Saved error analysis HTML to {save_path}")


def main():
    parser = argparse.ArgumentParser(description="Evaluate Donut receipt model")
    parser.add_argument("--model_path", required=True, help="Path to fine-tuned model")
    parser.add_argument("--config", default="configs/gcp_l4_enterprise.yaml", help="Training config YAML")
    parser.add_argument("--output_dir", default="evaluation_results", help="Where to save results")
    parser.add_argument("--max_samples", type=int, default=None, help="Limit evaluation samples")
    parser.add_argument("--split", default="test", choices=["train", "val", "test"], help="Which split to evaluate")
    args = parser.parse_args()

    with open(args.config, "r") as f:
        config = yaml.safe_load(f)

    os.makedirs(args.output_dir, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Loading model from {args.model_path}...")
    processor = DonutProcessor.from_pretrained(args.model_path)
    model = VisionEncoderDecoderModel.from_pretrained(args.model_path)
    model.to(device).eval()

    print(f"Loading dataset split: {args.split}")
    dataset = UnifiedReceiptDataset(
        root=config["data"]["dataset_root"],
        split=args.split,
        processor=None,
        include_datasets=config["data"].get("include_datasets"),
    )

    counts, errors = evaluate(model, processor, dataset, device, args.max_samples)
    plot_confusion_matrix(counts, args.output_dir)
    save_accuracy_json(counts, args.output_dir)
    save_error_html(errors, args.output_dir)

    print("\n=== Evaluation Complete ===")
    for field in FIELDS:
        total = sum(counts[field].values())
        c = counts[field]["correct"]
        m = counts[field]["minor_typo"]
        i = counts[field]["incorrect"]
        print(f"  {field:12s}: Correct={c}/{total} ({c/total*100:.1f}%) | "
              f"Minor={m} | Incorrect={i}")


if __name__ == "__main__":
    main()
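
# Example: a quick pass over 100 validation samples (flags as defined in main()):
#   python scripts/evaluate_model.py \
#       --model_path outputs/receipt_donut_gcp_enterprise/best_model \
#       --config configs/gcp_l4_enterprise.yaml \
#       --split val --max_samples 100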