Spaces:
Running
Running
from .rec_metric import RecMetric | |
class RecMPGMetric(object):
    """Combine four ``RecMetric`` trackers into one flat result dict.

    One ``RecMetric`` is kept for each of the ``char``, ``bpe``, ``wp`` and
    ``final`` prediction branches. The dict produced by the ``final`` tracker
    is used as the base result; the ``acc`` and ``norm_edit_dis`` values of
    the other three are copied into it under ``<branch>_acc`` and
    ``<branch>_norm_edit_dis`` keys.
    """

    def __init__(self,
                 main_indicator='acc',
                 is_filter=False,
                 ignore_space=True,
                 stream=False,
                 with_ratio=False,
                 max_len=25,
                 max_ratio=4,
                 **kwargs):
        """Create the four per-branch ``RecMetric`` trackers.

        All named arguments are forwarded unchanged to every ``RecMetric``;
        their meaning is defined by ``RecMetric``. Extra ``**kwargs`` are
        accepted and ignored (they are not forwarded).
        """
        self.main_indicator = main_indicator
        self.is_filter = is_filter
        self.ignore_space = ignore_space
        # Not read anywhere in this class; kept as a public attribute so
        # existing external readers keep working.
        self.eps = 1e-5

        # Every branch is configured identically, so build the arguments once.
        metric_args = dict(
            main_indicator=main_indicator,
            is_filter=is_filter,
            ignore_space=ignore_space,
            stream=stream,
            with_ratio=with_ratio,
            max_len=max_len,
            max_ratio=max_ratio,
        )
        self.char_metric = RecMetric(**metric_args)
        self.bpe_metric = RecMetric(**metric_args)
        self.wp_metric = RecMetric(**metric_args)
        self.final_metric = RecMetric(**metric_args)

    @staticmethod
    def _merge(final_metric, char_metric, bpe_metric, wp_metric):
        """Copy per-branch ``acc``/``norm_edit_dis`` into ``final_metric``.

        ``final_metric`` is mutated in place and returned. Keys are inserted
        in the order char, bpe, wp (``*_acc`` before ``*_norm_edit_dis``).
        """
        branches = (
            ('char', char_metric),
            ('bpe', bpe_metric),
            ('wp', wp_metric),
        )
        for prefix, metric in branches:
            final_metric[prefix + '_acc'] = metric['acc']
            final_metric[prefix + '_norm_edit_dis'] = metric['norm_edit_dis']
        return final_metric

    def __call__(self,
                 pred_label,
                 batch=None,
                 training=False,
                 *args,
                 **kwargs):
        """Update all four trackers with one batch and return merged results.

        Args:
            pred_label: indexable sequence. Items 0, 1, 2 and 3 are the
                char, bpe, wp and final branch outputs respectively; each is
                paired with the last item (``pred_label[-1]``) before being
                handed to its tracker. NOTE(review): the last item is
                presumably the ground-truth label - confirm against caller.
            batch: forwarded unchanged to every tracker.
            training: forwarded unchanged to every tracker.
            *args, **kwargs: accepted and ignored.

        Returns:
            The dict returned by the ``final`` tracker, extended with
            ``char_*``, ``bpe_*`` and ``wp_*`` ``acc`` / ``norm_edit_dis``
            entries.
        """
        char_metric = self.char_metric((pred_label[0], pred_label[-1]),
                                       batch,
                                       training=training)
        bpe_metric = self.bpe_metric((pred_label[1], pred_label[-1]),
                                     batch,
                                     training=training)
        wp_metric = self.wp_metric((pred_label[2], pred_label[-1]),
                                   batch,
                                   training=training)
        final_metric = self.final_metric((pred_label[3], pred_label[-1]),
                                         batch,
                                         training=training)
        return self._merge(final_metric, char_metric, bpe_metric, wp_metric)

    def get_metric(self):
        """Return the merged result of every tracker's ``get_metric()``.

        The returned dict is whatever the ``final`` tracker's
        ``get_metric()`` returns, plus these keys taken from the other
        trackers::

            {
                'char_acc': ..., 'char_norm_edit_dis': ...,
                'bpe_acc': ...,  'bpe_norm_edit_dis': ...,
                'wp_acc': ...,   'wp_norm_edit_dis': ...,
            }
        """
        char_metric = self.char_metric.get_metric()
        bpe_metric = self.bpe_metric.get_metric()
        wp_metric = self.wp_metric.get_metric()
        final_metric = self.final_metric.get_metric()
        return self._merge(final_metric, char_metric, bpe_metric, wp_metric)