# Tom Aarsen, commit 914502f: "Add cloned GLiNER repository" (4.63 kB)
from collections import defaultdict
import numpy as np
import torch
from seqeval.metrics.v1 import _prf_divide
def extract_tp_actual_correct(y_true, y_pred):
    """Count per-type true positives, predicted entities and gold entities.

    Args:
        y_true: Iterable of gold entities, each a ``(type_name, (start, end), idx)``
            triple (lists are accepted too). ``idx`` is part of the identity of
            an entity; ``flatten_for_eval`` sets it to the sentence index.
        y_pred: Iterable of predicted entities in the same format.

    Returns:
        ``(pred_sum, tp_sum, true_sum, target_names)``. ``target_names`` is the
        sorted union of entity types seen in either input. The three integer
        arrays hold, per type in ``target_names`` order:
        - the number of distinct predicted entities;
        - the number of exact matches (type, start, end and idx all equal);
        - the number of distinct gold entities.
        Duplicate entities are counted once because they are stored in sets.
    """
    entities_true = defaultdict(set)
    entities_pred = defaultdict(set)
    for type_name, (start, end), idx in y_true:
        entities_true[type_name].add((start, end, idx))
    for type_name, (start, end), idx in y_pred:
        entities_pred[type_name].add((start, end, idx))

    target_names = sorted(set(entities_true) | set(entities_pred))

    tp_counts, pred_counts, true_counts = [], [], []
    for type_name in target_names:
        # .get() avoids inserting empty sets into the defaultdicts.
        true_type = entities_true.get(type_name, set())
        pred_type = entities_pred.get(type_name, set())
        tp_counts.append(len(true_type & pred_type))
        pred_counts.append(len(pred_type))
        true_counts.append(len(true_type))

    # Build each array once instead of calling np.append, which copies the
    # whole array on every call. The original seeded the arrays as int32,
    # but np.append promoted them to the default integer type whenever any
    # type was present. A single fixed dtype keeps empty and non-empty
    # results consistent.
    pred_sum = np.array(pred_counts, dtype=np.int64)
    tp_sum = np.array(tp_counts, dtype=np.int64)
    true_sum = np.array(true_counts, dtype=np.int64)
    return pred_sum, tp_sum, true_sum, target_names
def flatten_for_eval(y_true, y_pred):
    """Flatten per-sentence entity lists, tagging each entity with its sentence index.

    Args:
        y_true: Sequence of per-sentence gold entity lists. Each entity is a
            list or tuple such as ``[type_name, (start, end)]``.
        y_pred: Sequence of per-sentence predicted entity lists, aligned with
            ``y_true``.

    Returns:
        ``(all_true, all_pred)``: flat lists in which every entity becomes a
        new list with its sentence index appended, e.g.
        ``["PER", (0, 1)]`` in sentence 0 becomes ``["PER", (0, 1), 0]``.
        The index keeps identical spans from different sentences distinct.

    Note:
        Sentences are paired with ``zip``, so if the two sequences differ in
        length the extra trailing sentences are ignored.
    """
    all_true = []
    all_pred = []
    for i, (true, pred) in enumerate(zip(y_true, y_pred)):
        # [*t, i] builds the same list as t + [i] for list entities, and
        # also accepts tuple entities (t + [i] raised TypeError for those).
        all_true.extend([*t, i] for t in true)
        all_pred.extend([*p, i] for p in pred)
    return all_true, all_pred
def compute_prf(y_true, y_pred, average='micro'):
    """Compute entity-level precision, recall and F1.

    Args:
        y_true: Per-sentence lists of gold entities ``[type_name, (start, end)]``.
        y_pred: Per-sentence lists of predicted entities, aligned with ``y_true``.
        average: How to combine per-type scores:
            - ``'micro'``: pool counts over all entity types (default);
            - ``'macro'``: unweighted mean of per-type scores;
            - ``'weighted'``: per-type scores weighted by gold support;
            - ``None``: per-type score arrays, ordered by sorted type name.

    Returns:
        Dict with keys ``'precision'``, ``'recall'`` and ``'f_score'``, in
        that order. Values are scalars, or arrays when ``average`` is None.

    Raises:
        ValueError: If ``average`` is not one of the supported values.
    """
    if average not in ('micro', 'macro', 'weighted', None):
        raise ValueError(
            f"average must be 'micro', 'macro', 'weighted' or None, got {average!r}"
        )

    y_true, y_pred = flatten_for_eval(y_true, y_pred)
    pred_sum, tp_sum, true_sum, _ = extract_tp_actual_correct(y_true, y_pred)

    if average == 'micro':
        # Pool the per-type counts into single-element arrays.
        tp_sum = np.array([tp_sum.sum()])
        pred_sum = np.array([pred_sum.sum()])
        true_sum = np.array([true_sum.sum()])

    # seqeval's _prf_divide handles zero denominators according to
    # zero_division ('warn' -> 0.0 plus a warning).
    precision = _prf_divide(
        numerator=tp_sum,
        denominator=pred_sum,
        metric='precision',
        modifier='predicted',
        average=average,
        warn_for=('precision', 'recall', 'f-score'),
        zero_division='warn'
    )
    recall = _prf_divide(
        numerator=tp_sum,
        denominator=true_sum,
        metric='recall',
        modifier='true',
        average=average,
        warn_for=('precision', 'recall', 'f-score'),
        zero_division='warn'
    )
    denominator = precision + recall
    denominator[denominator == 0.] = 1  # F1 is 0 where P and R are both 0
    f_score = 2 * (precision * recall) / denominator

    if average == 'micro':
        return {'precision': precision[0], 'recall': recall[0], 'f_score': f_score[0]}
    if average is None:
        return {'precision': precision, 'recall': recall, 'f_score': f_score}

    # 'macro' / 'weighted'. Previously these fell through to [0], which
    # returned only the first entity type's scores.
    if precision.size == 0 or (average == 'weighted' and true_sum.sum() == 0):
        # Nothing to average (no types), or all weights are zero
        # (np.average would raise ZeroDivisionError).
        return {'precision': 0.0, 'recall': 0.0, 'f_score': 0.0}
    weights = true_sum if average == 'weighted' else None
    return {
        'precision': np.average(precision, weights=weights),
        'recall': np.average(recall, weights=weights),
        'f_score': np.average(f_score, weights=weights),
    }
