# Mirror/src/metric.py
from collections import defaultdict
from typing import Tuple
from rex.metrics import calc_p_r_f1_from_tp_fp_fn, safe_division
from rex.metrics.base import MetricBase
from rex.metrics.tagging import tagging_prf1
from rex.utils.batch import decompose_batch_into_instances
from rex.utils.random import generate_random_string_with_datetime
from sklearn.metrics import accuracy_score, matthews_corrcoef
class MrcNERMetric(MetricBase):
def get_instances_from_batch(self, raw_batch: dict, out_batch: dict) -> Tuple:
gold_instances = []
pred_instances = []
batch_gold = decompose_batch_into_instances(raw_batch)
assert len(batch_gold) == len(out_batch["pred"])
for i, gold in enumerate(batch_gold):
gold_instances.append(
{
"id": gold["id"],
"ents": {(gold["ent_type"], gent) for gent in gold["gold_ents"]},
}
)
pred_instances.append(
{
"id": gold["id"],
"ents": {(gold["ent_type"], pent) for pent in out_batch["pred"][i]},
}
)
return gold_instances, pred_instances
def calculate_scores(self, golds: list, preds: list) -> dict:
id2gold = defaultdict(set)
id2pred = defaultdict(set)
        # aggregate all ents from different queries before evaluating
for gold in golds:
id2gold[gold["id"]].update(gold["ents"])
for pred in preds:
id2pred[pred["id"]].update(pred["ents"])
assert len(id2gold) == len(id2pred)
gold_ents = []
pred_ents = []
for _id in id2gold:
gold_ents.append(id2gold[_id])
pred_ents.append(id2pred[_id])
return tagging_prf1(gold_ents, pred_ents, type_idx=0)
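
# Illustrative sketch (not part of the original code): in the MRC setting each
# entity-type query yields one instance, and `calculate_scores` merges all
# instances that share the same `id` before scoring. Field names mirror the
# class above; the concrete ids, types, and spans are made up.
#
#   golds = [
#       {"id": "doc-1", "ents": {("PER", (0, 1))}},
#       {"id": "doc-1", "ents": {("LOC", (5, 7))}},
#   ]
#   preds = [
#       {"id": "doc-1", "ents": {("PER", (0, 1))}},
#       {"id": "doc-1", "ents": set()},
#   ]
#   # after aggregation, "doc-1" is scored once with tagging_prf1(..., type_idx=0)
#   # over the merged gold set {PER, LOC} against the merged pred set {PER}.
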
class MrcSpanMetric(MetricBase):
def get_instances_from_batch(self, raw_batch: dict, out_batch: dict) -> Tuple:
gold_instances = []
pred_instances = []
batch_gold = decompose_batch_into_instances(raw_batch)
assert len(batch_gold) == len(out_batch["pred"])
for i, gold in enumerate(batch_gold):
gold_instances.append(
{
"id": gold["id"],
"spans": set(tuple(span) for span in gold["gold_spans"]),
}
)
pred_instances.append(
{
"id": gold["id"],
"spans": set(out_batch["pred"][i]),
}
)
return gold_instances, pred_instances
def calculate_scores(self, golds: list, preds: list) -> dict:
id2gold = defaultdict(set)
id2pred = defaultdict(set)
        # aggregate all spans from different queries before evaluating
for gold in golds:
id2gold[gold["id"]].update(gold["spans"])
for pred in preds:
id2pred[pred["id"]].update(pred["spans"])
assert len(id2gold) == len(id2pred)
gold_spans = []
pred_spans = []
for _id in id2gold:
gold_spans.append(id2gold[_id])
pred_spans.append(id2pred[_id])
return tagging_prf1(gold_spans, pred_spans, type_idx=None)
def calc_char_event(golds, preds):
"""
Calculate char-level event argument scores
References:
- https://aistudio.baidu.com/aistudio/competition/detail/46/0/submit-result
Args:
golds: a list of gold answers (a list of `event_list`), len=#data,
format is a list of `event_list`
preds: a list of pred answers, len=#data
"""
def _match_arg_char_f1(gold_arg, pred_args):
gtype, grole, gstring = gold_arg
gchars = set(gstring)
garg_len = len(gchars)
cands = []
for parg in pred_args:
if parg[0] == gtype and parg[1] == grole:
pchars = set(str(parg[-1]))
parg_len = len(pchars)
pmatch = len(pchars & gchars)
p = safe_division(pmatch, parg_len)
r = safe_division(pmatch, garg_len)
f1 = safe_division(2 * p * r, p + r)
cands.append(f1)
        if cands:
            return max(cands)
        else:
            return 0.0
pscore = num_gargs = num_pargs = 0
for _golds, _preds in zip(golds, preds):
        # _golds and _preds are the gold/pred event lists of one data instance
gold_args = []
pred_args = []
for gold in _golds:
for arg in gold.get("arguments", []):
gold_args.append(
(gold.get("event_type"), arg.get("role"), arg.get("argument"))
)
for pred in _preds:
for arg in pred.get("arguments", []):
pred_args.append(
(pred.get("event_type"), arg.get("role"), arg.get("argument"))
)
num_gargs += len(gold_args)
num_pargs += len(pred_args)
for gold_arg in gold_args:
pscore += _match_arg_char_f1(gold_arg, pred_args)
p = safe_division(pscore, num_pargs)
r = safe_division(pscore, num_gargs)
f1 = safe_division(2 * p * r, p + r)
return {
"p": p,
"r": r,
"f1": f1,
"pscore": pscore,
"num_pargs": num_pargs,
"num_gargs": num_gargs,
}
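
# Worked example (illustrative, not from the original file): one gold argument
# "abc" vs one predicted argument "abcd" with matching event type and role.
# The char sets overlap on {a, b, c}, so the per-argument score is p=3/4,
# r=1.0, f1=6/7; with a single gold and a single pred argument, the returned
# p, r, and f1 all equal 6/7 ≈ 0.857.
#
#   golds = [[{"event_type": "attack", "arguments": [{"role": "agent", "argument": "abc"}]}]]
#   preds = [[{"event_type": "attack", "arguments": [{"role": "agent", "argument": "abcd"}]}]]
#   calc_char_event(golds, preds)
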
def calc_trigger_identification_metrics(golds, preds):
tp = fp = fn = 0
for _golds, _preds in zip(golds, preds):
gold_triggers = {gold["trigger"] for gold in _golds}
pred_triggers = {pred["trigger"] for pred in _preds}
tp += len(gold_triggers & pred_triggers)
fp += len(pred_triggers - gold_triggers)
fn += len(gold_triggers - pred_triggers)
metrics = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
return metrics
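
# Counting sketch (illustrative): trigger identification compares per-instance
# sets of trigger strings. With gold triggers {"explosion"} and predicted
# triggers {"explosion", "attack"}, this gives tp=1, fp=1, fn=0 before
# `calc_p_r_f1_from_tp_fp_fn` turns the counts into P/R/F1.
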
def calc_trigger_classification_metrics(golds, preds):
tp = fp = fn = 0
for _golds, _preds in zip(golds, preds):
gold_tgg_cls = {(gold["trigger"], gold["event_type"]) for gold in _golds}
pred_tgg_cls = {(pred["trigger"], pred["event_type"]) for pred in _preds}
tp += len(gold_tgg_cls & pred_tgg_cls)
fp += len(pred_tgg_cls - gold_tgg_cls)
fn += len(gold_tgg_cls - pred_tgg_cls)
metrics = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
return metrics
def calc_arg_identification_metrics(golds, preds):
"""Calculate argument identification metrics
Notice:
An entity could take different roles in an event,
so the base number must be calculated by
(arg, event type, pos, role)
"""
tp = fp = fn = 0
for _golds, _preds in zip(golds, preds):
gold_args = set()
pred_args = set()
for gold in _golds:
_args = {
(arg["role"], arg["argument"], gold["event_type"])
for arg in gold["arguments"]
}
gold_args.update(_args)
for pred in _preds:
_args = {
(arg["role"], arg["argument"], pred["event_type"])
for arg in pred["arguments"]
}
pred_args.update(_args)
# logic derived from OneIE
_tp = 0
_tp_fp = len(pred_args)
_tp_fn = len(gold_args)
_gold_args_wo_role = {_ga[1:] for _ga in gold_args}
for pred_arg in pred_args:
if pred_arg[1:] in _gold_args_wo_role:
_tp += 1
tp += _tp
fp += _tp_fp - _tp
fn += _tp_fn - _tp
metrics = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
return metrics
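
# Contrast with classification (illustrative): a predicted argument counts for
# identification as soon as its (argument, event type) pair appears among the
# gold arguments, even if the role differs. E.g. gold ("victim", "Bob", "Attack")
# vs pred ("agent", "Bob", "Attack") gives tp=1 here, but would be a miss in
# `calc_arg_classification_metrics` below, which matches the full triple.
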
def calc_arg_classification_metrics(golds, preds):
tp = fp = fn = 0
for _golds, _preds in zip(golds, preds):
gold_arg_cls = set()
pred_arg_cls = set()
for gold in _golds:
_args = {
(arg["argument"], arg["role"], gold["event_type"])
for arg in gold["arguments"]
}
gold_arg_cls.update(_args)
for pred in _preds:
_args = {
(arg["argument"], arg["role"], pred["event_type"])
for arg in pred["arguments"]
}
pred_arg_cls.update(_args)
tp += len(gold_arg_cls & pred_arg_cls)
fp += len(pred_arg_cls - gold_arg_cls)
fn += len(gold_arg_cls - pred_arg_cls)
metrics = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
return metrics
def calc_ent(golds, preds):
"""
Args:
golds, preds: [(type, index list), ...]
"""
res = tagging_prf1(golds, preds, type_idx=0)
return res
def calc_rel(golds, preds):
gold_ents = []
pred_ents = []
for gold, pred in zip(golds, preds):
gold_ins_ents = []
for t in gold:
gold_ins_ents.extend(t[1:])
gold_ents.append(gold_ins_ents)
pred_ins_ents = []
for t in pred:
pred_ins_ents.extend(t[1:])
pred_ents.append(pred_ins_ents)
metrics = {
"ent": tagging_prf1(gold_ents, pred_ents, type_idx=None),
"rel": tagging_prf1(golds, preds, type_idx=None),
}
return metrics
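
# Expected input sketch (illustrative): `golds`/`preds` contain one list of
# relation tuples per instance, e.g. [("born_in", head_span, tail_span), ...].
# The "ent" entry scores only the argument parts t[1:] (head/tail spans,
# ignoring the relation type), while the "rel" entry scores the full tuples.
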
def calc_cls(golds, preds):
metrics = {
"mcc": -1,
"acc": -1,
"mf1": tagging_prf1(golds, preds, type_idx=None),
}
y_true = []
y_pred = []
for gold, pred in zip(golds, preds):
y_true.append(" ".join(sorted(gold)))
y_pred.append(" ".join(sorted(pred)))
    if y_true and y_pred:
        metrics["acc"] = accuracy_score(y_true, y_pred)
        metrics["mcc"] = matthews_corrcoef(y_true, y_pred)
    else:
        metrics["acc"] = 0.0
        metrics["mcc"] = 0.0
return metrics
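
# Usage sketch (illustrative): each instance's labels are serialised into a
# canonical space-joined string so multi-label predictions can be scored with
# exact-match accuracy and MCC, e.g.
#   golds = [{"news", "sports"}, {"finance"}]
#   preds = [{"sports", "news"}, {"tech"}]
#   # y_true = ["news sports", "finance"], y_pred = ["news sports", "tech"]
#   # -> acc = 0.5; mcc and mf1 are computed on the same pairs
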
def calc_span(golds, preds, mode="span"):
    def _get_tokens(spans: list[tuple[tuple[int, ...], ...]]) -> list[int]:
tokens = []
for span in spans:
for part in span:
_toks = []
if len(part) == 1:
_toks = [part[0]]
elif len(part) > 1:
if mode == "w2":
_toks = [*part]
elif mode == "span":
_toks = [*range(part[0], part[1] + 1)]
else:
                        raise ValueError(f"unsupported span mode: {mode}")
tokens.extend(_toks)
return tokens
metrics = {
"em": -1,
"f1": None,
}
acc_num = 0
tp = fp = fn = 0
for gold, pred in zip(golds, preds):
if gold == pred:
acc_num += 1
gold_tokens = _get_tokens(gold)
pred_tokens = _get_tokens(pred)
tp += len(set(gold_tokens) & set(pred_tokens))
fp += len(set(pred_tokens) - set(gold_tokens))
fn += len(set(gold_tokens) - set(pred_tokens))
if len(golds) > 0:
metrics["em"] = acc_num / len(golds)
else:
metrics["em"] = 0.0
metrics["f1"] = calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
return metrics
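
# Mode sketch (illustrative): every span is a tuple of parts. With the default
# mode="span" a multi-token part is a (start, end) pair and expands to all
# covered indices, e.g. part (2, 5) -> tokens [2, 3, 4, 5]; with mode="w2" a
# part already enumerates its token indices and is kept as-is. "em" is the
# fraction of instances whose predicted spans exactly match the gold spans,
# and "f1" is token-level P/R/F1 over the flattened token sets.
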
class MultiPartSpanMetric(MetricBase):
def _encode_span_to_label_dict(self, span_to_label: dict) -> list:
span_to_label_list = []
for key, val in span_to_label.items():
span_to_label_list.append({"key": key, "val": val})
return span_to_label_list
def _decode_span_to_label(self, span_to_label_list: list) -> dict:
span_to_label = {}
for content in span_to_label_list:
span_to_label[tuple(content["key"])] = content["val"]
return span_to_label
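
    # Round-trip sketch (illustrative): the encoder turns tuple-keyed entries
    # into JSON-friendly records and the decoder restores them, e.g.
    #   {(1, 2): {"task": "ent", "string": "PER"}}
    #   <-> [{"key": (1, 2), "val": {"task": "ent", "string": "PER"}}]
    # After a JSON round trip the key comes back as a list, which
    # `_decode_span_to_label` converts back to a tuple.
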
def get_instances_from_batch(self, raw_batch: dict, out_batch: dict) -> Tuple:
gold_instances = []
pred_instances = []
batch_gold = decompose_batch_into_instances(raw_batch)
assert len(batch_gold) == len(out_batch["pred"])
for i, gold in enumerate(batch_gold):
ins_id = gold["raw"].get("id", generate_random_string_with_datetime())
            # encode the dict into a list of {"key", "val"} items to keep it
            # JSON-serializable, since the original dict keys are tuples
span_to_label_list = self._encode_span_to_label_dict(gold["span_to_label"])
gold["span_to_label"] = span_to_label_list
gold_instances.append(
{
"id": ins_id,
"span_to_label_list": span_to_label_list,
"raw_gold_content": gold,
"spans": set(
tuple(multi_part_span) for multi_part_span in gold["spans"]
),
}
)
pred_instances.append(
{
"id": ins_id,
"spans": set(
tuple(multi_part_span)
for multi_part_span in out_batch["pred"][i]
),
}
)
return gold_instances, pred_instances
def calculate_scores(self, golds: list, preds: list) -> dict:
# for general purpose evaluation
general_gold_spans, general_pred_spans = [], []
# cls task
gold_cls_list, pred_cls_list = [], []
# ent task
gold_ent_list, pred_ent_list = [], []
# rel task
gold_rel_list, pred_rel_list = [], []
# event task
gold_event_list, pred_event_list = [], []
# span task
gold_span_list, pred_span_list = [], []
# discon ent task
gold_discon_ent_list, pred_discon_ent_list = [], []
# hyper rel task
gold_hyper_rel_list, pred_hyper_rel_list = [], []
for gold, pred in zip(golds, preds):
general_gold_spans.append(gold["spans"])
general_pred_spans.append(pred["spans"])
span_to_label = self._decode_span_to_label(gold["span_to_label_list"])
gold_clses, pred_clses = [], []
gold_ents, pred_ents = [], []
gold_rels, pred_rels = [], []
gold_trigger_to_event = defaultdict(
lambda: {"event_type": "", "arguments": []}
)
pred_trigger_to_event = defaultdict(
lambda: {"event_type": "", "arguments": []}
)
gold_events, pred_events = [], []
gold_spans, pred_spans = [], []
gold_discon_ents, pred_discon_ents = [], []
gold_hyper_rels, pred_hyper_rels = [], []
raw_schema = gold["raw_gold_content"]["raw"]["schema"]
for span in gold["spans"]:
if span[0] in span_to_label:
label = span_to_label[span[0]]
if label["task"] == "cls" and len(span) == 1:
gold_clses.append(label["string"])
elif label["task"] == "ent" and len(span) == 2:
gold_ents.append((label["string"], *span[1:]))
elif label["task"] == "rel" and len(span) == 3:
gold_rels.append((label["string"], *span[1:]))
elif label["task"] == "event":
if label["type"] == "lm" and len(span) == 2:
gold_trigger_to_event[span[1]]["event_type"] = label["string"] # fmt: skip
elif label["type"] == "lr" and len(span) == 3:
gold_trigger_to_event[span[1]]["arguments"].append(
{"argument": span[2], "role": label["string"]}
)
elif label["task"] == "discontinuous_ent" and len(span) > 1:
gold_discon_ents.append((label["string"], *span[1:]))
elif label["task"] == "hyper_rel" and len(span) == 5 and span[3] in span_to_label: # fmt: skip
q_label = span_to_label[span[3]]
gold_hyper_rels.append((label["string"], span[1], span[2], q_label["string"], span[4])) # fmt: skip
else:
# span task has no labels
gold_spans.append(tuple(span))
for trigger, item in gold_trigger_to_event.items():
legal_roles = raw_schema["event"][item["event_type"]]
gold_events.append(
{
"trigger": trigger,
"event_type": item["event_type"],
"arguments": [
arg
for arg in filter(
lambda arg: arg["role"] in legal_roles,
item["arguments"],
)
],
}
)
for span in pred["spans"]:
if span[0] in span_to_label:
label = span_to_label[span[0]]
if label["task"] == "cls" and len(span) == 1:
pred_clses.append(label["string"])
elif label["task"] == "ent" and len(span) == 2:
pred_ents.append((label["string"], *span[1:]))
elif label["task"] == "rel" and len(span) == 3:
pred_rels.append((label["string"], *span[1:]))
elif label["task"] == "event":
if label["type"] == "lm" and len(span) == 2:
pred_trigger_to_event[span[1]]["event_type"] = label["string"] # fmt: skip
elif label["type"] == "lr" and len(span) == 3:
pred_trigger_to_event[span[1]]["arguments"].append(
{"argument": span[2], "role": label["string"]}
)
elif label["task"] == "discontinuous_ent" and len(span) > 1:
pred_discon_ents.append((label["string"], *span[1:]))
elif label["task"] == "hyper_rel" and len(span) == 5 and span[3] in span_to_label: # fmt: skip
q_label = span_to_label[span[3]]
pred_hyper_rels.append((label["string"], span[1], span[2], q_label["string"], span[4])) # fmt: skip
else:
# span task has no labels
pred_spans.append(tuple(span))
for trigger, item in pred_trigger_to_event.items():
if item["event_type"] not in raw_schema["event"]:
continue
legal_roles = raw_schema["event"][item["event_type"]]
pred_events.append(
{
"trigger": trigger,
"event_type": item["event_type"],
"arguments": [
arg
for arg in filter(
lambda arg: arg["role"] in legal_roles,
item["arguments"],
)
],
}
)
gold_cls_list.append(gold_clses)
pred_cls_list.append(pred_clses)
gold_ent_list.append(gold_ents)
pred_ent_list.append(pred_ents)
gold_rel_list.append(gold_rels)
pred_rel_list.append(pred_rels)
gold_event_list.append(gold_events)
pred_event_list.append(pred_events)
gold_span_list.append(gold_spans)
pred_span_list.append(pred_spans)
gold_discon_ent_list.append(gold_discon_ents)
pred_discon_ent_list.append(pred_discon_ents)
gold_hyper_rel_list.append(gold_hyper_rels)
pred_hyper_rel_list.append(pred_hyper_rels)
metrics = {
"general_spans": tagging_prf1(
general_gold_spans, general_pred_spans, type_idx=None
),
"cls": calc_cls(gold_cls_list, pred_cls_list),
"ent": calc_ent(gold_ent_list, pred_ent_list),
"rel": calc_rel(gold_rel_list, pred_rel_list),
"event": {
"trigger_id": calc_trigger_identification_metrics(
gold_event_list, pred_event_list
),
"trigger_cls": calc_trigger_classification_metrics(
gold_event_list, pred_event_list
),
"arg_id": calc_arg_identification_metrics(
gold_event_list, pred_event_list
),
"arg_cls": calc_arg_classification_metrics(
gold_event_list, pred_event_list
),
"char_event": calc_char_event(gold_event_list, pred_event_list),
},
"discontinuous_ent": tagging_prf1(
gold_discon_ent_list, pred_discon_ent_list, type_idx=None
),
"hyper_rel": tagging_prf1(
gold_hyper_rel_list, pred_hyper_rel_list, type_idx=None
),
# "span": tagging_prf1(gold_span_list, pred_span_list, type_idx=None),
"span": calc_span(gold_span_list, pred_span_list),
}
return metrics
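

if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the original module):
    # exercises `calc_span` with a made-up instance. Each span is a tuple of
    # (start, end) parts under the default mode="span".
    demo_gold_spans = [[((0, 2),), ((5, 5),)]]
    demo_pred_spans = [[((0, 2),)]]
    # the prediction misses the (5, 5) span, so "em" is 0.0;
    # token-level counts are tp=3, fp=0, fn=1
    print(calc_span(demo_gold_spans, demo_pred_spans))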