| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Evaluation script for Vietnamese POS Tagger (TRE-1). |
| |
| Usage: |
| uv run scripts/evaluate.py |
| uv run scripts/evaluate.py --version v1.0.0 |
| uv run scripts/evaluate.py --model models/pos_tagger/v1.0.0/model.crfsuite |
| uv run scripts/evaluate.py --save-plots |
| """ |
|
|
import re
from collections import Counter
from pathlib import Path

import click
import pycrfsuite
from datasets import load_dataset
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    precision_recall_fscore_support,
)

# Repository root: this script lives in scripts/, one level below it.
PROJECT_ROOT = Path(__file__).parent.parent
|
|
|
|
| FEATURE_TEMPLATES = [ |
| "T[0]", "T[0].lower", "T[0].istitle", "T[0].isupper", |
| "T[0].isdigit", "T[0].isalpha", "T[0].prefix2", "T[0].prefix3", |
| "T[0].suffix2", "T[0].suffix3", "T[-1]", "T[-1].lower", |
| "T[-1].istitle", "T[-1].isupper", "T[-2]", "T[-2].lower", |
| "T[1]", "T[1].lower", "T[1].istitle", "T[1].isupper", |
| "T[2]", "T[2].lower", "T[-1,0]", "T[0,1]", |
| "T[0].is_in_dict", "T[-1,0].is_in_dict", "T[0,1].is_in_dict", |
| ] |
|
|
|
|
def get_token_value(tokens, position, index):
    """Return the token at `position + index`, or a boundary sentinel.

    Offsets that fall before the sentence start yield "__BOS__" and offsets
    past the end yield "__EOS__".
    """
    target = position + index
    if 0 <= target < len(tokens):
        return tokens[target]
    return "__BOS__" if target < 0 else "__EOS__"
|
|
|
|
def apply_attribute(value, attribute, dictionary=None):
    """Apply a template attribute (e.g. "lower", "prefix3") to a token value.

    Boundary sentinels and missing/unknown attributes pass through unchanged.
    Predicate attributes return the strings "True"/"False". `dictionary` is an
    optional word collection consulted only by "is_in_dict".
    """
    if value in ("__BOS__", "__EOS__") or attribute is None:
        return value

    # Attributes that map directly onto str methods.
    simple = {
        "lower": value.lower,
        "upper": value.upper,
        "istitle": lambda: str(value.istitle()),
        "isupper": lambda: str(value.isupper()),
        "islower": lambda: str(value.islower()),
        "isdigit": lambda: str(value.isdigit()),
        "isalpha": lambda: str(value.isalpha()),
    }
    if attribute in simple:
        return simple[attribute]()

    if attribute == "is_in_dict":
        return str(value in dictionary) if dictionary else "False"

    if attribute.startswith(("prefix", "suffix")):
        # "prefix3" -> 3; bare "prefix"/"suffix" defaults to 2.
        length = int(attribute[6:]) if len(attribute) > 6 else 2
        if len(value) < length:
            return value
        return value[:length] if attribute.startswith("prefix") else value[-length:]

    return value
|
|
|
|
def parse_template(template):
    """Split a template like "T[-1,0].lower" into (offset list, attribute).

    Returns (None, None) when the string does not match the T[...] pattern;
    the attribute is None when no ".attr" suffix is present.
    """
    parsed = re.match(r"T\[([^\]]+)\](?:\.(\w+))?", template)
    if parsed is None:
        return None, None
    offsets = [int(part.strip()) for part in parsed.group(1).split(",")]
    return offsets, parsed.group(2)
|
|
|
|
def extract_features(tokens, position, dictionary=None):
    """Build the feature dict for the token at `position` in `tokens`.

    Each entry of FEATURE_TEMPLATES yields one feature keyed by the raw
    template string. Single-offset templates produce an attribute-transformed
    token; multi-offset templates are joined with "|", except "is_in_dict",
    which checks the space-joined phrase against `dictionary`.
    """
    feature_map = {}
    for template in FEATURE_TEMPLATES:
        offsets, attribute = parse_template(template)
        if offsets is None:
            continue  # malformed template: skip silently
        if len(offsets) == 1:
            raw = get_token_value(tokens, position, offsets[0])
            feature_map[template] = apply_attribute(raw, attribute, dictionary)
            continue
        window = [get_token_value(tokens, position, off) for off in offsets]
        if attribute == "is_in_dict":
            phrase = " ".join(window)
            feature_map[template] = str(phrase in dictionary) if dictionary else "False"
        else:
            feature_map[template] = "|".join(window)
    return feature_map
|
|
|
|
def sentence_to_features(tokens):
    """Convert a token sequence into per-token "key=value" feature lists
    in the format pycrfsuite expects."""
    feature_rows = []
    for position in range(len(tokens)):
        token_features = extract_features(tokens, position)
        feature_rows.append([f"{key}={val}" for key, val in token_features.items()])
    return feature_rows
|
|
|
|
def load_test_data():
    """Download the UDD-1 test split and return a list of (tokens, upos) pairs.

    Items with an empty token or tag sequence are skipped.
    """
    click.echo("Loading UDD-1 dataset...")
    dataset = load_dataset("undertheseanlp/UDD-1")

    sentences = [
        (item["tokens"], item["upos"])
        for item in dataset["test"]
        if item["tokens"] and item["upos"]
    ]

    click.echo(f"Test set: {len(sentences)} sentences")
    return sentences
|
|
|
|
def plot_confusion_matrix(y_true, y_pred, labels, output_path):
    """Render a confusion-matrix heatmap and save it to `output_path`.

    matplotlib/seaborn are imported lazily so the script works without them
    unless plotting is actually requested.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    matrix = confusion_matrix(y_true, y_pred, labels=labels)

    plt.figure(figsize=(12, 10))
    heatmap_opts = dict(
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=labels,
        yticklabels=labels,
    )
    sns.heatmap(matrix, **heatmap_opts)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix - Vietnamese POS Tagger (TRE-1)")
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close()
    click.echo(f"Confusion matrix saved to {output_path}")
