| """ |
| ERRANT-based grammatical error evaluation. |
| Uses the ERRANT toolkit for standardised GEC evaluation with |
| precision, recall, and F0.5 scores. |
| """ |
|
|
| from typing import List, Dict |
| from loguru import logger |
|
|
|
|
class ERRANTEvaluator:
    """Evaluates grammar correction quality using ERRANT annotations.

    Hypothesis and reference edits are both extracted by ERRANT against the
    same source sentence, matched span-by-span (start, end, correction), and
    accumulated into corpus-level TP/FP/FN counts that yield precision,
    recall and F0.5.
    """

    def __init__(self):
        # Import lazily so that a missing ``errant`` package, or any error
        # raised by ``errant.load``, degrades to zero scores in ``evaluate``
        # instead of failing when this module is imported.
        try:
            import errant
            self.annotator = errant.load("en")
            logger.info("ERRANT annotator loaded")
        except Exception as e:
            logger.warning(f"ERRANT failed to load: {e}. Evaluation will use fallback.")
            self.annotator = None

    @staticmethod
    def _edit_keys(edits) -> set:
        """Return the set of ``(o_start, o_end, c_str)`` match keys for *edits*.

        The ERRANT error type is deliberately excluded from the key: the
        hypothesis and the reference are annotated against different corrected
        sentences, so ERRANT may assign different categories to the very same
        correction, which would otherwise turn a true positive into one false
        positive plus one false negative.
        """
        return {(e.o_start, e.o_end, e.c_str) for e in edits}

    @staticmethod
    def _f_beta(precision: float, recall: float, beta: float = 0.5) -> float:
        """Return the F-beta score, or 0.0 when precision and recall are both 0."""
        if precision + recall <= 0:
            return 0.0
        b2 = beta ** 2
        return (1 + b2) * (precision * recall) / (b2 * precision + recall)

    def evaluate(
        self,
        sources: List[str],
        predictions: List[str],
        references: List[str],
    ) -> Dict[str, float]:
        """Compute corpus-level ERRANT precision, recall and F0.5.

        Args:
            sources: Original (uncorrected) sentences.
            predictions: System-corrected sentences, aligned with *sources*.
            references: Gold corrected sentences, aligned with *sources*.

        Returns:
            ``{"precision": p, "recall": r, "f0.5": f}``. All three are 0.0
            when ERRANT is unavailable. Samples whose annotation raises are
            skipped and reported in a warning.

        Raises:
            ValueError: if the three input lists differ in length (``zip``
                would otherwise silently drop the unmatched tail).
        """
        if not (len(sources) == len(predictions) == len(references)):
            raise ValueError(
                "sources, predictions and references must have the same length "
                f"(got {len(sources)}, {len(predictions)}, {len(references)})"
            )

        if self.annotator is None:
            logger.warning("ERRANT not available, returning zero scores")
            return {"precision": 0.0, "recall": 0.0, "f0.5": 0.0}

        tp = fp = fn = 0
        failed = 0

        for src, pred, ref in zip(sources, predictions, references):
            try:
                # NOTE(review): ``parse`` is called with ERRANT's default
                # options; confirm whether inputs are expected to be
                # pre-tokenised or need tokenisation.
                orig = self.annotator.parse(src)
                pred_set = self._edit_keys(
                    self.annotator.annotate(orig, self.annotator.parse(pred))
                )
                ref_set = self._edit_keys(
                    self.annotator.annotate(orig, self.annotator.parse(ref))
                )
            except Exception as e:
                # Best-effort: one bad sample must not abort the whole run,
                # but it is counted so the caller learns the scores are partial.
                failed += 1
                logger.debug(f"ERRANT annotation failed for a sample: {e}")
                continue

            tp += len(pred_set & ref_set)
            fp += len(pred_set - ref_set)
            fn += len(ref_set - pred_set)

        if failed:
            logger.warning(
                f"ERRANT annotation failed for {failed}/{len(sources)} samples; "
                "they were excluded from the scores"
            )

        # NOTE: an empty denominator yields 0.0 here; ERRANT's own compare
        # script reports 1.0 in that case, so "no edits anywhere" differs.
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0

        return {
            "precision": precision,
            "recall": recall,
            "f0.5": self._f_beta(precision, recall, beta=0.5),
        }
|
|