import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class LstmSeq2SeqEncoder(nn.Module):
    """Batch-first LSTM encoder that returns one output vector per input step.

    Padded positions are skipped by packing the input before the LSTM. The
    result is re-padded to the input's sequence length, so it stays aligned
    with ``x`` and ``mask``.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.,
                 bidirectional=False):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
            batch_first=True,
        )

    def forward(self, x, mask, hidden=None):
        """Run the LSTM over the unpadded part of each sequence.

        Args:
            x: ``(batch, seq_len, input_size)`` input (the LSTM is batch_first).
            mask: ``(batch, seq_len)`` mask. Its row sum is used as the
                sequence length, and ``pack_padded_sequence`` keeps the FIRST
                ``length`` steps. The mask is therefore expected to be
                right-padded (ones followed by zeros).
            hidden: optional initial ``(h_0, c_0)`` passed straight to the LSTM.

        Returns:
            ``(batch, seq_len, hidden_size * num_directions)`` tensor.
            Positions past each row's length are zero. Rows with an all-zero
            mask are entirely zero.
        """
        # pack_padded_sequence requires int64 lengths on the CPU.
        lengths = mask.sum(dim=1).to(device="cpu", dtype=torch.int64)

        # pack_padded_sequence rejects zero-length rows. Pack them with
        # length 1 and blank their outputs afterwards.
        empty_rows = lengths == 0
        packed_x = pack_padded_sequence(
            x, lengths.clamp(min=1), batch_first=True, enforce_sorted=False
        )

        packed_output, _ = self.lstm(packed_x, hidden)

        # total_length keeps the time dimension equal to x.size(1). Without
        # it, the output is only as long as the longest row in this batch.
        output, _ = pad_packed_sequence(
            packed_output, batch_first=True, total_length=x.size(1)
        )

        if empty_rows.any():
            output = output.masked_fill(
                empty_rows.to(output.device).view(-1, 1, 1), 0.0
            )
        return output