File size: 4,156 Bytes
29f689c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from .rec_metric import RecMetric


class RecMPGMetric(object):
    """Composite recognition metric that scores four prediction outputs.

    Each of the four outputs is compared against the same labels by its own
    ``RecMetric`` instance: ``char_metric``, ``bpe_metric``, ``wp_metric`` and
    ``final_metric``. The names presumably refer to character / BPE /
    WordPiece decoding branches plus the fused final prediction; TODO confirm
    against the model. The reported dict is the ``final_metric`` result,
    extended with ``<prefix>_acc`` and ``<prefix>_norm_edit_dis`` entries for
    the three auxiliary branches.
    """

    # Key prefixes for the auxiliary sub-metrics, in the order they are merged
    # into the final result dict.
    _AUX_PREFIXES = ('char', 'bpe', 'wp')

    def __init__(self,
                 main_indicator='acc',
                 is_filter=False,
                 ignore_space=True,
                 stream=False,
                 with_ratio=False,
                 max_len=25,
                 max_ratio=4,
                 **kwargs):
        """Create four identically configured ``RecMetric`` sub-metrics.

        All arguments except ``**kwargs`` are forwarded unchanged to every
        sub-metric. ``**kwargs`` is accepted and ignored.
        """
        self.main_indicator = main_indicator
        # is_filter / ignore_space / eps are not read anywhere in this class;
        # they are kept as public attributes so existing readers still work.
        self.is_filter = is_filter
        self.ignore_space = ignore_space
        self.eps = 1e-5
        # One shared configuration guarantees all four sub-metrics stay in sync.
        metric_kwargs = dict(main_indicator=main_indicator,
                             is_filter=is_filter,
                             ignore_space=ignore_space,
                             stream=stream,
                             with_ratio=with_ratio,
                             max_len=max_len,
                             max_ratio=max_ratio)
        self.char_metric = RecMetric(**metric_kwargs)
        self.bpe_metric = RecMetric(**metric_kwargs)
        self.wp_metric = RecMetric(**metric_kwargs)
        self.final_metric = RecMetric(**metric_kwargs)

    @staticmethod
    def _merge(final_metric, char_metric, bpe_metric, wp_metric):
        """Copy each auxiliary sub-metric's acc / norm_edit_dis into final_metric.

        ``final_metric`` is mutated in place and returned.
        """
        aux_results = (char_metric, bpe_metric, wp_metric)
        for prefix, result in zip(RecMPGMetric._AUX_PREFIXES, aux_results):
            final_metric[prefix + '_acc'] = result['acc']
            final_metric[prefix + '_norm_edit_dis'] = result['norm_edit_dis']
        return final_metric

    def __call__(self,
                 pred_label,
                 batch=None,
                 training=False,
                 *args,
                 **kwargs):
        """Score one batch with every sub-metric and return the merged dict.

        Args:
            pred_label: indexable sequence. Entries 0..3 are the predictions
                for the char, bpe, wp and final sub-metrics, in that order.
                The last entry is the labels, shared by all four.
            batch: forwarded unchanged to each sub-metric.
            training: forwarded unchanged to each sub-metric.

        Returns:
            The ``final_metric`` result dict, extended with ``char_acc``,
            ``char_norm_edit_dis``, ``bpe_acc``, ``bpe_norm_edit_dis``,
            ``wp_acc`` and ``wp_norm_edit_dis``.
        """
        labels = pred_label[-1]
        char_metric = self.char_metric((pred_label[0], labels),
                                       batch,
                                       training=training)
        bpe_metric = self.bpe_metric((pred_label[1], labels),
                                     batch,
                                     training=training)
        wp_metric = self.wp_metric((pred_label[2], labels),
                                   batch,
                                   training=training)
        final_metric = self.final_metric((pred_label[3], labels),
                                         batch,
                                         training=training)
        return self._merge(final_metric, char_metric, bpe_metric, wp_metric)

    def get_metric(self):
        """Return the accumulated metrics of all four sub-metrics.

        The result is ``final_metric.get_metric()`` (containing at least
        ``acc`` and ``norm_edit_dis``), extended with ``char_acc``,
        ``char_norm_edit_dis``, ``bpe_acc``, ``bpe_norm_edit_dis``,
        ``wp_acc`` and ``wp_norm_edit_dis``. Any state handling, such as
        resetting counters, is whatever ``RecMetric.get_metric`` does.
        """
        char_metric = self.char_metric.get_metric()
        bpe_metric = self.bpe_metric.get_metric()
        wp_metric = self.wp_metric.get_metric()
        final_metric = self.final_metric.get_metric()
        return self._merge(final_metric, char_metric, bpe_metric, wp_metric)