r"""
Nucleus Sampling was introduced in the paper
`The Curious Case of Neural Text Degeneration <https://arxiv.org/abs/1904.09751>`_.
If you use this implementation, make sure to cite them:

.. code-block:: text

    @inproceedings{holtzman2020curious,
        title={The Curious Case of Neural Text Degeneration},
        author={Ari Holtzman and Jan Buys and Li Du and Maxwell Forbes and Yejin Choi},
        booktitle={ICLR},
        year={2020}
    }

Some core parts of this code are adapted with minor modifications from Thomas Wolf's
gist: https://gist.githubusercontent.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""

from typing import Callable, Tuple

import torch
import torch.nn.functional as F


class AutoRegressiveNucleusSampling:
    """
    Implements nucleus (top-p) sampling for decoding captions. This class only
    works with auto-regressive models (Transformer-like), not recurrent models
    (LSTM-like).

    Parameters
    ----------
    eos_index: int
        The index of the end token (``[EOS]``) in vocabulary.
    max_steps: int, optional (default = 50)
        The maximum number of decoding steps.
    nucleus_size: float, optional (default = 0.9)
        Cumulative probability mass (top-p) that constitutes the sampling nucleus.
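
    Examples
    --------
    A minimal sketch; ``model.decoding_step``, ``sos_index``, ``eos_index`` and
    ``batch_size`` below are hypothetical stand-ins for a real captioning setup::

        >>> sampler = AutoRegressiveNucleusSampling(eos_index=eos_index)
        >>> start = torch.full((batch_size, 1), sos_index, dtype=torch.long)
        >>> captions, _ = sampler.search(start, model.decoding_step)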
    """

    def __init__(
        self,
        eos_index: int,
        max_steps: int = 50,
        nucleus_size: float = 0.9,
    ):
        super().__init__()
        self._eos_index = eos_index
        self.max_steps = max_steps
        self.nucleus_size = nucleus_size

    def search(
        self, start_predictions: torch.Tensor, step: Callable[..., torch.Tensor]
    ) -> Tuple[torch.Tensor, None]:
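        """
        Decode captions by nucleus sampling, given a tensor of start tokens
        (shape ``(batch_size, prompt_length)``) and a ``step`` callable that
        maps partial predictions of shape ``(batch_size, timesteps)`` to
        next-token logits of shape ``(batch_size, num_classes)``. Returns the
        sampled token indices and ``None`` (log-probs are not tracked here).
        """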

        batch_size = start_predictions.size()[0]

        # List of `(batch_size, )` tensors. One for each timestep.
        # This includes the start-of-sentence tokens, unlike the implementation
        # in `AutoregressiveBeamSearch`. We will remove them in the end.

        # Unroll `start_predictions` into a list of per-timestep tensors; this
        # also handles prompts longer than one token.
        predictions = [
            start_predictions[:, i] for i in range(start_predictions.size(1))
        ]

        for timestep in range(self.max_steps):
            # Get the predictions from last timestep (most recent).
            # shape: (batch_size, )
            last_predictions = predictions[-1]

            # If every predicted token from the last step is end-of-sentence token,
            # then we can stop early.
            if (last_predictions == self._eos_index).all():
                break

            # Combine step predictions made so far into one tensor. This is our
            # "partial" caption input to the transformer.
            # shape: (batch_size, timestep + 1)
            predictions_so_far = torch.stack(predictions).permute(1, 0)

            # Take a step, get the distribution of logits from next timestep.
            # shape: (batch_size, num_classes)
            current_logits = step(predictions_so_far)

            # Sort logits in descending order to determine the nucleus.
            sorted_logits, sorted_idx = torch.sort(current_logits, descending=True)

            # Get cumulative softmax probabilities. For every instance in the
            # batch, a variable number of tokens (N) will constitute the nucleus.
            # shape: (batch_size, num_classes)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Mark tokens at the tail of the distribution, i.e. those beyond the
            # top-p cumulative mass. These will be removed from the nucleus.
            sorted_idx_to_remove = cumulative_probs > self.nucleus_size

            # Shift the mask one position right so that the first token crossing
            # the threshold stays inside the nucleus.
            sorted_idx_to_remove[..., 1:] = sorted_idx_to_remove[..., :-1].clone()
            sorted_idx_to_remove[..., 0] = 0
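            # For example, with `nucleus_size = 0.9` and sorted probabilities
            # (0.5, 0.3, 0.15, 0.05), the cumulative sums are (0.5, 0.8, 0.95, 1.0);
            # after the shift, only the last token (0.05) is masked, so the token
            # that first crosses 0.9 remains inside the nucleus.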

            # Set tail logits to a large negative value so those tokens are never
            # sampled. Iterate over the instances in the batch.
            for t in range(current_logits.size()[0]):
                idx_to_remove = sorted_idx[t][sorted_idx_to_remove[t]]
                current_logits[t][idx_to_remove] = -1e12

                # Set logits for last predicted token to a large negative value to
                # avoid repetition.
                current_logits[t][last_predictions[t]] = -1e12

            # Sample from the filtered distribution.
            # shape: (batch_size, num_classes)
            current_probs = F.softmax(current_logits, dim=-1)

            # shape: (batch_size, )
            current_predictions = torch.multinomial(current_probs, 1)
            current_predictions = current_predictions.view(batch_size)

            # Set current predicted tokens to be end-of-sentence for instances where
            # last prediction was also end-of-sentence token.
            current_predictions[last_predictions == self._eos_index] = self._eos_index

            predictions.append(current_predictions)

        # Remove start-of-sentence token from predictions, and collect them together.
        # shape: (batch_size, num_decoded_steps), where num_decoded_steps can be
        # less than max_steps if decoding stopped early.
        all_predictions = torch.stack(predictions[1:]).permute(1, 0)

        # We don't return any logprobs of generated sequence with nucleus sampling,
        # unlike `AutoregressiveBeamSearch`.
        return all_predictions, None
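

if __name__ == "__main__":
    # A minimal smoke test: the toy `step` below returns uniform logits over a
    # hypothetical 10-token vocabulary, so sampling draws (almost) uniformly
    # from the nucleus. A real caller would pass a model's decoding step.
    vocab_size = 10

    def toy_step(partial_predictions: torch.Tensor) -> torch.Tensor:
        # One row of logits per instance in the batch.
        return torch.zeros(partial_predictions.size(0), vocab_size)

    sampler = AutoRegressiveNucleusSampling(eos_index=2, max_steps=5)
    # Batch of 4 captions, each starting with a single start token (index 0).
    start_predictions = torch.zeros(4, 1, dtype=torch.long)
    captions, _ = sampler.search(start_predictions, toy_step)
    print(captions.shape)  # torch.Size([4, n]) with n <= 5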