import os
from typing import Callable, Dict, List, Optional, Type

import torch
from allennlp.common.from_params import Params, T, pop_and_construct_arg
from allennlp.data.vocabulary import Vocabulary, DEFAULT_PADDING_TOKEN, DEFAULT_OOV_TOKEN
from allennlp.models.model import Model
from allennlp.modules import TextFieldEmbedder
from allennlp.modules.seq2seq_encoders.pytorch_seq2seq_wrapper import Seq2SeqEncoder
from allennlp.modules.span_extractors import SpanExtractor
from allennlp.training.metrics import Metric

from ..metrics import ExactMatch
from ..modules import SpanFinder, SpanTyping
from ..utils import num2mask, VIRTUAL_ROOT, Span, tensor2span


@Model.register("span")
class SpanModel(Model):
    """
    Identify/Find spans; link them as a tree; label them.
    """
    default_predictor = 'span'

    def __init__(
            self,
            vocab: Vocabulary,

            # Modules
            word_embedding: TextFieldEmbedder,
            span_extractor: SpanExtractor,
            span_finder: SpanFinder,
            span_typing: SpanTyping,

            # Config
            typing_loss_factor: float = 1.,
            max_recursion_depth: int = -1,
            max_decoding_spans: int = -1,
            debug: bool = False,

            # Ontology Constraints
            ontology_path: Optional[str] = None,

            # Metrics
            metrics: Optional[List[Metric]] = None,
    ) -> None:
        """
        Note on the jsonnet config: it doesn't strictly follow the init signature of every module,
        because we override the from_params method.
        Refer to SpanModel.from_params or the example jsonnet file.
        :param vocab: Automatically provided; no need to specify it in the config.
        ## Modules
        :param word_embedding: Refer to the module doc.
        :param span_extractor: Refer to the module doc.
        :param span_finder: Refer to the module doc.
        :param span_typing: Refer to the module doc.
        ## Configs
        :param typing_loss_factor: loss = span_finder_loss + span_typing_loss * typing_loss_factor
        :param max_recursion_depth: Maximum tree depth for inference. E.g., 1 for shallow event typing, 2 for SRL,
            -1 (unlimited) for dependency parsing.
        :param max_decoding_spans: Maximum number of spans for inference. -1 for unlimited.
        :param debug: Currently unused.
        """
        self._pad_idx = vocab.get_token_index(DEFAULT_PADDING_TOKEN, 'token')
        self._null_idx = vocab.get_token_index(DEFAULT_OOV_TOKEN, 'span_label')
        super().__init__(vocab)

        self.word_embedding = word_embedding
        self._span_finder = span_finder
        self._span_extractor = span_extractor
        self._span_typing = span_typing

        self.metrics = [ExactMatch(True), ExactMatch(False)]
        if metrics is not None:
            self.metrics.extend(metrics)

        if ontology_path is not None and os.path.exists(ontology_path):
            self._span_typing.load_ontology(ontology_path, self.vocab)

        self._max_decoding_spans = max_decoding_spans
        self._typing_loss_factor = typing_loss_factor
        self._max_recursion_depth = max_recursion_depth
        self.debug = debug

    def forward(
            self,
            tokens: Dict[str, Dict[str, torch.Tensor]],

            span_boundary: Optional[torch.Tensor] = None,
            span_labels: Optional[torch.Tensor] = None,
            parent_indices: Optional[torch.Tensor] = None,
            parent_mask: Optional[torch.Tensor] = None,

            bio_seqs: Optional[torch.Tensor] = None,
            raw_inputs: Optional[dict] = None,
            meta: Optional[dict] = None,

            **extra
    ) -> Dict[str, torch.Tensor]:
        """
        For training, provide all of the arguments below.
        For inference, it's enough to provide only `tokens`.

        :param tokens: Indexed input sentence. Shape: [batch, token]

        :param span_boundary: Start and end indices for every span. Note this includes both parent and
            non-parent spans. Shape: [batch, span, 2]. For the last dim, [0] is start idx and [1] is end idx.
        :param span_labels: Indexed label for spans, including parent and non-parent ones. Shape: [batch, span]
        :param parent_indices: The parent span idx of every span. Shape: [batch, span]
        :param parent_mask: True if this span is a parent. Shape: [batch, span]

        :param bio_seqs: Shape [batch, parent, token, 3]
        :param raw_inputs: Raw inputs; used during decoding (token offsets) and for metrics (gold spans).

        :param meta: Meta information. Will be copied to the outputs.

        :return:
            - loss: training loss
            - prediction: Predicted spans
            - meta: Meta info copied from input
            - inputs: Input sentences and spans (if they exist)
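
        Example (a sketch; shapes are illustrative for a batch of 2 sentences with up to
        8 gold spans, and the `batch` dict is assumed to come from the dataset reader)::

            output = model(
                tokens=batch['tokens'],                  # indexed text fields
                span_boundary=batch['span_boundary'],    # [2, 8, 2]
                span_labels=batch['span_labels'],        # [2, 8], padded with -1
                parent_indices=batch['parent_indices'],  # [2, 8]
                parent_mask=batch['parent_mask'],        # [2, 8], bool
                bio_seqs=batch['bio_seqs'],              # [2, n_parent, n_token, 3]
                raw_inputs=batch['raw_inputs'],
                meta=batch['meta'],
            )
            output['loss'].backward()  # training mode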
        """
        ret = {'inputs': raw_inputs, 'meta': meta or dict()}

        is_eval = span_labels is not None and not self.training  # evaluation on dev set
        is_test = span_labels is None  # test on test set
        # Shape [batch]
        num_spans = (span_labels != -1).sum(1) if span_labels is not None else None
        num_words = tokens['pieces']['mask'].sum(1)
        # Shape [batch, word, token_dim]
        token_vec = self.word_embedding(tokens)

        if span_labels is not None:
            # Revise the padding value from -1 to 0
            span_labels[span_labels == -1] = 0

        # Calculate Loss
        if self.training or is_eval:
            # Shape [batch, span, span_dim]
            span_vec = self._span_extractor(token_vec, span_boundary)
            finder_rst = self._span_finder(
                token_vec, num2mask(num_words), span_vec, num2mask(num_spans), span_labels, parent_indices,
                parent_mask, bio_seqs
            )
            typing_rst = self._span_typing(span_vec, parent_indices, span_labels)
            ret['loss'] = finder_rst['loss'] + typing_rst['loss'] * self._typing_loss_factor

        # Decoding
        if is_eval or is_test:
            pred_span_boundary, pred_span_labels, pred_parent_indices, pred_cursor, pred_label_confidence = (
                self.inference(num_words, token_vec, **extra)
            )
            prediction = self.post_process_pred(
                pred_span_boundary, pred_span_labels, pred_parent_indices, pred_cursor, pred_label_confidence
            )
            for pred, raw_in in zip(prediction, raw_inputs):
                pred.re_index(raw_in['offsets'], True, True, True)
                pred.remove_overlapping()
            ret['prediction'] = prediction
            if 'spans' in raw_inputs[0]:
                for pred, raw_in in zip(prediction, raw_inputs):
                    gold = raw_in['spans']
                    for metric in self.metrics:
                        metric(pred, gold)

        return ret

    def inference(
            self,
            num_words: torch.Tensor,
            token_vec: torch.Tensor,
            **auxiliaries
    ):
        n_batch = num_words.shape[0]
        # The decoding results are stored in the following tensors, prefixed with `pred`.
        # During inference, we completely ignore the arguments that default to None in the forward method.
        # The span indexing space is shifted to the decoding span space (since we have no gold spans here).
        # boundary indices of every predicted span
        pred_span_boundary = num_words.new_zeros([n_batch, self._max_decoding_spans, 2])
        # labels (and corresponding confidence) for predicted spans
        pred_span_labels = num_words.new_full(
            [n_batch, self._max_decoding_spans], self.vocab.get_token_index(VIRTUAL_ROOT, 'span_label')
        )
        pred_label_confidence = num_words.new_zeros([n_batch, self._max_decoding_spans])
        # spans masked as True will be treated as parents in the next round
        pred_parent_mask = num_words.new_zeros([n_batch, self._max_decoding_spans], dtype=torch.bool)
        pred_parent_mask[:, 0] = True
        # parent index (in the span indexing space) for every span
        pred_parent_indices = num_words.new_zeros([n_batch, self._max_decoding_spans])
        # the index we have reached for each batch item
        pred_cursor = num_words.new_ones([n_batch])

        # Pass context variables to the handler; extra variables will be ignored,
        # so pass the union of the variables needed by the different modules.
        span_find_handler = self._span_finder.inference_forward_handler(
            token_vec, num2mask(num_words), self._span_extractor, **auxiliaries
        )

        # Each step here processes one layer of the tree: it handles all the parents from the
        # previous layer, so a single step may involve zero to many parents per batch item.
        # -1 is documented as unlimited depth; since `range(-1)` is empty, bound the depth by
        # the span budget instead (each step either adds a span or triggers the break below).
        max_depth = self._max_recursion_depth if self._max_recursion_depth >= 0 else self._max_decoding_spans
        for _ in range(max_depth):
            cursor_before_find = pred_cursor.clone()
            span_find_handler(
                pred_span_boundary, pred_span_labels, pred_parent_mask, pred_parent_indices, pred_cursor
            )
            # Labels of old spans are re-predicted. This is harmless since, in theory,
            # their predictions should not change.
            span_typing_ret = self._span_typing(
                self._span_extractor(token_vec, pred_span_boundary), pred_parent_indices, pred_span_labels, True
            )
            pred_span_labels = span_typing_ret['prediction']
            pred_label_confidence = span_typing_ret['label_confidence']
            pred_span_labels[:, 0] = self.vocab.get_token_index(VIRTUAL_ROOT, 'span_label')
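            # Spans appended in this round (the indices the cursor just moved past) become
            # the parents for the next round; the XOR of the two cursor masks selects exactly them.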
            pred_parent_mask = (
                    num2mask(cursor_before_find, self._max_decoding_spans) ^ num2mask(pred_cursor,
                                                                                      self._max_decoding_spans)
            )

            # Break the inference loop if 1) all batch items have reached the max span limit, OR
            # 2) no parent was predicted in the last step. (Case 3, the max recursion limit, is the loop bound.)
            if (pred_cursor == self._max_decoding_spans).all() or pred_parent_mask.sum() == 0:
                break

        return pred_span_boundary, pred_span_labels, pred_parent_indices, pred_cursor, pred_label_confidence

    def one_step_prediction(
            self,
            tokens: Dict[str, Dict[str, torch.Tensor]],
            parent_boundary: torch.Tensor,
            parent_labels: torch.Tensor,
    ):
        """
        Single step prediction. Given parent span boundary indices, return the corresponding children spans
            and their labels.
        Restriction: Each sentence contain exactly 1 parent.
        For efficient multi-layer prediction, i.e. given a root, predict the whole tree,
            refer to the `forward' method.
        :param tokens: See forward.
        :param parent_boundary: Pairs of (start_idx, end_idx) for parents. Shape [batch, 2]
        :param parent_labels: Labels for parents. Shape [batch]
            Note: If `no_label` is enabled in the span_finder module, this will be ignored.
        :return:
            children_boundary: (start_idx, end_idx) for every child span. Padded with (0, 0).
                Shape [batch, children, 2]
            children_labels: Label for every child span. Padded with null_idx. Shape [batch, children]
            num_children: The number of children predicted per parent/batch item. Shape [batch]
                Tip: you can use the num2mask method to convert this to a bool mask (see the example below).
            children_distribution: Label distribution over every child span. Shape [batch, children, label]
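
        Example (a sketch; the tensors are illustrative and assume a single sentence whose
        parent span covers tokens 0..3, labeled with the virtual root)::

            parent_boundary = torch.tensor([[0, 3]])
            parent_labels = torch.tensor(
                [model.vocab.get_token_index(VIRTUAL_ROOT, 'span_label')]
            )
            boundary, labels, n_children, dist = model.one_step_prediction(
                tokens, parent_boundary, parent_labels
            )
            children_mask = num2mask(n_children, boundary.shape[1])  # bool mask over children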
        """
        num_words = tokens['pieces']['mask'].sum(1)
        # Shape [batch, word, token_dim]
        token_vec = self.word_embedding(tokens)
        n_batch = token_vec.shape[0]

        # The following variables assume the parent is the 0-th span; we let the model
        # extend the span list from there.
        pred_span_boundary = num_words.new_zeros([n_batch, self._max_decoding_spans, 2])
        pred_span_boundary[:, 0] = parent_boundary
        pred_span_labels = num_words.new_full([n_batch, self._max_decoding_spans], self._null_idx)
        pred_span_labels[:, 0] = parent_labels
        pred_parent_mask = num_words.new_zeros(pred_span_labels.shape, dtype=torch.bool)
        pred_parent_mask[:, 0] = True
        pred_parent_indices = num_words.new_zeros([n_batch, self._max_decoding_spans])
        # We start from idx 1, since idx 0 is the parent.
        pred_cursor = num_words.new_ones([n_batch])

        span_find_handler = self._span_finder.inference_forward_handler(
            token_vec, num2mask(num_words), self._span_extractor
        )
        span_find_handler(
            pred_span_boundary, pred_span_labels, pred_parent_mask, pred_parent_indices, pred_cursor
        )
        typing_out = self._span_typing(
            self._span_extractor(token_vec, pred_span_boundary), pred_parent_indices, pred_span_labels, True
        )
        pred_span_labels = typing_out['prediction']

        # Now remove the parent
        num_children = pred_cursor - 1
        max_children = int(num_children.max())
        children_boundary = pred_span_boundary[:, 1:max_children + 1]
        children_labels = pred_span_labels[:, 1:max_children + 1]
        children_distribution = typing_out['distribution'][:, 1:max_children + 1]
        return children_boundary, children_labels, num_children, children_distribution

    def post_process_pred(
            self, span_boundary, span_labels, parent_indices, num_spans, label_confidence
    ) -> List[Span]:
        pred_spans = tensor2span(
            span_boundary, span_labels, parent_indices, num_spans, label_confidence,
            self.vocab.get_index_to_token_vocabulary('span_label'),
            label_ignore=[self._null_idx],
        )
        return pred_spans

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        ret = dict()
        # Metrics over decoded spans are only reported when `reset` is True
        # (typically at the end of an epoch).
        if reset:
            for metric in self.metrics:
                ret.update(metric.get_metric(reset))
        ret.update(self._span_finder.get_metrics(reset))
        ret.update(self._span_typing.get_metric(reset))
        return ret

    @classmethod
    def from_params(
            cls: Type[T],
            params: Params,
            constructor_to_call: Callable[..., T] = None,
            constructor_to_inspect: Callable[..., T] = None,
            **extras,
    ) -> T:
        """
        Specify the dependencies between modules, e.g. the input dim of one module may depend
        on the output dim of another.
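
        Example config skeleton (a sketch; the keys mirror the construction logic below, and
        all values are illustrative placeholders)::

            {
                "type": "span",
                "label_dim": 64,
                "word_embedding": { ... },
                "span_extractor": { ... },
                "span_finder": {"type": "bio", "bio_encoder": { ... }},
                "span_typing": {"type": "mlp", ... }
            }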
        """
        vocab = extras['vocab']
        word_embedding = pop_and_construct_arg('SpanModel', 'word_embedding', TextFieldEmbedder, None, params, **extras)
        label_dim, token_emb_dim = params.pop('label_dim'), word_embedding.get_output_dim()
        span_extractor = pop_and_construct_arg(
            'SpanModel', 'span_extractor', SpanExtractor, None, params, input_dim=token_emb_dim, **extras
        )
        label_embedding = torch.nn.Embedding(vocab.get_vocab_size('span_label'), label_dim)
        extras['label_emb'] = label_embedding

        if params.get('span_finder').get('type') == 'bio':
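            # Both `input_size` (RNN-style encoders) and `input_dim` (e.g. transformer encoders)
            # are passed; from_params forwards only the one the encoder's constructor accepts.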
            bio_encoder = Seq2SeqEncoder.from_params(
                params['span_finder'].pop('bio_encoder'),
                input_size=span_extractor.get_output_dim() + token_emb_dim + label_dim,
                input_dim=span_extractor.get_output_dim() + token_emb_dim + label_dim,
                **extras
            )
            extras['span_finder'] = SpanFinder.from_params(
                params.pop('span_finder'), bio_encoder=bio_encoder, **extras
            )
        else:
            extras['span_finder'] = pop_and_construct_arg(
                'SpanModel', 'span_finder', SpanFinder, None, params, **extras
            )
            extras['span_finder'].label_emb = label_embedding

        if params.get('span_typing').get('type') == 'mlp':
            extras['span_typing'] = SpanTyping.from_params(
                params.pop('span_typing'),
                input_dim=span_extractor.get_output_dim() * 2 + label_dim,
                n_category=vocab.get_vocab_size('span_label'),
                label_to_ignore=[
                    vocab.get_token_index(lti, 'span_label')
                    for lti in [DEFAULT_OOV_TOKEN, DEFAULT_PADDING_TOKEN]
                ],
                **extras
            )
        else:
            extras['span_typing'] = pop_and_construct_arg(
                'SpanModel', 'span_typing', SpanTyping, None, params, **extras
            )
            extras['span_typing'].label_emb = label_embedding

        return super().from_params(
            params,
            word_embedding=word_embedding,
            span_extractor=span_extractor,
            **extras
        )