spell-correction / utils /metrics.py
Đặng Xuân Lộc
Add application file
9fab08e
raw
history blame
5.55 kB
from utils.logger import get_logger
import numpy as np
from rapidfuzz.distance.Levenshtein import normalized_distance
import multiprocessing
import time
import utils.alignment as alignment
def _get_mned_metric_from_TruePredict(true_text, predict_text):
    """Return the normalized Levenshtein distance between one prediction and its reference.

    The prediction is passed as the first argument and the reference as the
    second argument of rapidfuzz's ``normalized_distance``.
    """
    distance = normalized_distance(predict_text, true_text)
    return distance
def get_mned_metric_from_TruePredict(batch_true_text, batch_predict_text):
    """Return the mean normalized edit distance over a batch of text pairs.

    Args:
        batch_true_text: iterable of reference strings.
        batch_predict_text: iterable of predicted strings, paired positionally
            with ``batch_true_text`` (pairing stops at the shorter iterable).

    Returns:
        The arithmetic mean of the per-pair normalized distances, or ``0.0``
        when there are no pairs (instead of raising ``ZeroDivisionError``).
    """
    distances = [
        _get_mned_metric_from_TruePredict(true_text, predict_text)
        for true_text, predict_text in zip(batch_true_text, batch_predict_text)
    ]
    # An empty batch has no distance to average; avoid dividing by zero.
    if not distances:
        return 0.0
    return sum(distances) / len(distances)
def get_metric_for_tfm(batch_predicts, batch_targets, batch_length):
    """Count element-wise matches and mismatches between predictions and targets.

    For every (predict, target, length) triple the first and last elements are
    dropped (presumably start/end markers - confirm against the caller), the
    remainder is cut to ``length`` items, and the two sequences are compared
    position by position.

    Returns:
        A ``(num_correct, num_wrong)`` pair accumulated over the whole batch.
    """
    correct_total = 0
    wrong_total = 0
    for raw_predict, raw_target, n_items in zip(batch_predicts, batch_targets, batch_length):
        # Strip the boundary elements first, then keep only the leading n_items.
        predict_arr = np.array(raw_predict[1:-1][0:n_items])
        target_arr = np.array(raw_target[1:-1][0:n_items])
        correct_total += np.sum(predict_arr == target_arr)
        wrong_total += np.sum(predict_arr != target_arr)
    return correct_total, wrong_total
def allign_seq2trueseq(seq, true_seq, gap_symbol = "-"):
    """Split two aligned, equal-length strings into word pairs.

    Word boundaries are the positions where ``true_seq`` holds a space. The
    characters of ``seq`` that sit under one true word form the matching
    "sequence word". When ``seq`` holds ``gap_symbol`` at a boundary (i.e. it
    has no separator where the true text has a space), the word before the
    boundary gets a trailing ``"@@"`` and the word after it a leading ``"@@"``.

    Args:
        seq: aligned string to be split (may contain ``gap_symbol``).
        true_seq: aligned reference string of the same length.
        gap_symbol: the character used as alignment gap.

    Returns:
        ``(seq_list, true_list)``: the words taken from ``seq`` (gap symbols
        kept, ``"@@"`` markers added) and the words of ``true_seq`` with gap
        symbols removed. Both lists have the same length.

    Raises:
        AssertionError: if the two inputs differ in length.
    """
    assert len(true_seq) == len(seq)

    seq_list = []
    true_list = []
    true_chars = []
    seq_chars = []
    # True when the boundary that opened the current word was a gap in `seq`.
    prev_boundary_is_gap = False

    for seq_char, true_char in zip(seq, true_seq):
        if true_char != " ":
            true_chars.append(true_char)
            seq_chars.append(seq_char)
            continue

        # A space in the true text closes the current word.
        boundary_is_gap = seq_char == gap_symbol
        seq_word = "".join(seq_chars)
        if prev_boundary_is_gap:
            seq_word = "@@" + seq_word
        if boundary_is_gap:
            seq_word = seq_word + "@@"
        true_list.append("".join(true_chars).replace(gap_symbol, ""))
        seq_list.append(seq_word)

        true_chars = []
        seq_chars = []
        prev_boundary_is_gap = boundary_is_gap

    # Emit the characters after the last true space; previously this final
    # word was dropped. Nothing is emitted when the input ends with a space
    # (or is empty), so such inputs give the same result as before.
    if true_chars:
        seq_word = "".join(seq_chars)
        if prev_boundary_is_gap:
            seq_word = "@@" + seq_word
        true_list.append("".join(true_chars).replace(gap_symbol, ""))
        seq_list.append(seq_word)

    return seq_list, true_list
def align_2seq2trueseq(wrong_text, pred_text, true_text, gap_symbol = "-"):
    """Align the noisy and the predicted text against the true text, word by word.

    Each of ``wrong_text`` and ``pred_text`` is aligned to ``true_text`` with
    ``alignment.needle`` and then cut into words by ``allign_seq2trueseq``.

    Returns:
        A list of ``(wrong_word, pred_word, true_word)`` tuples; the true words
        come from the alignment of ``wrong_text``.

    Raises:
        AssertionError: if ``gap_symbol`` is ``None`` or not exactly one character.
    """
    assert gap_symbol is not None and len(gap_symbol) == 1

    aligned_wrong, aligned_true = alignment.needle(wrong_text, true_text, gap_symbol)
    wrong_words, true_words = allign_seq2trueseq(aligned_wrong, aligned_true, gap_symbol)

    aligned_pred, aligned_true = alignment.needle(pred_text, true_text, gap_symbol)
    pred_words, _unused_true_words = allign_seq2trueseq(aligned_pred, aligned_true, gap_symbol)

    word_triples = zip(wrong_words, pred_words, true_words)
    return list(word_triples)
