r"""
Nucleus Sampling was introduced in the paper
`The Curious Case of Neural Text Degeneration <https://arxiv.org/abs/1904.09751>`_.
If you take it from here, make sure to cite them:
.. code-block:: text
    @inproceedings{holtzman2020curious,
title={The Curious Case of Neural Text Degeneration},
author={Ari Holtzman and Jan Buys and Li Du and Maxwell Forbes and Yejin Choi},
journal={ICLR},
year={2020}
}
Some core parts of this code are adapted with minor modifications from Thomas Wolf's
gist: https://gist.githubusercontent.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
from typing import Callable, List, Tuple
import torch
import torch.nn.functional as F
class AutoRegressiveNucleusSampling(object):
    """
    Implements nucleus (top-p) sampling for decoding captions. This class only
    works for auto-regressive models (Transformer-like), not recurrent models
    (LSTM-like).

    Parameters
    ----------
    eos_index: int
        The index of the end token (``[EOS]``) in vocabulary.
    max_steps: int, optional (default = 50)
        The maximum number of decoding steps.
    nucleus_size: float, optional (default = 0.9)
        Cumulative probability threshold (the ``p`` in "top-p"): the nucleus is
        the smallest set of highest-probability tokens whose cumulative softmax
        probability exceeds this value.
    """

    def __init__(
        self,
        eos_index: int,
        max_steps: int = 50,
        nucleus_size: float = 0.9,
    ):
        super().__init__()
        self._eos_index = eos_index
        self.max_steps = max_steps
        self.nucleus_size = nucleus_size

    def search(
        self, start_predictions: torch.Tensor, step: Callable[..., torch.Tensor]
    ) -> Tuple[torch.Tensor, None]:
        """
        Decode a batch of sequences by repeatedly sampling from the nucleus.

        Parameters
        ----------
        start_predictions: torch.Tensor
            Shape ``(batch_size, prompt_length)`` tensor of token indices used
            to prime decoding (at minimum, a start-of-sentence column).
        step: Callable[..., torch.Tensor]
            Callable that receives the partial captions so far, shape
            ``(batch_size, num_steps)``, and returns logits for the next
            timestep, shape ``(batch_size, num_classes)``.

        Returns
        -------
        Tuple[torch.Tensor, None]
            Sampled token indices (start tokens removed), shape
            ``(batch_size, num_decoded_steps)``, and ``None`` — log-probs of
            the generated sequence are not tracked, unlike
            ``AutoregressiveBeamSearch``.
        """
        batch_size = start_predictions.size()[0]

        # List of `(batch_size, )` tensors. One for each timestep.
        # This includes the start-of-sentence tokens, unlike the implementation
        # in `AutoregressiveBeamSearch`. We will remove them in the end.
        # Transpose `start_predictions` and make a list when prompt is provided.
        predictions = [
            start_predictions[:, i] for i in range(start_predictions.size(1))
        ]
        for timestep in range(self.max_steps):
            # Get the predictions from last timestep (most recent).
            # shape: (batch_size, )
            last_predictions = predictions[-1]

            # If every predicted token from the last step is end-of-sentence token,
            # then we can stop early.
            if (last_predictions == self._eos_index).all():
                break

            # Combine step predictions made so far into one tensor. This is our
            # "partial" caption input to the transformer.
            # shape: (batch_size, timestep + 1)
            predictions_so_far = torch.stack(predictions).permute(1, 0)

            # Take a step, get the distribution of logits from next timestep.
            # shape: (batch_size, num_classes)
            current_logits = step(predictions_so_far)

            # Sort logits in descending order to determine the nucleus.
            sorted_logits, sorted_idx = torch.sort(current_logits, descending=True)

            # Get cumulative softmax probabilities. For every instance in batch,
            # a variable number of tokens (N) will constitute the nucleus.
            # shape: (batch_size, num_classes)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Determine indices of tokens at the tail of distribution. These will
            # be removed from the nucleus.
            sorted_idx_to_remove = cumulative_probs > self.nucleus_size

            # Shift the indices to the right to keep the first token outside the
            # nucleus (guarantees at least one token is always sampleable).
            sorted_idx_to_remove[..., 1:] = sorted_idx_to_remove[..., :-1].clone()
            sorted_idx_to_remove[..., 0] = 0

            # Set logits to large negative value to avoid sampling them. Iterate
            # over the batch of examples.
            for t in range(current_logits.size()[0]):
                idx_to_remove = sorted_idx[t][sorted_idx_to_remove[t]]
                current_logits[t][idx_to_remove] = -1e12

                # Set logits for last predicted token to a large negative value
                # to avoid repetition.
                current_logits[t][last_predictions[t]] = -1e12

            # Sample from the filtered distribution.
            # shape: (batch_size, num_classes)
            current_probs = F.softmax(current_logits, dim=-1)

            # shape: (batch_size, )
            current_predictions = torch.multinomial(current_probs, 1)
            current_predictions = current_predictions.view(batch_size)

            # Set current predicted tokens to be end-of-sentence for instances
            # where last prediction was also end-of-sentence token.
            current_predictions[last_predictions == self._eos_index] = self._eos_index

            predictions.append(current_predictions)

        # Remove start-of-sentence token from predictions, and collect them together.
        # shape: (batch_size, max_steps) .. or could be less than max_steps.
        all_predictions = torch.stack(predictions[1:]).permute(1, 0)

        # We don't return any logprobs of generated sequence with nucleus sampling,
        # unlike `AutoregressiveBeamSearch`.
        return all_predictions, None