Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
from collections import Counter | |
import torch | |
from torch import nn | |
# import seqeval | |
from .utils_ner import get_entities | |
class metrics_mlm_acc(nn.Module):
    """Token accuracy of a masked-LM head, restricted to masked positions.

    Only positions whose ``masked_lm_metric`` value is > 0 are scored, so the
    result is ``#correct scored tokens / #scored tokens``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits, labels, masked_lm_metric):
        """Return the accuracy over masked positions as a 0-dim float tensor.

        Args:
            logits: scores whose last dimension indexes the vocabulary; the
                prediction at each position is the arg-max over that dim.
            labels: gold ids, one per prediction position (same number of
                elements as ``logits.shape[:-1]``).
            masked_lm_metric: mask with one entry per position (any rank);
                entries > 0 mark the positions to score.

        Returns:
            0-dim float32 tensor on the inputs' device. When no position is
            masked, the accuracy is defined as 0 instead of NaN (0/0).
        """
        # reshape (not view) so non-contiguous inputs are accepted as well.
        y_pred = torch.argmax(logits, dim=-1).reshape(-1)
        y_true = labels.reshape(-1)
        # Use the same "> 0" criterion for numerator and denominator, so that
        # mask values other than 0/1 cannot push the ratio above 1.
        mask = masked_lm_metric.reshape(-1) > 0
        # Count on-device in one vectorised op instead of a per-element Python
        # loop (which also only worked for 2-D masks).
        n_masked = mask.sum()
        n_correct = (torch.eq(y_pred, y_true) & mask).sum()
        # clamp(min=1): if nothing is masked, n_correct is 0, so the result is 0
        # rather than NaN.
        return n_correct.float() / n_masked.clamp(min=1)
class SeqEntityScore(object):
    """Accumulates entity-level precision / recall / F1 over many sentences.

    Entities are extracted from gold and predicted paths with ``get_entities``.
    The first element of each entity is treated as its type when building the
    per-type breakdown. A predicted entity counts as right when an equal
    entity appears among the gold entities of the same sentence.
    """

    def __init__(self, id2label, markup='bios', middle_prefix='I-'):
        # These three settings are forwarded verbatim to get_entities.
        self.id2label = id2label
        self.markup = markup
        self.middle_prefix = middle_prefix
        self.reset()

    def reset(self):
        """Drop every gold / predicted / matched entity collected so far."""
        self.origins, self.founds, self.rights = [], [], []

    def compute(self, origin, found, right):
        """Turn raw counts into ``(recall, precision, f1)``.

        Args:
            origin: number of gold entities.
            found: number of predicted entities.
            right: number of predicted entities that match a gold entity.

        A zero denominator yields 0 for that metric instead of raising.
        """
        recall = right / origin if origin != 0 else 0
        precision = right / found if found != 0 else 0
        denom = precision + recall
        if denom == 0:
            f1 = 0.
        else:
            f1 = (2 * precision * recall) / denom
        return recall, precision, f1

    def result(self):
        """Summarise everything accumulated since the last reset.

        Returns:
            ``(overall, class_info)``. ``overall`` maps 'acc' (precision),
            'recall' and 'f1' to unrounded values. ``class_info`` maps each
            entity type that occurs in the gold entities to the same three
            keys, rounded to 4 decimals. Types that appear only in predictions
            get no entry.
        """
        gold_by_type = Counter(entity[0] for entity in self.origins)
        pred_by_type = Counter(entity[0] for entity in self.founds)
        hit_by_type = Counter(entity[0] for entity in self.rights)

        class_info = {}
        for entity_type, n_gold in gold_by_type.items():
            # Counter yields 0 for types it has never seen.
            r, p, f = self.compute(n_gold,
                                   pred_by_type[entity_type],
                                   hit_by_type[entity_type])
            class_info[entity_type] = {"acc": round(p, 4),
                                       'recall': round(r, 4),
                                       'f1': round(f, 4)}

        recall, precision, f1 = self.compute(len(self.origins),
                                             len(self.founds),
                                             len(self.rights))
        return {'acc': precision, 'recall': recall, 'f1': f1}, class_info

    def update(self, label_paths, pred_paths):
        """Add the entities of a batch of sentences to the running totals.

        Args:
            label_paths: one gold tag path per sentence.
            pred_paths: one predicted tag path per sentence, aligned with
                ``label_paths``. Pairs are consumed with ``zip``, so extra
                trailing paths in the longer list are ignored.

        Each path is decoded with ``get_entities`` using this scorer's
        ``id2label``, ``markup`` and ``middle_prefix``. Whether paths hold tag
        strings or label ids depends on get_entities — confirm against it.
        Illustrative input shape: ``[['O', 'B-PER', 'I-PER'], ['B-LOC', 'O']]``.
        """
        def extract(path):
            return get_entities(path, self.id2label, self.markup, self.middle_prefix)

        for gold_path, pred_path in zip(label_paths, pred_paths):
            gold_entities = extract(gold_path)
            pred_entities = extract(pred_path)
            self.origins.extend(gold_entities)
            self.founds.extend(pred_entities)
            self.rights.extend(e for e in pred_entities if e in gold_entities)