def _get_metric_from_TrueWrongPredictV3(true_text, wrong_text, predict_text, vocab = None):
    """Count word-level TP / FP / FN for one (true, wrong, predict) sentence triple.

    The three texts are aligned word by word with ``align_2seq2trueseq`` and
    every aligned word triple is classified:

    * input word already correct, prediction keeps it  -> nothing counted
    * input word already correct, prediction changes it -> false positive
    * input word wrong, prediction equals the true word -> true positive
    * input word wrong, prediction still differs        -> false negative

    A wrong prediction costs 1, or the number of surplus space-separated
    pieces when the prediction split the word into more pieces than the truth.

    Args:
        true_text: reference sentence.
        wrong_text: noisy input sentence.
        predict_text: model output sentence.
        vocab: optional object exposing ``chartoken2idx``; used to find a gap
            symbol that does not occur in any of the three texts.

    Returns:
        ``(TP, FP, FN)`` as integers.

    Raises:
        AssertionError: if ``vocab`` is given but offers no usable gap symbol.
    """
    gap_symbol = "-"
    if vocab is not None:
        # The first four vocabulary entries are skipped
        # (presumably special tokens - confirm against the vocab class).
        vocab_symbols = set(list(vocab.chartoken2idx.keys())[4:])
        used_symbols = set(wrong_text + predict_text + true_text)
        usable_symbols = vocab_symbols.difference(used_symbols)
        assert len(usable_symbols) > 0
        if "-" not in usable_symbols:
            # Pick deterministically (set.pop() order varies between runs) and
            # only among symbols that can serve as a gap: exactly one character
            # (align_2seq2trueseq asserts this) and not the space, which is the
            # word delimiter of the alignment.
            candidates = sorted(
                symbol for symbol in usable_symbols
                if isinstance(symbol, str) and len(symbol) == 1 and symbol != " "
            )
            assert len(candidates) > 0
            gap_symbol = candidates[0]

    word_triples = align_2seq2trueseq(wrong_text, predict_text, true_text, gap_symbol)

    TP, FP, FN = 0, 0, 0
    for wrong, predict, true in word_triples:
        if wrong == true:
            # NOTE(review): the previous check was `predict[:-2] == true`, which
            # also accepted the true word followed by ANY two characters. It
            # presumably meant to ignore a trailing "@@" merge marker (the merge
            # is still charged to the next word, which carries the leading
            # "@@") - confirm this is the intended counting.
            if predict == true or predict == true + "@@":
                continue
            FP += _mismatch_cost(predict, true)
        elif predict == true:
            TP += 1
        else:
            FN += _mismatch_cost(predict, true)
    return TP, FP, FN


def _mismatch_cost(predict, true):
    """Return the error weight of a wrong predicted word: 1, or the surplus piece count."""
    predict_pieces = len(predict.split(" "))
    true_pieces = len(true.split(" "))
    if predict_pieces == true_pieces:
        return 1
    penalty = predict_pieces - true_pieces
    assert penalty > 0
    return penalty
def worker_task(true_text, wrong_text, predict_text, vocab):
    """Compute the (TP, FP, FN) counts for one sample.

    Module-level wrapper around ``_get_metric_from_TrueWrongPredictV3`` that is
    handed to the process pool as the per-sample task.
    """
    counts = _get_metric_from_TrueWrongPredictV3(true_text, wrong_text, predict_text, vocab)
    true_pos, false_pos, false_neg = counts
    return (true_pos, false_pos, false_neg)
from multiprocessing import Pool
def get_metric_from_TrueWrongPredictV3(batch_true_text, batch_wrong_text, batch_predict_text, vocab, twp_logger):
    """Sum word-level TP / FP / FN over a batch, computing samples in parallel.

    Each (true, wrong, predict) triple is scored by ``worker_task`` in a
    process pool and the per-sample counts are added up.

    Args:
        batch_true_text: sequence of reference sentences.
        batch_wrong_text: sequence of noisy input sentences.
        batch_predict_text: sequence of model output sentences.
        vocab: vocabulary object forwarded to every worker; must not be ``None``.
        twp_logger: optional logger; when truthy, the three texts and the
            counts of every sample are written with ``log(..., file_only=True)``.

    Returns:
        ``(TPs, FPs, FNs)`` summed over the batch.

    Raises:
        AssertionError: if ``vocab`` is ``None``.
    """
    assert vocab is not None

    tasks = [
        (true_text, wrong_text, pred_text, vocab)
        for true_text, wrong_text, pred_text in zip(batch_true_text, batch_wrong_text, batch_predict_text)
    ]

    # Use about a third of the cores, but never fewer than one worker:
    # Pool(0) raises ValueError on machines with fewer than three cores.
    num_workers = max(1, multiprocessing.cpu_count() // 3)
    with Pool(num_workers) as pool:
        per_sample_counts = pool.starmap(worker_task, tasks)

    TPs, FPs, FNs = 0, 0, 0
    for i, (tp, fp, fn) in enumerate(per_sample_counts):
        TPs += tp
        FPs += fp
        FNs += fn
        if twp_logger:
            twp_logger.log(batch_true_text[i], file_only=True)
            twp_logger.log(batch_wrong_text[i], file_only=True)
            twp_logger.log(batch_predict_text[i], file_only=True)
            twp_logger.log(f"{tp} - {fp} - {fn}", file_only=True)
    return TPs, FPs, FNs