File size: 14,393 Bytes
2618264
 
 
 
 
 
 
 
 
dc318e0
 
 
2618264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
from dataclasses import dataclass
import math
from operator import itemgetter
import torch
from torch import nn
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
from transformers.models.bert.modeling_bert import BertOnlyMLMHead
from transformers.utils import ModelOutput
from .BertForSyntaxParsing import BertSyntaxParsingHead, SyntaxLabels, SyntaxLogitsOutput, parse_logits as syntax_parse_logits
from .BertForPrefixMarking import BertPrefixMarkingHead, parse_logits as prefix_parse_logits, encode_sentences_for_bert_for_prefix_marking
from .BertForMorphTagging import BertMorphTaggingHead, MorphLogitsOutput, MorphLabels, parse_logits as morph_parse_logits 
    
import warnings

@dataclass
class JointParsingOutput(ModelOutput):
    """Result container returned by ``BertForJointParsing.forward``.

    Every field defaults to ``None``. The per-head ``*_logits`` fields are only
    filled for heads that were actually run in the forward pass.
    """
    # loss produced by the head that was run with labels (if any)
    loss: Optional[torch.FloatTensor] = None
    # logits will contain the optional predictions for the given labels;
    # forward() resets this to None whenever no labels are passed.
    # NOTE(review): forward() stores whichever head matched labels_type here,
    # so the value is not necessarily a SyntaxLogitsOutput - confirm annotation.
    logits: Optional[Union[SyntaxLogitsOutput, None]] = None
    # passed through unchanged from the underlying BertModel output
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # if no labels are given, we will always include the syntax logits separately
    syntax_logits: Optional[SyntaxLogitsOutput] = None
    # raw outputs of the remaining heads (None when the head is absent or skipped)
    ner_logits: Optional[torch.FloatTensor] = None
    prefix_logits: Optional[torch.FloatTensor] = None
    lex_logits: Optional[torch.FloatTensor] = None
    morph_logits: Optional[MorphLogitsOutput] = None