class Evaluator:
    """Entity-level evaluator for paired gold / predicted span lists.

    Args:
        all_true: Per-sample lists of gold spans ``(start, end, label)``.
        all_outs: Per-sample lists of predicted spans in the same format,
            aligned with ``all_true``.
    """

    def __init__(self, all_true, all_outs):
        self.all_true = all_true
        self.all_outs = all_outs

    def get_entities_fr(self, ents):
        """Convert ``(start, end, label)`` spans to ``[label, (start, end)]`` lists."""
        return [[lab, (s, e)] for s, e, lab in ents]

    def transform_data(self):
        """Convert every sample of ``all_true`` and ``all_outs`` to ``[label, (start, end)]`` form.

        Samples are paired with ``zip``, so trailing unmatched samples are dropped.

        Returns:
            ``(all_true_ent, all_outs_ent)``: per-sample converted lists.
        """
        all_true_ent = []
        all_outs_ent = []
        for true_ents, pred_ents in zip(self.all_true, self.all_outs):
            all_true_ent.append(self.get_entities_fr(true_ents))
            all_outs_ent.append(self.get_entities_fr(pred_ents))
        return all_true_ent, all_outs_ent

    @torch.no_grad()
    def evaluate(self):
        """Compute micro-averaged entity precision, recall and F1.

        Returns:
            ``(output_str, f1)``. ``output_str`` is a tab-separated
            ``"P: ..\\tR: ..\\tF1: ..\\n"`` line with percentages; ``f1`` is
            the raw F1 value.
        """
        all_true_typed, all_outs_typed = self.transform_data()
        scores = compute_prf(all_true_typed, all_outs_typed)
        # Look metrics up by key. Unpacking .values() silently depends on the
        # dict's insertion order and on it having exactly three entries.
        precision = scores['precision']
        recall = scores['recall']
        f1 = scores['f_score']
        output_str = f"P: {precision:.2%}\tR: {recall:.2%}\tF1: {f1:.2%}\n"
        return output_str, f1
def is_nested(idx1, idx2):
    """Tell whether either span contains the other.

    Both arguments are indexed with ``[0]`` as start and ``[1]`` as end; any
    further items (e.g. a label) are ignored. Containment is non-strict, so
    identical spans count as nested.
    """
    start1, end1 = idx1[0], idx1[1]
    start2, end2 = idx2[0], idx2[1]
    # Does the first span enclose the second?
    if start1 <= start2 and end2 <= end1:
        return True
    # Otherwise, does the second enclose the first?
    return start2 <= start1 and end1 <= end2
def has_overlapping(idx1, idx2):
    """Tell whether two spans conflict under flat (non-nested) decoding.

    Spans are indexed with ``[0]`` as start and ``[1]`` as end. They conflict
    unless one starts strictly after the other ends, so spans that share an
    endpoint count as overlapping. Spans whose first two items are equal
    always conflict, even if malformed (start > end).
    """
    if idx1[:2] == idx2[:2]:
        return True
    disjoint = idx1[0] > idx2[1] or idx2[0] > idx1[1]
    return not disjoint
def has_overlapping_nested(idx1, idx2):
    """Tell whether two spans conflict under nested decoding.

    Spans are indexed with ``[0]`` as start and ``[1]`` as end. Spans with
    equal ``[:2]`` always conflict. Otherwise, disjoint spans and spans where
    one contains the other are compatible. Spans that partially overlap
    (cross) conflict. The final ``idx1 != idx2`` term compares the whole
    objects, as the original did.
    """
    if idx1[:2] == idx2[:2]:
        return True
    disjoint = idx1[0] > idx2[1] or idx2[0] > idx1[1]
    compatible = (disjoint or is_nested(idx1, idx2)) and idx1 != idx2
    return not compatible
def greedy_search(spans, flat_ner=True):  # start, end, class, score
    """Greedily keep the highest-scoring spans that do not conflict.

    Args:
        spans: Iterable of candidate spans ``(start, end, class, score)``; the
            score must be the last item.
        flat_ner: If True, reject any span that overlaps an already kept span
            (``has_overlapping``). If False, allow spans nested inside each
            other but reject partial overlaps (``has_overlapping_nested``).

    Returns:
        The kept spans with the score dropped (``span[:-1]``), sorted by
        start position.
    """
    has_ov = has_overlapping if flat_ner else has_overlapping_nested

    selected = []
    # Highest score first. The sort is stable, so equal scores keep their
    # input order. Iterating the sorted result directly (instead of
    # range(len(spans))) also accepts generators.
    for span in sorted(spans, key=lambda x: -x[-1]):
        candidate = span[:-1]
        if not any(has_ov(candidate, kept) for kept in selected):
            selected.append(candidate)

    return sorted(selected, key=lambda x: x[0])