Shaltiel committed on
Commit cb3a4de
1 Parent(s): 93a48aa

Upload 7 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ vocab.txt filter=lfs diff=lfs merge=lfs -text
BertForMorphTagging.py ADDED
@@ -0,0 +1,194 @@
+ from collections import OrderedDict
+ from operator import itemgetter
+ from transformers.utils import ModelOutput
+ import torch
+ from torch import nn
+ from typing import List, Tuple, Optional
+ from dataclasses import dataclass
+ from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
+
+ ALL_POS = ['DET', 'NOUN', 'VERB', 'CCONJ', 'ADP', 'PRON', 'PUNCT', 'ADJ', 'ADV', 'SCONJ', 'NUM', 'PROPN', 'AUX', 'X', 'INTJ', 'SYM']
+ ALL_PREFIX_POS = ['SCONJ', 'DET', 'ADV', 'CCONJ', 'ADP', 'NUM']
+ ALL_SUFFIX_POS = ['none', 'ADP_PRON', 'PRON']
+ ALL_FEATURES = [
+     ('Gender', ['none', 'Masc', 'Fem', 'Fem,Masc']),
+     ('Number', ['none', 'Sing', 'Plur', 'Plur,Sing', 'Dual', 'Dual,Plur']),
+     ('Person', ['none', '1', '2', '3', '1,2,3']),
+     ('Tense', ['none', 'Past', 'Fut', 'Pres', 'Imp'])
+ ]
+
+ @dataclass
+ class MorphLogitsOutput(ModelOutput):
+     prefix_logits: torch.FloatTensor = None
+     pos_logits: torch.FloatTensor = None
+     features_logits: List[torch.FloatTensor] = None
+     suffix_logits: torch.FloatTensor = None
+     suffix_features_logits: List[torch.FloatTensor] = None
+
+     def detach(self):
+         return MorphLogitsOutput(self.prefix_logits.detach(), self.pos_logits.detach(), [logits.detach() for logits in self.features_logits], self.suffix_logits.detach(), [logits.detach() for logits in self.suffix_features_logits])
+
+
+ @dataclass
+ class MorphTaggingOutput(ModelOutput):
+     loss: Optional[torch.FloatTensor] = None
+     logits: Optional[MorphLogitsOutput] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+ @dataclass
+ class MorphLabels(ModelOutput):
+     prefix_labels: Optional[torch.FloatTensor] = None
+     pos_labels: Optional[torch.FloatTensor] = None
+     features_labels: Optional[List[torch.FloatTensor]] = None
+     suffix_labels: Optional[torch.FloatTensor] = None
+     suffix_features_labels: Optional[List[torch.FloatTensor]] = None
+
+     def detach(self):
+         return MorphLabels(self.prefix_labels.detach(), self.pos_labels.detach(), [labels.detach() for labels in self.features_labels], self.suffix_labels.detach(), [labels.detach() for labels in self.suffix_features_labels])
+
+     def to(self, device):
+         return MorphLabels(self.prefix_labels.to(device), self.pos_labels.to(device), [feat.to(device) for feat in self.features_labels], self.suffix_labels.to(device), [feat.to(device) for feat in self.suffix_features_labels])
+
+ class BertForMorphTagging(BertPreTrainedModel):
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.bert = BertModel(config)
+
+         self.num_prefix_classes = len(ALL_PREFIX_POS)
+         self.num_pos_classes = len(ALL_POS)
+         self.num_suffix_classes = len(ALL_SUFFIX_POS)
+         self.num_features_classes = list(map(len, map(itemgetter(1), ALL_FEATURES)))
+         # we need a classifier for prefix cls and POS cls
+         # the prefix will use BCEWithLogits for multi-label classification
+         self.prefix_cls = nn.Linear(config.hidden_size, self.num_prefix_classes)
+         # and pos + feats will use good old cross entropy for single label
+         self.pos_cls = nn.Linear(config.hidden_size, self.num_pos_classes)
+         self.features_cls = nn.ModuleList([nn.Linear(config.hidden_size, len(features)) for _, features in ALL_FEATURES])
+         # and suffix + feats will also be cross entropy
+         self.suffix_cls = nn.Linear(config.hidden_size, self.num_suffix_classes)
+         self.suffix_features_cls = nn.ModuleList([nn.Linear(config.hidden_size, len(features)) for _, features in ALL_FEATURES])
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         labels: Optional[MorphLabels] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         bert_outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         # run each of the classifiers on the transformed output
+         prefix_logits = self.prefix_cls(bert_outputs[0])
+         pos_logits = self.pos_cls(bert_outputs[0])
+         suffix_logits = self.suffix_cls(bert_outputs[0])
+         features_logits = [cls(bert_outputs[0]) for cls in self.features_cls]
+         suffix_features_logits = [cls(bert_outputs[0]) for cls in self.suffix_features_cls]
+
+         loss = None
+         if labels is not None:
+             # step 1: prefix labels loss
+             loss_fct = nn.BCEWithLogitsLoss(weight=(labels.prefix_labels != -1).float())
+             loss = loss_fct(prefix_logits, labels.prefix_labels)
+             # step 2: pos labels loss
+             loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
+             loss += loss_fct(pos_logits.view(-1, self.num_pos_classes), labels.pos_labels.view(-1))
+             # step 2b: features
+             for feat_logits, feat_labels, num_features in zip(features_logits, labels.features_labels, self.num_features_classes):
+                 loss += loss_fct(feat_logits.view(-1, num_features), feat_labels.view(-1))
+             # step 3: suffix logits loss
+             loss += loss_fct(suffix_logits.view(-1, self.num_suffix_classes), labels.suffix_labels.view(-1))
+             # step 3b: suffix features
+             for feat_logits, feat_labels, num_features in zip(suffix_features_logits, labels.suffix_features_labels, self.num_features_classes):
+                 loss += loss_fct(feat_logits.view(-1, num_features), feat_labels.view(-1))
+
+         if not return_dict:
+             return (loss, (prefix_logits, pos_logits, features_logits, suffix_logits, suffix_features_logits)) + bert_outputs[2:]
+
+         return MorphTaggingOutput(
+             loss=loss,
+             logits=MorphLogitsOutput(prefix_logits, pos_logits, features_logits, suffix_logits, suffix_features_logits),
+             hidden_states=bert_outputs.hidden_states,
+             attentions=bert_outputs.attentions,
+         )
+
+     def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, padding='longest'):
+         # tokenize the inputs and move them to the relevant device
+         inputs = tokenizer(sentences, padding=padding, return_tensors='pt')
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+         # calculate the logits
+         logits = self.forward(**inputs, return_dict=True).logits
+
+         prefix_logits, pos_logits, feats_logits, suffix_logits, suffix_feats_logits = \
+             logits['prefix_logits'], logits['pos_logits'], logits['features_logits'], logits['suffix_logits'], logits['suffix_features_logits']
+
+         prefix_predictions = (prefix_logits > 0.5).int()  # threshold at 0.5 for the multi-label prefix classification
+         pos_predictions = pos_logits.argmax(axis=-1)
+         suffix_predictions = suffix_logits.argmax(axis=-1)
+         feats_predictions = [logits.argmax(axis=-1) for logits in feats_logits]
+         suffix_feats_predictions = [logits.argmax(axis=-1) for logits in suffix_feats_logits]
+
+         # create the return dictionary
+         # for each sentence, return a dict object with the following fields: { text, tokens }
+         # where tokens is a list of dicts, and each dict is:
+         # { token: str, pos: str, feats: dict, prefixes: List[str], suffix: str | bool, suffix_feats: dict | None }
+         special_tokens = set(tokenizer.special_tokens_map.values())
+         ret = []
+         for sent_idx, sentence in enumerate(sentences):
+             input_id_strs = tokenizer.convert_ids_to_tokens(inputs['input_ids'][sent_idx])
+             # iterate through each token in the sentence, ignoring special tokens
+             tokens = []
+             for token_idx, token_str in enumerate(input_id_strs):
+                 if token_str not in special_tokens:
+                     if token_str.startswith('##'):
+                         tokens[-1]['token'] += token_str[2:]
+                         continue
+                     tokens.append(dict(
+                         token=token_str,
+                         pos=ALL_POS[pos_predictions[sent_idx, token_idx]],
+                         feats=get_features_dict_from_predictions(feats_predictions, (sent_idx, token_idx)),
+                         prefixes=[ALL_PREFIX_POS[idx] for idx, i in enumerate(prefix_predictions[sent_idx, token_idx]) if i > 0],
+                         suffix=get_suffix_or_false(ALL_SUFFIX_POS[suffix_predictions[sent_idx, token_idx]]),
+                     ))
+                     if tokens[-1]['suffix']:
+                         tokens[-1]['suffix_feats'] = get_features_dict_from_predictions(suffix_feats_predictions, (sent_idx, token_idx))
+             ret.append(dict(text=sentence, tokens=tokens))
+         return ret
+
+ def get_suffix_or_false(suffix):
+     return False if suffix == 'none' else suffix
+
+ def get_features_dict_from_predictions(predictions, idx):
+     ret = {}
+     for (feat_idx, (feat_name, feat_values)) in enumerate(ALL_FEATURES):
+         val = feat_values[predictions[feat_idx][idx]]
+         if val != 'none':
+             ret[feat_name] = val
+     return ret
+
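For reference, a minimal usage sketch (not part of the uploaded files): the repo id below is a placeholder for this repository, and the example sentence is illustrative. Loading goes through the auto_map entry in config.json, which points AutoModel at the BertForMorphTagging class above, so trust_remote_code=True is required; predictions then come from the predict() method defined in that class.

    from transformers import AutoModel, AutoTokenizer

    repo_id = 'namespace/this-model-repo'  # placeholder -- substitute the actual repo id

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)  # resolves auto_map -> BertForMorphTagging
    model.eval()

    # predict() tokenizes the sentences, runs the forward pass, and decodes the per-token
    # prefix / POS / morphological-feature / suffix predictions into plain dicts
    sentences = ['Example sentence goes here']
    predictions = model.predict(sentences, tokenizer)
    for token in predictions[0]['tokens']:
        print(token['token'], token['pos'], token['prefixes'], token['feats'], token['suffix'])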
config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "architectures": [
+     "BertForMaskedLM"
+   ],
+   "auto_map": {
+     "AutoModel": "BertForMorphTagging.BertForMorphTagging"
+   },
+   "attention_probs_dropout_prob": 0.1,
+   "gradient_checkpointing": false,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "position_embedding_type": "absolute",
+   "transformers_version": "4.6.0.dev0",
+   "type_vocab_size": 2,
+   "use_cache": true,
+   "vocab_size": 128000,
+   "newmodern": true
+ }
description.txt ADDED
@@ -0,0 +1 @@
+ FullParagraphBigModern, Phase2 / 512, Iter 36,000
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a9516849580201b58a84fa1859624fa91fa425d9b0046e6c31cc436873c01dd5
+ size 737655809
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_lower_case": true,
+   "mask_token": "[MASK]",
+   "model_max_length": 512,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "BertTokenizer",
+   "unk_token": "[UNK]"
+ }
vocab.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0fb90bfa35244d26f0065d1fcd0b5becc3da3d44d616a7e2aacaf6320b9fa2d0
+ size 1500244