| | Tutorial: Simple LSTM |
| | ===================== |
| |
|
| | In this tutorial we will extend fairseq by adding a new |
| | :class:`~fairseq.models.FairseqEncoderDecoderModel` that encodes a source |
| | sentence with an LSTM and then passes the final hidden state to a second LSTM |
| | that decodes the target sentence (without attention). |
| |
|
| | This tutorial covers: |
| |
|
| | 1. **Writing an Encoder and Decoder** to encode/decode the source/target |
| | sentence, respectively. |
| | 2. **Registering a new Model** so that it can be used with the existing |
| | :ref:`Command-line Tools`. |
| | 3. **Training the Model** using the existing command-line tools. |
| | 4. **Making generation faster** by modifying the Decoder to use |
| | :ref:`Incremental decoding`. |
| |
|
| |
|
| | 1. Building an Encoder and Decoder |
| | ---------------------------------- |
| |
|
| | In this section we'll define a simple LSTM Encoder and Decoder. All Encoders |
| | should implement the :class:`~fairseq.models.FairseqEncoder` interface and |
| | Decoders should implement the :class:`~fairseq.models.FairseqDecoder` interface. |
| | These interfaces themselves extend :class:`torch.nn.Module`, so FairseqEncoders |
| | and FairseqDecoders can be written and used in the same ways as ordinary PyTorch |
| | Modules. |
| |
|
| |
|
| | Encoder |
| | ~~~~~~~ |
| |
|
| | Our Encoder will embed the tokens in the source sentence, feed them to a |
| | :class:`torch.nn.LSTM` and return the final hidden state. To create our encoder, |
| | save the following in a new file named :file:`fairseq/models/simple_lstm.py`:: |
| | |
| | import torch.nn as nn |
| | from fairseq import utils |
| | from fairseq.models import FairseqEncoder |
| |
|
class SimpleLSTMEncoder(FairseqEncoder):
    """Single-layer, unidirectional LSTM encoder.

    Embeds the source tokens, feeds them to a :class:`torch.nn.LSTM` and
    returns the LSTM's final hidden state.

    Args:
        args: parsed command-line arguments; only ``args.left_pad_source``
            is read by this class
        dictionary: source dictionary, used for the vocabulary size
            (``len(dictionary)``) and the padding index (``dictionary.pad()``)
        embed_dim (int, optional): dimensionality of the token embeddings
        hidden_dim (int, optional): dimensionality of the LSTM hidden state
        dropout (float, optional): dropout probability applied to the
            embedded tokens
    """

    def __init__(
        self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,
    ):
        super().__init__(dictionary)
        self.args = args

        # Our encoder will embed the inputs before feeding them to the LSTM.
        self.embed_tokens = nn.Embedding(
            num_embeddings=len(dictionary),
            embedding_dim=embed_dim,
            padding_idx=dictionary.pad(),
        )
        self.dropout = nn.Dropout(p=dropout)

        # We'll use a single-layer, unidirectional LSTM for simplicity.
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=False,
            batch_first=True,
        )

    def forward(self, src_tokens, src_lengths):
        """
        Args:
            src_tokens (LongTensor): source tokens of shape `(batch, src_len)`
            src_lengths (LongTensor): length of each source sentence, of
                shape `(batch)`

        Returns:
            dict: ``{'final_hidden': Tensor}``, where the tensor has shape
            `(batch, hidden_dim)`
        """
        # The inputs to the ``forward()`` function are determined by the
        # Task, and in particular the ``'net_input'`` key in each
        # mini-batch. We discuss Tasks in the next tutorial, but for now just
        # know that *src_tokens* has shape `(batch, src_len)` and *src_lengths*
        # has shape `(batch)`.

        # Note that the source is typically padded on the left. This can be
        # configured by adding the `--left-pad-source "False"` command-line
        # argument, but here we'll make the Encoder handle either kind of
        # padding by converting everything to be right-padded.
        if self.args.left_pad_source:
            # Convert left-padding to right-padding.
            src_tokens = utils.convert_padding_direction(
                src_tokens,
                padding_idx=self.dictionary.pad(),
                left_to_right=True
            )

        # Embed the source.
        x = self.embed_tokens(src_tokens)

        # Apply dropout.
        x = self.dropout(x)

        # Pack the sequence into a PackedSequence object to feed to the LSTM.
        # ``pack_padded_sequence()`` requires the lengths tensor to live on
        # the CPU, whereas the mini-batch (including *src_lengths*) is moved
        # to the GPU during training, so we copy the lengths back explicitly.
        # ``.cpu()`` is a no-op for tensors that are already on the CPU.
        # NOTE: with the default ``enforce_sorted=True`` the batch must be
        # sorted by decreasing source length.
        x = nn.utils.rnn.pack_padded_sequence(
            x, src_lengths.cpu(), batch_first=True,
        )

        # Get the output from the LSTM.
        _outputs, (final_hidden, _final_cell) = self.lstm(x)

        # Return the Encoder's output. This can be any object and will be
        # passed directly to the Decoder.
        return {
            # ``final_hidden`` comes out of the LSTM with shape
            # `(num_layers=1, bsz, hidden_dim)`; dropping the layer dimension
            # gives `(bsz, hidden_dim)`.
            'final_hidden': final_hidden.squeeze(0),
        }

    # Encoders are required to implement this method so that we can rearrange
    # the order of the batch elements during inference (e.g., beam search).
    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to `new_order`.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            `encoder_out` rearranged according to `new_order`
        """
        final_hidden = encoder_out['final_hidden']
        return {
            # The batch dimension is dimension 0 of ``final_hidden``.
            'final_hidden': final_hidden.index_select(0, new_order),
        }
| |
|
| |
|
| | Decoder |
| | ~~~~~~~ |
| |
|
| | Our Decoder will predict the next word, conditioned on the Encoder's final |
| | hidden state and an embedded representation of the previous target word. |
| | Feeding the ground-truth previous target word during training (rather than the |
| | model's own prediction) is sometimes called *teacher forcing*. More |
| | specifically, we'll use a |
| | :class:`torch.nn.LSTM` to produce a sequence of hidden states that we'll project |
| | to the size of the output vocabulary to predict each target word. |
| |
|
| | :: |
| |
|
| | import torch |
| | from fairseq.models import FairseqDecoder |
| |
|
class SimpleLSTMDecoder(FairseqDecoder):
    """Single-layer LSTM decoder conditioned on the Encoder's final hidden state.

    The Encoder's final hidden state is used twice: it is concatenated to
    every embedded target token, and it is the initial hidden state of the
    LSTM. The LSTM outputs are projected to the size of the vocabulary.

    Args:
        dictionary: target dictionary, used for the vocabulary size
            (``len(dictionary)``) and the padding index (``dictionary.pad()``)
        encoder_hidden_dim (int, optional): dimensionality of the Encoder's
            final hidden state
        embed_dim (int, optional): dimensionality of the token embeddings
        hidden_dim (int, optional): dimensionality of the LSTM hidden state;
            must equal *encoder_hidden_dim*
        dropout (float, optional): dropout probability applied to the
            embedded tokens

    Raises:
        ValueError: if *encoder_hidden_dim* differs from *hidden_dim*
    """

    def __init__(
        self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
        dropout=0.1,
    ):
        super().__init__(dictionary)

        # ``forward()`` uses the Encoder's final hidden state directly as the
        # LSTM's initial hidden state, so both sizes have to agree. Fail early
        # with a clear message instead of a shape error on the first batch.
        if encoder_hidden_dim != hidden_dim:
            raise ValueError(
                'encoder_hidden_dim ({}) must equal hidden_dim ({}), because '
                'the encoder\'s final hidden state initializes the decoder '
                'LSTM'.format(encoder_hidden_dim, hidden_dim)
            )

        # Our decoder will embed the inputs before feeding them to the LSTM.
        self.embed_tokens = nn.Embedding(
            num_embeddings=len(dictionary),
            embedding_dim=embed_dim,
            padding_idx=dictionary.pad(),
        )
        self.dropout = nn.Dropout(p=dropout)

        # We'll use a single-layer, unidirectional LSTM for simplicity.
        self.lstm = nn.LSTM(
            # For the first layer we'll concatenate the Encoder's final hidden
            # state with the embedded target tokens.
            input_size=encoder_hidden_dim + embed_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=False,
        )

        # Define the output projection.
        self.output_projection = nn.Linear(hidden_dim, len(dictionary))

    # During training Decoders are expected to take the entire target sequence
    # (shifted right by one position) and produce logits over the vocabulary.
    # The *prev_output_tokens* tensor begins with the end-of-sentence symbol,
    # ``dictionary.eos()``, followed by the target sequence.
    def forward(self, prev_output_tokens, encoder_out):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (dict): output from the Encoder; its
                ``'final_hidden'`` entry has shape
                `(batch, encoder_hidden_dim)`

        Returns:
            tuple:
                - the logits over the vocabulary, of shape
                  `(batch, tgt_len, vocab)`
                - ``None``, since this decoder does not use attention
        """
        bsz, tgt_len = prev_output_tokens.size()

        # Extract the final hidden state from the Encoder.
        final_encoder_hidden = encoder_out['final_hidden']

        # Embed the target sequence, which has been shifted right by one
        # position and now starts with the end-of-sentence symbol.
        x = self.embed_tokens(prev_output_tokens)

        # Apply dropout.
        x = self.dropout(x)

        # Concatenate the Encoder's final hidden state to *every* embedded
        # target token.
        x = torch.cat(
            [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
            dim=2,
        )

        # Using PackedSequence objects in the Decoder is harder than in the
        # Encoder, since the targets are not sorted in descending length order,
        # which is a requirement of ``pack_padded_sequence()``. Instead we'll
        # feed nn.LSTM directly.
        # The LSTM expects states of shape `(num_layers, bsz, hidden_dim)`,
        # hence the extra leading dimension.
        initial_state = (
            final_encoder_hidden.unsqueeze(0),  # hidden
            torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
        )
        output, _ = self.lstm(
            x.transpose(0, 1),  # convert to shape `(tgt_len, bsz, dim)`
            initial_state,
        )
        x = output.transpose(0, 1)  # convert to shape `(bsz, tgt_len, hidden)`

        # Project the outputs to the size of the vocabulary.
        x = self.output_projection(x)

        # Return the logits and ``None`` for the attention weights
        return x, None
| |
|
| |
|
| | 2. Registering the Model |
| | ------------------------ |
| |
|
| | Now that we've defined our Encoder and Decoder we must *register* our model with |
| | fairseq using the :func:`~fairseq.models.register_model` function decorator. |
| | Once the model is registered we'll be able to use it with the existing |
| | :ref:`Command-line Tools`. |
| |
|
| | All registered models must implement the |
| | :class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence |
| | models (i.e., any model with a single Encoder and Decoder), we can instead |
| | implement the :class:`~fairseq.models.FairseqEncoderDecoderModel` interface. |
| |
|
| | Create a small wrapper class in the same file and register it in fairseq with |
| | the name ``'simple_lstm'``:: |
| | |
| | from fairseq.models import FairseqEncoderDecoderModel, register_model |
| |
|
| | # Note: the register_model "decorator" should immediately precede the |
| | # definition of the Model class. |
| |
|
@register_model('simple_lstm')
class SimpleLSTMModel(FairseqEncoderDecoderModel):
    """Sequence-to-sequence model pairing a SimpleLSTMEncoder with a
    SimpleLSTMDecoder, registered with fairseq as ``'simple_lstm'``."""

    @staticmethod
    def add_args(parser):
        """Add the model-specific command-line arguments to *parser*.

        The Encoder and the Decoder each get the same three options:
        embedding size, hidden size and dropout probability.
        """
        for side in ('encoder', 'decoder'):
            parser.add_argument(
                '--{}-embed-dim'.format(side), type=int, metavar='N',
                help='dimensionality of the {} embeddings'.format(side),
            )
            parser.add_argument(
                '--{}-hidden-dim'.format(side), type=int, metavar='N',
                help='dimensionality of the {} hidden state'.format(side),
            )
            parser.add_argument(
                '--{}-dropout'.format(side), type=float, default=0.1,
                help='{} dropout probability'.format(side),
            )

    @classmethod
    def build_model(cls, args, task):
        """Build and return a SimpleLSTMModel for the given *args* and *task*.

        Fairseq creates models through this factory rather than through the
        constructor, which leaves a model free to return an instance of a
        different type. Here we simply return a SimpleLSTMModel.
        """
        # The Encoder reads the source side of the task...
        src_encoder = SimpleLSTMEncoder(
            args=args,
            dictionary=task.source_dictionary,
            embed_dim=args.encoder_embed_dim,
            hidden_dim=args.encoder_hidden_dim,
            dropout=args.encoder_dropout,
        )
        # ...and the Decoder writes the target side. It needs to know the
        # size of the Encoder's hidden state that it is conditioned on.
        tgt_decoder = SimpleLSTMDecoder(
            dictionary=task.target_dictionary,
            encoder_hidden_dim=args.encoder_hidden_dim,
            embed_dim=args.decoder_embed_dim,
            hidden_dim=args.decoder_hidden_dim,
            dropout=args.decoder_dropout,
        )
        seq2seq = SimpleLSTMModel(src_encoder, tgt_decoder)

        # Show the resulting architecture.
        print(seq2seq)

        return seq2seq

    # There is no need to define ``forward()`` here: the implementation
    # inherited from the FairseqEncoderDecoderModel base class already wires
    # the Encoder to the Decoder. Override it only if you need more control
    # over how the two interact. The inherited version looks like:
    #
    #   def forward(self, src_tokens, src_lengths, prev_output_tokens):
    #       encoder_out = self.encoder(src_tokens, src_lengths)
    #       decoder_out = self.decoder(prev_output_tokens, encoder_out)
    #       return decoder_out
| |
|
| | Finally let's define a *named architecture* with the configuration for our |
| | model. This is done with the :func:`~fairseq.models.register_model_architecture` |
| | function decorator. Thereafter this named architecture can be used with the |
| | ``--arch`` command-line argument, e.g., ``--arch tutorial_simple_lstm``:: |
| | |
| | from fairseq.models import register_model_architecture |
| |
|
| | # The first argument to ``register_model_architecture()`` should be the name |
| | # of the model we registered above (i.e., 'simple_lstm'). The function we |
| | # register here should take a single argument *args* and modify it in-place |
| | # to match the desired architecture. |
| |
|
@register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
def tutorial_simple_lstm(args):
    """Fill in the defaults of the ``tutorial_simple_lstm`` architecture.

    *args* is modified in-place. A dimension that is already set on *args*
    (e.g., given explicitly on the command-line) is kept; only missing ones
    fall back to 256.
    """
    for option in (
        'encoder_embed_dim',
        'encoder_hidden_dim',
        'decoder_embed_dim',
        'decoder_hidden_dim',
    ):
        # ``getattr()`` with a fallback leaves existing values untouched.
        setattr(args, option, getattr(args, option, 256))
| |
|
| |
|
| | 3. Training the Model |
| | --------------------- |
| |
|
| | Now we're ready to train the model. We can use the existing :ref:`fairseq-train` |
| | command-line tool for this, making sure to specify our new Model architecture |
| | (``--arch tutorial_simple_lstm``). |
| |
|
| | .. note:: |
| |
|
| | Make sure you've already preprocessed the data from the IWSLT example in the |
| | :file:`examples/translation/` directory. |
| |
|
| | .. code-block:: console |
| |
|
| | > fairseq-train data-bin/iwslt14.tokenized.de-en \ |
| | --arch tutorial_simple_lstm \ |
| | --encoder-dropout 0.2 --decoder-dropout 0.2 \ |
| | --optimizer adam --lr 0.005 --lr-shrink 0.5 \ |
| | --max-tokens 12000 |
| | (...) |
| | | epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396 |
| | | epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954 |
| |
|
| | The model files should appear in the :file:`checkpoints/` directory. While this |
| | model architecture is not very good, we can use the :ref:`fairseq-generate` script to |
| | generate translations and compute our BLEU score over the test set: |
| |
|
| | .. code-block:: console |
| |
|
| | > fairseq-generate data-bin/iwslt14.tokenized.de-en \ |
| | --path checkpoints/checkpoint_best.pt \ |
| | --beam 5 \ |
| | --remove-bpe |
| | (...) |
| | | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s) |
| | | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146) |
| |
|
| |
|
| | 4. Making generation faster |
| | --------------------------- |
| |
|
| | While autoregressive generation from sequence-to-sequence models is inherently |
| | slow, our implementation above is especially slow because it recomputes the |
| | entire sequence of Decoder hidden states for every output token (i.e., it is |
| | ``O(n^2)``). We can make this significantly faster by instead caching the |
| | previous hidden states. |
| |
|
| | In fairseq this is called :ref:`Incremental decoding`. Incremental decoding is a |
| | special mode at inference time where the Model only receives a single timestep |
| | of input corresponding to the immediately previous output token (i.e., the |
| | token it has just generated) and must produce the next output incrementally. |
| | Thus the model must cache any long-term state that is needed about the |
| | sequence, e.g., hidden states, convolutional states, etc. |
| |
|
| | To implement incremental decoding we will modify our model to implement the |
| | :class:`~fairseq.models.FairseqIncrementalDecoder` interface. Compared to the |
| | standard :class:`~fairseq.models.FairseqDecoder` interface, the incremental |
| | decoder interface allows ``forward()`` methods to take an extra keyword argument |
| | (*incremental_state*) that can be used to cache state across time-steps. |
| |
|
| | Let's replace our ``SimpleLSTMDecoder`` with an incremental one:: |
| | |
| | import torch |
| | from fairseq.models import FairseqIncrementalDecoder |
| |
|
class SimpleLSTMDecoder(FairseqIncrementalDecoder):
    """Incremental version of the single-layer LSTM decoder.

    Behaves like the non-incremental decoder during training. At inference
    time, when an *incremental_state* is given, it processes only the most
    recent target token and keeps the LSTM's hidden and cell states in the
    cache between time-steps.

    Args:
        dictionary: target dictionary, used for the vocabulary size
            (``len(dictionary)``) and the padding index (``dictionary.pad()``)
        encoder_hidden_dim (int, optional): dimensionality of the Encoder's
            final hidden state
        embed_dim (int, optional): dimensionality of the token embeddings
        hidden_dim (int, optional): dimensionality of the LSTM hidden state;
            must equal *encoder_hidden_dim*
        dropout (float, optional): dropout probability applied to the
            embedded tokens

    Raises:
        ValueError: if *encoder_hidden_dim* differs from *hidden_dim*
    """

    def __init__(
        self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
        dropout=0.1,
    ):
        super().__init__(dictionary)

        # ``forward()`` uses the Encoder's final hidden state directly as the
        # LSTM's initial hidden state, so both sizes have to agree. Fail early
        # with a clear message instead of a shape error on the first batch.
        if encoder_hidden_dim != hidden_dim:
            raise ValueError(
                'encoder_hidden_dim ({}) must equal hidden_dim ({}), because '
                'the encoder\'s final hidden state initializes the decoder '
                'LSTM'.format(encoder_hidden_dim, hidden_dim)
            )

        # The layers are the same as in the non-incremental decoder.
        self.embed_tokens = nn.Embedding(
            num_embeddings=len(dictionary),
            embedding_dim=embed_dim,
            padding_idx=dictionary.pad(),
        )
        self.dropout = nn.Dropout(p=dropout)
        self.lstm = nn.LSTM(
            input_size=encoder_hidden_dim + embed_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=False,
        )
        self.output_projection = nn.Linear(hidden_dim, len(dictionary))

    # We now take an additional kwarg (*incremental_state*) for caching the
    # previous hidden and cell states.
    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`
            encoder_out (dict): output from the Encoder; its
                ``'final_hidden'`` entry has shape
                `(batch, encoder_hidden_dim)`
            incremental_state (optional): cache used to carry the LSTM state
                across time-steps; ``None`` outside of incremental inference

        Returns:
            tuple:
                - the logits over the vocabulary, of shape
                  `(batch, tgt_len, vocab)`, where `tgt_len` is 1 in
                  incremental mode
                - ``None``, since this decoder does not use attention
        """
        if incremental_state is not None:
            # If the *incremental_state* argument is not ``None`` then we are
            # in incremental inference mode. While *prev_output_tokens* will
            # still contain the entire decoded prefix, we will only use the
            # last step and assume that the rest of the state is cached.
            prev_output_tokens = prev_output_tokens[:, -1:]

        # This remains the same as before.
        bsz, tgt_len = prev_output_tokens.size()
        final_encoder_hidden = encoder_out['final_hidden']
        x = self.embed_tokens(prev_output_tokens)
        x = self.dropout(x)
        x = torch.cat(
            [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
            dim=2,
        )

        # We will now check the cache and load the cached previous hidden and
        # cell states, if they exist. Otherwise we will initialize them as
        # before: the hidden state from the Encoder's final hidden state and
        # the cell state to zeros. We will use the
        # ``utils.get_incremental_state()`` and
        # ``utils.set_incremental_state()`` helpers.
        initial_state = utils.get_incremental_state(
            self, incremental_state, 'prev_state',
        )
        if initial_state is None:
            # first time initialization, same as the original version
            initial_state = (
                final_encoder_hidden.unsqueeze(0),  # hidden
                torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
            )

        # Run one step of our LSTM (or the whole sequence when training).
        output, latest_state = self.lstm(x.transpose(0, 1), initial_state)

        # Update the cache with the latest hidden and cell states.
        utils.set_incremental_state(
            self, incremental_state, 'prev_state', latest_state,
        )

        # This remains the same as before
        x = output.transpose(0, 1)
        x = self.output_projection(x)
        return x, None

    # The ``FairseqIncrementalDecoder`` interface also requires implementing a
    # ``reorder_incremental_state()`` method, which is used during beam search
    # to select and reorder the incremental state.
    def reorder_incremental_state(self, incremental_state, new_order):
        """
        Reorder the cached LSTM state according to `new_order`.

        Args:
            incremental_state: cache holding the ``'prev_state'`` entry
                written by ``forward()``
            new_order (LongTensor): desired order of the batch elements
        """
        # Load the cached state.
        prev_state = utils.get_incremental_state(
            self, incremental_state, 'prev_state',
        )
        if prev_state is None:
            # Nothing has been cached yet, so there is nothing to reorder.
            return

        # Reorder batches according to *new_order*. The cached states have
        # shape `(num_layers, bsz, hidden_dim)`, so the batch is dimension 1.
        hidden, cell = prev_state
        reordered_state = (
            hidden.index_select(1, new_order),
            cell.index_select(1, new_order),
        )

        # Update the cached state.
        utils.set_incremental_state(
            self, incremental_state, 'prev_state', reordered_state,
        )
| |
|
| | Finally, we can rerun generation and observe the speedup: |
| |
|
| | .. code-block:: console |
| |
|
| | # Before |
| |
|
| | > fairseq-generate data-bin/iwslt14.tokenized.de-en \ |
| | --path checkpoints/checkpoint_best.pt \ |
| | --beam 5 \ |
| | --remove-bpe |
| | (...) |
| | | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s) |
| | | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146) |
| |
|
| | # After |
| |
|
| | > fairseq-generate data-bin/iwslt14.tokenized.de-en \ |
| | --path checkpoints/checkpoint_best.pt \ |
| | --beam 5 \ |
| | --remove-bpe |
| | (...) |
| | | Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s) |
| | | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146) |
| |
|