"""Search algorithms for transducer models."""
from typing import List
from typing import Optional
from typing import Union
import numpy as np
import torch
from espnet.nets.pytorch_backend.transducer.utils import create_lm_batch_state
from espnet.nets.pytorch_backend.transducer.utils import init_lm_state
from espnet.nets.pytorch_backend.transducer.utils import is_prefix
from espnet.nets.pytorch_backend.transducer.utils import recombine_hyps
from espnet.nets.pytorch_backend.transducer.utils import select_lm_state
from espnet.nets.pytorch_backend.transducer.utils import substract
from espnet.nets.transducer_decoder_interface import Hypothesis
from espnet.nets.transducer_decoder_interface import NSCHypothesis
from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface
class BeamSearchTransducer:
"""Beam search implementation for transducer."""
def __init__(
self,
decoder: Union[TransducerDecoderInterface, torch.nn.Module],
joint_network: torch.nn.Module,
beam_size: int,
        lm: Optional[torch.nn.Module] = None,
lm_weight: float = 0.1,
search_type: str = "default",
max_sym_exp: int = 2,
u_max: int = 50,
nstep: int = 1,
prefix_alpha: int = 1,
score_norm: bool = True,
nbest: int = 1,
):
"""Initialize transducer beam search.
Args:
decoder: Decoder class to use
joint_network: Joint Network class
beam_size: Number of hypotheses kept during search
lm: LM class to use
lm_weight: lm weight for soft fusion
search_type: type of algorithm to use for search
max_sym_exp: number of maximum symbol expansions at each time step ("tsd")
u_max: maximum output sequence length ("alsd")
nstep: number of maximum expansion steps at each time step ("nsc")
prefix_alpha: maximum prefix length in prefix search ("nsc")
score_norm: normalize final scores by length ("default")
nbest: number of returned final hypothesis
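
        Example:
            A minimal usage sketch; ``decoder``, ``joint_net`` and ``enc_out``
            stand in for objects built elsewhere::

                searcher = BeamSearchTransducer(decoder, joint_net, beam_size=5)
                nbest_hyps = searcher(enc_out)  # enc_out: (T_max, D_enc)
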
"""
self.decoder = decoder
self.joint_network = joint_network
self.beam_size = beam_size
self.hidden_size = decoder.dunits
self.vocab_size = decoder.odim
self.blank = decoder.blank
if self.beam_size <= 1:
self.search_algorithm = self.greedy_search
elif search_type == "default":
self.search_algorithm = self.default_beam_search
elif search_type == "tsd":
self.search_algorithm = self.time_sync_decoding
elif search_type == "alsd":
self.search_algorithm = self.align_length_sync_decoding
elif search_type == "nsc":
self.search_algorithm = self.nsc_beam_search
else:
            raise NotImplementedError(f"Unsupported search type: {search_type}")
self.lm = lm
self.lm_weight = lm_weight
if lm is not None:
self.use_lm = True
            self.is_wordlm = hasattr(lm.predictor, "wordlm")
self.lm_predictor = lm.predictor.wordlm if self.is_wordlm else lm.predictor
self.lm_layers = len(self.lm_predictor.rnn)
else:
self.use_lm = False
self.max_sym_exp = max_sym_exp
self.u_max = u_max
self.nstep = nstep
self.prefix_alpha = prefix_alpha
self.score_norm = score_norm
self.nbest = nbest
def __call__(self, h: torch.Tensor) -> Union[List[Hypothesis], List[NSCHypothesis]]:
"""Perform beam search.
Args:
h: Encoded speech features (T_max, D_enc)
Returns:
nbest_hyps: N-best decoding results
"""
self.decoder.set_device(h.device)
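        # Decoders exposing a `decoders` module list (the custom decoder)
        # handle their data type internally; others need it set explicitly.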
if not hasattr(self.decoder, "decoders"):
self.decoder.set_data_type(h.dtype)
nbest_hyps = self.search_algorithm(h)
return nbest_hyps
def sort_nbest(
self, hyps: Union[List[Hypothesis], List[NSCHypothesis]]
) -> Union[List[Hypothesis], List[NSCHypothesis]]:
"""Sort hypotheses by score or score given sequence length.
Args:
hyps: list of hypotheses
Return:
hyps: sorted list of hypotheses
"""
if self.score_norm:
hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
else:
hyps.sort(key=lambda x: x.score, reverse=True)
return hyps[: self.nbest]
def greedy_search(self, h: torch.Tensor) -> List[Hypothesis]:
"""Greedy search implementation for transformer-transducer.
Args:
h: Encoded speech features (T_max, D_enc)
Returns:
hyp: 1-best decoding results
"""
dec_state = self.decoder.init_state(1)
hyp = Hypothesis(score=0.0, yseq=[self.blank], dec_state=dec_state)
cache = {}
y, state, _ = self.decoder.score(hyp, cache)
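        # Frame-synchronous greedy decoding: for each encoder frame, take the
        # argmax of the joint network output. A non-blank prediction extends
        # the hypothesis and advances the decoder state; blank moves on to
        # the next frame.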
        for hi in h:
ytu = torch.log_softmax(self.joint_network(hi, y), dim=-1)
logp, pred = torch.max(ytu, dim=-1)
if pred != self.blank:
hyp.yseq.append(int(pred))
hyp.score += float(logp)
hyp.dec_state = state
y, state, _ = self.decoder.score(hyp, cache)
return [hyp]
def default_beam_search(self, h: torch.Tensor) -> List[Hypothesis]:
"""Beam search implementation.
Args:
x: Encoded speech features (T_max, D_enc)
Returns:
nbest_hyps: N-best decoding results
"""
beam = min(self.beam_size, self.vocab_size)
beam_k = min(beam, (self.vocab_size - 1))
dec_state = self.decoder.init_state(1)
kept_hyps = [Hypothesis(score=0.0, yseq=[self.blank], dec_state=dec_state)]
cache = {}
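        # One expansion round per encoder frame (cf. Graves, 2012): pop the
        # best pending hypothesis, add its blank extension to kept_hyps, and
        # push its top-k non-blank extensions back for further expansion.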
for hi in h:
hyps = kept_hyps
kept_hyps = []
while True:
max_hyp = max(hyps, key=lambda x: x.score)
hyps.remove(max_hyp)
y, state, lm_tokens = self.decoder.score(max_hyp, cache)
ytu = torch.log_softmax(self.joint_network(hi, y), dim=-1)
top_k = ytu[1:].topk(beam_k, dim=-1)
kept_hyps.append(
Hypothesis(
score=(max_hyp.score + float(ytu[0:1])),
yseq=max_hyp.yseq[:],
dec_state=max_hyp.dec_state,
lm_state=max_hyp.lm_state,
)
)
if self.use_lm:
lm_state, lm_scores = self.lm.predict(max_hyp.lm_state, lm_tokens)
else:
lm_state = max_hyp.lm_state
for logp, k in zip(*top_k):
score = max_hyp.score + float(logp)
if self.use_lm:
score += self.lm_weight * lm_scores[0][k + 1]
hyps.append(
Hypothesis(
score=score,
yseq=max_hyp.yseq[:] + [int(k + 1)],
dec_state=state,
lm_state=lm_state,
)
)
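                # Expansion within this frame stops once at least `beam`
                # blank-extended hypotheses outscore the best hypothesis still
                # waiting to be expanded.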
hyps_max = float(max(hyps, key=lambda x: x.score).score)
kept_most_prob = sorted(
[hyp for hyp in kept_hyps if hyp.score > hyps_max],
key=lambda x: x.score,
)
if len(kept_most_prob) >= beam:
kept_hyps = kept_most_prob
break
return self.sort_nbest(kept_hyps)
def time_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]:
"""Time synchronous beam search implementation.
Based on https://ieeexplore.ieee.org/document/9053040
Args:
h: Encoded speech features (T_max, D_enc)
Returns:
nbest_hyps: N-best decoding results
"""
beam = min(self.beam_size, self.vocab_size)
beam_state = self.decoder.init_state(beam)
B = [
Hypothesis(
yseq=[self.blank],
score=0.0,
dec_state=self.decoder.select_state(beam_state, 0),
)
]
cache = {}
if self.use_lm and not self.is_wordlm:
B[0].lm_state = init_lm_state(self.lm_predictor)
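        # Time-synchronous decoding: each frame allows up to max_sym_exp
        # symbol expansions. A collects blank extensions (duplicate label
        # sequences merged via log-sum-exp); D collects label extensions,
        # which become the candidates C for the next expansion round.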
for hi in h:
A = []
C = B
h_enc = hi.unsqueeze(0)
for v in range(self.max_sym_exp):
D = []
beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
C,
beam_state,
cache,
self.use_lm,
)
beam_logp = torch.log_softmax(self.joint_network(h_enc, beam_y), dim=-1)
beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)
                seq_A = [a.yseq for a in A]
for i, hyp in enumerate(C):
if hyp.yseq not in seq_A:
A.append(
Hypothesis(
score=(hyp.score + float(beam_logp[i, 0])),
yseq=hyp.yseq[:],
dec_state=hyp.dec_state,
lm_state=hyp.lm_state,
)
)
else:
dict_pos = seq_A.index(hyp.yseq)
A[dict_pos].score = np.logaddexp(
A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
)
if v < (self.max_sym_exp - 1):
if self.use_lm:
beam_lm_states = create_lm_batch_state(
[c.lm_state for c in C], self.lm_layers, self.is_wordlm
)
beam_lm_states, beam_lm_scores = self.lm.buff_predict(
beam_lm_states, beam_lm_tokens, len(C)
)
for i, hyp in enumerate(C):
for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
new_hyp = Hypothesis(
score=(hyp.score + float(logp)),
yseq=(hyp.yseq + [int(k)]),
dec_state=self.decoder.select_state(beam_state, i),
lm_state=hyp.lm_state,
)
if self.use_lm:
new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
new_hyp.lm_state = select_lm_state(
beam_lm_states, i, self.lm_layers, self.is_wordlm
)
D.append(new_hyp)
C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]
B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
return self.sort_nbest(B)
def align_length_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]:
"""Alignment-length synchronous beam search implementation.
Based on https://ieeexplore.ieee.org/document/9053040
Args:
h: Encoded speech features (T_max, D_enc)
Returns:
nbest_hyps: N-best decoding results
"""
beam = min(self.beam_size, self.vocab_size)
h_length = int(h.size(0))
u_max = min(self.u_max, (h_length - 1))
beam_state = self.decoder.init_state(beam)
B = [
Hypothesis(
yseq=[self.blank],
score=0.0,
dec_state=self.decoder.select_state(beam_state, 0),
)
]
final = []
cache = {}
if self.use_lm and not self.is_wordlm:
B[0].lm_state = init_lm_state(self.lm_predictor)
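        # Alignment-length synchronous decoding: i indexes the combined
        # alignment length t + u, so hypotheses of different output lengths u
        # attend to different encoder frames t within the same step. A
        # hypothesis that consumes the last frame with a blank is final.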
for i in range(h_length + u_max):
A = []
B_ = []
h_states = []
for hyp in B:
u = len(hyp.yseq) - 1
t = i - u + 1
if t > (h_length - 1):
continue
B_.append(hyp)
h_states.append((t, h[t]))
if B_:
beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
B_,
beam_state,
cache,
self.use_lm,
)
                h_enc = torch.stack([hs[1] for hs in h_states])
beam_logp = torch.log_softmax(self.joint_network(h_enc, beam_y), dim=-1)
beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)
if self.use_lm:
beam_lm_states = create_lm_batch_state(
[b.lm_state for b in B_], self.lm_layers, self.is_wordlm
)
beam_lm_states, beam_lm_scores = self.lm.buff_predict(
beam_lm_states, beam_lm_tokens, len(B_)
)
                for j, hyp in enumerate(B_):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[j, 0])),
                        yseq=hyp.yseq[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                    )
                    A.append(new_hyp)
                    if h_states[j][0] == (h_length - 1):
                        final.append(new_hyp)
                    for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq[:] + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, j),
                            lm_state=hyp.lm_state,
                        )
                        if self.use_lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[j, k]
                            new_hyp.lm_state = select_lm_state(
                                beam_lm_states, j, self.lm_layers, self.is_wordlm
                            )
                        A.append(new_hyp)
B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
B = recombine_hyps(B)
if final:
return self.sort_nbest(final)
else:
return B
def nsc_beam_search(self, h: torch.Tensor) -> List[NSCHypothesis]:
"""N-step constrained beam search implementation.
Based and modified from https://arxiv.org/pdf/2002.03577.pdf.
Please reference ESPnet (b-flo, PR #2444) for any usage outside ESPnet
until further modifications.
Note: the algorithm is not in his "complete" form but works almost as
intended.
Args:
h: Encoded speech features (T_max, D_enc)
Returns:
nbest_hyps: N-best decoding results
"""
beam = min(self.beam_size, self.vocab_size)
beam_k = min(beam, (self.vocab_size - 1))
beam_state = self.decoder.init_state(beam)
init_tokens = [
NSCHypothesis(
yseq=[self.blank],
score=0.0,
dec_state=self.decoder.select_state(beam_state, 0),
)
]
cache = {}
beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
init_tokens,
beam_state,
cache,
self.use_lm,
)
state = self.decoder.select_state(beam_state, 0)
if self.use_lm:
beam_lm_states, beam_lm_scores = self.lm.buff_predict(
None, beam_lm_tokens, 1
)
lm_state = select_lm_state(
beam_lm_states, 0, self.lm_layers, self.is_wordlm
)
lm_scores = beam_lm_scores[0]
else:
lm_state = None
lm_scores = None
kept_hyps = [
NSCHypothesis(
yseq=[self.blank],
score=0.0,
dec_state=state,
y=[beam_y[0]],
lm_state=lm_state,
lm_scores=lm_scores,
)
]
for hi in h:
hyps = sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True)
kept_hyps = []
h_enc = hi.unsqueeze(0)
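            # Prefix search: when one hypothesis is a prefix of another (with
            # at most prefix_alpha extra labels), add to the longer one the
            # probability of reaching it by extending the shorter prefix,
            # merging via log-sum-exp.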
for j, hyp_j in enumerate(hyps[:-1]):
for hyp_i in hyps[(j + 1) :]:
curr_id = len(hyp_j.yseq)
next_id = len(hyp_i.yseq)
if (
is_prefix(hyp_j.yseq, hyp_i.yseq)
and (curr_id - next_id) <= self.prefix_alpha
):
ytu = torch.log_softmax(
self.joint_network(hi, hyp_i.y[-1]), dim=-1
)
curr_score = hyp_i.score + float(ytu[hyp_j.yseq[next_id]])
for k in range(next_id, (curr_id - 1)):
ytu = torch.log_softmax(
self.joint_network(hi, hyp_j.y[k]), dim=-1
)
curr_score += float(ytu[hyp_j.yseq[k + 1]])
hyp_j.score = np.logaddexp(hyp_j.score, curr_score)
S = []
V = []
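            # N-step constrained expansion: S gathers blank-terminated
            # extensions, V gathers label extensions; V is re-expanded up to
            # nstep times before moving to the next frame.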
for n in range(self.nstep):
beam_y = torch.stack([hyp.y[-1] for hyp in hyps])
beam_logp = torch.log_softmax(self.joint_network(h_enc, beam_y), dim=-1)
beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)
for i, hyp in enumerate(hyps):
S.append(
NSCHypothesis(
yseq=hyp.yseq[:],
score=hyp.score + float(beam_logp[i, 0:1]),
y=hyp.y[:],
dec_state=hyp.dec_state,
lm_state=hyp.lm_state,
lm_scores=hyp.lm_scores,
)
)
for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
score = hyp.score + float(logp)
if self.use_lm:
score += self.lm_weight * float(hyp.lm_scores[k])
V.append(
NSCHypothesis(
yseq=hyp.yseq[:] + [int(k)],
score=score,
y=hyp.y[:],
dec_state=hyp.dec_state,
lm_state=hyp.lm_state,
lm_scores=hyp.lm_scores,
)
)
V.sort(key=lambda x: x.score, reverse=True)
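                # Remove candidates duplicating a current hypothesis, then
                # prune to the beam width (`substract` is the helper's
                # spelling in espnet.nets.pytorch_backend.transducer.utils).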
V = substract(V, hyps)[:beam]
beam_state = self.decoder.create_batch_states(
beam_state,
[v.dec_state for v in V],
[v.yseq for v in V],
)
beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
V,
beam_state,
cache,
self.use_lm,
)
if self.use_lm:
beam_lm_states = create_lm_batch_state(
[v.lm_state for v in V], self.lm_layers, self.is_wordlm
)
beam_lm_states, beam_lm_scores = self.lm.buff_predict(
beam_lm_states, beam_lm_tokens, len(V)
)
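                # Intermediate steps keep expanding from the refreshed decoder
                # and LM states; the final step also folds in the blank score
                # (when nstep > 1) before the (S + V) beam is pruned.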
if n < (self.nstep - 1):
for i, v in enumerate(V):
v.y.append(beam_y[i])
v.dec_state = self.decoder.select_state(beam_state, i)
if self.use_lm:
v.lm_state = select_lm_state(
beam_lm_states, i, self.lm_layers, self.is_wordlm
)
v.lm_scores = beam_lm_scores[i]
hyps = V[:]
else:
beam_logp = torch.log_softmax(
self.joint_network(h_enc, beam_y), dim=-1
)
for i, v in enumerate(V):
if self.nstep != 1:
v.score += float(beam_logp[i, 0])
v.y.append(beam_y[i])
v.dec_state = self.decoder.select_state(beam_state, i)
if self.use_lm:
v.lm_state = select_lm_state(
beam_lm_states, i, self.lm_layers, self.is_wordlm
)
v.lm_scores = beam_lm_scores[i]
kept_hyps = sorted((S + V), key=lambda x: x.score, reverse=True)[:beam]
return self.sort_nbest(kept_hyps)