import numpy as np
import torch
from torch import nn

class RNNEncoder(nn.Module):
    def __init__(self, cfg):
        super(RNNEncoder, self).__init__()
        self.cfg = cfg
        self.rnn_type = cfg.MODEL.LANGUAGE_BACKBONE.RNN_TYPE
        self.variable_length = cfg.MODEL.LANGUAGE_BACKBONE.VARIABLE_LENGTH
        self.word_embedding_size = cfg.MODEL.LANGUAGE_BACKBONE.WORD_EMBEDDING_SIZE
        self.word_vec_size = cfg.MODEL.LANGUAGE_BACKBONE.WORD_VEC_SIZE
        self.hidden_size = cfg.MODEL.LANGUAGE_BACKBONE.HIDDEN_SIZE
        self.bidirectional = cfg.MODEL.LANGUAGE_BACKBONE.BIDIRECTIONAL
        self.input_dropout_p = cfg.MODEL.LANGUAGE_BACKBONE.INPUT_DROPOUT_P
        self.dropout_p = cfg.MODEL.LANGUAGE_BACKBONE.DROPOUT_P
        self.n_layers = cfg.MODEL.LANGUAGE_BACKBONE.N_LAYERS
        self.corpus_path = cfg.MODEL.LANGUAGE_BACKBONE.CORPUS_PATH
        self.vocab_size = cfg.MODEL.LANGUAGE_BACKBONE.VOCAB_SIZE

        # language encoder: embedding lookup -> dropout -> MLP -> RNN (LSTM/GRU)
        self.embedding = nn.Embedding(self.vocab_size, self.word_embedding_size)
        self.input_dropout = nn.Dropout(self.input_dropout_p)
        self.mlp = nn.Sequential(nn.Linear(self.word_embedding_size, self.word_vec_size), nn.ReLU())
        self.rnn = getattr(nn, self.rnn_type.upper())(
            self.word_vec_size,
            self.hidden_size,
            self.n_layers,
            batch_first=True,
            bidirectional=self.bidirectional,
            dropout=self.dropout_p,
        )
        self.num_dirs = 2 if self.bidirectional else 1

    def forward(self, input, mask=None):
        word_id = input
        # drop trailing padding columns: keep only up to the longest sentence in the batch
        max_len = (word_id != 0).sum(1).max().item()
        word_id = word_id[:, :max_len]
        # encode
        output, hidden, embedded, final_output = self.encode(word_id)

        return {
            "hidden": hidden,
            "output": output,
            "embedded": embedded,
            "final_output": final_output,
        }

    def encode(self, input_labels):
        """
        Inputs:
        - input_labels : long tensor (batch, seq_len), zero-padded word ids
        Outputs:
        - output       : float tensor (batch, max_len, hidden_size * num_dirs)
        - hidden       : float tensor (batch, num_layers * num_dirs * hidden_size)
        - embedded     : float tensor (batch, max_len, word_vec_size)
        - final_output : float tensor (batch, num_dirs * hidden_size)
        """
        # per-sample lengths (index 0 is the padding id)
        input_lengths_list = (input_labels != 0).sum(1).tolist()
        if self.variable_length:
            input_lengths_list, sorted_lengths_list, sort_idxs, recover_idxs = self.sort_inputs(input_labels)
            input_labels = input_labels[sort_idxs]

        # embedding
        embedded = self.embedding(input_labels)  # (n, seq_len, word_embedding_size)
        embedded = self.input_dropout(embedded)  # (n, seq_len, word_embedding_size)
        embedded = self.mlp(embedded)  # (n, seq_len, word_vec_size)

        if self.variable_length:
            embedded = nn.utils.rnn.pack_padded_sequence(embedded, sorted_lengths_list, batch_first=True)

        # forward rnn
        self.rnn.flatten_parameters()
        output, hidden = self.rnn(embedded)

        # recover the original (unsorted) batch order
        if self.variable_length:
            # recover embedded
            embedded, _ = nn.utils.rnn.pad_packed_sequence(
                embedded, batch_first=True
            )  # (batch, max_len, word_vec_size)
            embedded = embedded[recover_idxs]

            # recover output
            output, _ = nn.utils.rnn.pad_packed_sequence(
                output, batch_first=True
            )  # (batch, max_len, hidden_size * num_dirs)
            output = output[recover_idxs]

        # hidden state: keep h (drop the LSTM cell state), undo the sort,
        # then flatten layers and directions into a single vector per sample
        if self.rnn_type == "lstm":
            hidden = hidden[0]  # (num_layers * num_dirs, batch, hidden_size)
        if self.variable_length:
            hidden = hidden[:, recover_idxs, :]  # (num_layers * num_dirs, batch, hidden_size)
        hidden = hidden.transpose(0, 1).contiguous()  # (batch, num_layers * num_dirs, hidden_size)
        hidden = hidden.view(hidden.size(0), -1)  # (batch, num_layers * num_dirs * hidden_size)

        # final output: hidden state at the last non-padding token of each sample
        final_output = []
        for ii in range(output.shape[0]):
            final_output.append(output[ii, int(input_lengths_list[ii]) - 1, :])
        final_output = torch.stack(final_output, dim=0)  # (batch, num_dirs * hidden_size)

        return output, hidden, embedded, final_output

    def sort_inputs(self, input_labels):
        """Sort a zero-padded batch by descending length, as required by pack_padded_sequence."""
        device = input_labels.device
        input_lengths = (input_labels != 0).sum(1)
        input_lengths_list = input_lengths.data.cpu().numpy().tolist()
        sorted_input_lengths_list = np.sort(input_lengths_list)[::-1].tolist()  # lengths in descending order
        sort_idxs = np.argsort(input_lengths_list)[::-1].tolist()  # indices that sort the batch
        s2r = {s: r for r, s in enumerate(sort_idxs)}
        recover_idxs = [s2r[s] for s in range(len(input_lengths_list))]  # indices that undo the sort
        assert max(input_lengths_list) == input_labels.size(1)
        # move index lists to long tensors on the input device
        sort_idxs = input_labels.data.new(sort_idxs).long().to(device)
        recover_idxs = input_labels.data.new(recover_idxs).long().to(device)
        return input_lengths_list, sorted_input_lengths_list, sort_idxs, recover_idxs
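

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It shows how the
# encoder might be exercised end to end: a dummy config is built with
# types.SimpleNamespace in place of the project's real config object (which
# only needs to expose cfg.MODEL.LANGUAGE_BACKBONE.*), a batch of zero-padded
# token ids is fed through forward(), and the output shapes are printed.
# All hyperparameter values below are illustrative assumptions, not values
# taken from the original configuration files.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    lang_cfg = SimpleNamespace(
        RNN_TYPE="lstm",
        VARIABLE_LENGTH=True,
        WORD_EMBEDDING_SIZE=512,
        WORD_VEC_SIZE=512,
        HIDDEN_SIZE=512,
        BIDIRECTIONAL=True,
        INPUT_DROPOUT_P=0.1,
        DROPOUT_P=0.0,   # 0.0 avoids the single-layer RNN dropout warning
        N_LAYERS=1,
        CORPUS_PATH="",  # stored by __init__ but not used in this module
        VOCAB_SIZE=1000,
    )
    cfg = SimpleNamespace(MODEL=SimpleNamespace(LANGUAGE_BACKBONE=lang_cfg))

    encoder = RNNEncoder(cfg)
    encoder.eval()

    # two zero-padded sentences of different lengths (0 is the padding id)
    word_ids = torch.tensor([
        [4, 7, 12, 9, 3, 0, 0, 0],
        [5, 2, 0, 0, 0, 0, 0, 0],
    ])
    with torch.no_grad():
        out = encoder(word_ids)
    for key, value in out.items():
        print(key, tuple(value.shape))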