"""Define RNN-based encoders."""
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

from onmt.encoders.encoder import EncoderBase
from onmt.utils.rnn_factory import rnn_factory


class RNNEncoder(EncoderBase):
    """A generic recurrent neural network encoder.

    Args:
        rnn_type (str):
            style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU]
        bidirectional (bool) : use a bidirectional RNN
        num_layers (int) : number of stacked layers
        hidden_size (int) : total hidden size of each layer; it is split
            evenly across directions, so it must be divisible by the number
            of directions (2 when ``bidirectional``, else 1)
        dropout (float) : dropout value for :class:`torch.nn.Dropout`
        embeddings (onmt.modules.Embeddings): embedding module to use
            (required; must not be ``None``)
        use_bridge (bool) : if ``True``, pass the final hidden state(s)
            through a learned linear + ReLU "bridge" before returning them
    """

    def __init__(
        self,
        rnn_type,
        bidirectional,
        num_layers,
        hidden_size,
        dropout=0.0,
        embeddings=None,
        use_bridge=False,
    ):
        super(RNNEncoder, self).__init__()
        assert embeddings is not None, "RNNEncoder requires an embeddings module"

        num_directions = 2 if bidirectional else 1
        assert (
            hidden_size % num_directions == 0
        ), "hidden_size must be divisible by the number of directions"
        # From here on ``hidden_size`` is the per-direction size.
        hidden_size = hidden_size // num_directions

        self.embeddings = embeddings

        # rnn_factory returns the recurrent module together with a flag that
        # tells forward() to skip sequence packing for this RNN.
        self.rnn, self.no_pack_padded_seq = rnn_factory(
            rnn_type,
            input_size=embeddings.embedding_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
        )

        # Initialize the bridge layer
        self.use_bridge = use_bridge
        if self.use_bridge:
            self._initialize_bridge(rnn_type, hidden_size, num_layers)

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor.

        Reads ``rnn_type``, ``brnn``, ``enc_layers``, ``enc_hid_size``,
        ``dropout`` and ``bridge`` from ``opt``. When ``opt.dropout`` is a
        list, its first element is used as the encoder dropout.
        """
        # isinstance (rather than an exact type comparison) also accepts
        # list subclasses.
        if isinstance(opt.dropout, list):
            dropout = opt.dropout[0]
        else:
            dropout = opt.dropout
        return cls(
            opt.rnn_type,
            opt.brnn,
            opt.enc_layers,
            opt.enc_hid_size,
            dropout,
            embeddings,
            opt.bridge,
        )

    def forward(self, src, src_len=None):
        """See :func:`EncoderBase.forward()`

        Returns:
            tuple: ``(enc_out, enc_final_hs, src_len)`` where ``enc_out`` is
            the RNN output, ``enc_final_hs`` is the final hidden state
            (passed through the bridge when enabled) and ``src_len`` is
            returned unchanged.
        """
        emb = self.embeddings(src)

        # Decide once whether to pack, so packing and unpacking always agree.
        use_packing = src_len is not None and not self.no_pack_padded_seq

        rnn_input = emb
        if use_packing:
            # src lengths data is wrapped inside a Tensor.
            src_len_list = src_len.view(-1).tolist()
            rnn_input = pack(
                emb, src_len_list, batch_first=True, enforce_sorted=False
            )

        enc_out, enc_final_hs = self.rnn(rnn_input)

        if use_packing:
            # pad_packed_sequence returns (padded_output, lengths); only the
            # padded output is needed.
            enc_out = unpack(enc_out, batch_first=True)[0]

        if self.use_bridge:
            enc_final_hs = self._bridge(enc_final_hs)

        return enc_out, enc_final_hs, src_len

    def _initialize_bridge(self, rnn_type, hidden_size, num_layers):
        """Create one square linear layer per recurrent state.

        Args:
            rnn_type (str): an ``"LSTM"`` gets two layers (hidden and cell
                state); any other type gets one.
            hidden_size (int): hidden size as passed by ``__init__`` (already
                divided by the number of directions).
            num_layers (int): number of stacked layers.
        """
        # LSTM has hidden and cell state, other only one
        number_of_states = 2 if rnn_type == "LSTM" else 1
        # Total number of states
        self.total_hidden_dim = hidden_size * num_layers

        # Build a linear layer for each
        self.bridge = nn.ModuleList(
            [
                nn.Linear(self.total_hidden_dim, self.total_hidden_dim, bias=True)
                for _ in range(number_of_states)
            ]
        )

    def _bridge(self, hidden):
        """Forward hidden state through bridge.

        final hidden state ``(num_layers x dir, batch, hidden_size)``

        A tuple ``hidden`` (LSTM) yields a tuple with one transformed state
        per bridge layer; otherwise a single transformed tensor is returned.
        """

        def bottle_hidden(linear, states):
            """
            Transform from 3D to 2D, apply linear and return initial size
            """
            states = states.permute(1, 0, 2).contiguous()
            size = states.size()
            result = linear(states.view(-1, self.total_hidden_dim))
            result = F.relu(result).view(size)
            return result.permute(1, 0, 2).contiguous()

        if isinstance(hidden, tuple):  # LSTM
            outs = tuple(
                bottle_hidden(layer, hidden[ix])
                for ix, layer in enumerate(self.bridge)
            )
        else:
            outs = bottle_hidden(self.bridge[0], hidden)
        return outs

    def update_dropout(self, dropout, attention_dropout=None):
        """Set the dropout value on the underlying RNN module.

        ``attention_dropout`` is accepted but not used by this encoder.
        """
        self.rnn.dropout = dropout