import torch
import torch.nn as nn
import torch.nn.functional as F


class DocumentBiLSTM(nn.Module):
    """
    BiLSTM document classifier with stability improvements inspired by DocBERT
    (layer normalization and dropout on the pooled hidden representation).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
                 n_layers=2, dropout=0.5, pad_idx=0):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)

        self.lstm = nn.LSTM(embedding_dim,
                            hidden_dim,
                            num_layers=n_layers,
                            bidirectional=True,
                            dropout=dropout if n_layers > 1 else 0,
                            batch_first=True)

        # The BiLSTM yields hidden_dim features per direction, so the pooled
        # representation and the classifier head are hidden_dim * 2 wide.
        self.layer_norm = nn.LayerNorm(hidden_dim * 2)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        # input_ids: [batch, seq_len] -> embedded: [batch, seq_len, embedding_dim]
        embedded = self.embedding(input_ids)
        embedded = self.dropout(embedded)

        if attention_mask is not None:
            # Pack the batch so the LSTM skips padded positions.
            seq_lengths = attention_mask.sum(dim=1).to(torch.int64).cpu()

            # pack_padded_sequence with enforce_sorted=True expects sequences in
            # descending length order; sort and remember the permutation.
            seq_lengths, indices = torch.sort(seq_lengths, descending=True)
            sorted_embedded = embedded[indices]

            packed_embedded = nn.utils.rnn.pack_padded_sequence(
                sorted_embedded, seq_lengths, batch_first=True, enforce_sorted=True
            )

            packed_output, (hidden, cell) = self.lstm(packed_embedded)

            # Per-token outputs are unpacked but not used here; classification
            # relies only on the final hidden states.
            output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)

            # Undo the sort so the hidden states line up with the original batch
            # order (dim 1 of hidden is the batch dimension).
            _, restore_indices = torch.sort(indices)
            hidden = hidden[:, restore_indices]
        else:
            # No mask provided: run the LSTM over the full (possibly padded) batch.
            _, (hidden, cell) = self.lstm(embedded)

        # hidden: [n_layers * 2, batch, hidden_dim]; the last two entries hold the
        # top layer's final forward and backward states.
        hidden_cat = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)

        normalized = self.layer_norm(hidden_cat)
        dropped = self.dropout(normalized)

        prediction = self.fc(dropped)
        return prediction
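

# Minimal usage sketch. The sizes below (vocabulary, dimensions, batch shape) are
# illustrative assumptions, not values taken from any particular training setup.
if __name__ == "__main__":
    model = DocumentBiLSTM(vocab_size=10_000, embedding_dim=128,
                           hidden_dim=256, output_dim=4)

    # Dummy batch: 3 documents of max length 12; the first one has three padded
    # positions (pad_idx=0) reflected in its attention mask.
    input_ids = torch.randint(1, 10_000, (3, 12))
    attention_mask = torch.ones(3, 12, dtype=torch.long)
    input_ids[0, 9:] = 0
    attention_mask[0, 9:] = 0

    logits = model(input_ids, attention_mask=attention_mask)
    print(logits.shape)  # expected: torch.Size([3, 4])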