# Thin, non-Module holder around a torch.nn.Module. Because this class does not
# derive from nn.Module, assigning it as an attribute of a model gives a second
# handle on the wrapped module without registering its parameters a second time.
class ModuleRef:
    def __init__(self, module: torch.nn.Module):
        # keep the wrapped module reachable under the `module` attribute
        self.module = module

    def forward(self, *args, **kwargs):
        # delegate straight to the wrapped module's forward
        wrapped = self.module
        return wrapped.forward(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        # delegate to the wrapped module's own __call__
        wrapped = self.module
        return wrapped(*args, **kwargs)

class BertForJointParsing(BertPreTrainedModel):
    """BERT encoder shared by up to five optional prediction heads.

    The available heads are syntax parsing, NER, prefix marking, lexeme
    prediction and morphological tagging. Which heads are built is decided by
    the ``do_*`` flags on the config, optionally overridden by the constructor
    arguments. A head that is not built stays ``None`` on the instance.
    """
    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config, do_syntax=None, do_ner=None, do_prefix=None, do_lex=None, do_morph=None, syntax_head_size=64):
        """Build the encoder and the heads enabled on ``config``.

        Args:
            config: model configuration; it is modified in place when any of
                the ``do_*`` overrides below is not ``None``.
            do_syntax, do_ner, do_prefix, do_lex, do_morph: optional overrides
                written to the matching ``config.do_*`` flag. ``None`` leaves
                the value already on the config untouched.
            syntax_head_size: stored as ``config.syntax_head_size``, but only
                when ``do_syntax`` is explicitly given.
        """
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # create all the heads as None, and then populate them as defined
        self.syntax, self.ner, self.prefix, self.lex, self.morph = (None,)*5

        if do_syntax is not None:
            config.do_syntax = do_syntax
            config.syntax_head_size = syntax_head_size
        if do_ner is not None: config.do_ner = do_ner
        if do_prefix is not None: config.do_prefix = do_prefix
        if do_lex is not None: config.do_lex = do_lex
        if do_morph is not None: config.do_morph = do_morph

        # add all the individual heads
        if config.do_syntax:
            self.syntax = BertSyntaxParsingHead(config)
        if config.do_ner:
            self.num_labels = config.num_labels
            self.classifier = nn.Linear(config.hidden_size, config.num_labels) # name it same as in BertForTokenClassification
            # ModuleRef keeps `self.ner` from registering the classifier a second time
            self.ner = ModuleRef(self.classifier)
        if config.do_prefix:
            self.prefix = BertPrefixMarkingHead(config)
        if config.do_lex:
            self.cls = BertOnlyMLMHead(config) # name it the same as in BertForMaskedLM
            # ModuleRef keeps `self.lex` from registering the MLM head a second time
            self.lex = ModuleRef(self.cls)
        if config.do_morph:
            self.morph = BertMorphTaggingHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the lexeme head's decoder layer, or ``None`` without a lex head."""
        return self.cls.predictions.decoder if self.lex is not None else None

    def set_output_embeddings(self, new_embeddings):
        """Replace the lexeme head's decoder layer; a no-op without a lex head."""
        if self.lex is not None:
            self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        prefix_class_id_options: Optional[torch.Tensor] = None,
        labels: Optional[Union[SyntaxLabels, MorphLabels, torch.Tensor]] = None,
        labels_type: Optional[Literal['syntax', 'ner', 'prefix', 'lex', 'morph']] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        compute_syntax_mst: Optional[bool] = None
    ):
        """Run the encoder and then the relevant head(s).

        Without ``labels`` every loaded head is run and its output is placed in
        the matching ``*_logits`` field. With ``labels`` only the head named by
        ``labels_type`` is run; its loss goes to ``loss`` and its output to
        both ``logits`` and the matching ``*_logits`` field.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids, head_mask,
            inputs_embeds, output_attentions, output_hidden_states: forwarded
                to the underlying ``BertModel``.
            prefix_class_id_options: forwarded to the prefix head.
            labels: targets for the head selected by ``labels_type``.
            labels_type: which head ``labels`` belong to.
            return_dict: ignored; a ``JointParsingOutput`` is always returned.
            compute_syntax_mst: forwarded to the syntax head.

        Returns:
            JointParsingOutput. ``loss`` is ``None`` when no head produced one.

        Raises:
            ValueError: if ``labels`` is given without ``labels_type``; if
                ``labels_type == 'prefix'`` without ``prefix_class_id_options``;
                or if ``compute_syntax_mst`` is given while no syntax head is
                loaded.
        """
        if return_dict is False:
            warnings.warn("Specified `return_dict=False` but the flag is ignored and treated as always True in this model.")

        if labels is not None and labels_type is None:
            raise ValueError("Cannot specify labels without labels_type")

        # the label type for the prefix head is 'prefix' (see the head dispatch
        # below); the previous comparison against 'seg' could never match
        if labels_type == 'prefix' and prefix_class_id_options is None:
            raise ValueError('Cannot calculate prefix logits without prefix_class_id_options')

        if compute_syntax_mst is not None and self.syntax is None:
            raise ValueError("Cannot compute syntax MST when the syntax head isn't loaded")

        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # calculate the extended attention mask for any child that might need it
        extended_attention_mask = None
        if attention_mask is not None:
            # callers may supply inputs_embeds instead of input_ids; in that
            # case the (batch, seq) shape is everything but the embedding dim
            if input_ids is not None:
                input_shape = input_ids.size()
            else:
                input_shape = inputs_embeds.size()[:-1]
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        # extract the hidden states, and apply the dropout
        hidden_states = self.dropout(bert_outputs[0])

        # initialise every output, including the loss: heads such as NER / lex
        # only assign a loss when labels are present, and no head may run at all
        loss = None
        logits = None
        syntax_logits = None
        ner_logits = None
        prefix_logits = None
        lex_logits = None
        morph_logits = None

        # Calculate the syntax
        if self.syntax is not None and (labels is None or labels_type == 'syntax'):
            # apply the syntax head
            loss, syntax_logits = self.syntax(hidden_states, extended_attention_mask, labels, compute_syntax_mst)
            logits = syntax_logits

        # Calculate the NER
        if self.ner is not None and (labels is None or labels_type == 'ner'):
            ner_logits = self.ner(hidden_states)
            logits = ner_logits
            if labels is not None:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        # Calculate the segmentation
        if self.prefix is not None and (labels is None or labels_type == 'prefix'):
            loss, prefix_logits = self.prefix(hidden_states, prefix_class_id_options, labels)
            logits = prefix_logits

        # Calculate the lexeme
        if self.lex is not None and (labels is None or labels_type == 'lex'):
            lex_logits = self.lex(hidden_states)
            logits = lex_logits
            if labels is not None:
                loss_fct = nn.CrossEntropyLoss()  # -100 index = padding token
                loss = loss_fct(lex_logits.view(-1, self.config.vocab_size), labels.view(-1))

        # Calculate the morphology
        if self.morph is not None and (labels is None or labels_type == 'morph'):
            loss, morph_logits = self.morph(hidden_states, labels)
            logits = morph_logits

        # no labels => logits = None
        if labels is None: logits = None

        return JointParsingOutput(
            loss,
            logits,
            hidden_states=bert_outputs.hidden_states,
            attentions=bert_outputs.attentions,
            # all the predicted logits section
            syntax_logits=syntax_logits,
            ner_logits=ner_logits,
            prefix_logits=prefix_logits,
            lex_logits=lex_logits,
            morph_logits=morph_logits
        )

    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False):
        """Run every loaded head on raw text and merge the parsed results.

        Args:
            sentences: a single sentence or a list of sentences.
            tokenizer: tokenizer used to encode the sentences and decode ids.
            padding: padding strategy handed to the encoder call.
            truncation: truncation flag handed to the tokenizer. NOTE: it is
                not passed on when a prefix head is loaded, because the prefix
                encoding helper is only called with ``padding``.
            compute_syntax_mst: forwarded to ``forward`` when a syntax head is
                loaded; ignored otherwise.
            per_token_ner: when true, also store the NER tag on each token
                under the ``'ner'`` key.

        Returns:
            One dict per sentence with ``text`` and ``tokens`` (a list of dicts
            starting as ``{'token': ...}``). Depending on the loaded heads the
            token dicts gain ``'syntax'``, ``'seg'``, ``'lex'``, ``'morph'`` and
            ``'ner'`` entries, and the sentence dict gains ``'root_idx'`` and
            ``'ner_entities'``. A single dict (not a list) is returned when
            ``sentences`` was a plain string.
        """
        is_single_sentence = isinstance(sentences, str)
        if is_single_sentence:
            sentences = [sentences]

        # predict the logits for the sentence
        if self.prefix is not None:
            inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
        else:
            inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_tensors='pt')

        # Copy the tensors to the right device, and parse!
        inputs = {k:v.to(self.device) for k,v in inputs.items()}
        # forward() rejects any non-None compute_syntax_mst when the syntax head
        # is missing, so only pass the flag along if that head exists
        mst_flag = compute_syntax_mst if self.syntax is not None else None
        output = self.forward(**inputs, return_dict=True, compute_syntax_mst=mst_flag)

        final_output = [dict(text=sentence, tokens=[dict(token=t) for t in combine_token_wordpieces(ids, tokenizer)]) for sentence, ids in zip(sentences, inputs['input_ids'])]
        # Syntax logits: each sentence gets a dict(tree: List[dict(word,dep_head,dep_head_idx,dep_func)], root_idx: int)
        if output.syntax_logits is not None:
            for sent_idx,parsed in enumerate(syntax_parse_logits(inputs, sentences, tokenizer, output.syntax_logits)):
                merge_token_list(final_output[sent_idx]['tokens'], parsed['tree'], 'syntax')
                final_output[sent_idx]['root_idx'] = parsed['root_idx']

        # Prefix logits: each sentence gets a list([prefix_segment, word_without_prefix]) - **WITH CLS & SEP**
        if output.prefix_logits is not None:
            for sent_idx,parsed in enumerate(prefix_parse_logits(inputs, sentences, tokenizer, output.prefix_logits)):
                # drop the first and last entries (CLS / SEP) before merging
                merge_token_list(final_output[sent_idx]['tokens'], map(tuple, parsed[1:-1]), 'seg')

        # Lex logits each sentence gets a list(tuple(word, lexeme))
        if output.lex_logits is not None:
            for sent_idx, parsed in enumerate(lex_parse_logits(inputs, sentences, tokenizer, output.lex_logits)):
                merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'lex')

        # morph logits each sentences get a dict(text=str, tokens=list(dict(token, pos, feats, prefixes, suffix, suffix_feats?)))
        if output.morph_logits is not None:
            for sent_idx,parsed in enumerate(morph_parse_logits(inputs, sentences, tokenizer, output.morph_logits)):
                merge_token_list(final_output[sent_idx]['tokens'], parsed['tokens'], 'morph')

        # NER logits each sentence gets a list(tuple(word, ner))
        if output.ner_logits is not None:
            for sent_idx,parsed in enumerate(ner_parse_logits(inputs, sentences, tokenizer, output.ner_logits, self.config.id2label)):
                if per_token_ner:
                    merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
                final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)

        if is_single_sentence:
            final_output = final_output[0]
        return final_output

