r""" Nucleus Sampling was introduced in the paper `The Curious Case of Neural Text Degeneration `_. If you take it from here, make sure to cite them: .. code-block:: text @inproceedings{, title={The Curious Case of Neural Text Degeneration}, author={Ari Holtzman and Jan Buys and Li Du and Maxwell Forbes and Yejin Choi}, journal={ICLR}, year={2020} } Some core parts of this code are adapted with minor modifications from Thomas Wolf's gist: https://gist.githubusercontent.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 """ from typing import Callable, List, Tuple import torch import torch.nn.functional as F class AutoRegressiveNucleusSampling(object): """ Implements the nucleus sampling for decoding captions. This class only works for auto-regressive models (Transformer-like), not recurrent models (LSTM-like). Parameters ---------- eos_index: int The index of the end token (``[EOS]``) in vocabulary. max_steps: int, optional (default = 50) The maximum number of decoding steps. nucleus_size: int, optional (default = 5) Size of top-K nucleus for sampling. """ def __init__( self, eos_index: int, max_steps: int = 50, nucleus_size: float = 0.9, ): super().__init__() self._eos_index = eos_index self.max_steps = max_steps self.nucleus_size = nucleus_size def search( self, start_predictions: torch.Tensor, step: Callable[..., torch.Tensor] ) -> Tuple[torch.Tensor, None]: batch_size = start_predictions.size()[0] # List of `(batch_size, )` tensors. One for each timestep. # This includes the start-of-sentence tokens, unlike the implementation # in `AutoregressiveBeamSearch`. We will remove them in the end. # Transpose `start_predictions` and make a list when prompt is provided. predictions = [ start_predictions[:, i] for i in range(start_predictions.size(1)) ] for timestep in range(self.max_steps): # Get the predictions from last timestep (most recent). # shape: (batch_size, ) last_predictions = predictions[-1] # If every predicted token from the last step is end-of-sentence token, # then we can stop early. if (last_predictions == self._eos_index).all(): break # Combine step predictions made so far into one tensor. This is our # "partial" caption input to the transformer. # shape: (batch_size, timestep + 1) predictions_so_far = torch.stack(predictions).permute(1, 0) # Take a step, get the distribution of logits from next timestep. # shape: (batch_size, num_classes) current_logits = step(predictions_so_far) # Sort logits in descending order to determine the nucleus. sorted_logits, sorted_idx = torch.sort(current_logits, descending=True) # Get cumulative softmax probabilites. For every instance in batch, a # variable amount of tokens (N) will consitute the nucleus. # shape: (batch_size, num_classes) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # Determine indices of tokens at the tail of distribution. These will be # removed from the nucleus. sorted_idx_to_remove = cumulative_probs > self.nucleus_size # Shift the indices to the right to keep the first token outside nucleus. sorted_idx_to_remove[..., 1:] = sorted_idx_to_remove[..., :-1].clone() sorted_idx_to_remove[..., 0] = 0 # Set logits to large negative value to avoid sampling them. Iterate over # the batch of examples. for t in range(current_logits.size()[0]): idx_to_remove = sorted_idx[t][sorted_idx_to_remove[t]] current_logits[t][idx_to_remove] = -1e12 # Set logits for last predicted token to a large negative value to # avoid repetition. current_logits[t][last_predictions[t]] = -1e12 # Sample from the filtered distribution. 
            # shape: (batch_size, num_classes)
            current_probs = F.softmax(current_logits, dim=-1)

            # shape: (batch_size, )
            current_predictions = torch.multinomial(current_probs, 1)
            current_predictions = current_predictions.view(batch_size)

            # Set current predicted tokens to be end-of-sentence for instances
            # where the last prediction was also the end-of-sentence token.
            current_predictions[last_predictions == self._eos_index] = self._eos_index
            predictions.append(current_predictions)

        # Remove the start-of-sentence token from predictions, and collect them
        # together.
        # shape: (batch_size, max_steps) .. or could be less than max_steps.
        all_predictions = torch.stack(predictions[1:]).permute(1, 0)

        # We don't return any logprobs of the generated sequence with nucleus
        # sampling, unlike `AutoregressiveBeamSearch`.
        return all_predictions, None
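

# A minimal usage sketch, not part of the original module: the `step` callable
# below returns random logits and stands in for a real Transformer decoder that
# maps partial captions to next-token logits. `VOCAB_SIZE`, `SOS_INDEX`, and
# `EOS_INDEX` are hypothetical values chosen only for illustration.
if __name__ == "__main__":
    VOCAB_SIZE = 1000
    SOS_INDEX, EOS_INDEX = 1, 2

    def step(partial_predictions: torch.Tensor) -> torch.Tensor:
        # shape: (batch_size, num_classes) -- random logits, for illustration.
        return torch.randn(partial_predictions.size(0), VOCAB_SIZE)

    sampler = AutoRegressiveNucleusSampling(eos_index=EOS_INDEX, nucleus_size=0.9)

    # Start every caption in the batch with the start-of-sentence token.
    # shape: (batch_size, 1)
    start_predictions = torch.full((4, 1), SOS_INDEX, dtype=torch.long)

    predictions, _ = sampler.search(start_predictions, step)
    print(predictions.shape)  # (4, <= max_steps)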