Spaces:
Runtime error
Runtime error
File size: 1,088 Bytes
914502f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
class LstmSeq2SeqEncoder(nn.Module):
    """LSTM encoder over padded, batch-first sequences.

    The input is packed before being fed to the LSTM, so padded time steps
    do not influence the recurrent state of the real (unpadded) steps.

    Args:
        input_size: Size of the last (feature) dimension of the input.
        hidden_size: LSTM hidden size. The output feature size is
            ``hidden_size``, or ``2 * hidden_size`` when bidirectional.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout passed to ``nn.LSTM`` (applied between stacked
            layers, so it has no effect when ``num_layers == 1``).
        bidirectional: Whether to run a bidirectional LSTM.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0., bidirectional=False):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
            batch_first=True,
        )

    def forward(self, x, mask, hidden=None):
        """Encode ``x`` while ignoring padded positions.

        Args:
            x: Batch-first input of shape ``(batch, seq_len, input_size)``
                (batch-first is fixed by ``batch_first=True`` above).
            mask: Per-position validity indicator. Row sums are used as the
                sequence lengths, so valid positions are assumed to form a
                prefix of each row — NOTE(review): confirm callers never
                pass masks with holes.
            hidden: Optional initial ``(h_0, c_0)`` passed through to the
                LSTM.

        Returns:
            Tensor of shape ``(batch, seq_len, num_directions * hidden_size)``.
            The time dimension always equals ``x.size(1)``. Padded positions,
            and rows whose mask is entirely zero, are filled with zeros.
        """
        seq_len = x.size(1)
        # pack_padded_sequence expects the lengths tensor on the CPU.
        lengths = mask.sum(dim=1).cpu()
        # pack_padded_sequence rejects zero lengths. Pack such rows with
        # length 1 so the call succeeds, then zero their outputs below.
        empty = lengths == 0
        packed_x = pack_padded_sequence(
            x, lengths.clamp(min=1), batch_first=True, enforce_sorted=False
        )

        packed_output, _ = self.lstm(packed_x, hidden)

        # Without total_length the time dimension would shrink to
        # max(lengths), misaligning the output with x and mask whenever
        # every row in the batch carries trailing padding.
        output, _ = pad_packed_sequence(
            packed_output, batch_first=True, total_length=seq_len
        )
        if empty.any():
            output = output.masked_fill(empty.to(output.device).view(-1, 1, 1), 0.0)
        return output
|