File size: 11,071 Bytes
b003ea2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6626d0
b003ea2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
706324f
b003ea2
 
 
 
 
 
 
 
 
 
 
 
 
a6162ae
b003ea2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3075e8
b003ea2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
from transformers.utils import ModelOutput
import torch
from torch import nn
from typing import List, Tuple, Optional
from dataclasses import dataclass
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast

# Define the prefix classes, and the possible (Hebrew) prefix strings for each class.
# Class 0 groups the multi-letter ש-compounds; every other class is a single letter.
# NOTE: these literals must be real Hebrew characters (UTF-8 source) so they can match
# the tokenizer's output.
POSSIBLE_PREFIX_CLASSES = [['לכש', 'כש', 'מש', 'בש', 'לש'], ['מ'], ['ש'], ['ה'], ['ו'], ['כ'], ['ל'], ['ב']]
# Map each individual prefix string to its class index.
PREFIXES_TO_CLASS = {w: i for i, l in enumerate(POSSIBLE_PREFIX_CLASSES) for w in l}
# All prefixes sorted longest-first (stable for equal lengths), so that when decomposing
# a run of prefixes at the start of a word, the first match found is the longest one.
ALL_PREFIX_ITEMS = list(sorted(PREFIXES_TO_CLASS.keys(), key=len, reverse=True))
# Number of prefix classes; the index equal to this value is used elsewhere as the NONE id.
TOTAL_POSSIBLE_PREFIX_CLASSES = len(POSSIBLE_PREFIX_CLASSES)

def get_prefixes_from_str(s, greedy=False):
    """Yield candidate prefixes peeled off the front of *s*, left to right.

    At each step the longest entry of ``ALL_PREFIX_ITEMS`` that the remaining
    text starts with is yielded and consumed. Unless *greedy* is true, a
    multi-letter match is followed by its first letter on its own, as an
    alternative reading. Scanning always resumes after the longer match, and it
    stops as soon as the remaining text does not begin with a prefix letter.
    """
    remaining = s
    while len(remaining) > 0 and remaining[0] in PREFIXES_TO_CLASS:
        # ALL_PREFIX_ITEMS is ordered longest-first, so the first hit is the longest match.
        match = next((candidate for candidate in ALL_PREFIX_ITEMS if remaining.startswith(candidate)), None)
        if match is None:
            break
        yield match
        # A multi-letter prefix may really be only its first letter, so offer that too.
        # We still jump past the full match: if those letters do form a prefix, the
        # longest reading is the one that applies.
        if len(match) > 1 and not greedy:
            yield match[0]
        remaining = remaining[len(match):]

def get_prefix_classes_from_str(s, greedy=False):
    """Yield the prefix-class index of every candidate prefix at the start of *s*.

    The candidates, in order, are those produced by
    ``get_prefixes_from_str(s, greedy)``; each is mapped through ``PREFIXES_TO_CLASS``.
    """
    yield from (PREFIXES_TO_CLASS[candidate] for candidate in get_prefixes_from_str(s, greedy))

@dataclass
class PrefixesClassifiersOutput(ModelOutput):
    """Dict-style output of ``BertForPrefixMarking.forward`` (used when ``return_dict`` is true).

    Fields:
        logits: per-token scores from the prefix classifiers; ``forward`` stacks one
            2-way classifier output per prefix class, giving
            ``(..., TOTAL_POSSIBLE_PREFIX_CLASSES, 2)`` (presumably
            ``(batch, seq_len, N, 2)`` for batched ``input_ids``).
        hidden_states: passed through unchanged from the BERT encoder output.
        attentions: passed through unchanged from the BERT encoder output.
    """
    # NOTE: field order matters -- ModelOutput exposes the fields positionally as a tuple.
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

