# SATEv1.5 / eval_morpheme.py
# Author: Shuwei Hou
# Branch: initial_for_hf (commit 5806e12)
import csv
import os
from collections import Counter
from morpheme.morpheme_stanza_v1 import extract_inflectional_morphemes
# Morpheme mappings
# Digit codes used in the dataset's "morpheme_code" column for each
# inflectional morpheme label. Code "5" is absent from this map —
# presumably reserved for another morpheme type; TODO confirm against
# the dataset's annotation guide.
MORPHEME_NUM_MAP = {
    "Plural": "1",
    "Possessive": "2",
    "3rd Person Singular": "3",
    "Past Tense": "4",
    "Progressive": "6"
}
# Inverse lookup: digit code -> label name.
MORPHEME_LABEL_MAP = {v: k for k, v in MORPHEME_NUM_MAP.items()}
# Canonical label ordering used when reporting per-label metrics.
ALL_MORPHEME_LABELS = list(MORPHEME_NUM_MAP.keys())
class Detector:
    """Thin wrapper around the stanza-based inflectional-morpheme extractor.

    Normalizes raw annotations into (lowercased_word, morpheme_label)
    pairs, keeping only labels that appear in MORPHEME_NUM_MAP.
    """

    def analyze(self, text):
        """Return [(word, morpheme_label), ...] predicted for *text*.

        Annotations lacking a word or carrying an unknown morpheme label
        are dropped (and logged).
        """
        raw_annotations = extract_inflectional_morphemes(text)
        print(f"\n🔎 DEBUG: Analyzing utterance: '{text}'")
        for idx, ann in enumerate(raw_annotations):
            print(f" Annotation {idx}: {ann}")
        predictions = []
        for ann in raw_annotations:
            label = ann.get("inflectional_morpheme")
            token = ann.get("word")
            # Guard clause: keep only well-formed annotations with a
            # recognized morpheme label.
            if not (token and label in MORPHEME_NUM_MAP):
                print(f" Skipping invalid annotation: {ann}")
                continue
            predictions.append((token.lower(), label))
        print(f" Filtered predictions: {predictions}")
        return predictions
def to_morpheme_string(morpheme_list, utterance, num_map=None):
    """Render morpheme predictions as a per-token digit string.

    Each whitespace token of *utterance* gets one character: '0' when no
    morpheme was predicted for it, otherwise the digit code for the
    predicted morpheme.

    Args:
        morpheme_list: [(lowercased_word, morpheme_label), ...] pairs, as
            produced by Detector.analyze.
        utterance: Source text; tokenized by str.split().
        num_map: Optional label -> digit-code mapping; defaults to the
            module-level MORPHEME_NUM_MAP (backward-compatible).

    Returns:
        A string with one digit character per token of *utterance*.
    """
    if num_map is None:
        num_map = MORPHEME_NUM_MAP
    tokens = utterance.strip().split()
    token_tags = ['0'] * len(tokens)
    lowered_tokens = [t.lower() for t in tokens]
    for word, morph in morpheme_list:
        matches = [i for i, tok in enumerate(lowered_tokens) if tok == word]
        if not matches:
            print(f"⚠ Word '{word}' not found in utterance: '{utterance}'")
            continue
        # Bug fix: assign each annotation to the first *untagged* matching
        # token, so a word occurring twice with two annotations tags both
        # occurrences instead of overwriting the first match every time.
        # If every occurrence is already tagged, fall back to overwriting
        # the first match (the original behavior).
        target = next((i for i in matches if token_tags[i] == '0'), matches[0])
        token_tags[target] = num_map[morph]
    return ''.join(token_tags)
