rewrite / src /evaluation /errant_evaluator.py
morpheuslord's picture
Add files using upload-large-folder tool
12fd5f2 verified
"""
ERRANT-based grammatical error evaluation.
Uses the ERRANT toolkit for standardised GEC evaluation with
precision, recall, and F0.5 scores.
"""
from typing import List, Dict
from loguru import logger
class ERRANTEvaluator:
    """Evaluates grammar correction quality using ERRANT annotations.

    Each prediction and each reference is aligned against its source
    sentence with ERRANT. The resulting edit sets are compared to count
    true positives, false positives and false negatives, which are pooled
    over the whole corpus into precision, recall and F0.5.
    """

    def __init__(self):
        # ERRANT (and the language model it loads) is optional: if loading
        # fails for any reason, keep a null annotator so evaluate() degrades
        # to zero scores instead of crashing the caller.
        try:
            import errant

            self.annotator = errant.load("en")
            logger.info("ERRANT annotator loaded")
        except Exception as e:
            logger.warning(f"ERRANT failed to load: {e}. Evaluation will use fallback.")
            self.annotator = None

    def _edit_set(self, orig, cor) -> set:
        """Return the ERRANT edits from *orig* to *cor* as ``(o_start, o_end, c_str)`` tuples."""
        # Match on span + correction string only, as ERRANT's span-based
        # correction scoring does. The error type is classified from the parse
        # of the whole corrected sentence, so the same edit can receive a
        # different type in the prediction and the reference; including it in
        # the key would turn genuine matches into a false positive plus a
        # false negative.
        return {(e.o_start, e.o_end, e.c_str) for e in self.annotator.annotate(orig, cor)}

    @staticmethod
    def _f_beta(precision: float, recall: float, beta: float = 0.5) -> float:
        """Return the F-beta score; 0.0 when both precision and recall are 0."""
        if precision + recall <= 0:
            return 0.0
        b2 = beta**2
        return (1 + b2) * precision * recall / (b2 * precision + recall)

    def evaluate(
        self,
        sources: List[str],
        predictions: List[str],
        references: List[str],
    ) -> Dict[str, float]:
        """Compute corpus-level ERRANT precision, recall and F0.5.

        Args:
            sources: Original (uncorrected) sentences.
            predictions: System corrections, aligned index-by-index with
                ``sources``.
            references: Gold corrections, aligned index-by-index with
                ``sources``.

        Returns:
            ``{"precision": p, "recall": r, "f0.5": f}``. All three are 0.0
            when ERRANT could not be loaded. Precision (recall) is 0.0 when
            there are no predicted (reference) edits at all.

        Raises:
            ValueError: If the three input lists differ in length.

        Note:
            ``annotator.parse`` is called with its default arguments.
            NOTE(review): in ERRANT that presumably means inputs are treated
            as already whitespace-tokenised; confirm against the data fed here.
            Samples whose annotation raises are skipped, and a warning
            reports how many were skipped.
        """
        if self.annotator is None:
            logger.warning("ERRANT not available, returning zero scores")
            return {"precision": 0.0, "recall": 0.0, "f0.5": 0.0}

        # zip() would silently truncate to the shortest list and report
        # metrics for only part of the corpus.
        if not (len(sources) == len(predictions) == len(references)):
            raise ValueError(
                "sources, predictions and references must have the same length, got "
                f"{len(sources)}, {len(predictions)} and {len(references)}"
            )

        tp = fp = fn = 0
        skipped = 0
        for src, pred, ref in zip(sources, predictions, references):
            try:
                # Parse the source once; both corrections are aligned to it.
                orig = self.annotator.parse(src)
                pred_set = self._edit_set(orig, self.annotator.parse(pred))
                ref_set = self._edit_set(orig, self.annotator.parse(ref))
            except Exception as e:
                # Best-effort: one unparseable sample should not abort the run.
                logger.debug(f"ERRANT annotation failed for a sample: {e}")
                skipped += 1
                continue
            tp += len(pred_set & ref_set)
            fp += len(pred_set - ref_set)
            fn += len(ref_set - pred_set)

        if skipped:
            logger.warning(
                f"ERRANT annotation failed for {skipped} of {len(sources)} samples; "
                "they were excluded from scoring"
            )

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        # F0.5 weighs precision higher than recall (beta = 0.5), the GEC convention.
        return {
            "precision": precision,
            "recall": recall,
            "f0.5": self._f_beta(precision, recall, beta=0.5),
        }