# SATEv1.5 / eval_segmentation.py
# Shuwei Hou
# initial_for_hf
# 5806e12
import ast
import os
import re
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support
from segmentation import segment_batchalign
from segmentation import segment_SaT
from segmentation import segment_SaT_cunit_3l
from segmentation import segment_SaT_cunit_12l
from segmentation import segment_SaT_cunit_3l_r32a64
from segmentation import segment_SaT_cunit_3l_r64a128
from segmentation import segment_SaT_cunit_3l_no_shuffle
from tqdm import tqdm
def clean_text(text):
    """Normalise a transcript line for word-level comparison.

    The text is lower-cased, every character that is neither a word
    character (``\\w``) nor whitespace is removed, and leading/trailing
    whitespace is trimmed.

    >>> clean_text("  Hello, World!  ")
    'hello world'
    """
    lowered = text.lower()
    without_punctuation = re.sub(r"[^\w\s]", "", lowered)
    return without_punctuation.strip()
def _to_binary_labels(labels):
    """Return *labels* as a list of ``0``/``1`` ints.

    Accepts any value ``int()`` understands whose numeric value is exactly 0
    or 1 (Python/NumPy ints and bools, ``0.0``/``1.0`` floats, ``"0"``/``"1"``
    strings). Anything else raises ``ValueError``, so that, for example,
    probabilities are not silently truncated to 0.
    """
    converted = []
    for label in labels:
        value = int(label)
        if value not in (0, 1) or float(label) != value:
            raise ValueError(f"Expected binary (0/1) labels, got {label!r}")
        converted.append(value)
    return converted


def eval_segmentation(dataset_path, segmentation_model, model_name="unknown", chunk_num=10,
                      output_dir="benchmark_result/segmentation", drop_last=True):
    """Benchmark a word-level segmentation model against per-row boundaries.

    The CSV at *dataset_path* is read with pandas, and its
    ``cleaned_transcription`` column is processed in consecutive windows of
    *chunk_num* rows. Within a window, each non-missing row is normalised with
    ``clean_text`` and split on whitespace. The words of all rows are joined
    with single spaces into one model input. The gold label sequence marks
    the last word of every row with ``1`` and all other words with ``0``.
    *segmentation_model* is called on that input and must return one label
    per word. When the count does not match, a message is printed and the
    chunk is skipped.

    Per-chunk inputs and label sequences are written to
    ``<output_dir>/<model_name>_results.csv``. TP/FP/FN counts and binary
    (positive label 1) precision/recall/F1 are printed and returned.

    Args:
        dataset_path: Path to a CSV file with a ``cleaned_transcription`` column.
        segmentation_model: Callable taking the space-joined word string and
            returning a sequence of 0/1 labels, one per word.
        model_name: Name used in printed output and in the result file name.
        chunk_num: Number of dataset rows concatenated into one model input.
        output_dir: Directory for the result CSV; created if missing.
        drop_last: If True (default, the original behaviour), a trailing
            window with fewer than *chunk_num* rows is not evaluated. If
            False, it is evaluated like any other chunk.

    Returns:
        ``(precision, recall, f1)`` over all evaluated chunks, as computed by
        ``precision_recall_fscore_support(..., average='binary',
        zero_division=0)``.

    Raises:
        ValueError: If the model emits a label that is not 0 or 1.
    """
    os.makedirs(output_dir, exist_ok=True)
    df = pd.read_csv(dataset_path)

    results = []
    # Labels are accumulated as ints here rather than re-parsed from the
    # string columns written to the CSV.
    all_gt = []
    all_pred = []
    for i in tqdm(range(0, len(df), chunk_num), desc="Evaluating chunks"):
        chunk = df.iloc[i:i + chunk_num]
        if drop_last and len(chunk) < chunk_num:
            continue

        word_sequence = []
        gt_label_sequence = []
        for row in chunk["cleaned_transcription"]:
            if pd.isna(row):
                continue
            words = clean_text(row).split()
            if not words:
                continue
            word_sequence.extend(words)
            # Only the final word of each row closes a segment.
            gt_label_sequence.extend([0] * (len(words) - 1) + [1])

        if not word_sequence:
            # Every row was missing or punctuation-only: nothing to segment.
            print(f"No words in chunk {i}. Skipping...")
            continue

        input_text = " ".join(word_sequence)
        predicted_labels = _to_binary_labels(segmentation_model(input_text))
        if len(predicted_labels) != len(gt_label_sequence):
            print(f"Label length mismatch at chunk {i}. Skipping...")
            continue

        results.append({
            "word_sequence": input_text,
            "gt_label_sequence": " ".join(map(str, gt_label_sequence)),
            "predict_label_sequence": " ".join(map(str, predicted_labels)),
        })
        all_gt.extend(gt_label_sequence)
        all_pred.extend(predicted_labels)

    # Explicit columns keep the CSV header even when no chunk was evaluated.
    result_df = pd.DataFrame(
        results,
        columns=["word_sequence", "gt_label_sequence", "predict_label_sequence"],
    )
    result_df.to_csv(os.path.join(output_dir, f"{model_name}_results.csv"), index=False)

    tp = sum(g == 1 and p == 1 for g, p in zip(all_gt, all_pred))
    fp = sum(g == 0 and p == 1 for g, p in zip(all_gt, all_pred))
    fn = sum(g == 1 and p == 0 for g, p in zip(all_gt, all_pred))
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_gt, all_pred, average='binary', zero_division=0
    )
    print(f"{model_name} - TP: {tp}, FP: {fp}, FN: {fn}")
    print(f"{model_name} - Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}")
    return precision, recall, f1
if __name__ == "__main__":
    dataset_path = "./data/enni_salt_for_segmentation/test.csv"

    # Each entry: (display name used in console output,
    #              segmentation callable,
    #              model_name passed to eval_segmentation / used in the CSV file name).
    # Add or remove models here instead of copy-pasting evaluate/print stanzas.
    models = [
        # ("BatchAlign", segment_batchalign, "batchalign"),
        ("SaT", segment_SaT, "SaT"),
        ("SaT_cunit_3l", segment_SaT_cunit_3l, "SaT_cunit_3l"),
        ("SaT_cunit_12l", segment_SaT_cunit_12l, "SaT_cunit_12l"),
        ("SaT_cunit_3l_r32a64", segment_SaT_cunit_3l_r32a64, "SaT_cunit_3l_r32a64"),
        ("SaT_cunit_3l_r64a128", segment_SaT_cunit_3l_r64a128, "SaT_cunit_3l_r64a128"),
        ("SaT_cunit_3l_no_shuffle", segment_SaT_cunit_3l_no_shuffle, "SaT_cunit_3l_no_shuffle"),
    ]

    # Insertion-ordered, so the comparison table follows the order above.
    scores = {}
    for display_name, segment_fn, result_name in models:
        print(f"\nEvaluating {display_name} segmentation model...")
        scores[display_name] = eval_segmentation(dataset_path, segment_fn, result_name)

    print("\n" + "=" * 80)
    print("COMPARISON RESULTS:")
    print("=" * 80)
    for display_name, (precision, recall, f1) in scores.items():
        print(f"{display_name} - Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}")