def eval_morpheme(dataset_path, morpheme_detector):
    """Evaluate *morpheme_detector* against a gold-annotated CSV dataset.

    Args:
        dataset_path: Path to a CSV with columns "cleaned_transcription"
            (the utterance text) and "morpheme_code" (space-separated
            digit codes, see MORPHEME_NUM_MAP).
        morpheme_detector: Object exposing analyze(text) ->
            [(word, morpheme_label), ...].

    Side effects:
        Writes per-utterance results to benchmark_result/morpheme/result.csv
        and prints per-label and macro-averaged precision/recall/F1.
    """
    output_dir = "benchmark_result/morpheme"
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "result.csv")
    with open(dataset_path, newline='', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        data = list(reader)
    results = []
    # Per-label TP/FP/FN tallies accumulated over all utterances.
    # Comparison is set-based per utterance, so duplicate morphemes within
    # one utterance count once.
    TP_counter = Counter()
    FP_counter = Counter()
    FN_counter = Counter()
    for i, row in enumerate(data):
        print("\n=== NEW SAMPLE ===")
        utterance = row["cleaned_transcription"]
        gold_str = row.get("morpheme_code", "")
        print(f"Row {i}: {utterance}")
        print(f"Raw GOLD code: '{gold_str}'")
        gold = []
        if gold_str:
            # Map digit codes back to label names; unknown codes are dropped.
            gold_numbers = [x.strip() for x in gold_str.split()]
            gold = [MORPHEME_LABEL_MAP.get(num) for num in gold_numbers if num in MORPHEME_LABEL_MAP]
        print(f"Parsed GOLD labels: {gold}")
        pred_morphemes = morpheme_detector.analyze(utterance)
        # Metrics compare label SETS only — word positions and multiplicity
        # are ignored here (positions matter only for result.csv output).
        gold_set = set(gold)
        pred_set = set([m for (_, m) in pred_morphemes])
        TP = gold_set & pred_set
        FP = pred_set - gold_set
        FN = gold_set - pred_set
        if TP: print(f" TP: {TP}")
        if FP: print(f"⚠ FP: {FP}")
        if FN: print(f" FN: {FN}")
        if not TP and not FP and not FN: print("No match — check formatting.")
        for m in TP:
            TP_counter[m] += 1
        for m in FP:
            FP_counter[m] += 1
        for m in FN:
            FN_counter[m] += 1
        predicted_str = to_morpheme_string(pred_morphemes, utterance)
        token_count = len(utterance.strip().split())
        # Normalize the gold code string to one digit per token: remove all
        # whitespace, pad with '0', truncate to the utterance's token count.
        gold_str_fixed = ''.join(gold_str.split()).ljust(token_count, '0')[:token_count]
        results.append({
            "utterance": utterance,
            "gold": gold_str_fixed,
            "predicted": predicted_str,
        })
    with open(output_path, "w", newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=["utterance", "gold", "predicted"])
        writer.writeheader()
        writer.writerows(results)
    # METRICS
    print("\n\n=== Evaluation Metrics ===")
    observed_labels = {m for m in TP_counter.keys() | FP_counter.keys() | FN_counter.keys()}
    # NOTE(review): "Comparative" is not a key of MORPHEME_NUM_MAP, so it can
    # never appear in observed_labels and this branch looks unreachable —
    # confirm whether a Comparative code (perhaps the missing "5") was meant
    # to be in the map.
    if "Comparative" in observed_labels:
        eval_labels = ALL_MORPHEME_LABELS
    else:
        eval_labels = [m for m in ALL_MORPHEME_LABELS if m != "Comparative"]
    macro_p, macro_r, macro_f1 = 0, 0, 0
    for label in eval_labels:
        TP = TP_counter[label]
        FP = FP_counter[label]
        FN = FN_counter[label]
        # Guard against division by zero for labels never predicted/present.
        precision = TP / (TP + FP) if TP + FP > 0 else 0
        recall = TP / (TP + FN) if TP + FN > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
        macro_p += precision
        macro_r += recall
        macro_f1 += f1
        print(f"{MORPHEME_NUM_MAP[label]} ({label}): Precision={precision:.3f}, Recall={recall:.3f}, F1={f1:.3f}")
    n = len(eval_labels)
    print("\n-- Macro-Averaged Metrics --")
    print(f"Precision: {macro_p / n:.3f}")
    print(f"Recall: {macro_r / n:.3f}")
    print(f"F1 Score: {macro_f1 / n:.3f}")
if __name__ == "__main__":
    # Run the benchmark on the held-out ENNI/SALT test split.
    eval_morpheme("./data/enni_salt_for_morpheme/test.csv", Detector())