Mirror / src /analyze.py
Spico's picture
update
5953ef9
raw
history blame
No virus
4.72 kB
from collections import defaultdict
from rex.metrics.tagging import tagging_prf1
from rex.utils.io import load_jsonlines
from rex.utils.position import find_all_positions
def main():
    """Stop in the debugger on every dumped instance whose spans were mispredicted.

    Loads the hard-coded ``test.final.jsonl`` dump and calls ``breakpoint()``
    for each record whose ``gold["spans"]`` differ from its ``pred["spans"]``.
    Intended for interactive inspection only; nothing is returned.
    """
    middle_filepath = "outputs/InstructBert_TagSpan_DebertaV3Base_ACE05EN_labelmap_Rel_updateTag_bs32/middle/test.final.jsonl"
    data = load_jsonlines(middle_filepath)
    for ins in data:
        gold, pred = ins["gold"], ins["pred"]
        if gold["spans"] == pred["spans"]:
            # Prediction agrees with gold: nothing to inspect.
            continue
        # ``ins``, ``gold`` and ``pred`` are in scope for inspection here.
        breakpoint()
def check_ent_string_matching_upper_bound(filepath: str, strategy: str = "first"):
    """Score how well plain string matching can recover the gold entity spans.

    For every record loaded from ``filepath``, each distinct gold entity string
    is looked up in the record's text with ``find_all_positions``.  The
    occurrences found become pseudo predictions, which are scored against the
    gold entities with ``tagging_prf1``.  The micro result is printed; nothing
    is returned.

    Args:
        filepath: File readable by ``load_jsonlines``.  Each record must
            provide ``text`` and ``ans.ent``, a list of entities with
            ``text``, ``type`` and ``span`` fields.
        strategy: ``"first"`` predicts every occurrence of every gold entity
            string, skipping only ``(type, position)`` pairs already emitted.
            Any other value (the script uses ``"longer_first"``) walks the
            entity strings from longest to shortest and assigns each one the
            first occurrence that overlaps no span claimed so far.
    """

    def _check_overlap(x, y):
        # Two spans overlap unless one starts strictly after the other ends.
        return not (x[0] > y[1] or y[0] > x[1])

    data = load_jsonlines(filepath)
    golds = []
    preds = []
    for ins in data:
        text = ins["text"]
        # De-duplicate the gold entities, then order them longest string first
        # so that longer mentions claim their span before shorter ones.
        # NOTE: equally long strings keep the set's iteration order.
        gold_ents = list(
            {
                (ent["text"], ent["type"], tuple(ent["span"]))
                for ent in ins["ans"]["ent"]
            }
        )
        gold_ents.sort(key=lambda x: len(x[0]), reverse=True)

        pred_ents = []
        matched = set()  # (entity type, span) pairs already predicted
        for ent_string, ent_type, _ in gold_ents:
            positions = find_all_positions(text, ent_string)
            if strategy == "first":
                for position in positions:
                    # Normalise so the span is hashable whatever sequence type
                    # the position helper returns.
                    span = tuple(position)
                    if (ent_type, span) not in matched:
                        matched.add((ent_type, span))
                        pred_ents.append((ent_string, ent_type, span))
            else:
                for position in positions:
                    span = tuple(position)
                    # Overlap is judged per occurrence.  The previous version
                    # kept one flag for the whole entity, so a single
                    # overlapping occurrence blocked all later ones and the
                    # search could never reach a free occurrence.
                    if any(_check_overlap(taken, span) for _, taken in matched):
                        continue
                    if (ent_type, span) not in matched:
                        matched.add((ent_type, span))
                        pred_ents.append((ent_string, ent_type, span))
                        # One occurrence per entity string is enough.
                        break
        golds.append(gold_ents)
        preds.append(pred_ents)

    results = tagging_prf1(golds, preds)
    print(f"filepath: {filepath}, Strategy: {strategy}")
    print(f"Results: {results['micro']}")
def check_rel_tanl_upper_bound(filepath):
    """Score the relations recoverable when tails are identified by text alone.

    For every record loaded from ``filepath`` the gold relations are replayed
    as pseudo predictions, but each tail span is re-derived from the tail's
    surface string.  A relation is kept only if exactly one gold entity span
    in the record carries that string; relations whose tail string is
    ambiguous or unknown are dropped.  Head spans and relation labels are
    taken from gold unchanged.  The pseudo predictions are scored against the
    de-duplicated gold triples with ``tagging_prf1`` and the micro result is
    printed; nothing is returned.

    Args:
        filepath: File readable by ``load_jsonlines``.  Each record must
            provide ``ans.ent`` (entities with ``text`` and ``span``) and
            ``ans.rel`` (relations with ``relation``, ``head.span``,
            ``tail.span`` and ``tail.text``).
    """
    data = load_jsonlines(filepath)
    golds = []
    preds = []
    for ins in data:
        rels = ins["ans"]["rel"]

        # Surface string -> every distinct gold span carrying that string.
        ent_text_to_spans = defaultdict(set)
        for ent in ins["ans"]["ent"]:
            ent_text_to_spans[ent["text"]].add(tuple(ent["span"]))

        gold_rels = list(
            {
                (
                    tuple(rel["head"]["span"]),
                    rel["relation"],
                    tuple(rel["tail"]["span"]),
                )
                for rel in rels
            }
        )

        pred_rels = []
        for pred_rel in rels:
            # ``get`` avoids creating empty entries in the defaultdict for
            # tail strings that match no entity.
            candidate_spans = ent_text_to_spans.get(pred_rel["tail"]["text"], ())
            if len(candidate_spans) != 1:
                # NOTE: a disabled earlier variant also resolved ambiguous
                # strings by taking an arbitrary span, and unknown strings by
                # taking the first text occurrence.
                continue
            (tail_span,) = candidate_spans
            pred_rels.append(
                (tuple(pred_rel["head"]["span"]), pred_rel["relation"], tail_span)
            )
        golds.append(gold_rels)
        preds.append(pred_rels)

    results = tagging_prf1(golds, preds)
    print(f"filepath: {filepath}")
    print(f"Results: {results['micro']}")
if __name__ == "__main__":
    # Other analyses, disabled by default -- enable by hand when needed:
    #
    #   main()
    #
    #   for filepath in [
    #       "/data/tzhu/Mirror/resources/Mirror/uie/ent/ace04/test.jsonl",
    #       "/data/tzhu/Mirror/resources/Mirror/uie/ent/ace05/test.jsonl",
    #       "/data/tzhu/Mirror/resources/Mirror/uie/ent/conll03/test.jsonl",
    #   ]:
    #       for strategy in ["first", "longer_first"]:
    #           check_ent_string_matching_upper_bound(filepath, strategy)

    # Relation-extraction test splits to run the upper-bound analysis on.
    rel_test_files = (
        "/data/tzhu/Mirror/resources/Mirror/uie/rel/ace05-rel/test.jsonl",
        "/data/tzhu/Mirror/resources/Mirror/uie/rel/conll04/test.jsonl",
        "/data/tzhu/Mirror/resources/Mirror/uie/rel/nyt/test.jsonl",
        "/data/tzhu/Mirror/resources/Mirror/uie/rel/scierc/test.jsonl",
    )
    for rel_test_file in rel_test_files:
        check_rel_tanl_upper_bound(rel_test_file)