"""Parallel beam search module for online simulation."""

import logging
from pathlib import Path
from typing import List

import yaml

import torch

from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.beam_search import Hypothesis
from espnet.nets.e2e_asr_common import end_detect


class BatchBeamSearchOnlineSim(BatchBeamSearch):
    """Online beam search implementation.

    This simulates streaming decoding.
    It requires the encoded features of the entire utterance and
    extracts them from it block by block, as should be done
    in streaming processing.
    This is based on Tsunoo et al., "STREAMING TRANSFORMER ASR
    WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH"
    (https://arxiv.org/abs/2006.14941).
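
    Example:
        A minimal usage sketch; the config path and the encoder output
        ``enc`` (a tensor of shape ``(T, D)``) are illustrative
        assumptions, not part of this module::

            beam_search.set_streaming_config("exp/asr_train/config.yaml")
            nbest_hyps = beam_search(x=enc)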
    """

    def set_streaming_config(self, asr_config: str):
        """Set config file for streaming decoding.

        Args:
            asr_config (str): The config file for asr training

        """
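        # Try to read block_size/hop_size/look_ahead from the training
        # config; if it only points to another config file, read that below.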
        train_config_file = Path(asr_config)
        self.block_size = None
        self.hop_size = None
        self.look_ahead = None
        config = None
        with train_config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
            if "encoder_conf" in args.keys():
                if "block_size" in args["encoder_conf"].keys():
                    self.block_size = args["encoder_conf"]["block_size"]
                if "hop_size" in args["encoder_conf"].keys():
                    self.hop_size = args["encoder_conf"]["hop_size"]
                if "look_ahead" in args["encoder_conf"].keys():
                    self.look_ahead = args["encoder_conf"]["look_ahead"]
            elif "config" in args.keys():
                config = args["config"]
                if config is None:
                    logging.info(
                        "Cannot find config file for streaming decoding: "
                        + "apply batch beam search instead."
                    )
                    return
        if (
            self.block_size is None or self.hop_size is None or self.look_ahead is None
        ) and config is not None:
            config_file = Path(config)
            with config_file.open("r", encoding="utf-8") as f:
                args = yaml.safe_load(f)
            if "encoder_conf" in args.keys():
                enc_args = args["encoder_conf"]
                if enc_args and "block_size" in enc_args:
                    self.block_size = enc_args["block_size"]
                if enc_args and "hop_size" in enc_args:
                    self.hop_size = enc_args["hop_size"]
                if enc_args and "look_ahead" in enc_args:
                    self.look_ahead = enc_args["look_ahead"]

    def set_block_size(self, block_size: int):
        """Set block size for streaming decoding.

        Args:
            block_size (int): The block size of encoder

        """
        self.block_size = block_size

    def set_hop_size(self, hop_size: int):
        """Set hop size for streaming decoding.

        Args:
            hop_size (int): The hop size of encoder

        """
        self.hop_size = hop_size

    def set_look_ahead(self, look_ahead: int):
        """Set look ahead size for streaming decoding.

        Args:
            look_ahead (int): The look ahead size of encoder

        """
        self.look_ahead = look_ahead

    def forward(
        self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses an end-detect function
                to automatically find maximum hypothesis lengths
            minlenratio (float): Input length ratio to obtain min output length.

        Returns:
            list[Hypothesis]: N-best decoding results

        """
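        # Conservative mode: when moving on to the next block, roll back to
        # the previous step's hypotheses so that no token chosen on
        # incomplete encoder evidence is kept.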
        self.conservative = True

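        # Start with the first block only; without block parameters, fall
        # back to plain batch beam search over the whole utterance.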
        if self.block_size and self.hop_size and self.look_ahead:
            cur_end_frame = int(self.block_size - self.look_ahead)
        else:
            cur_end_frame = x.shape[0]
        process_idx = 0
        if cur_end_frame < x.shape[0]:
            h = x.narrow(0, 0, cur_end_frame)
        else:
            h = x

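        # Set output length bounds from the input length.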
        if maxlenratio == 0:
            maxlen = x.shape[0]
        else:
            maxlen = max(1, int(maxlenratio * x.size(0)))
        minlen = int(minlenratio * x.size(0))
        logging.info("decoder input length: " + str(x.shape[0]))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

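        # Initialize hypotheses and start the main loop of prefix search
        # over the currently visible encoder block.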
        running_hyps = self.init_hyp(h)
        prev_hyps = []
        ended_hyps = []
        prev_repeat = False

        continue_decode = True

        while continue_decode:
            move_to_next_block = False
            if cur_end_frame < x.shape[0]:
                h = x.narrow(0, 0, cur_end_frame)
            else:
                h = x

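            # Extend the scorer probabilities and states (e.g., CTC prefix
            # scores) to cover the newly revealed encoder frames.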
            self.extend(h, running_hyps)

            while process_idx < maxlen:
                logging.debug("position " + str(process_idx))
                best = self.search(running_hyps, h)

                if process_idx == maxlen - 1:
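                    # Final step: let post_process finalize the remaining
                    # hypotheses.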
                    running_hyps = self.post_process(
                        process_idx, maxlen, maxlenratio, best, ended_hyps
                    )
                n_batch = best.yseq.shape[0]
                local_ended_hyps = []
                is_local_eos = (
                    best.yseq[torch.arange(n_batch), best.length - 1] == self.eos
                )
                for i in range(is_local_eos.shape[0]):
                    if is_local_eos[i]:
                        hyp = self._select(best, i)
                        local_ended_hyps.append(hyp)
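                    # Repetition check: if the newest token already appears
                    # earlier in the hypothesis, the current block is likely
                    # exhausted, so request the next block. This implicitly
                    # implements Eq. (11) of https://arxiv.org/abs/2006.14941;
                    # the prev_repeat flag is used instead of keeping a set.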
                    elif (
                        not prev_repeat
                        and best.yseq[i, -1] in best.yseq[i, :-1]
                        and cur_end_frame < x.shape[0]
                    ):
                        move_to_next_block = True
                        prev_repeat = True
                if maxlenratio == 0.0 and end_detect(
                    [lh.asdict() for lh in local_ended_hyps], process_idx
                ):
                    logging.info(f"end detected at {process_idx}")
                    continue_decode = False
                    break
                if len(local_ended_hyps) > 0 and cur_end_frame < x.shape[0]:
                    move_to_next_block = True

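                # Advance the visible encoder region by one hop; at the tail
                # of the utterance, expose all remaining frames instead.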
                if move_to_next_block:
                    if (
                        self.hop_size
                        and cur_end_frame + int(self.hop_size) + int(self.look_ahead)
                        < x.shape[0]
                    ):
                        cur_end_frame += int(self.hop_size)
                    else:
                        cur_end_frame = x.shape[0]
                    logging.debug("Going to next block: %d", cur_end_frame)
                    if process_idx > 1 and len(prev_hyps) > 0 and self.conservative:
                        running_hyps = prev_hyps
                        process_idx -= 1
                        prev_hyps = []
                    break

                prev_repeat = False
                prev_hyps = running_hyps
                running_hyps = self.post_process(
                    process_idx, maxlen, maxlenratio, best, ended_hyps
                )

                if cur_end_frame >= x.shape[0]:
                    for hyp in local_ended_hyps:
                        ended_hyps.append(hyp)

                if len(running_hyps) == 0:
                    logging.info("no hypothesis remains. Finish decoding.")
                    continue_decode = False
                    break
                else:
                    logging.debug(f"remaining hypotheses: {len(running_hyps)}")

                process_idx += 1

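        # Rank all ended hypotheses by total score.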
        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)

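        # If no hypothesis reached EOS, retry with a smaller minlenratio.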
        if len(nbest_hyps) == 0:
            logging.warning(
                "there are no N-best results; performing recognition "
                "again with a smaller minlenratio."
            )
            return (
                []
                if minlenratio < 0.1
                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
            )

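        # Report the best result.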
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: "
                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
                + "\n"
            )
        return nbest_hyps

    def extend(self, x: torch.Tensor, hyps: Hypothesis) -> None:
        """Extend probabilities and states with more encoded chunks.

        The scorer states in ``hyps`` are extended in place.

        Args:
            x (torch.Tensor): The extended encoder output feature
            hyps (Hypothesis): Current list of hypotheses

        """
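        # Only scorers that support streaming (e.g., CTC prefix scoring)
        # provide extend_prob/extend_state; other scorers are left as-is.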
        for k, d in self.scorers.items():
            if hasattr(d, "extend_prob"):
                d.extend_prob(x)
            if hasattr(d, "extend_state"):
                hyps.states[k] = d.extend_state(hyps.states[k])