class BertForPrefixMarking(BertPreTrainedModel):
    """BERT token classifier that marks which prefixes a word starts with.

    Each token's BERT hidden state is concatenated with learned embeddings of the
    prefix classes that are *possible* for that token (``prefix_class_id_options``,
    as built by ``encode_sentences_for_bert_for_prefix_marking``). The result is
    passed through a linear + tanh layer and then scored by
    ``TOTAL_POSSIBLE_PREFIX_CLASSES`` independent 2-way linear classifiers, one per
    prefix class.
    """

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(0.1)

        # An embedding table with one embedding per prefix class, plus index
        # TOTAL_POSSIBLE_PREFIX_CLASSES for NONE. Each token looks up one embedding per
        # class (the class id or NONE) and the results are concatenated, so each
        # embedding gets an equal share of hidden_size.
        prefix_class_embed = config.hidden_size // TOTAL_POSSIBLE_PREFIX_CLASSES
        self.prefix_class_embeddings = nn.Embedding(TOTAL_POSSIBLE_PREFIX_CLASSES + 1, prefix_class_embed)

        # transform(concat(bert_hidden, class_embeds)) -> tanh -> one 2-way classifier per prefix class
        self.transform = nn.Linear(config.hidden_size + prefix_class_embed * TOTAL_POSSIBLE_PREFIX_CLASSES, config.hidden_size)
        self.activation = nn.Tanh()
        self.classifiers = nn.ModuleList([nn.Linear(config.hidden_size, 2) for _ in range(TOTAL_POSSIBLE_PREFIX_CLASSES)])

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        prefix_class_id_options: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        prefix_class_id_options (`torch.LongTensor` of shape `(batch_size, sequence_length, TOTAL_POSSIBLE_PREFIX_CLASSES)`, *optional*):
            For every token and prefix class `c`, holds `c` if the class is possible for that token, or
            `TOTAL_POSSIBLE_PREFIX_CLASSES` (the NONE id) otherwise. See
            `encode_sentences_for_bert_for_prefix_marking`. If omitted, every entry is treated as NONE.

        Returns:
            A `PrefixesClassifiersOutput` (or a plain tuple when `return_dict` is false) whose `logits`
            have shape `(batch_size, sequence_length, TOTAL_POSSIBLE_PREFIX_CLASSES, 2)`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = bert_outputs[0]  # batch x seq_len x hidden_dim
        sequence_output = self.dropout(sequence_output)

        if prefix_class_id_options is None:
            # No candidate information supplied: mark every class as NONE for every token
            # (otherwise the embedding lookup below would fail on None).
            prefix_class_id_options = torch.full(
                sequence_output.shape[:-1] + (TOTAL_POSSIBLE_PREFIX_CLASSES,),
                TOTAL_POSSIBLE_PREFIX_CLASSES,
                dtype=torch.long,
                device=sequence_output.device,
            )

        # Embedding lookup: batch x seq_len x N -> batch x seq_len x N x (hidden_dim // N)
        possible_class_embed = self.prefix_class_embeddings(prefix_class_id_options)
        # Flatten the last two dims -> batch x seq_len x hidden_dim_2
        possible_class_embed = possible_class_embed.reshape(possible_class_embed.shape[:-2] + (-1,))

        # Concatenate the class embeddings onto the sequence output before the transform:
        # batch x seq_len x (hidden_dim + hidden_dim_2) -> batch x seq_len x hidden_dim
        pre_transform_output = torch.cat((sequence_output, possible_class_embed), dim=-1)
        pre_logits_output = self.activation(self.transform(pre_transform_output))
        # Run each classifier and stack along a new class axis: batch x seq_len x N x 2
        logits = torch.cat([cls(pre_logits_output).unsqueeze(-2) for cls in self.classifiers], dim=-2)

        if not return_dict:
            return (logits,) + bert_outputs[2:]

        return PrefixesClassifiersOutput(
            logits=logits,
            hidden_states=bert_outputs.hidden_states,
            attentions=bert_outputs.attentions,
        )

    def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, padding='longest'):
        """Segment the prefixes off every word in *sentences*.

        Returns one list per sentence. That list has one entry per tokenizer token,
        excluding the pad token and '##' continuation pieces; tokenizer special tokens,
        if any, presumably show up as their own entries. Wordpieces are merged back into
        whole words. Each entry is either ``[word]`` (no prefix) or ``[prefix, rest]``.
        """
        # Step 1: tokenize and build the input tensors plus the prefix-candidate tensor.
        inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Step 2: run the model. This is inference only, so skip building an autograd graph;
        # call the module (not .forward) so any registered hooks run.
        with torch.no_grad():
            logits = self(**inputs, return_dict=True).logits

        # Argmax over the final (yes/no) dimension: batch x sequence x prefixes
        logit_preds = torch.argmax(logits, dim=-1)

        ret = []

        for sent_idx, sent_ids in enumerate(inputs['input_ids']):
            tokens = tokenizer.convert_ids_to_tokens(sent_ids)
            ret.append([])
            for tok_idx, token in enumerate(tokens):
                # skip padding and continuation pieces (those are merged into the word below)
                if token == tokenizer.pad_token: continue
                if token.startswith('##'): continue

                # merge the following '##' pieces to recover the whole word
                next_tok_idx = tok_idx + 1
                while next_tok_idx < len(tokens) and tokens[next_tok_idx].startswith('##'):
                    token += tokens[next_tok_idx][2:]
                    next_tok_idx += 1

                prefix_len = get_predicted_prefix_len_from_logits(token, logit_preds[sent_idx, tok_idx])

                # A prefix covering the entire word would leave an empty stem; treat that as
                # "no prefix" rather than emitting an empty segment.
                if not prefix_len or prefix_len >= len(token):
                    ret[-1].append([token])
                else:
                    ret[-1].append([token[:prefix_len], token[prefix_len:]])

        return ret