def aggregate_ner_tokens(predictions):
    """Group consecutive per-word NER tags into entity phrases.

    ``predictions`` is an iterable of ``(word, tag)`` pairs. A tag of ``'O'``
    closes any open entity. A tag starting with ``'B-'``, or one whose type
    (the tag with its first two characters removed) differs from the open
    entity, starts a new entity. Anything else extends the open entity.

    Returns a list of ``dict(phrase=..., label=...)`` where ``phrase`` is the
    space-joined words of the entity.
    """
    spans = []
    open_label = None
    for word, tag in predictions:
        if tag == 'O':
            # outside any entity: forget the currently open one
            open_label = None
            continue

        entity_type = tag[2:]
        extends_open_span = (not tag.startswith('B-')) and entity_type == open_label
        if extends_open_span:
            spans[-1][0].append(word)
        else:
            open_label = entity_type
            spans.append(([word], entity_type))

    merged = []
    for words, label in spans:
        merged.append({'phrase': ' '.join(words), 'label': label})
    return merged
        

def merge_token_list(src, update, key):
    """Store each item of ``update`` under ``key`` in the matching entry of ``src``.

    The two iterables are walked in lockstep and the shorter one decides where
    to stop. ``src`` entries are modified in place; nothing is returned.
    """
    paired = zip(src, update)
    for target_entry, new_value in paired:
        target_entry[key] = new_value
        
