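"""Benchmark word-level utterance segmentation models on a labeled test set.

Each model maps a whitespace-joined word sequence to one 0/1 boundary label
per word (1 = utterance-final word). Per-chunk predictions are written to
benchmark_result/segmentation/ and scored with precision/recall/F1 on the
boundary class.
"""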

import os
import re

import pandas as pd
from sklearn.metrics import precision_recall_fscore_support
from tqdm import tqdm

from segmentation import (
    segment_batchalign,
    segment_SaT,
    segment_SaT_cunit_3l,
    segment_SaT_cunit_12l,
    segment_SaT_cunit_3l_r32a64,
    segment_SaT_cunit_3l_r64a128,
    segment_SaT_cunit_3l_no_shuffle,
)


def clean_text(text):
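    """Lowercase the text, remove punctuation, and trim surrounding whitespace."""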
    return re.sub(r"[^\w\s]", "", text.lower()).strip()


def eval_segmentation(dataset_path, segmentation_model, model_name="unknown", chunk_num=10):
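    """Score a segmentation model against gold utterance boundaries.

    Reads `dataset_path` in chunks of `chunk_num` rows, joins each chunk into
    a single word sequence, and compares the model's per-word 0/1 boundary
    labels against the gold labels. Returns (precision, recall, f1) for the
    boundary class.
    """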
    os.makedirs("benchmark_result/segmentation", exist_ok=True)

    df = pd.read_csv(dataset_path)
    results = []

    for i in tqdm(range(0, len(df), chunk_num), desc="Evaluating chunks"):
        chunk = df.iloc[i:i + chunk_num]
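        # Skip the trailing partial chunk so every evaluated sequence covers
        # the same number of rows.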
        if len(chunk) < chunk_num:
            continue

        word_sequence = []
        gt_label_sequence = []

        for row in chunk["cleaned_transcription"]:
            if pd.isna(row):
                continue
            cleaned = clean_text(row)
            words = cleaned.split()
            if not words:
                continue
            word_sequence.extend(words)
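            # Gold labels: 1 on the utterance-final word, 0 on all others.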
            gt_label_sequence.extend([0] * (len(words) - 1) + [1])

        input_text = " ".join(word_sequence)
        predicted_labels = segmentation_model(input_text)

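        # Predictions must align one-to-one with the words; otherwise the
        # comparison with the gold labels is undefined.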
        if len(predicted_labels) != len(gt_label_sequence):
            print(f"Label length mismatch at chunk {i}. Skipping...")
            continue

        results.append({
            "word_sequence": input_text,
            "gt_label_sequence": " ".join(map(str, gt_label_sequence)),
            "predict_label_sequence": " ".join(map(str, predicted_labels)),
        })

    result_df = pd.DataFrame(results)
    result_df.to_csv(f"benchmark_result/segmentation/{model_name}_results.csv", index=False)

    all_gt = []
    all_pred = []
    for row in results:
        all_gt.extend(map(int, row["gt_label_sequence"].split()))
        all_pred.extend(map(int, row["predict_label_sequence"].split()))

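    # Raw boundary-class counts, printed alongside the sklearn metrics as a
    # sanity check.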
    tp = sum((g == 1 and p == 1) for g, p in zip(all_gt, all_pred))
    fp = sum((g == 0 and p == 1) for g, p in zip(all_gt, all_pred))
    fn = sum((g == 1 and p == 0) for g, p in zip(all_gt, all_pred))

    precision, recall, f1, _ = precision_recall_fscore_support(
        all_gt, all_pred, average="binary", zero_division=0
    )

    print(f"{model_name} - TP: {tp}, FP: {fp}, FN: {fn}")
    print(f"{model_name} - Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}")

    return precision, recall, f1


if __name__ == "__main__":
    dataset_path = "./data/enni_salt_for_segmentation/test.csv"

    print("\nEvaluating SaT segmentation model...")
    sat_precision, sat_recall, sat_f1 = eval_segmentation(
        dataset_path, segment_SaT, "SaT"
    )

    print("\nEvaluating SaT_cunit_3l segmentation model...")
    sat_cunit_3l_precision, sat_cunit_3l_recall, sat_cunit_3l_f1 = eval_segmentation(
        dataset_path, segment_SaT_cunit_3l, "SaT_cunit_3l"
    )

    print("\nEvaluating SaT_cunit_12l segmentation model...")
    sat_cunit_12l_precision, sat_cunit_12l_recall, sat_cunit_12l_f1 = eval_segmentation(
        dataset_path, segment_SaT_cunit_12l, "SaT_cunit_12l"
    )

    print("\nEvaluating SaT_cunit_3l_r32a64 segmentation model...")
    sat_cunit_3l_r32a64_precision, sat_cunit_3l_r32a64_recall, sat_cunit_3l_r32a64_f1 = eval_segmentation(
        dataset_path, segment_SaT_cunit_3l_r32a64, "SaT_cunit_3l_r32a64"
    )

    print("\nEvaluating SaT_cunit_3l_r64a128 segmentation model...")
    sat_cunit_3l_r64a128_precision, sat_cunit_3l_r64a128_recall, sat_cunit_3l_r64a128_f1 = eval_segmentation(
        dataset_path, segment_SaT_cunit_3l_r64a128, "SaT_cunit_3l_r64a128"
    )

    print("\nEvaluating SaT_cunit_3l_no_shuffle segmentation model...")
    sat_cunit_3l_no_shuffle_precision, sat_cunit_3l_no_shuffle_recall, sat_cunit_3l_no_shuffle_f1 = eval_segmentation(
        dataset_path, segment_SaT_cunit_3l_no_shuffle, "SaT_cunit_3l_no_shuffle"
    )

    print("\n" + "=" * 80)
    print("COMPARISON RESULTS:")
    print("=" * 80)

    print(f"SaT - Precision: {sat_precision:.3f}, Recall: {sat_recall:.3f}, F1: {sat_f1:.3f}")
    print(f"SaT_cunit_3l - Precision: {sat_cunit_3l_precision:.3f}, Recall: {sat_cunit_3l_recall:.3f}, F1: {sat_cunit_3l_f1:.3f}")
    print(f"SaT_cunit_12l - Precision: {sat_cunit_12l_precision:.3f}, Recall: {sat_cunit_12l_recall:.3f}, F1: {sat_cunit_12l_f1:.3f}")
    print(f"SaT_cunit_3l_r32a64 - Precision: {sat_cunit_3l_r32a64_precision:.3f}, Recall: {sat_cunit_3l_r32a64_recall:.3f}, F1: {sat_cunit_3l_r32a64_f1:.3f}")
    print(f"SaT_cunit_3l_r64a128 - Precision: {sat_cunit_3l_r64a128_precision:.3f}, Recall: {sat_cunit_3l_r64a128_recall:.3f}, F1: {sat_cunit_3l_r64a128_f1:.3f}")
    print(f"SaT_cunit_3l_no_shuffle - Precision: {sat_cunit_3l_no_shuffle_precision:.3f}, Recall: {sat_cunit_3l_no_shuffle_recall:.3f}, F1: {sat_cunit_3l_no_shuffle_f1:.3f}")