def encode_sentences_for_bert_for_prefix_marking(tokenizer: BertTokenizerFast, sentences: List[str], padding='longest'):
    """Tokenize *sentences* and attach per-token prefix-class candidates.

    Returns the tokenizer output (PyTorch tensors, with truncation enabled) plus a
    ``'prefix_class_id_options'`` entry of shape
    ``input_ids.shape + (TOTAL_POSSIBLE_PREFIX_CLASSES,)``.

    For a token that begins a word, slot ``c`` holds ``c`` if prefix class ``c`` can
    occur at the start of the word (after merging its ``##`` continuation pieces). All
    other slots hold ``TOTAL_POSSIBLE_PREFIX_CLASSES`` (the NONE id). Tokens whose first
    character is not a prefix letter, and words shorter than two letters, keep NONE
    everywhere.
    """
    inputs = tokenizer(sentences, padding=padding, truncation=True, return_tensors='pt')

    # prefix_id_options has the input_ids shape plus one extra dimension of size N,
    # initialised to the NONE id everywhere.
    prefix_id_options = torch.full(inputs['input_ids'].shape + (TOTAL_POSSIBLE_PREFIX_CLASSES,), TOTAL_POSSIBLE_PREFIX_CLASSES, dtype=torch.long)

    # go through each token, and fill in the vector accordingly
    for sent_idx, sent_ids in enumerate(inputs['input_ids']):
        tokens = tokenizer.convert_ids_to_tokens(sent_ids)
        for tok_idx, token in enumerate(tokens):
            # If the first letter isn't a valid prefix letter, there is nothing to do.
            # This also skips '##' continuation pieces.
            if not token or token[0] not in PREFIXES_TO_CLASS:
                continue

            # merge the following '##' pieces so we analyze the whole word
            word = token
            next_tok_idx = tok_idx + 1
            while next_tok_idx < len(tokens) and tokens[next_tok_idx].startswith('##'):
                word += tokens[next_tok_idx][2:]
                next_tok_idx += 1

            # A one-letter word cannot be prefix + stem. Check after merging, so that a word
            # whose first wordpiece happens to be a single letter is still considered.
            if len(word) < 2:
                continue

            # For every possible prefix class, store the class id itself (used as the embedding index).
            for pre_class in get_prefix_classes_from_str(word):
                prefix_id_options[sent_idx, tok_idx, pre_class] = pre_class

    inputs['prefix_class_id_options'] = prefix_id_options
    return inputs

def get_predicted_prefix_len_from_logits(token: str, token_logits) -> int:
    """Return how many leading characters of *token* the model marked as prefix.

    Args:
        token: the whole (wordpiece-merged) word.
        token_logits: per-prefix-class predictions for this word, indexed by
            ``PREFIXES_TO_CLASS[prefix]``; an entry whose ``.item()`` is truthy is
            treated as "this prefix is present". In ``BertForPrefixMarking.predict``
            it is the argmax over each class's 2-way classifier.

    Returns:
        The total length of the accepted run of leading prefixes (0 if none).
    """
    # Walk the candidate prefixes (from get_prefixes_from_str) in order. While the model says
    # "yes", extend the matched length; at the first "no", stop. This rejects predicted
    # prefix combinations that don't actually occur at the start of the word.
    # For example, for the word ושכשהלכתי, if the model predicts ו & כש, we only take the
    # vav, because getting to כש requires accepting the ש in between as well.
    # Two extra rules:
    # 1] Don't allow the same prefix multiple times.
    # 2] Every accepted prefix must continue the run from the start of the word.
    #    The exception is a rejected multi-letter prefix: its single-letter reading may still
    #    be accepted, but it is then forced to be the last one.
    cur_len, skip_next, last_check, seen_prefixes = 0, False, False, set()
    for prefix in get_prefixes_from_str(token):
        # Are we skipping this candidate? This happens right after a multi-letter prefix
        # (e.g. כש) was accepted: the next candidate is its single-letter alternative (כ),
        # which must not be counted again.
        if skip_next:
            skip_next = False
            continue
        # check for duplicate prefixes, we don't allow two of the same prefix
        # if it predicted two of the same, then we will break out
        if prefix in seen_prefixes: break
        seen_prefixes.add(prefix)

        # check if we predicted this prefix
        if token_logits[PREFIXES_TO_CLASS[prefix]].item():
            cur_len += len(prefix)
            if last_check: break
            skip_next = len(prefix) > 1
        # Otherwise the model predicted "no", which normally ends the prefix run, so we break out.
        # *Except* for a multi-letter prefix: then we allow just its first letter. E.g. if כש
        # doesn't match, we may still accept כ. But we know the word continues with a ש, and
        # since it isn't כש, it can't be כ-ש- either, so that letter must be the last prefix.
        elif len(prefix) > 1:
            last_check = True
        else:
            break

    return cur_len