conex / espnet /nets /batch_beam_search_online_sim.py
tobiasc's picture
Initial commit
ad16788
raw
history blame
10.1 kB
"""Parallel beam search module for online simulation."""
import logging
from pathlib import Path
from typing import List
import yaml
import torch
from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.beam_search import Hypothesis
from espnet.nets.e2e_asr_common import end_detect
class BatchBeamSearchOnlineSim(BatchBeamSearch):
    """Online beam search implementation.

    This simulates streaming decoding.
    It requires encoded features of entire utterance and
    extracts block by block from it as it should be done
    in streaming processing.
    This is based on Tsunoo et al, "STREAMING TRANSFORMER ASR
    WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH"
    (https://arxiv.org/abs/2006.14941).

    Block-wise decoding is used only when ``block_size`` and ``hop_size`` are
    set (non-zero) and ``look_ahead`` is set (it may be 0). These come from
    ``set_streaming_config`` or the individual setters. Otherwise ``forward``
    decodes the whole utterance at once.
    """

    # Keys read from the ``encoder_conf`` section of a training config.
    _STREAMING_KEYS = ("block_size", "hop_size", "look_ahead")

    def set_streaming_config(self, asr_config: str):
        """Set config file for streaming decoding.

        The streaming parameters (``block_size``, ``hop_size`` and
        ``look_ahead``) are read from the ``encoder_conf`` section of
        ``asr_config``. If that file has no ``encoder_conf`` section but has a
        ``config`` entry, the file that entry points to is read instead.
        Parameters that cannot be found are left as ``None``.

        Args:
            asr_config (str): The config file for asr training

        """
        self.block_size = None
        self.hop_size = None
        self.look_ahead = None

        args = self._load_yaml_config(asr_config)
        config = None
        if "encoder_conf" in args:
            self._apply_encoder_conf(args["encoder_conf"])
        elif "config" in args:
            config = args["config"]

        if self._streaming_params_complete():
            # Everything was found in the training config itself; nothing
            # else to look up (and nothing to warn about).
            return
        if config is None:
            logging.info(
                "Cannot find config file for streaming decoding: "
                + "apply batch beam search instead."
            )
            return

        # Fall back to the config file referenced by the ``config`` entry.
        # A missing or empty ``encoder_conf`` there is tolerated.
        self._apply_encoder_conf(self._load_yaml_config(config).get("encoder_conf"))

    @staticmethod
    def _load_yaml_config(path) -> dict:
        """Load a YAML file, returning an empty dict for an empty document."""
        with Path(path).open("r", encoding="utf-8") as f:
            # safe_load returns None for an empty file; normalise to {}.
            return yaml.safe_load(f) or {}

    def _apply_encoder_conf(self, enc_args):
        """Copy the streaming keys present in ``enc_args`` onto ``self``.

        ``enc_args`` may be ``None`` (e.g. an empty ``encoder_conf:`` entry),
        in which case nothing is changed.
        """
        if not enc_args:
            return
        for key in self._STREAMING_KEYS:
            if key in enc_args:
                setattr(self, key, enc_args[key])

    def _streaming_params_complete(self) -> bool:
        """Return True when none of the streaming parameters is ``None``."""
        return all(
            getattr(self, key, None) is not None for key in self._STREAMING_KEYS
        )

    def set_block_size(self, block_size: int):
        """Set block size for streaming decoding.

        Args:
            block_size (int): The block size of encoder

        """
        self.block_size = block_size

    def set_hop_size(self, hop_size: int):
        """Set hop size for streaming decoding.

        Args:
            hop_size (int): The hop size of encoder

        """
        self.hop_size = hop_size

    def set_look_ahead(self, look_ahead: int):
        """Set look ahead size for streaming decoding.

        Args:
            look_ahead (int): The look ahead size of encoder

        """
        self.look_ahead = look_ahead

    def forward(
        self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses a end-detect function
                to automatically find maximum hypothesis lengths
            minlenratio (float): Input length ratio to obtain min output length.

        Returns:
            list[Hypothesis]: N-best decoding results

        """
        self.conservative = True  # always true

        # The streaming attributes only exist once a setter or
        # set_streaming_config() has run; treat them as unset otherwise.
        block_size = getattr(self, "block_size", None)
        hop_size = getattr(self, "hop_size", None)
        look_ahead = getattr(self, "look_ahead", None)
        total_frames = x.shape[0]

        # look_ahead == 0 is a valid setting, so compare it against None
        # instead of testing truthiness (which would disable streaming).
        if block_size and hop_size and look_ahead is not None:
            # The first visible chunk ends block_size - look_ahead frames in.
            cur_end_frame = int(block_size - look_ahead)
        else:
            cur_end_frame = total_frames
        process_idx = 0
        if cur_end_frame < total_frames:
            h = x.narrow(0, 0, cur_end_frame)
        else:
            h = x

        # set length bounds
        if maxlenratio == 0:
            maxlen = total_frames
        else:
            maxlen = max(1, int(maxlenratio * x.size(0)))
        minlen = int(minlenratio * x.size(0))
        logging.info("decoder input length: " + str(x.shape[0]))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # main loop of prefix search
        running_hyps = self.init_hyp(h)
        prev_hyps = []
        ended_hyps = []
        prev_repeat = False

        continue_decode = True
        # Also stop once process_idx reaches maxlen: the inner loop can make
        # no further progress then, and without this check the outer loop
        # would spin forever (e.g. empty input with maxlenratio == 0).
        while continue_decode and process_idx < maxlen:
            move_to_next_block = False
            # Expose only the frames of the current block to the decoder.
            if cur_end_frame < total_frames:
                h = x.narrow(0, 0, cur_end_frame)
            else:
                h = x

            # extend states for ctc
            self.extend(h, running_hyps)

            while process_idx < maxlen:
                logging.debug("position " + str(process_idx))
                best = self.search(running_hyps, h)

                if process_idx == maxlen - 1:
                    # end decoding
                    running_hyps = self.post_process(
                        process_idx, maxlen, maxlenratio, best, ended_hyps
                    )
                n_batch = best.yseq.shape[0]
                local_ended_hyps = []
                is_local_eos = (
                    best.yseq[torch.arange(n_batch), best.length - 1] == self.eos
                )
                for i in range(is_local_eos.shape[0]):
                    if is_local_eos[i]:
                        hyp = self._select(best, i)
                        local_ended_hyps.append(hyp)
                    # NOTE(tsunoo): check repetitions here
                    # This is a implicit implementation of
                    # Eq (11) in https://arxiv.org/abs/2006.14941
                    # A flag prev_repeat is used instead of using set
                    elif (
                        not prev_repeat
                        and best.yseq[i, -1] in best.yseq[i, :-1]
                        and cur_end_frame < total_frames
                    ):
                        move_to_next_block = True
                        prev_repeat = True
                if maxlenratio == 0.0 and end_detect(
                    [lh.asdict() for lh in local_ended_hyps], process_idx
                ):
                    logging.info(f"end detected at {process_idx}")
                    continue_decode = False
                    break
                # An <eos> before the utterance is fully visible is treated
                # as a signal to read the next block.
                if len(local_ended_hyps) > 0 and cur_end_frame < total_frames:
                    move_to_next_block = True

                if move_to_next_block:
                    if (
                        hop_size
                        and cur_end_frame + int(hop_size) + int(look_ahead)
                        < total_frames
                    ):
                        cur_end_frame += int(hop_size)
                    else:
                        cur_end_frame = total_frames
                    logging.debug("Going to next block: %d", cur_end_frame)
                    if process_idx > 1 and len(prev_hyps) > 0 and self.conservative:
                        # Roll back one step; the outer loop redoes it with
                        # the enlarged block.
                        running_hyps = prev_hyps
                        process_idx -= 1
                        prev_hyps = []
                    break

                prev_repeat = False
                prev_hyps = running_hyps
                running_hyps = self.post_process(
                    process_idx, maxlen, maxlenratio, best, ended_hyps
                )

                # Record this step's ended hypotheses once the whole
                # utterance is visible.
                if cur_end_frame >= total_frames:
                    ended_hyps.extend(local_ended_hyps)

                if len(running_hyps) == 0:
                    logging.info("no hypothesis. Finish decoding.")
                    continue_decode = False
                    break
                else:
                    logging.debug(f"remained hypotheses: {len(running_hyps)}")
                # increment number
                process_idx += 1

        nbest_hyps = sorted(ended_hyps, key=lambda hyp: hyp.score, reverse=True)
        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition "
                "again with smaller minlenratio."
            )
            return (
                []
                if minlenratio < 0.1
                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
            )

        # report the best result
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: "
                + "".join([self.token_list[tok] for tok in best.yseq[1:-1]])
                + "\n"
            )
        return nbest_hyps

    def extend(self, x: torch.Tensor, hyps: Hypothesis) -> Hypothesis:
        """Extend probabilities and states with more encoded chunks.

        Every scorer that defines ``extend_prob`` is given the new encoder
        output. Every scorer that defines ``extend_state`` has its entry in
        ``hyps.states`` replaced in place.

        Args:
            x (torch.Tensor): The extended encoder output feature
            hyps (Hypothesis): Current hypotheses

        Returns:
            Hypothesis: ``hyps`` itself, with its states updated in place

        """
        for k, d in self.scorers.items():
            if hasattr(d, "extend_prob"):
                d.extend_prob(x)
            if hasattr(d, "extend_state"):
                hyps.states[k] = d.extend_state(hyps.states[k])
        return hyps