Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence | |
class LstmSeq2SeqEncoder(nn.Module):
    """Run an (optionally bidirectional) LSTM over a padded, masked batch.

    The batch is packed with ``pack_padded_sequence`` so padded timesteps never
    feed the recurrent state. The result is re-padded back to the input's
    sequence length.

    Args:
        input_size: Feature size of each input timestep.
        hidden_size: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout between stacked LSTM layers. PyTorch applies it after
            all but the last layer, so it has no effect when ``num_layers == 1``.
        bidirectional: Whether to run a bidirectional LSTM.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.,
                 bidirectional=False):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
            batch_first=True,
        )

    def forward(self, x, mask, hidden=None):
        """Encode ``x`` and return the per-timestep LSTM outputs.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``. The LSTM is
                built with ``batch_first=True``.
            mask: Tensor of shape ``(batch, seq_len)``. Each row sum is taken as
                that sequence's length, and the first ``length`` timesteps are
                treated as valid (valid positions must be a leading prefix).
            hidden: Optional initial ``(h_0, c_0)`` state forwarded to the LSTM.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size * num_directions)``.
            Timesteps past each sequence's length are zero. Rows whose mask is
            all zeros are returned as all zeros instead of raising.
        """
        # pack_padded_sequence requires int64 lengths on the CPU.
        lengths = mask.sum(dim=1).to(device="cpu", dtype=torch.int64)

        # pack_padded_sequence rejects zero-length rows. Pack empty rows with
        # length 1 and zero their outputs afterwards, so one empty row does
        # not crash the whole batch.
        empty_rows = lengths == 0
        packed_x = pack_padded_sequence(
            x, lengths.clamp(min=1), batch_first=True, enforce_sorted=False
        )

        packed_output, _ = self.lstm(packed_x, hidden)

        # Pass total_length so the output is padded to the input's seq_len
        # rather than to the longest sequence in this batch. Otherwise the
        # result can be shorter than ``x``/``mask``.
        output, _ = pad_packed_sequence(
            packed_output, batch_first=True, total_length=x.size(1)
        )

        if empty_rows.any():
            row_mask = empty_rows.to(output.device).view(-1, 1, 1)
            output = output.masked_fill(row_mask, 0.0)

        return output