|
""" |
|
Taken from ESPNet, modified by Florian Lux |
|
""" |
|
|
|
import os |
|
from abc import ABC |
|
|
|
import torch |
|
|
|
|
|
def cumsum_durations(durations):
    """
    Compute cumulative frame boundaries and their midpoints for a duration sequence.

    Args:
        durations: Iterable of per-token durations (in frames).

    Returns:
        tuple: (boundaries, centers) where boundaries is the running sum
        starting at 0 (length len(durations) + 1) and centers holds the
        midpoint between each pair of consecutive boundaries.
    """
    boundaries = [0]
    for dur in durations:
        boundaries.append(boundaries[-1] + dur)
    # Midpoint of every adjacent boundary pair.
    centers = [(left + right) / 2 for left, right in zip(boundaries, boundaries[1:])]
    return boundaries, centers
|
|
|
|
|
def delete_old_checkpoints(checkpoint_dir, keep=5):
    """
    Remove old "checkpoint_<step>.pt" files, keeping only the newest ones.

    Files named "best.pt" or without the ".pt" suffix are never touched.

    Args:
        checkpoint_dir: Directory containing checkpoint files.
        keep (int): Number of most recent checkpoints to retain.
            keep=0 deletes every numbered checkpoint.
    """
    steps = []
    for name in os.listdir(checkpoint_dir):
        if name.endswith(".pt") and name != "best.pt":
            # Filenames look like "checkpoint_<step>.pt"; extract the step.
            steps.append(int(name.split(".")[0].split("_")[1]))
    if len(steps) <= keep:
        return
    steps.sort()
    # Original slicing steps[:-keep] silently kept everything for keep=0
    # (since lst[:-0] == lst[:0] == []); handle that case explicitly.
    stale_steps = steps[:-keep] if keep > 0 else steps
    for step in stale_steps:
        os.remove(os.path.join(checkpoint_dir, "checkpoint_{}.pt".format(step)))
|
|
|
|
|
def get_most_recent_checkpoint(checkpoint_dir, verbose=True):
    """
    Find the "checkpoint_<step>.pt" file with the highest step number.

    Args:
        checkpoint_dir: Directory to scan for checkpoint files.
        verbose (bool): If True, print which checkpoint will be reloaded.

    Returns:
        Path to the newest checkpoint, or None if no checkpoints exist.
    """
    steps = [
        int(name.split(".")[0].split("_")[1])
        for name in os.listdir(checkpoint_dir)
        if name.endswith(".pt") and name != "best.pt"
    ]
    if not steps:
        print("No previous checkpoints found, cannot reload.")
        return None
    latest = max(steps)
    if verbose:
        print("Reloading checkpoint_{}.pt".format(latest))
    return os.path.join(checkpoint_dir, "checkpoint_{}.pt".format(latest))
