Add field-level confusion matrix evaluation script
evaluate_model.py  +264 -0

evaluate_model.py  ADDED

@@ -0,0 +1,264 @@
"""
Evaluate the fine-tuned Donut model and generate a Field-Level Confusion Matrix.
Run this on the Workbench where the model and datasets are located.

Usage:
    python scripts/evaluate_model.py \
        --model_path outputs/receipt_donut_gcp_enterprise/best_model \
        --config configs/gcp_l4_enterprise.yaml \
        --output_dir evaluation_results

Outputs:
    - evaluation_results/field_confusion_matrix.png
    - evaluation_results/field_accuracy.json
    - evaluation_results/error_analysis.html
"""

import os
import sys
import json
import argparse
from pathlib import Path
from html import escape  # escape GT/prediction text in the HTML error report

import Levenshtein
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from transformers import DonutProcessor, VisionEncoderDecoderModel

sys.path.insert(0, str(Path(__file__).parent.parent))
from core.unified_dataset import UnifiedReceiptDataset


FIELDS = ["merchant", "date", "subtotal", "tax", "total", "address"]
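
# Each ground-truth / prediction record is assumed to be a flat dict keyed by
# these fields, e.g. (illustrative values, not taken from the datasets):
#   {"merchant": "Trader Joe's", "date": "2024-03-01", "subtotal": "23.50",
#    "tax": "1.88", "total": "25.38", "address": "123 Main St"}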


def normalize_text(text):
    """Lowercase, strip whitespace, and drop "$" and "," so prices compare fairly."""
    if text is None:
        return ""
    return str(text).lower().strip().replace("$", "").replace(",", "")
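
# Illustrative behavior (not part of this commit): normalize_text("$1,234.50 ")
# returns "1234.50", so "$12.99" and "12.99" count as an exact match.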


def categorize_match(gt, pred):
    """
    Categorize a single field prediction into:
    - correct: exact match after normalization
    - minor_typo: Levenshtein distance under 20% of the longer string's length
    - incorrect: everything else
    """
    gt_norm = normalize_text(gt)
    pred_norm = normalize_text(pred)

    if not gt_norm and not pred_norm:
        return "correct"  # Both missing = agreement
    if not gt_norm or not pred_norm:
        return "incorrect"  # One missing, one present

    if gt_norm == pred_norm:
        return "correct"

    dist = Levenshtein.distance(gt_norm, pred_norm)
    max_len = max(len(gt_norm), len(pred_norm))
    ratio = dist / max_len if max_len > 0 else 0

    if ratio < 0.20:
        return "minor_typo"
    return "incorrect"
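
# Worked example (illustrative): categorize_match("Wal-Mart", "walmart")
# normalizes to "wal-mart" vs "walmart"; Levenshtein distance 1 over max
# length 8 gives ratio 0.125 < 0.20, so the prediction counts as "minor_typo".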


def run_inference(model, processor, image_path, device):
    """Run model inference on a single image and return the parsed JSON dict."""
    img = Image.open(image_path).convert("RGB")
    pixel_values = processor(img, return_tensors="pt").pixel_values.to(device)
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]]).to(device)

    with torch.no_grad():
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=512,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,  # required so outputs.sequences exists below
        )

    seq = processor.tokenizer.batch_decode(outputs.sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    seq = seq.replace(
        processor.tokenizer.decode([model.config.decoder_start_token_id]), ""
    ).strip()

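    # Parsing below assumes the model emits raw JSON text. If it was instead
    # fine-tuned on Donut-style tag sequences (<s_merchant>...</s_merchant>),
    # processor.token2json(seq) would be the usual parser here (an assumption
    # about the training setup, not something this commit specifies).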
    try:
        return json.loads(seq)
    except json.JSONDecodeError:
        return {}


def evaluate(model, processor, dataset, device, max_samples=None):
    """
    Evaluate the model on a dataset and return per-field statistics.
    """
    counts = {field: {"correct": 0, "minor_typo": 0, "incorrect": 0} for field in FIELDS}
    errors = []

    n = min(len(dataset), max_samples) if max_samples else len(dataset)
    print(f"Evaluating on {n} samples...")

    for i in range(n):
        sample = dataset[i]
        image_path = sample["image_path"]
        gt = sample["ground_truth"]

        pred = run_inference(model, processor, image_path, device)

        sample_error = {"image": image_path, "gt": gt, "pred": pred, "fields": {}}
        all_correct = True

        for field in FIELDS:
            gt_val = gt.get(field, "")
            pred_val = pred.get(field, "")
            cat = categorize_match(gt_val, pred_val)
            counts[field][cat] += 1
            sample_error["fields"][field] = cat
            if cat != "correct":
                all_correct = False

        if not all_correct:
            errors.append(sample_error)

        if (i + 1) % 50 == 0:
            print(f"  Processed {i + 1}/{n}")

    return counts, errors
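
# Aggregate view (illustrative, not produced by this script): an overall
# exact-match rate can be derived from the returned counts, e.g.
#   total = sum(sum(c.values()) for c in counts.values())
#   exact = sum(c["correct"] for c in counts.values())
#   print(f"Overall field accuracy: {exact / total:.1%}")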


def plot_confusion_matrix(counts, output_dir):
    """Generate a grouped bar chart of per-field outcome counts."""
    categories = ["correct", "minor_typo", "incorrect"]
    colors = ["#4CAF50", "#FFC107", "#F44336"]

    fig, ax = plt.subplots(figsize=(10, 6))
    x = np.arange(len(FIELDS))
    width = 0.25

    for i, cat in enumerate(categories):
        values = [counts[f][cat] for f in FIELDS]
        ax.bar(x + i * width, values, width, label=cat.replace("_", " ").title(), color=colors[i])

    ax.set_xlabel("Field")
    ax.set_ylabel("Count")
    ax.set_title("Field-Level Confusion Matrix (Validation/Test Set)")
    ax.set_xticks(x + width)
    ax.set_xticklabels(FIELDS, rotation=15, ha="right")
    ax.legend()
    ax.grid(axis="y", linestyle="--", alpha=0.5)
    plt.tight_layout()

    save_path = os.path.join(output_dir, "field_confusion_matrix.png")
    plt.savefig(save_path, dpi=150)
    print(f"Saved confusion matrix to {save_path}")
    plt.close()


def save_accuracy_json(counts, output_dir):
    """Save numerical accuracy breakdown per field."""
    results = {}
    for field in FIELDS:
        total = sum(counts[field].values())
        results[field] = {
            "correct_pct": round(counts[field]["correct"] / total * 100, 1),
            "minor_typo_pct": round(counts[field]["minor_typo"] / total * 100, 1),
            "incorrect_pct": round(counts[field]["incorrect"] / total * 100, 1),
            "counts": counts[field],
        }

    save_path = os.path.join(output_dir, "field_accuracy.json")
    with open(save_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"Saved accuracy JSON to {save_path}")
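
# Example field_accuracy.json entry (illustrative numbers, not real results):
#   "merchant": {"correct_pct": 91.4, "minor_typo_pct": 5.1, "incorrect_pct": 3.4,
#                "counts": {"correct": 320, "minor_typo": 18, "incorrect": 12}}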


def save_error_html(errors, output_dir, max_display=50):
    """Generate an HTML file showing side-by-side GT vs Pred errors."""
    html = ["<html><head><style>",
            "body{font-family:sans-serif;margin:20px}",
            "table{border-collapse:collapse;width:100%}",
            "th,td{border:1px solid #ccc;padding:8px;text-align:left}",
            "th{background:#f0f0f0}",
            ".correct{color:green}.minor{color:orange}.incorrect{color:red}",
            "</style></head><body>",
            f"<h1>Error Analysis ({min(len(errors), max_display)} of {len(errors)} failures)</h1>",
            "<table><tr><th>Image</th><th>Field</th><th>Ground Truth</th><th>Predicted</th><th>Status</th></tr>"]

    for err in errors[:max_display]:
        img_name = os.path.basename(err["image"])
        for field in FIELDS:
            status = err["fields"][field]
            if status == "correct":
                continue
            # Only minor_typo / incorrect reach this point, so two classes suffice.
            css_class = "minor" if status == "minor_typo" else "incorrect"
            # escape() keeps stray "<" or "&" in receipt text from breaking the table.
            html.append(f"<tr><td>{escape(img_name)}</td><td>{field}</td>"
                        f"<td>{escape(str(err['gt'].get(field, 'N/A')))}</td>"
                        f"<td>{escape(str(err['pred'].get(field, 'N/A')))}</td>"
                        f"<td class='{css_class}'>{status}</td></tr>")

    html.append("</table></body></html>")

    save_path = os.path.join(output_dir, "error_analysis.html")
    with open(save_path, "w") as f:
        f.write("\n".join(html))
    print(f"Saved error analysis HTML to {save_path}")


def main():
    parser = argparse.ArgumentParser(description="Evaluate Donut receipt model")
    parser.add_argument("--model_path", required=True, help="Path to fine-tuned model")
    parser.add_argument("--config", default="configs/gcp_l4_enterprise.yaml", help="Training config YAML")
    parser.add_argument("--output_dir", default="evaluation_results", help="Where to save results")
    parser.add_argument("--max_samples", type=int, default=None, help="Limit evaluation samples")
    parser.add_argument("--split", default="test", choices=["train", "val", "test"], help="Which split to evaluate")
    args = parser.parse_args()

    import yaml
    with open(args.config, "r") as f:
        config = yaml.safe_load(f)

    os.makedirs(args.output_dir, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Loading model from {args.model_path}...")
    processor = DonutProcessor.from_pretrained(args.model_path)
    model = VisionEncoderDecoderModel.from_pretrained(args.model_path)
    model.to(device).eval()

    print(f"Loading dataset split: {args.split}")
    dataset = UnifiedReceiptDataset(
        root=config["data"]["dataset_root"],
        split=args.split,
        processor=None,
        include_datasets=config["data"].get("include_datasets"),
    )

    counts, errors = evaluate(model, processor, dataset, device, args.max_samples)
    plot_confusion_matrix(counts, args.output_dir)
    save_accuracy_json(counts, args.output_dir)
    save_error_html(errors, args.output_dir)

    print("\n=== Evaluation Complete ===")
    for field in FIELDS:
        total = sum(counts[field].values())
        c = counts[field]["correct"]
        m = counts[field]["minor_typo"]
        i = counts[field]["incorrect"]
        print(f"  {field:12s}: Correct={c}/{total} ({c/total*100:.1f}%) | "
              f"Minor={m} | Incorrect={i}")


if __name__ == "__main__":
    main()
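
# Quick smoke test (illustrative invocation): evaluate 20 validation samples only
#   python scripts/evaluate_model.py \
#       --model_path outputs/receipt_donut_gcp_enterprise/best_model \
#       --split val --max_samples 20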