File size: 10,813 Bytes
24d0437
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
from typing import Tuple

import numpy as np
import torch
import torch.nn
from torch.nn.functional import softmax
from torch.nn.utils.rnn import pack_padded_sequence

import flair
from flair.data import Dictionary, Label, List, Sentence

# Names of the sequence-boundary tags. ViterbiLoss and ViterbiDecoder look up the indices of these
# items in the tag dictionary they are given.
START_TAG: str = "<START>"
STOP_TAG: str = "<STOP>"


class ViterbiLoss(torch.nn.Module):
    """
    Calculates the CRF loss for each sequence up to its length t: the log-sum-exp score over all possible tag
    paths minus the score of the gold tag path.
    """

    def __init__(self, tag_dictionary: Dictionary):
        """
        :param tag_dictionary: tag_dictionary of task; must contain the START_TAG and STOP_TAG items
        """
        super().__init__()
        self.tag_dictionary = tag_dictionary
        self.tagset_size = len(tag_dictionary)
        self.start_tag = tag_dictionary.get_idx_for_item(START_TAG)
        self.stop_tag = tag_dictionary.get_idx_for_item(STOP_TAG)

    def forward(self, features_tuple: tuple, targets: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation of Viterbi Loss
        :param features_tuple: CRF scores from forward method in shape (batch size, seq len, tagset size, tagset size),
            lengths of sentences in batch, transitions from CRF. The last two dimensions of the scores are indexed
            as [current tag, previous tag]. The batch has to be ordered by decreasing sentence length.
        :param targets: true tags of all sentences, concatenated in batch order; they are converted to matrix indices.
        :return: Viterbi Loss summed over all sentences in the batch (it is not divided by the batch size)
        """
        features, lengths, transitions = features_tuple

        batch_size = features.size(0)
        seq_len = features.size(1)

        targets_per_sentence, matrix_indices = self._format_targets(targets, lengths)
        targets_matrix_indices = torch.tensor(matrix_indices, dtype=torch.long, device=flair.device).unsqueeze(2)

        # Flatten each (tagset size x tagset size) score matrix and pick the score of the gold transition per token
        scores_at_targets = torch.gather(features.view(batch_size, seq_len, -1), 2, targets_matrix_indices)
        # Packing drops the scores that belong to padded positions
        scores_at_targets = pack_padded_sequence(scores_at_targets, lengths, batch_first=True)[0]

        # Transition score from the last gold tag of every sentence to the stop tag
        last_gold_tags = [tags[length - 1] for tags, length in zip(targets_per_sentence, lengths.tolist())]
        transitions_to_stop = transitions[self.stop_tag, last_gold_tags]
        gold_score = scores_at_targets.sum() + transitions_to_stop.sum()

        scores_upto_t = torch.zeros(batch_size, self.tagset_size, device=flair.device)

        for t in range(int(lengths.max())):
            # since batch is ordered, we can save computation time by reducing our effective batch_size
            batch_size_t = int((lengths > t).sum())

            if t == 0:
                # Initially, get scores from <start> tag to all other tags
                scores_upto_t[:batch_size_t] = features[:batch_size_t, t, :, self.start_tag]
            else:
                # We add scores at current timestep to scores accumulated up to previous timestep, and log-sum-exp
                # Remember, the cur_tag of the previous timestep is the prev_tag of this timestep
                scores_upto_t[:batch_size_t] = self._log_sum_exp(
                    features[:batch_size_t, t, :, :] + scores_upto_t[:batch_size_t].unsqueeze(1), dim=2
                )

        # Finish every path with the transition from its last tag to the stop tag
        all_paths_scores = self._log_sum_exp(scores_upto_t + transitions[self.stop_tag].unsqueeze(0), dim=1).sum()

        return all_paths_scores - gold_score

    @staticmethod
    def _log_sum_exp(tensor, dim):
        """
        Calculates the log-sum-exponent of a tensor's dimension in a numerically stable way.
        :param tensor: tensor
        :param dim: dimension to calculate log-sum-exp of
        :return: log-sum-exp, with dimension `dim` removed
        """
        # torch.logsumexp also copes with slices that are entirely -inf, where subtracting the maximum by hand
        # would compute (-inf) - (-inf) and yield NaN.
        return torch.logsumexp(tensor, dim=dim)

    def _format_targets(self, targets: torch.Tensor, lengths: torch.IntTensor):
        """
        Formats targets into matrix indices.
        CRF scores contain per sentence, per token a (tagset_size x tagset_size) matrix, containing emission score for
            token j + transition prob from previous token i. Rows are the "to tag" and columns the "from tag", so
            cell [10, 5] contains the emission score for tag 10 + transition score from previous tag 5. Once the
            matrix is flattened, this cell is addressed through the 1-dim index (5 + tagset_size * 10) = 125,
            if our tagset consists of 12 tags.
        :param targets: targets as in tag dictionary, concatenated over all sentences of the batch
        :param lengths: lengths of sentences in batch
        :return: tuple of (targets per sentence, padded with the stop tag up to the longest sentence,
            flattened matrix indices per sentence and token)
        """
        max_length = int(lengths.max().item())
        flat_targets = targets.tolist()

        targets_per_sentence = []
        offset = 0
        for length in lengths.tolist():
            sentence_targets = flat_targets[offset : offset + length]
            offset += length
            # Pad with the stop tag so that every sentence has the length of the longest one
            sentence_targets += [self.stop_tag] * (max_length - len(sentence_targets))
            targets_per_sentence.append(sentence_targets)

        # First token: transition from the start tag; afterwards: from the previous gold tag to the current one
        matrix_indices = [
            [self.start_tag + tags[0] * self.tagset_size]
            + [previous + current * self.tagset_size for previous, current in zip(tags, tags[1:])]
            for tags in targets_per_sentence
        ]

        return targets_per_sentence, matrix_indices


class ViterbiDecoder:
    """
    Decodes a given sequence using the Viterbi algorithm.
    """

    def __init__(self, tag_dictionary: Dictionary):
        """
        :param tag_dictionary: Dictionary of tags for sequence labeling task; must contain the START_TAG and
            STOP_TAG items
        """
        self.tag_dictionary = tag_dictionary
        self.tagset_size = len(tag_dictionary)
        self.start_tag = tag_dictionary.get_idx_for_item(START_TAG)
        self.stop_tag = tag_dictionary.get_idx_for_item(STOP_TAG)

    def decode(
        self, features_tuple: tuple, probabilities_for_all_classes: bool, sentences: List[Sentence]
    ) -> Tuple[List, List]:
        """
        Decoding function returning the most likely sequence of tags.
        :param features_tuple: CRF scores from forward method in shape (batch size, seq len, tagset size, tagset size),
            lengths of sentence in batch, transitions of CRF. The batch has to be ordered by decreasing sentence
            length, because only the first `batch_size_t` rows are updated at each timestep.
        :param probabilities_for_all_classes: whether to return probabilities for all tags
        :param sentences: sentences of the batch, in the same order as the scores; only used to attach the
            per-tag labels to their tokens when `probabilities_for_all_classes` is set
        :return: tuple of (per sentence a list of (tag, confidence) pairs, per sentence and token a list with one
            Label per tag - empty unless `probabilities_for_all_classes` is set)
        """
        features, lengths, transitions = features_tuple
        all_tags = []

        batch_size = features.size(0)
        seq_len = features.size(1)
        # Plain ints keep the per-timestep bookkeeping below free of 0-dim tensor comparisons
        lengths_list = [int(length) for length in lengths]

        # Create a tensor to hold accumulated sequence scores at each current tag
        scores_upto_t = torch.zeros(batch_size, seq_len + 1, self.tagset_size, device=flair.device)
        # Create a tensor to hold back-pointers
        # i.e., indices of the previous_tag that corresponds to maximum accumulated score at current tag
        # Let pads be the <end> tag index, since that was the last tag in the decoded sequence
        backpointers = torch.full(
            (batch_size, seq_len + 1, self.tagset_size), self.stop_tag, dtype=torch.long, device=flair.device
        )

        for t in range(seq_len):
            # effective batch size (sans pads) at this timestep
            batch_size_t = sum(length > t for length in lengths_list)
            # sentences whose last token is at this timestep
            terminates = [i for i, length in enumerate(lengths_list) if length == t + 1]

            if t == 0:
                scores_upto_t[:batch_size_t, t] = features[:batch_size_t, t, :, self.start_tag]
                # Every tag at the first timestep was reached from the start tag
                backpointers[:batch_size_t, t, :] = self.start_tag
            else:
                # We add scores at current timestep to scores accumulated up to previous timestep, and
                # choose the previous timestep that corresponds to the max. accumulated score for each current timestep
                scores_upto_t[:batch_size_t, t], backpointers[:batch_size_t, t, :] = torch.max(
                    features[:batch_size_t, t, :, :] + scores_upto_t[:batch_size_t, t - 1].unsqueeze(1), dim=2
                )

            # If sentence is over, add transition to STOP-tag
            if terminates:
                scores_upto_t[terminates, t + 1], backpointers[terminates, t + 1, :] = torch.max(
                    scores_upto_t[terminates, t].unsqueeze(1) + transitions[self.stop_tag].unsqueeze(0), dim=2
                )

        # Decode/trace best path backwards
        decoded = torch.zeros((batch_size, backpointers.size(1)), dtype=torch.long, device=flair.device)
        pointer = torch.full((batch_size, 1), self.stop_tag, dtype=torch.long, device=flair.device)

        for t in reversed(range(backpointers.size(1))):
            decoded[:, t] = torch.gather(backpointers[:, t, :], 1, pointer).squeeze(1)
            pointer = decoded[:, t].unsqueeze(1)

        # Sanity check: following the back-pointers must lead every sequence back to the start tag
        assert bool((decoded[:, 0] == self.start_tag).all())

        # remove start-tag and backscore to stop-tag
        scores_upto_t = scores_upto_t[:, :-1, :]
        decoded = decoded[:, 1:]

        # Max + Softmax to get confidence score for predicted label and append label to each token
        # NOTE(review): the confidence is the highest softmax value at a token, which is not looked up at the
        # decoded tag - confirm that this is intended.
        scores = softmax(scores_upto_t, dim=2)
        confidences = scores.max(dim=2).values

        tags = []
        for tag_seq, conf_seq, length in zip(decoded.tolist(), confidences.tolist(), lengths_list):
            tags.append(
                [
                    (self.tag_dictionary.get_item_for_index(tag), conf)
                    for tag, conf in zip(tag_seq[:length], conf_seq[:length])
                ]
            )

        if probabilities_for_all_classes:
            all_tags = self._all_scores_for_token(scores.cpu(), lengths, sentences)

        return tags, all_tags

    def _all_scores_for_token(self, scores: torch.Tensor, lengths: torch.IntTensor, sentences: List[Sentence]):
        """
        Returns all scores for each tag in tag dictionary.
        :param scores: Softmax scores of the whole batch, one distribution over all tags per token.
        :param lengths: lengths of sentences in batch, used to cut off padded positions
        :param sentences: sentences of the batch; their tokens are attached to the created labels
        :return: per sentence and token a list with one Label per tag in the tag dictionary
        """
        # detach()/cpu() are no-ops for the tensor handed over by decode(), but keep numpy() from failing on
        # tensors that require grad or live on another device
        scores = scores.detach().cpu().numpy()
        prob_tags_per_sentence = []
        for scores_sentence, length, sentence in zip(scores, lengths, sentences):
            scores_sentence = scores_sentence[: int(length)]
            prob_tags_per_sentence.append(
                [
                    [
                        Label(token, self.tag_dictionary.get_item_for_index(score_id), score)
                        for score_id, score in enumerate(score_dist)
                    ]
                    for score_dist, token in zip(scores_sentence, sentence)
                ]
            )
        return prob_tags_per_sentence