File size: 10,144 Bytes
ad16788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
"""Parallel beam search module for online simulation."""

import logging
from pathlib import Path
from typing import List

import yaml

import torch

from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.beam_search import Hypothesis
from espnet.nets.e2e_asr_common import end_detect


class BatchBeamSearchOnlineSim(BatchBeamSearch):
    """Online beam search implementation.

    This simulates streaming decoding.
    It requires encoded features of entire utterance and
    extracts block by block from it as it shoud be done
    in streaming processing.
    This is based on Tsunoo et al, "STREAMING TRANSFORMER ASR
    WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH"
    (https://arxiv.org/abs/2006.14941).
    """

    def set_streaming_config(self, asr_config: str):
        """Set config file for streaming decoding.

        Args:
            asr_config (str): The config file for asr training

        """
        train_config_file = Path(asr_config)
        self.block_size = None
        self.hop_size = None
        self.look_ahead = None
        config = None
        with train_config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
            if "encoder_conf" in args.keys():
                if "block_size" in args["encoder_conf"].keys():
                    self.block_size = args["encoder_conf"]["block_size"]
                if "hop_size" in args["encoder_conf"].keys():
                    self.hop_size = args["encoder_conf"]["hop_size"]
                if "look_ahead" in args["encoder_conf"].keys():
                    self.look_ahead = args["encoder_conf"]["look_ahead"]
            elif "config" in args.keys():
                config = args["config"]
                if config is None:
                    logging.info(
                        "Cannot find config file for streaming decoding: "
                        + "apply batch beam search instead."
                    )
                    return
        if (
            self.block_size is None or self.hop_size is None or self.look_ahead is None
        ) and config is not None:
            config_file = Path(config)
            with config_file.open("r", encoding="utf-8") as f:
                args = yaml.safe_load(f)
            if "encoder_conf" in args.keys():
                enc_args = args["encoder_conf"]
            if enc_args and "block_size" in enc_args:
                self.block_size = enc_args["block_size"]
            if enc_args and "hop_size" in enc_args:
                self.hop_size = enc_args["hop_size"]
            if enc_args and "look_ahead" in enc_args:
                self.look_ahead = enc_args["look_ahead"]

    def set_block_size(self, block_size: int):
        """Set block size for streaming decoding.

        Args:
            block_size (int): The block size of encoder
        """
        self.block_size = block_size

    def set_hop_size(self, hop_size: int):
        """Set hop size for streaming decoding.

        Args:
            hop_size (int): The hop size of encoder
        """
        self.hop_size = hop_size

    def set_look_ahead(self, look_ahead: int):
        """Set look ahead size for streaming decoding.

        Args:
            look_ahead (int): The look ahead size of encoder
        """
        self.look_ahead = look_ahead

    def forward(
        self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses a end-detect function
                to automatically find maximum hypothesis lengths
            minlenratio (float): Input length ratio to obtain min output length.

        Returns:
            list[Hypothesis]: N-best decoding results

        """
        self.conservative = True  # always true

        if self.block_size and self.hop_size and self.look_ahead:
            cur_end_frame = int(self.block_size - self.look_ahead)
        else:
            cur_end_frame = x.shape[0]
        process_idx = 0
        if cur_end_frame < x.shape[0]:
            h = x.narrow(0, 0, cur_end_frame)
        else:
            h = x

        # set length bounds
        if maxlenratio == 0:
            maxlen = x.shape[0]
        else:
            maxlen = max(1, int(maxlenratio * x.size(0)))
        minlen = int(minlenratio * x.size(0))
        logging.info("decoder input length: " + str(x.shape[0]))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # main loop of prefix search
        running_hyps = self.init_hyp(h)
        prev_hyps = []
        ended_hyps = []
        prev_repeat = False

        continue_decode = True

        while continue_decode:
            move_to_next_block = False
            if cur_end_frame < x.shape[0]:
                h = x.narrow(0, 0, cur_end_frame)
            else:
                h = x

            # extend states for ctc
            self.extend(h, running_hyps)

            while process_idx < maxlen:
                logging.debug("position " + str(process_idx))
                best = self.search(running_hyps, h)

                if process_idx == maxlen - 1:
                    # end decoding
                    running_hyps = self.post_process(
                        process_idx, maxlen, maxlenratio, best, ended_hyps
                    )
                n_batch = best.yseq.shape[0]
                local_ended_hyps = []
                is_local_eos = (
                    best.yseq[torch.arange(n_batch), best.length - 1] == self.eos
                )
                for i in range(is_local_eos.shape[0]):
                    if is_local_eos[i]:
                        hyp = self._select(best, i)
                        local_ended_hyps.append(hyp)
                    # NOTE(tsunoo): check repetitions here
                    # This is a implicit implementation of
                    # Eq (11) in https://arxiv.org/abs/2006.14941
                    # A flag prev_repeat is used instead of using set
                    elif (
                        not prev_repeat
                        and best.yseq[i, -1] in best.yseq[i, :-1]
                        and cur_end_frame < x.shape[0]
                    ):
                        move_to_next_block = True
                        prev_repeat = True
                if maxlenratio == 0.0 and end_detect(
                    [lh.asdict() for lh in local_ended_hyps], process_idx
                ):
                    logging.info(f"end detected at {process_idx}")
                    continue_decode = False
                    break
                if len(local_ended_hyps) > 0 and cur_end_frame < x.shape[0]:
                    move_to_next_block = True

                if move_to_next_block:
                    if (
                        self.hop_size
                        and cur_end_frame + int(self.hop_size) + int(self.look_ahead)
                        < x.shape[0]
                    ):
                        cur_end_frame += int(self.hop_size)
                    else:
                        cur_end_frame = x.shape[0]
                    logging.debug("Going to next block: %d", cur_end_frame)
                    if process_idx > 1 and len(prev_hyps) > 0 and self.conservative:
                        running_hyps = prev_hyps
                        process_idx -= 1
                        prev_hyps = []
                    break

                prev_repeat = False
                prev_hyps = running_hyps
                running_hyps = self.post_process(
                    process_idx, maxlen, maxlenratio, best, ended_hyps
                )

                if cur_end_frame >= x.shape[0]:
                    for hyp in local_ended_hyps:
                        ended_hyps.append(hyp)

                if len(running_hyps) == 0:
                    logging.info("no hypothesis. Finish decoding.")
                    continue_decode = False
                    break
                else:
                    logging.debug(f"remained hypotheses: {len(running_hyps)}")
                # increment number
                process_idx += 1

        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition "
                "again with smaller minlenratio."
            )
            return (
                []
                if minlenratio < 0.1
                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
            )

        # report the best result
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: "
                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
                + "\n"
            )
        return nbest_hyps

    def extend(self, x: torch.Tensor, hyps: Hypothesis) -> List[Hypothesis]:
        """Extend probabilities and states with more encoded chunks.

        Args:
            x (torch.Tensor): The extended encoder output feature
            hyps (Hypothesis): Current list of hypothesis

        Returns:
            Hypothesis: The exxtended hypothesis

        """
        for k, d in self.scorers.items():
            if hasattr(d, "extend_prob"):
                d.extend_prob(x)
            if hasattr(d, "extend_state"):
                hyps.states[k] = d.extend_state(hyps.states[k])