import collections from itertools import repeat import torch import torch.nn.utils.rnn as rnn_utils def _ntuple(n): def parse(x): if isinstance(x, collections.Iterable): return x return tuple(repeat(x, n)) return parse _single = _ntuple(1) _pair = _ntuple(2) _triple = _ntuple(3) _quadruple = _ntuple(4) def prepare_rnn_seq(rnn_input, lengths, hx=None, masks=None, batch_first=False): ''' Args: rnn_input: [seq_len, batch_size, input_size]: tensor containing the features of the input sequence. lengths: [batch_size]: tensor containing the lengthes of the input sequence hx: [num_layers * num_directions, batch_size, hidden_size]: tensor containing the initial hidden state for each element in the batch. masks: [seq_len, batch_size]: tensor containing the mask for each element in the batch. batch_first: If True, then the input and output tensors are provided as [batch_size, seq_len, feature]. Returns: ''' def check_decreasing(lengths): lens, order = torch.sort(lengths, dim=0, descending=True) if torch.ne(lens, lengths).sum() == 0: return None else: _, rev_order = torch.sort(order) return lens, order, rev_order check_res = check_decreasing(lengths) if check_res is None: lens = lengths rev_order = None else: lens, order, rev_order = check_res batch_dim = 0 if batch_first else 1 rnn_input = rnn_input.index_select(batch_dim, order) if hx is not None: # hack lstm if isinstance(hx, tuple): hx, cx = hx hx = hx.index_select(1, order) cx = cx.index_select(1, order) hx = (hx, cx) else: hx = hx.index_select(1, order) lens = lens.tolist() seq = rnn_utils.pack_padded_sequence(rnn_input, lens, batch_first=batch_first) if masks is not None: if batch_first: masks = masks[:, :lens[0]] else: masks = masks[:lens[0]] return seq, hx, rev_order, masks def recover_rnn_seq(seq, rev_order, hx=None, batch_first=False): output, _ = rnn_utils.pad_packed_sequence(seq, batch_first=batch_first) if rev_order is not None: batch_dim = 0 if batch_first else 1 output = output.index_select(batch_dim, rev_order) if hx is not None: # hack lstm if isinstance(hx, tuple): hx, cx = hx hx = hx.index_select(1, rev_order) cx = cx.index_select(1, rev_order) hx = (hx, cx) else: hx = hx.index_select(1, rev_order) return output, hx