import unittest
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn


@dataclass
class CRFOutput:
    loss: Optional[torch.Tensor]
    real_path_score: Optional[torch.Tensor]
    total_score: torch.Tensor
    best_path_score: torch.Tensor
    best_path: Optional[torch.Tensor]


class MaskedCRFLoss(nn.Module):
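    """Linear-chain CRF loss over time-major sequences with masking.

    Path scores combine per-step emission scores with learned tag-to-tag
    transitions plus start/stop transitions; steps where the mask is 0 are
    carried through the forward and Viterbi recursions unchanged.
    """
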
    __constants__ = ["num_tags", "mask_id"]

    num_tags: int
    mask_id: int

    def __init__(self, num_tags: int, mask_id: int = 0):
        super().__init__()
        self.num_tags = num_tags
        self.mask_id = mask_id
        # transitions[i, j] scores moving from tag i to tag j; start/stop
        # transitions score the first and last tag of each sequence
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))
        self.start_transitions = nn.Parameter(torch.randn(num_tags))
        self.stop_transitions = nn.Parameter(torch.randn(num_tags))

    def extra_repr(self) -> str:
        s = "num_tags={num_tags}, mask_id={mask_id}"
        return s.format(**self.__dict__)

    def forward(self, emissions, tags, mask, return_best_path=False):
        """Return the CRF loss and path scores for a batch of sequences.

        emissions: (seq_length, batch_size, num_tags)
        tags: (seq_length, batch_size), or None to skip the loss
        mask: (seq_length, batch_size); 1 marks real tokens, 0 padding
        """
        seq_length, batch_size = emissions.shape[:2]
        mask = mask.float()
        # Always return the best path during eval. During training the
        # backtrace is skipped by default because it is slow and not needed.
        if not self.training:
            return_best_path = True
        # Compute the log-partition function (normalizer over all paths)
        total_score, best_path_score, best_path = self.compute_log_partition_function(
            emissions, mask, return_best_path=return_best_path
        )

        if tags is None:
            return CRFOutput(None, None, total_score, best_path_score, best_path)

        # Compute the score of the gold path. Include the emission at t=0 so
        # the numerator matches the partition function, which folds
        # emissions[0] into its start scores.
        real_path_score = torch.zeros(batch_size, device=tags.device)
        real_path_score += self.start_transitions[tags[0]]
        real_path_score += emissions[0, torch.arange(batch_size), tags[0]]
        for i in range(1, seq_length):
            current_tag = tags[i]
            real_path_score += self.transitions[tags[i - 1], current_tag] * mask[i]
            real_path_score += (
                emissions[i, torch.arange(batch_size), current_tag] * mask[i]
            )
        # Transition to STOP_TAG from the last unmasked tag (assumes padding
        # sits at the end of the sequence, i.e. the mask is contiguous)
        last_indices = mask.long().sum(0).clamp(min=1) - 1  # (batch_size,)
        last_tags = tags.gather(0, last_indices.unsqueeze(0)).squeeze(0)
        real_path_score += self.stop_transitions[last_tags]

        # Negative log likelihood of the gold path, averaged over the batch
        loss = torch.mean(total_score - real_path_score)
        return CRFOutput(loss, real_path_score, total_score, best_path_score, best_path)

    def compute_log_partition_function(self, emissions, mask, return_best_path=False):
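        """Run the forward and Viterbi recursions in a single masked pass.

        Returns (log_partition, best_path_score, best_path); best_path is
        None unless return_best_path is True, and holds -100 at masked
        positions.
        """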
        init_alphas = self.start_transitions + emissions[0]  # (batch_size, num_tags)
        forward_var = init_alphas
        forward_viterbi_var = init_alphas
        # backpointers[i - 1][b, t] holds the best previous tag for tag t at
        # step i; the backtrace below reads the list back-to-front
        backpointers = []

        for i, emission in enumerate(emissions[1:, :, :], 1):
            broadcast_emission = emission.unsqueeze(2)  # (batch_size, num_tags, 1)
            # Transposed so that entry [to, from] matches the
            # transitions[from, to] convention used in forward()
            broadcast_transitions = self.transitions.t().unsqueeze(
                0
            )  # (1, num_tags, num_tags)

            # next_tag_var[b, to, from] = forward_var[b, from]
            #     + emission[b, to] + transitions[from, to]
            next_tag_var = (
                forward_var.unsqueeze(1) + broadcast_emission + broadcast_transitions
            )
            next_tag_viterbi_var = (
                forward_viterbi_var.unsqueeze(1)
                + broadcast_emission
                + broadcast_transitions
            )

            # Sum out the previous tag for the forward recursion; take the
            # max (and its argmax backpointer) for the Viterbi recursion
            next_unmasked_forward_var = torch.logsumexp(next_tag_var, dim=2)
            viterbi_scores, best_next_tags = torch.max(next_tag_viterbi_var, dim=2)
            # Where mask == 1 take the updated variables; where mask == 0
            # carry the previous ones through the padded step unchanged
            forward_var = (
                mask[i].unsqueeze(-1) * next_unmasked_forward_var
                + (1 - mask[i]).unsqueeze(-1) * forward_var
            )
            # Update viterbi with mask
            forward_viterbi_var = (
                mask[i].unsqueeze(-1) * viterbi_scores
                + (1 - mask[i]).unsqueeze(-1) * forward_viterbi_var
            )
            backpointers.append(best_next_tags)

        # Transition to STOP_TAG
        terminal_var = forward_var + self.stop_transitions
        terminal_viterbi_var = forward_viterbi_var + self.stop_transitions

        alpha = torch.logsumexp(terminal_var, dim=1)
        best_path_score, best_final_tags = torch.max(terminal_viterbi_var, dim=1)

        best_path = None
        if return_best_path:
            # Backtrace: follow the backpointers from the last step to the first
            best_path = [best_final_tags]
            for bptrs, mask_data in zip(reversed(backpointers), torch.flip(mask, [0])):
                step_mask = mask_data.to(dtype=torch.bool)
                best_tag_id = torch.gather(
                    bptrs, 1, best_final_tags.unsqueeze(1)
                ).squeeze(1)
                # Step back only where the mask is on, and do it out of place
                # so each appended step is a distinct tensor (an in-place
                # masked_scatter_ would make every entry of best_path alias
                # the same storage)
                best_final_tags = best_final_tags.masked_scatter(
                    step_mask, best_tag_id.masked_select(step_mask)
                )
                best_path.append(best_final_tags)
            # Reverse because the steps were collected last-to-first
            best_path = torch.stack(best_path[::-1])
            # Mark masked positions with the conventional ignore index -100
            best_path = best_path.where(mask == 1, -100)

        return alpha, best_path_score, best_path

    def viterbi_decode(self, emissions, mask):
        """Decode the highest-scoring tag sequence for each batch item."""
        seq_length, batch_size, num_tags = emissions.shape
        mask = mask.float()

        # backpointers[i - 1][b, t] holds the best previous tag for tag t at
        # step i; the backtrace below reads the list back-to-front
        backpointers = []

        # Initialize the Viterbi variables in log space
        init_vvars = self.start_transitions + emissions[0]  # (batch_size, num_tags)
        forward_var = init_vvars

        for i, emission in enumerate(emissions[1:, :, :], 1):
            broadcast_emission = emission.unsqueeze(2)  # (batch_size, num_tags, 1)
            # Same transposed [to, from] layout as in compute_log_partition_function
            broadcast_transitions = self.transitions.t().unsqueeze(0)
            next_tag_var = (
                forward_var.unsqueeze(1) + broadcast_emission + broadcast_transitions
            )

            viterbi_scores, best_next_tags = torch.max(next_tag_var, dim=2)
            # Where mask == 1 take the new Viterbi scores; where mask == 0
            # carry forward_var through the padded step unchanged
            forward_var = (
                mask[i].unsqueeze(-1) * viterbi_scores
                + (1 - mask[i]).unsqueeze(-1) * forward_var
            )
            backpointers.append(best_next_tags)

        # Transition to STOP_TAG
        terminal_var = forward_var + self.stop_transitions
        best_path_score, best_final_tags = torch.max(terminal_var, dim=1)

        # Backtrace: follow the backpointers from the last step to the first
        best_path = [best_final_tags]
        for bptrs, mask_data in zip(reversed(backpointers), torch.flip(mask, [0])):
            step_mask = mask_data.to(dtype=torch.bool)
            best_tag_id = torch.gather(bptrs, 1, best_final_tags.unsqueeze(1)).squeeze(
                1
            )
            # Out-of-place update so each appended step is a distinct tensor,
            # as in compute_log_partition_function
            best_final_tags = best_final_tags.masked_scatter(
                step_mask, best_tag_id.masked_select(step_mask)
            )
            best_path.append(best_final_tags)

        # Reverse because the steps were collected last-to-first
        best_path = torch.stack(best_path[::-1])
        # Mark masked positions with the conventional ignore index -100
        best_path = best_path.where(mask == 1, -100)

        return best_path, best_path_score
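

# A minimal usage sketch (an illustrative addition, not part of the original
# module): it shows the expected time-major shapes and the train/eval calling
# pattern. The sizes below are arbitrary assumptions.
def example_usage():
    seq_length, batch_size, num_tags = 11, 4, 5
    crf = MaskedCRFLoss(num_tags)
    emissions = torch.randn(seq_length, batch_size, num_tags)
    tags = torch.randint(num_tags, (seq_length, batch_size))
    mask = torch.ones(seq_length, batch_size)  # 1 = real token, 0 = padding

    crf.train()
    out = crf(emissions, tags, mask)
    out.loss.backward()  # mean NLL of the gold paths

    crf.eval()  # eval mode always backtraces the best path
    with torch.no_grad():
        out = crf(emissions, tags, mask)
    return out.best_path, out.best_path_score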


class MaskedCRFLossTest(unittest.TestCase):
    def setUp(self):
        self.num_tags = 5
        self.mask_id = 0

        self.crf_model = MaskedCRFLoss(self.num_tags, self.mask_id)

        self.seq_length, self.batch_size = 11, 5
        # Make up some inputs, including a random 0/1 mask
        self.emissions = torch.randn(self.seq_length, self.batch_size, self.num_tags)
        self.tags = torch.randint(self.num_tags, (self.seq_length, self.batch_size))
        self.mask = torch.randint(2, (self.seq_length, self.batch_size))

    def test_forward(self):
        # Forward should run end-to-end and produce a finite scalar loss
        output = self.crf_model(self.emissions, self.tags, self.mask)
        self.assertTrue(torch.isfinite(output.loss).all())

    def test_viterbi_decode(self):
        # Decoding should run end-to-end and return one score per batch item
        path, best_path_score = self.crf_model.viterbi_decode(
            self.emissions, self.mask
        )
        self.assertEqual(best_path_score.shape, (self.batch_size,))

    def test_forward_output(self):
        # The NLL of the gold path is non-negative, since the log-partition
        # upper-bounds the score of any single path
        output = self.crf_model(self.emissions, self.tags, self.mask)
        loss = output.loss
        self.assertTrue((loss >= 0).all())

    def test_compute_log_partition_function_output(self):
        # The log-partition (logsumexp over all paths) must upper-bound the
        # best single-path score
        (
            partition,
            best_path_score,
            best_path,
        ) = self.crf_model.compute_log_partition_function(self.emissions, self.mask)
        self.assertTrue((partition >= best_path_score).all())

    def test_viterbi_decode_output(self):
        # The decoded path must have shape (seq_length, batch_size) and hold
        # only valid tag ids, with -100 at masked positions
        path, best_path_score = self.crf_model.viterbi_decode(self.emissions, self.mask)
        self.assertEqual(
            path.shape, (self.seq_length, self.batch_size)
        )  # checking dimensions
        self.assertTrue(
            (((0 <= path) & (path < self.num_tags)) | (path == -100)).all()
        )  # checking tag validity
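

# Allow running the tests directly as a script
if __name__ == "__main__":
    unittest.main()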