|
|
|
|
def plot_per_tag_metrics(report_dict, output_path):
    """Save a grouped bar chart of precision/recall/F1 per POS tag.

    `report_dict` is sklearn's classification_report(..., output_dict=True);
    the aggregate keys are excluded from the per-tag bars.
    """
    import matplotlib.pyplot as plt

    aggregates = ("accuracy", "macro avg", "weighted avg")
    tags = [name for name in report_dict if name not in aggregates]

    # (label, values, color) for the three bar series, drawn side by side.
    width = 0.25
    series = [
        ("Precision", [report_dict[t]["precision"] for t in tags], "#2ecc71", -width),
        ("Recall", [report_dict[t]["recall"] for t in tags], "#3498db", 0.0),
        ("F1-Score", [report_dict[t]["f1-score"] for t in tags], "#e74c3c", width),
    ]

    positions = range(len(tags))
    fig, ax = plt.subplots(figsize=(14, 6))
    for label, values, color, offset in series:
        ax.bar([i + offset for i in positions], values, width, label=label, color=color)

    ax.set_xlabel("POS Tag")
    ax.set_ylabel("Score")
    ax.set_title("Per-Tag Performance Metrics - Vietnamese POS Tagger (TRE-1)")
    ax.set_xticks(positions)
    ax.set_xticklabels(tags, rotation=45)
    ax.legend()
    ax.set_ylim(0, 1.1)
    ax.grid(axis="y", alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close()
    click.echo(f"Per-tag metrics saved to {output_path}")
|
|
|
|
| def analyze_errors(y_true, y_pred, tokens_flat, top_n=10): |
| """Analyze common error patterns.""" |
| errors = Counter() |
| error_examples = {} |
|
|
| for true, pred, token in zip(y_true, y_pred, tokens_flat): |
| if true != pred: |
| key = (true, pred) |
| errors[key] += 1 |
| if key not in error_examples: |
| error_examples[key] = token |
|
|
| click.echo(f"\nTop {top_n} Error Patterns:") |
| click.echo("-" * 60) |
| click.echo(f"{'True':<10} {'Predicted':<10} {'Count':<8} {'Example'}") |
| click.echo("-" * 60) |
|
|
| for (true, pred), count in errors.most_common(top_n): |
| example = error_examples.get((true, pred), "") |
| click.echo(f"{true:<10} {pred:<10} {count:<8} {example}") |
|
|
|
|
| def get_latest_version(task="pos_tagger"): |
| """Get the latest model version (sorted by timestamp).""" |
| models_dir = PROJECT_ROOT / "models" / task |
| if not models_dir.exists(): |
| return None |
| versions = [d.name for d in models_dir.iterdir() if d.is_dir()] |
| if not versions: |
| return None |
| return sorted(versions)[-1] |
|
|
|
|
@click.command()
@click.option(
    "--version", "-v",
    default=None,
    help="Model version to evaluate (default: latest)",
)
@click.option(
    "--model", "-m",
    default=None,
    help="Custom model path (overrides version-based path)",
)
@click.option(
    "--save-plots",
    is_flag=True,
    help="Save confusion matrix and per-tag metrics plots",
)
def evaluate(version, model, save_plots):
    """Evaluate Vietnamese POS Tagger on UDD-1 test set.

    Loads a trained CRF model (explicit --model path wins over --version;
    with neither, the latest version on disk is used), predicts tags for the
    UDD-1 test split, and prints accuracy, macro/weighted P/R/F1, a per-tag
    report, top error patterns, and the tag distribution. With --save-plots,
    also writes a confusion-matrix heatmap and a per-tag metrics chart under
    results/pos_tagger/. Returns the token-level accuracy.
    """
    # Resolve which model to evaluate.
    if version is None and model is None:
        version = get_latest_version("pos_tagger")
        if version is None:
            raise click.ClickException("No models found in models/pos_tagger/")

    if model:
        model_path = Path(model)
    else:
        model_path = PROJECT_ROOT / "models" / "pos_tagger" / version / "model.crfsuite"

    # Label used in plot filenames. Previously, running with --model but no
    # --version produced files like "confusion_matrix_None.png"; fall back
    # to the model's parent directory name instead.
    version_label = version if version is not None else model_path.parent.name

    if save_plots:
        results_dir = PROJECT_ROOT / "results" / "pos_tagger"
        results_dir.mkdir(parents=True, exist_ok=True)

    click.echo(f"Loading model from {model_path}...")
    tagger = pycrfsuite.Tagger()
    tagger.open(str(model_path))

    test_data = load_test_data()

    click.echo("Extracting features and predicting...")
    X_test = [sentence_to_features(tokens) for tokens, _ in test_data]
    y_test = [tags for _, tags in test_data]
    tokens_test = [tokens for tokens, _ in test_data]

    y_pred = [tagger.tag(xseq) for xseq in X_test]

    # Flatten sentence-level sequences to token level for sklearn metrics.
    y_test_flat = [tag for tags in y_test for tag in tags]
    y_pred_flat = [tag for tags in y_pred for tag in tags]
    tokens_flat = [token for tokens in tokens_test for token in tokens]

    labels = sorted(set(y_test_flat))

    accuracy = accuracy_score(y_test_flat, y_pred_flat)
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        y_test_flat, y_pred_flat, average="macro"
    )
    _, _, f1_weighted, _ = precision_recall_fscore_support(
        y_test_flat, y_pred_flat, average="weighted"
    )

    click.echo("\n" + "=" * 60)
    click.echo("EVALUATION RESULTS")
    click.echo("=" * 60)

    click.echo("\nOverall Metrics:")
    click.echo(f" Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    click.echo(f" Precision (macro): {precision_macro:.4f}")
    click.echo(f" Recall (macro): {recall_macro:.4f}")
    click.echo(f" F1 (macro): {f1_macro:.4f}")
    click.echo(f" F1 (weighted): {f1_weighted:.4f}")

    click.echo("\nPer-Tag Classification Report:")
    report = classification_report(y_test_flat, y_pred_flat, digits=4)
    click.echo(report)

    analyze_errors(y_test_flat, y_pred_flat, tokens_flat)

    # Tag distribution gives context for the per-tag scores above.
    tag_counts = Counter(y_test_flat)
    total_tokens = len(y_test_flat)

    click.echo("\nTest Set Tag Distribution:")
    click.echo("-" * 40)
    for tag in labels:
        count = tag_counts[tag]
        pct = count / total_tokens * 100
        click.echo(f" {tag:<8} {count:>6} ({pct:>5.2f}%)")

    if save_plots:
        cm_path = results_dir / f"confusion_matrix_{version_label}.png"
        plot_confusion_matrix(
            y_test_flat, y_pred_flat, labels,
            str(cm_path)
        )

        report_dict = classification_report(
            y_test_flat, y_pred_flat, output_dict=True
        )
        metrics_path = results_dir / f"per_tag_metrics_{version_label}.png"
        plot_per_tag_metrics(report_dict, str(metrics_path))

    return accuracy
|
|
|
|
| if __name__ == "__main__": |
| evaluate() |
|
|