|
|
|
|
|
def make_pad_mask(lengths, xs=None, length_dim=-1, device=None):
    """
    Make mask tensor containing indices of padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.

    Returns:
        Tensor: Mask tensor that is True at padded positions
            (dtype is torch.bool on PyTorch >= 1.2, torch.uint8 before).

    Raises:
        ValueError: If length_dim is 0 (the batch dimension).
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    batch_size = int(len(lengths))
    # Width of the mask: the longest sequence, or the reference tensor's size.
    max_len = int(max(lengths)) if xs is None else xs.size(length_dim)

    if device is not None:
        positions = torch.arange(0, max_len, dtype=torch.int64, device=device)
    else:
        positions = torch.arange(0, max_len, dtype=torch.int64)
    positions = positions.unsqueeze(0).expand(batch_size, max_len)
    limits = positions.new(lengths).unsqueeze(-1)
    # Padded wherever the frame index reaches or exceeds the sequence length.
    mask = positions >= limits

    if xs is not None:
        assert xs.size(0) == batch_size, (xs.size(0), batch_size)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim

        # Insert singleton axes everywhere except batch and length dims,
        # then broadcast the mask to the full shape of xs.
        selector = tuple(
            slice(None) if axis in (0, length_dim) else None for axis in range(xs.dim())
        )
        mask = mask[selector].expand_as(xs).to(xs.device)
    return mask
|
|
|
|
|
def make_non_pad_mask(lengths, xs=None, length_dim=-1, device=None):
    """
    Make mask tensor containing indices of non-padded part.

    This is simply the logical inverse of :func:`make_pad_mask`.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.

    Returns:
        Tensor: Mask tensor that is True at non-padded positions
            (dtype is torch.bool on PyTorch >= 1.2, torch.uint8 before).
    """
    pad_mask = make_pad_mask(lengths, xs, length_dim, device=device)
    return ~pad_mask
|
|
|
|
|
def initialize(model, init):
    """
    Initialize weights of a neural network module.

    Parameters are initialized using the given method or distribution.

    Args:
        model: Target.
        init: Method of initialization, one of "xavier_uniform",
            "xavier_normal", "kaiming_uniform" or "kaiming_normal".

    Raises:
        ValueError: If ``init`` names an unknown scheme (only raised when
            the model actually has a parameter with dim > 1).
    """
    initializers = {
        "xavier_uniform": torch.nn.init.xavier_uniform_,
        "xavier_normal": torch.nn.init.xavier_normal_,
        "kaiming_uniform": lambda t: torch.nn.init.kaiming_uniform_(t, nonlinearity="relu"),
        "kaiming_normal": lambda t: torch.nn.init.kaiming_normal_(t, nonlinearity="relu"),
    }

    # Weight matrices (dim > 1) get the requested initializer.
    for param in model.parameters():
        if param.dim() > 1:
            if init not in initializers:
                raise ValueError("Unknown initialization: " + init)
            initializers[init](param.data)

    # Vectors such as biases (dim == 1) are zeroed.
    for param in model.parameters():
        if param.dim() == 1:
            param.data.zero_()

    # Embeddings and LayerNorms are restored to their own defaults.
    for module in model.modules():
        if isinstance(module, (torch.nn.Embedding, torch.nn.LayerNorm)):
            module.reset_parameters()
|
|
|
|
|
def pad_list(xs, pad_value):
    """
    Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).
    """
    batch = len(xs)
    longest = max(tensor.size(0) for tensor in xs)
    # Allocate the output on the same device/dtype as the first tensor,
    # pre-filled with the padding value.
    padded = xs[0].new(batch, longest, *xs[0].size()[1:]).fill_(pad_value)
    for idx, tensor in enumerate(xs):
        padded[idx, : tensor.size(0)] = tensor
    return padded
|
|
|
|
|
def subsequent_mask(size, device="cpu", dtype=torch.bool):
    """
    Create mask for subsequent steps (size, size).

    Position (i, j) is true iff j <= i, i.e. each step may attend
    only to itself and earlier steps.

    :param int size: size of mask
    :param str device: "cpu" or "cuda" or torch.Tensor.device
    :param torch.dtype dtype: result dtype
    :rtype: torch.Tensor
    """
    full = torch.ones(size, size, device=device, dtype=dtype)
    return torch.tril(full)
|
|
|
|
|
class ScorerInterface:
    """
    Scorer interface for beam search.

    The scorer performs scoring of the all tokens in vocabulary.

    Examples:
        * Search heuristics
            * :class:`espnet.nets.scorers.length_bonus.LengthBonus`
        * Decoder networks of the sequence-to-sequence models
            * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder`
            * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
        * Neural language models
            * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
            * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
            * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`

    """

    def init_state(self, x):
        """
        Get an initial state for decoding (optional).

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state

        """
        # Stateless scorers need no decoding state.
        return None

    def select_state(self, state, i, new_id=None):
        """
        Select state with relative ids in the main beam search.

        Args:
            state: Decoder state for prefix tokens
            i (int): Index to select a state in the main beam search
            new_id (int): New label index to select a state if necessary

        Returns:
            state: pruned state

        """
        if state is None:
            return None
        return state[i]

    def score(self, y, state, x):
        """
        Score new token (required).

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): The encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                scores for next token that has a shape of `(n_vocab)`
                and next state for ys

        """
        # Subclasses must provide the actual scoring implementation.
        raise NotImplementedError

    def final_score(self, state):
        """
        Score eos (optional).

        Args:
            state: Scorer state for prefix tokens

        Returns:
            float: final score

        """
        # Default: no extra score at end of sequence.
        return 0.0
|
|
|
|
|
class BatchScorerInterface(ScorerInterface, ABC):
    """Batch scorer interface: scores whole beams, falling back to per-item scoring."""

    def batch_init_state(self, x):
        """
        Get an initial state for decoding (optional).

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state

        """
        return self.init_state(x)

    def batch_score(self, ys, states, xs):
        """
        Score new token batch (required).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.

        """
        scores = list()
        outstates = list()
        # Default implementation: score each batch element individually
        # via ScorerInterface.score, then stack the results.
        # (The original loop used enumerate but never used the index.)
        for y, state, x in zip(ys, states, xs):
            score, outstate = self.score(y, state, x)
            outstates.append(outstate)
            scores.append(score)
        scores = torch.cat(scores, 0).view(ys.shape[0], -1)
        return scores, outstates
|
|
|
|
|
def to_device(m, x):
    """
    Send tensor into the device of the module.

    Args:
        m (torch.nn.Module or torch.Tensor): Torch module (or tensor)
            whose device the tensor should be moved to.
        x (Tensor): Torch tensor.

    Returns:
        Tensor: Torch tensor located in the same place as torch module.

    Raises:
        TypeError: If ``m`` is neither a module nor a tensor.
    """
    if isinstance(m, torch.nn.Module):
        # NOTE(review): raises StopIteration for a parameterless module —
        # presumably callers always pass modules with parameters; confirm.
        device = next(m.parameters()).device
    elif isinstance(m, torch.Tensor):
        device = m.device
    else:
        # Fixed typo in the error message: "bot got" -> "but got".
        raise TypeError(
            "Expected torch.nn.Module or torch.tensor, " f"but got: {type(m)}"
        )
    return x.to(device)
|
|