| """Utility functions for transducer models.""" | |
| import os | |
| import numpy as np | |
| import torch | |
| from espnet.nets.pytorch_backend.nets_utils import pad_list | |


def prepare_loss_inputs(ys_pad, hlens, blank_id=0, ignore_id=-1):
    """Prepare tensors for transducer loss computation.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        hlens (torch.Tensor): batch of hidden sequence lengths (B)
                              or batch of masks (B, 1, Tmax)
        blank_id (int): index of blank label
        ignore_id (int): index of initial padding

    Returns:
        ys_in_pad (torch.Tensor): batch of padded target sequences + blank (B, Lmax + 1)
        ys_out_pad (torch.Tensor): batch of padded target sequences + blank (B, Lmax + 1)
        target (torch.Tensor): batch of padded target sequences (B, Lmax)
        pred_len (torch.Tensor): batch of hidden sequence lengths (B)
        target_len (torch.Tensor): batch of output sequence lengths (B)

    """
    device = ys_pad.device

    ys = [y[y != ignore_id] for y in ys_pad]

    blank = ys[0].new([blank_id])

    ys_in_pad = pad_list([torch.cat([blank, y], dim=0) for y in ys], blank_id)
    ys_out_pad = pad_list([torch.cat([y, blank], dim=0) for y in ys], ignore_id)

    target = pad_list(ys, blank_id).type(torch.int32).to(device)
    target_len = torch.IntTensor([y.size(0) for y in ys]).to(device)

    if torch.is_tensor(hlens):
        if hlens.dim() > 1:
            hs = [h[h != 0] for h in hlens]
            hlens = list(map(int, [h.size(0) for h in hs]))
        else:
            hlens = list(map(int, hlens))

    pred_len = torch.IntTensor(hlens).to(device)

    return ys_in_pad, ys_out_pad, target, pred_len, target_len
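
# A minimal usage sketch (values are illustrative):
#
#     ys_pad = torch.tensor([[2, 4, -1], [3, 5, 6]])  # (B=2, Lmax=3), -1 pads
#     hlens = torch.tensor([8, 10])                   # encoder output lengths
#     ys_in, ys_out, target, pred_len, target_len = prepare_loss_inputs(ys_pad, hlens)
#     # ys_in prepends blank (id 0): [[0, 2, 4, 0], [0, 3, 5, 6]]
#     # target_len: tensor([2, 3], dtype=torch.int32)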


def valid_aux_task_layer_list(aux_layer_ids, enc_num_layers):
    """Check whether input list of auxiliary layer ids is valid.

    Return the valid list sorted, with duplicates removed.

    Args:
        aux_layer_ids (list): Auxiliary layer ids
        enc_num_layers (int): Number of encoder layers

    Returns:
        valid (list): Validated list of layers for auxiliary task

    """
    if (
        not isinstance(aux_layer_ids, list)
        or not aux_layer_ids
        or not all(isinstance(layer, int) for layer in aux_layer_ids)
    ):
        raise ValueError("--aux-task-layer-list argument takes a list of layer ids.")

    sorted_list = sorted(set(aux_layer_ids))
    valid = [layer for layer in sorted_list if 0 <= layer < enc_num_layers]

    if sorted_list != valid:
        raise ValueError(
            "Provided list of layer ids for auxiliary task is incorrect. "
            "IDs should be in the range [0, %d]." % (enc_num_layers - 1)
        )

    return valid
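
# For instance, with a 6-layer encoder (enc_num_layers=6):
#
#     valid_aux_task_layer_list([3, 1, 3], 6)  # -> [1, 3]
#     valid_aux_task_layer_list([1, 6], 6)     # ValueError: 6 is out of range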


def is_prefix(x, pref):
    """Check whether pref is a strict prefix of x.

    Args:
        x (list): token id sequence
        pref (list): token id sequence

    Returns:
        (bool): whether pref is a prefix of x

    """
    if len(pref) >= len(x):
        return False

    for i in range(len(pref)):
        if pref[i] != x[i]:
            return False

    return True
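
# Proper-prefix semantics: a sequence is never a prefix of itself.
#
#     is_prefix([1, 2, 3], [1, 2])     # True
#     is_prefix([1, 2, 3], [1, 2, 3])  # False
#     is_prefix([1, 2, 3], [2])        # False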


def substract(x, subset):
    """Remove from x the hypotheses whose token id sequence appears in subset.

    Args:
        x (list): set of hypotheses
        subset (list): subset of hypotheses

    Returns:
        final (list): new set

    """
    final = []

    for x_ in x:
        if any(x_.yseq == sub.yseq for sub in subset):
            continue
        final.append(x_)

    return final
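
# A small sketch, assuming hypothesis objects expose a `yseq` attribute as in
# the comparison above (names here are illustrative):
#
#     kept = substract(hyps, pruned)
#     # `kept` holds every hypothesis of `hyps` whose yseq does not match the
#     # yseq of any hypothesis in `pruned`.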


def select_lm_state(lm_states, idx, lm_layers, is_wordlm):
    """Get LM state from batch for given id.

    Args:
        lm_states (list or dict): batch of LM states
        idx (int): index to extract state from batch state
        lm_layers (int): number of LM layers
        is_wordlm (bool): whether provided LM is a word-LM

    Returns:
        idx_state (dict): LM state for given id

    """
    if is_wordlm:
        idx_state = lm_states[idx]
    else:
        idx_state = {}

        idx_state["c"] = [lm_states["c"][layer][idx] for layer in range(lm_layers)]
        idx_state["h"] = [lm_states["h"][layer][idx] for layer in range(lm_layers)]

    return idx_state


def create_lm_batch_state(lm_states_list, lm_layers, is_wordlm):
    """Create batch of LM states.

    Args:
        lm_states_list (list): list of individual LM states
        lm_layers (int): number of LM layers
        is_wordlm (bool): whether provided LM is a word-LM

    Returns:
        batch_states (list or dict): batch of LM states

    """
    if is_wordlm:
        batch_states = lm_states_list
    else:
        batch_states = {}

        batch_states["c"] = [
            torch.stack([state["c"][layer] for state in lm_states_list])
            for layer in range(lm_layers)
        ]
        batch_states["h"] = [
            torch.stack([state["h"][layer] for state in lm_states_list])
            for layer in range(lm_layers)
        ]

    return batch_states
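
# Round-trip sketch for an RNN LM (is_wordlm=False); shapes are illustrative.
# Selecting per-hypothesis states and re-batching them are inverse operations:
#
#     # lm_batch: {"h": [Tensor(B, units)] * lm_layers, "c": [...]}
#     states = [select_lm_state(lm_batch, i, lm_layers, False) for i in range(B)]
#     rebatched = create_lm_batch_state(states, lm_layers, False)
#     # rebatched["h"][layer] equals lm_batch["h"][layer] for every layer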


def init_lm_state(lm_model):
    """Initialize LM state.

    Args:
        lm_model (torch.nn.Module): LM module

    Returns:
        lm_state (dict): initial LM state

    """
    lm_layers = len(lm_model.rnn)
    lm_units_typ = lm_model.typ
    lm_units = lm_model.n_units

    p = next(lm_model.parameters())

    h = [
        torch.zeros(lm_units).to(device=p.device, dtype=p.dtype)
        for _ in range(lm_layers)
    ]

    lm_state = {"h": h}

    if lm_units_typ == "lstm":
        lm_state["c"] = [
            torch.zeros(lm_units).to(device=p.device, dtype=p.dtype)
            for _ in range(lm_layers)
        ]

    return lm_state
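
# The returned layout matches what select_lm_state / create_lm_batch_state
# consume. For a hypothetical 2-layer LSTM LM with n_units=650:
#
#     state = init_lm_state(lm_model)
#     # len(state["h"]) == len(state["c"]) == 2
#     # state["h"][0].shape == torch.Size([650])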


def recombine_hyps(hyps):
    """Recombine hypotheses with equivalent output sequence.

    Args:
        hyps (list): list of hypotheses

    Returns:
        final (list): list of recombined hypotheses

    """
    final = []

    for hyp in hyps:
        seq_final = [f.yseq for f in final if f.yseq]

        if hyp.yseq in seq_final:
            seq_pos = seq_final.index(hyp.yseq)

            final[seq_pos].score = np.logaddexp(final[seq_pos].score, hyp.score)
        else:
            final.append(hyp)

    return final
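
# Merging happens in the log domain: if two hypotheses share the same yseq
# with scores s1 and s2, the surviving hypothesis is scored
# np.logaddexp(s1, s2) = log(exp(s1) + exp(s2)), i.e. their probabilities add.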


def pad_sequence(seqlist, pad_token):
    """Left pad list of token id sequences.

    Args:
        seqlist (list): list of token id sequences
        pad_token (int): padding token id

    Returns:
        final (list): list of left-padded token id sequences

    """
    maxlen = max(len(x) for x in seqlist)

    final = [([pad_token] * (maxlen - len(x))) + x for x in seqlist]

    return final
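
# Left padding keeps the most recent tokens right-aligned across a batch:
#
#     pad_sequence([[1, 2, 3], [4, 5]], 0)  # -> [[1, 2, 3], [0, 4, 5]]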


def check_state(state, max_len, pad_token):
    """Check state and left pad or trim if necessary.

    Args:
        state (list): list of L decoder states (1, in_len, dec_dim)
        max_len (int): maximum length authorized
        pad_token (int): padding token id

    Returns:
        final (list): list of L padded decoder states (1, max_len, dec_dim)

    """
    if state is None or max_len < 1 or state[0].size(1) == max_len:
        return state

    curr_len = state[0].size(1)

    if curr_len > max_len:
        trim_val = int(state[0].size(1) - max_len)

        for i, s in enumerate(state):
            state[i] = s[:, trim_val:, :]
    else:
        layers = len(state)
        ddim = state[0].size(2)

        final_dims = (1, max_len, ddim)
        final = [state[0].data.new(*final_dims).fill_(pad_token) for _ in range(layers)]

        for i, s in enumerate(state):
            final[i][:, (max_len - s.size(1)) : max_len, :] = s

        return final

    return state
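
# Behavior sketch with assumed shapes (dec_dim=4): trimming drops the oldest
# (leftmost) positions, padding inserts pad_token on the left.
#
#     out = check_state([torch.zeros(1, 5, 4)], 3, 0.0)
#     # out[0].shape == (1, 3, 4) -- the two oldest positions were trimmed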


def check_batch_state(state, max_len, pad_token):
    """Check batch of states and left pad or trim if necessary.

    Args:
        state (list): list of B decoder states, each of shape (?, dec_dim)
        max_len (int): maximum length authorized
        pad_token (int): padding token id

    Returns:
        final (torch.Tensor): batch of decoder states (B, max_len, dec_dim)

    """
    final_dims = (len(state), max_len, state[0].size(1))
    final = state[0].data.new(*final_dims).fill_(pad_token)

    for i, s in enumerate(state):
        curr_len = s.size(0)

        if curr_len < max_len:
            final[i, (max_len - curr_len) : max_len, :] = s
        else:
            final[i, :, :] = s[(curr_len - max_len) :, :]

    return final
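
# Same left-pad/trim convention on a batch, with made-up shapes (dec_dim=4):
#
#     a = torch.zeros(2, 4)  # 2 decoding steps
#     b = torch.ones(5, 4)   # 5 decoding steps
#     out = check_batch_state([a, b], 3, 0.0)
#     # out.shape == (2, 3, 4); out[0, 0] is padding, out[1] is b's last 3 steps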


def custom_torch_load(model_path, model, training=True):
    """Load transducer model modules and parameters.

    Training-only modules and parameters are removed when the model
    is loaded for inference (training=False).

    Args:
        model_path (str): Model path
        model (torch.nn.Module): The model with pretrained modules
        training (bool): Whether the loaded model will be used for training

    """
    if "snapshot" in os.path.basename(model_path):
        model_state_dict = torch.load(
            model_path, map_location=lambda storage, loc: storage
        )["model"]
    else:
        model_state_dict = torch.load(
            model_path, map_location=lambda storage, loc: storage
        )

    if not training:
        model_state_dict = {
            k: v for k, v in model_state_dict.items() if not k.startswith("aux")
        }

    model.load_state_dict(model_state_dict)

    del model_state_dict
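
# Typical inference-time usage (the checkpoint path below is illustrative):
#
#     custom_torch_load("exp/train/results/model.loss.best", model, training=False)
#     # auxiliary-task parameters ("aux*") saved during training are dropped
#     # before the remaining weights are loaded.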