|
from pathlib import Path |
|
|
|
import pandas as pd |
|
from rex.utils.initialization import set_seed_and_log_path |
|
from rex.utils.io import load_json |
|
from rich.console import Console |
|
from rich.table import Table |
|
|
|
from src.task import SchemaGuidedInstructBertTask |
|
|
|
# Seed RNGs for reproducibility and route rex's logging to a scratch log file.
set_seed_and_log_path(log_path="tmp_eval.log")
|
|
|
|
|
if __name__ == "__main__":
    # Directory of the trained task dump (config, checkpoints, measures).
    task_dir = "mirror_outputs/Mirror_Pretrain_AllExcluded_2"

    # Restore the task from its dump dir with the best checkpoint loaded.
    # `initialize=False` / `dump_configfile=False`: evaluation only — no
    # trainer setup and no config rewrite inside the task dir.
    task: SchemaGuidedInstructBertTask = SchemaGuidedInstructBertTask.from_taskdir(
        task_dir,
        load_best_model=True,
        initialize=False,
        dump_configfile=False,
        update_config={
            "regenerate_cache": True,
            "eval_on_data": ["dev"],
            "select_best_on_data": "dev",
            "select_best_by_key": "metric",
            "best_metric_field": "general_spans.micro.f1",
            "eval_batch_size": 32,
        },
    )
    table = Table(title=task_dir)

    # (dataset name, data path) pairs. The name's prefix (ent_/rel_/event_/...)
    # selects which metric is pulled out of the eval results below.
    data_pairs = [
        ["ent_conll03_test", "resources/Mirror/uie/ent/conll03/test.jsonl"],
        ["rel_conll04_test", "resources/Mirror/uie/rel/conll04/test.jsonl"],
        ["event_ace05_test", "resources/Mirror/uie/event/ace05-evt/test.jsonl"],
    ]

    # Parallel-list accumulator; also the columns of the final DataFrame.
    eval_res = {"task": [], "dataset": [], "metric_val": []}
    table.add_column("Task", justify="left", style="cyan")
    table.add_column("Dataset", justify="left", style="magenta")
    table.add_column("Metric (%)", justify="right", style="green")

    def _record(task_name: str, dataset: str, value: float) -> None:
        """Append one (task, dataset, metric) row to ``eval_res``."""
        eval_res["task"].append(task_name)
        eval_res["dataset"].append(dataset)
        eval_res["metric_val"].append(value)

    for dname, fpath in data_pairs:
        dname = dname.lower()
        # Point the data manager at this dataset, then evaluate, dumping
        # predictions and intermediate results alongside the task dir.
        task.data_manager.update_datapath(dname, fpath)
        _, res = task.eval(dname, verbose=True, dump=True, dump_middle=True)

        if dname.startswith("ent_"):
            _record("ent", dname, res["ent"]["micro"]["f1"])
        elif dname.startswith("rel_"):
            _record("rel", dname, res["rel"]["rel"]["micro"]["f1"])
        elif dname.startswith("event_"):
            # Events contribute two rows: trigger and argument classification.
            _record("event", dname + "_tgg", res["event"]["trigger_cls"]["f1"])
            _record("event", dname + "_arg", res["event"]["arg_cls"]["f1"])
        elif dname.startswith("absa_"):
            # ABSA is evaluated through the relation metrics.
            _record("absa", dname, res["rel"]["rel"]["micro"]["f1"])
        elif dname.startswith("cls_"):
            if "_glue_" in dname:
                # GLUE CoLA is scored by Matthews corr., the rest by accuracy.
                metric = res["cls"]["mcc"] if "_cola" in dname else res["cls"]["acc"]
            else:
                metric = res["cls"]["mf1"]["micro"]["f1"]
            _record("cls", dname, metric)
        elif dname.startswith("span"):
            # Span extraction reports both exact match and F1.
            _record("span_em", dname, res["span"]["em"])
            _record("span_f1", dname, res["span"]["f1"]["f1"])
        elif dname.startswith("discontinuous_ent"):
            _record("discontinuous_ent", dname, res["discontinuous_ent"]["micro"]["f1"])
        elif dname.startswith("hyper_rel"):
            _record("hyper_rel", dname, res["hyper_rel"]["micro"]["f1"])
        else:
            # Fix: the original raised a bare ValueError with no diagnostic,
            # making an unrecognized dataset prefix hard to debug.
            raise ValueError(f"cannot infer task type from dataset name: {dname!r}")

    # Render one table row per collected metric, scaled to a percentage.
    for task_name, dataset, metric_val in zip(
        eval_res["task"], eval_res["dataset"], eval_res["metric_val"]
    ):
        table.add_row(task_name, dataset, f"{100 * metric_val:.3f}")

    console = Console()
    console.print(table)

    # Persist the raw numbers next to the task's other measure dumps.
    df = pd.DataFrame(eval_res)
    df.to_excel(task.measures_path.joinpath("data_eval_res.xlsx"))
|
|