def combine_token_wordpieces(input_ids: torch.Tensor, tokenizer: BertTokenizerFast):
    """Turn token ids into whole words.

    The ids are decoded with ``tokenizer.convert_ids_to_tokens``. The CLS, SEP
    and PAD tokens are dropped, and every token starting with ``'##'`` is glued
    (without the ``'##'``) onto the word before it.
    """
    skipped = [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]
    words = []
    for piece in tokenizer.convert_ids_to_tokens(input_ids):
        if piece in skipped:
            continue
        if not piece.startswith('##'):
            words.append(piece)
        else:
            # continuation piece: extend the most recent word
            words[-1] = words[-1] + piece[2:]
    return words

def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str]):
    """Decode NER logits into ``(word, label)`` pairs for each sentence.

    The predicted class of a position is the argmax over the last dimension of
    ``logits``, mapped through ``id2label``. CLS / SEP / PAD positions are
    skipped. A ``'##'`` wordpiece is appended to the previous word, which keeps
    the label predicted for its first piece.

    Returns a list with one list of ``(word, label)`` tuples per sentence.
    """
    input_ids = inputs['input_ids']
    predicted_ids = torch.argmax(logits, dim=-1)

    parsed_batch = []
    for sent_idx in range(len(sentences)):
        words = []
        for pos, token_id in enumerate(input_ids[sent_idx]):
            # ignore cls, sep, pad
            if token_id in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]:
                continue

            piece = tokenizer._convert_id_to_token(token_id)
            if not piece.startswith('##'):
                label = id2label[predicted_ids[sent_idx, pos].item()]
                words.append((piece, label))
            else:
                # wordpieces should just be appended to the previous word
                prev_word, prev_label = words[-1][0], words[-1][1]
                words[-1] = (prev_word + piece[2:], prev_label)
        parsed_batch.append(words)
    return parsed_batch

def lex_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor):
    """Decode lexeme logits into ``(word, lexeme)`` pairs for each sentence.

    The predicted id of a position is the argmax over the last dimension of
    ``logits``, converted to a token with the tokenizer. CLS / SEP / PAD
    positions are skipped. A ``'##'`` wordpiece is appended to the previous
    word, which keeps the lexeme predicted for its first piece.

    Returns a list with one list of ``(word, lexeme)`` tuples per sentence.
    """
    input_ids = inputs['input_ids']
    predicted_ids = torch.argmax(logits, dim=-1)

    parsed_batch = []
    for sent_idx in range(len(sentences)):
        words = []
        for pos, token_id in enumerate(input_ids[sent_idx]):
            # ignore cls, sep, pad
            if token_id in [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]:
                continue

            piece = tokenizer._convert_id_to_token(token_id)
            if not piece.startswith('##'):
                lexeme = tokenizer._convert_id_to_token(predicted_ids[sent_idx, pos])
                words.append((piece, lexeme))
            else:
                # wordpieces should just be appended to the previous word
                prev_word, prev_lexeme = words[-1][0], words[-1][1]
                words[-1] = (prev_word + piece[2:], prev_lexeme)
        parsed_batch.append(words)
    return parsed_batch