davda54 committed
Commit 55f9b9d
1 Parent(s): a8838b6
Files changed (8)
  1. .gitattributes +1 -0
  2. app.py +36 -3
  3. config.json +27 -0
  4. dataset.py +74 -0
  5. lemma_rule.py +101 -0
  6. model.py +660 -0
  7. requirements.txt +5 -1
  8. tokenizer.py +231 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ checkpoint.bin filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -3,6 +3,19 @@ import tabulate
  import matplotlib.pyplot as plt
  import networkx as nx

+ from model import Parser
+
+
+ parser = Parser()
+
+ def parse(text):
+     output = parser.parse(text)
+
+     dependency_tree = render_dependency_tree(output["forms"], output["heads"], output["deprel"])
+     table = render_table(output["forms"], output["lemmas"], output["upos"], output["xpos"], output["feats"], output["ne"])
+
+     return dependency_tree, table
+

  def render_dependency_tree(words, parents, labels):
      fig, ax = plt.subplots(figsize=(32, 16))
@@ -111,12 +124,17 @@ edge_labels = [
      if line and not line.startswith("#")
  ]

- def render_table(forms, lemmas, upos, xpos, feats, metadata, edges, edge_labels):
+ def render_table(forms, lemmas, upos, xpos, feats, named_entities):
      feats = [[f"*{f.split('=')[0]}:* {f.split('=')[1]}" for f in (feat.split("|")) if '=' in f] for feat in feats]
      max_len = max(1, max([len(feat) for feat in feats]))
      feats = [feat + [""] * (max_len - len(feat)) for feat in feats]
      feats = list(zip(*feats))

+     named_entities = [
+         "" if ne == "O" else f"<< {ne.split('-')[1]} >>" if ne.startswith("B") else ne.split('-')[1] if ne.startswith("I") and i + 1 < len(named_entities) and named_entities[i + 1].startswith("I") else f"{ne.split('-')[1]} >>"
+         for i, ne in enumerate(named_entities)
+     ]
+
      array = [
          [""] + forms,
          ["*LEMMAS:*"] + lemmas,
@@ -124,6 +142,7 @@ def render_table(forms, lemmas, upos, xpos, feats, metadata, edges, edge_labels)
          ["*XPOS:*"] + xpos,
          ["*UFEATS:*"] + list(feats[0]),
          *([""] + list(row) for row in feats[1:]),
+         ["*NE:*"] + named_entities,
      ]

      #return tabulate.tabulate(array, headers="firstrow", tablefmt="unsafehtml")
@@ -141,13 +160,13 @@ with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
      gr.HTML(description)

      with gr.Row():
-         with gr.Column(scale=1):
+         with gr.Column(scale=1, variant="panel"):
              source = gr.Textbox(
                  label="Input sentence", placeholder="Write a sentence to parse", show_label=False, lines=1, max_lines=5, autofocus=True
              )
              submit = gr.Button("Submit", variant="primary")

-         with gr.Column(scale=1):
+         with gr.Column(scale=1, variant="panel"):
              dataset = gr.Dataset(components=[gr.Textbox(visible=False)],
                  label="Input examples",
                  samples=[
@@ -161,4 +180,18 @@ with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
      table = gr.DataFrame(**render_table(forms, lemmas, upos, xpos, feats, metadata, edges, edge_labels), interactive=False, datatype="markdown")
      dependency_plot = gr.Plot(render_dependency_tree(forms, edges, edge_labels), container=False)

+     source.submit(
+         fn=parse, inputs=[source], outputs=[dependency_plot, table], queue=True
+     )
+     submit.click(
+         fn=parse, inputs=[source], outputs=[dependency_plot, table], queue=True
+     )
+     dataset.click(
+         fn=lambda text: text, inputs=[dataset], outputs=[source]
+     ).then(
+         fn=parse, inputs=[source], outputs=[dependency_plot, table], queue=True
+     )
+
+
+ demo.queue(max_size=32, concurrency_count=2)
  demo.launch()
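
For reference, a hand trace of the BIO-to-markup conversion added to render_table above (the tag sequence is a hypothetical model output, not taken from this commit):

tags = ["O", "B-PER", "I-PER", "I-PER", "O"]
# walking the comprehension above:
#   "O"     -> ""            (outside any entity)
#   "B-PER" -> "<< PER >>"   (entity-opening token)
#   "I-PER" -> "PER"         (inside the entity; the next tag is also I-)
#   "I-PER" -> "PER >>"      (last token of the entity)
#   "O"     -> ""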
config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "architectures": [
+     "NorbertForMaskedLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "modeling_norbert.NorbertConfig",
+     "AutoModel": "modeling_norbert.NorbertModel",
+     "AutoModelForMaskedLM": "modeling_norbert.NorbertForMaskedLM",
+     "AutoModelForSequenceClassification": "modeling_norbert.NorbertForSequenceClassification",
+     "AutoModelForTokenClassification": "modeling_norbert.NorbertForTokenClassification",
+     "AutoModelForQuestionAnswering": "modeling_norbert.NorbertForQuestionAnswering",
+     "AutoModelForMultipleChoice": "modeling_norbert.NorbertForMultipleChoice"
+   },
+   "attention_probs_dropout_prob": 0.1,
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 1024,
+   "intermediate_size": 2730,
+   "layer_norm_eps": 1e-07,
+   "max_position_embeddings": 512,
+   "num_attention_heads": 16,
+   "num_hidden_layers": 24,
+   "position_bucket_size": 32,
+   "torch_dtype": "float32",
+   "transformers_version": "4.23.1",
+   "vocab_size": 50000
+ }
+
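
Since model.py in this commit loads the file directly rather than going through the auto_map, a minimal sketch of how the config is consumed (no assumptions beyond the code in this commit):

from model import NorbertConfig, NorbertModel

# PretrainedConfig keeps unknown keys such as "auto_map" as extra attributes.
config = NorbertConfig.from_json_file("config.json")
encoder = NorbertModel(config)  # randomly initialized until checkpoint.bin is loaded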
dataset.py ADDED
@@ -0,0 +1,74 @@
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer
+
+ from tokenizer import NLTKWordTokenizer
+ from lemma_rule import apply_lemma_rule
+
+
+ class Dataset:
+     def __init__(self):
+         self.word_tokenizer = NLTKWordTokenizer()
+         self.subword_tokenizer = AutoTokenizer.from_pretrained("ltg/norbert3-large")
+
+     def prepare_input(self, sentence: str):
+         word_spans = list(self.word_tokenizer.span_tokenize(sentence))
+         forms = [sentence[start:end] for start, end in word_spans]
+
+         subwords, alignment = [self.subword_tokenizer.convert_tokens_to_ids("[CLS]")], [0]
+         for i, word in enumerate(forms):
+             space_before = (i == 0) or sentence[word_spans[i - 1][1]] == " "
+
+             # very very ugly hack ;(
+             encoding = self.subword_tokenizer(f"| {word}" if space_before else f"|{word}", add_special_tokens=False)
+             subwords += encoding.input_ids[1:]
+             alignment += (len(encoding.input_ids) - 1) * [i + 1]
+
+         subwords.append(self.subword_tokenizer.convert_tokens_to_ids("[SEP]"))
+         alignment.append(alignment[-1] + 1)
+
+         subwords = torch.tensor([subwords])
+         alignment = torch.tensor([alignment])
+         alignment = F.one_hot(alignment, num_classes=len(forms) + 2).float()
+
+         return forms, subwords, alignment
+
+     def decode_output(self, forms, lemma_p, upos_p, xpos_p, feats_p, dep_p, ne_p, head_p):
+         lemmas = [apply_lemma_rule(form, self.lemma_vocab[lemma_p[0, i, :].argmax().item()]) for i, form in enumerate(forms)]
+         upos = [self.upos_vocab[upos_p[0, i, :].argmax().item()] for i in range(len(forms))]
+         xpos = [self.xpos_vocab[xpos_p[0, i, :].argmax().item()] for i in range(len(forms))]
+         feats = [self.feats_vocab[feats_p[0, i, :].argmax().item()] for i in range(len(forms))]
+         heads = [head_p[0, i].item() for i in range(len(forms))]
+         deprel = [self.arc_dep_vocab[dep_p[0, i, :].argmax().item()] for i in range(len(forms))]
+         ne = [self.ne_vocab[ne_p[0, i, :].argmax().item()] for i in range(len(forms))]
+
+         return lemmas, upos, xpos, feats, heads, deprel, ne
+
+     # save state dict
+     def state_dict(self):
+         return {
+             "forms_vocab": self.forms_vocab,
+             "lemma_vocab": self.lemma_vocab,
+             "upos_vocab": self.upos_vocab,
+             "xpos_vocab": self.xpos_vocab,
+             "feats_vocab": self.feats_vocab,
+             "arc_dep_vocab": self.arc_dep_vocab,
+             "ne_vocab": self.ne_vocab
+         }
+
+     # load state dict
+     def load_state_dict(self, state_dict):
+         self.forms_vocab = state_dict["forms_vocab"]
+         self.lemma_vocab = state_dict["lemma_vocab"]
+         self.upos_vocab = state_dict["upos_vocab"]
+         self.xpos_vocab = state_dict["xpos_vocab"]
+         self.feats_vocab = state_dict["feats_vocab"]
+         self.arc_dep_vocab = state_dict["arc_dep_vocab"]
+         self.ne_vocab = state_dict["ne_vocab"]
+
+         self.lemma_indexer = {i: n for n, i in enumerate(self.lemma_vocab)}
+         self.upos_indexer = {i: n for n, i in enumerate(self.upos_vocab)}
+         self.xpos_indexer = {i: n for n, i in enumerate(self.xpos_vocab)}
+         self.feats_indexer = {i: n for n, i in enumerate(self.feats_vocab)}
+         self.ne_indexer = {i: n for n, i in enumerate(self.ne_vocab)}
+         self.arc_dep_indexer = {i: n for n, i in enumerate(self.arc_dep_vocab)}
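
A quick sketch of what prepare_input returns for a single sentence; the shapes follow directly from the code above, while the example sentence is arbitrary:

from dataset import Dataset

ds = Dataset()  # fetches the ltg/norbert3-large tokenizer on first use
forms, subwords, alignment = ds.prepare_input("Hun bor i Oslo.")

# forms:     surface tokens, e.g. ["Hun", "bor", "i", "Oslo", "."]
# subwords:  LongTensor [1, n_subwords] covering [CLS] ... [SEP]
# alignment: FloatTensor [1, n_subwords, len(forms) + 2], a one-hot map from
#            each subword to its word (slots 0 and -1 belong to [CLS]/[SEP])
assert alignment.shape == (1, subwords.size(1), len(forms) + 2)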
lemma_rule.py ADDED
@@ -0,0 +1,101 @@
+ def min_edit_script(source, target, allow_copy):
+     a = [[(len(source) + len(target) + 1, None)] * (len(target) + 1) for _ in range(len(source) + 1)]
+     for i in range(0, len(source) + 1):
+         for j in range(0, len(target) + 1):
+             if i == 0 and j == 0:
+                 a[i][j] = (0, "")
+             else:
+                 if allow_copy and i and j and source[i - 1] == target[j - 1] and a[i-1][j-1][0] < a[i][j][0]:
+                     a[i][j] = (a[i-1][j-1][0], a[i-1][j-1][1] + "→")
+                 if i and a[i-1][j][0] < a[i][j][0]:
+                     a[i][j] = (a[i-1][j][0] + 1, a[i-1][j][1] + "-")
+                 if j and a[i][j-1][0] < a[i][j][0]:
+                     a[i][j] = (a[i][j-1][0] + 1, a[i][j-1][1] + "+" + target[j - 1])
+     return a[-1][-1][1]
+
+
+ def gen_lemma_rule(form, lemma, allow_copy):
+     form = form.lower()
+
+     previous_case = -1
+     lemma_casing = ""
+     for i, c in enumerate(lemma):
+         case = "↑" if c.lower() != c else "↓"
+         if case != previous_case:
+             lemma_casing += "{}{}{}".format("¦" if lemma_casing else "", case, i if i <= len(lemma) // 2 else i - len(lemma))
+         previous_case = case
+     lemma = lemma.lower()
+
+     best, best_form, best_lemma = 0, 0, 0
+     for l in range(len(lemma)):
+         for f in range(len(form)):
+             cpl = 0
+             while f + cpl < len(form) and l + cpl < len(lemma) and form[f + cpl] == lemma[l + cpl]: cpl += 1
+             if cpl > best:
+                 best = cpl
+                 best_form = f
+                 best_lemma = l
+
+     rule = lemma_casing + ";"
+     if not best:
+         rule += "a" + lemma
+     else:
+         rule += "d{}¦{}".format(
+             min_edit_script(form[:best_form], lemma[:best_lemma], allow_copy),
+             min_edit_script(form[best_form + best:], lemma[best_lemma + best:], allow_copy),
+         )
+     return rule
+
+
+ def apply_lemma_rule(form, lemma_rule):
+     if lemma_rule == "<unk>":
+         return form
+
+     if ';' not in lemma_rule:
+         raise ValueError('lemma_rule %r for form %r missing semicolon' % (lemma_rule, form))
+
+     casing, rule = lemma_rule.split(";", 1)
+     if rule.startswith("a"):
+         lemma = rule[1:]
+     else:
+         form = form.lower()
+         rules, rule_sources = rule[1:].split("¦"), []
+         assert len(rules) == 2
+         for rule in rules:
+             source, i = 0, 0
+             while i < len(rule):
+                 if rule[i] == "→" or rule[i] == "-":
+                     source += 1
+                 else:
+                     assert rule[i] == "+"
+                     i += 1
+                 i += 1
+             rule_sources.append(source)
+
+         try:
+             lemma, form_offset = "", 0
+             for i in range(2):
+                 j, offset = 0, (0 if i == 0 else len(form) - rule_sources[1])
+                 while j < len(rules[i]):
+                     if rules[i][j] == "→":
+                         lemma += form[offset]
+                         offset += 1
+                     elif rules[i][j] == "-":
+                         offset += 1
+                     else:
+                         assert(rules[i][j] == "+")
+                         lemma += rules[i][j + 1]
+                         j += 1
+                     j += 1
+                 if i == 0:
+                     lemma += form[rule_sources[0] : len(form) - rule_sources[1]]
+         except:
+             lemma = form
+
+     for rule in casing.split("¦"):
+         if rule == "↓0": continue  # The lemma is lowercased initially
+         if not rule: continue  # Empty lemma might generate empty casing rule
+         case, offset = rule[0], int(rule[1:])
+         lemma = lemma[:offset] + (lemma[offset:].upper() if case == "↑" else lemma[offset:].lower())
+
+     return lemma
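
apply_lemma_rule inverts gen_lemma_rule by construction, so a round-trip check is the simplest sanity test (the concrete rule string, here a casing marker plus a one-character deletion, is an implementation detail and is not asserted):

from lemma_rule import gen_lemma_rule, apply_lemma_rule

rule = gen_lemma_rule("Dogs", "dog", allow_copy=True)
assert apply_lemma_rule("Dogs", rule) == "dog"

# unknown rules fall back to returning the form unchanged
assert apply_lemma_rule("Oslo", "<unk>") == "Oslo"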
model.py ADDED
@@ -0,0 +1,660 @@
+ import math
+ from typing import List, Optional, Tuple, Union
+ import dependency_decoding
+ import ftfy
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils import checkpoint
+
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.activations import gelu_new
+ from transformers.modeling_outputs import (
+     MaskedLMOutput,
+     MultipleChoiceModelOutput,
+     QuestionAnsweringModelOutput,
+     SequenceClassifierOutput,
+     TokenClassifierOutput,
+     BaseModelOutput
+ )
+ from transformers.pytorch_utils import softmax_backward_data
+ from transformers.configuration_utils import PretrainedConfig
+
+ from dataset import Dataset
+
+
+ class NorbertConfig(PretrainedConfig):
+     """Configuration class to store the configuration of a `NorbertModel`."""
+     def __init__(
+         self,
+         vocab_size=50000,
+         attention_probs_dropout_prob=0.1,
+         hidden_dropout_prob=0.1,
+         hidden_size=768,
+         intermediate_size=2048,
+         max_position_embeddings=512,
+         position_bucket_size=32,
+         num_attention_heads=12,
+         num_hidden_layers=12,
+         layer_norm_eps=1.0e-7,
+         output_all_encoded_layers=True,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.max_position_embeddings = max_position_embeddings
+         self.output_all_encoded_layers = output_all_encoded_layers
+         self.position_bucket_size = position_bucket_size
+         self.layer_norm_eps = layer_norm_eps
+
+
+ class Encoder(nn.Module):
+     def __init__(self, config, activation_checkpointing=False):
+         super().__init__()
+         self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.num_hidden_layers)])
+
+         for i, layer in enumerate(self.layers):
+             layer.mlp.mlp[1].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
+             layer.mlp.mlp[-2].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
+
+         self.activation_checkpointing = activation_checkpointing
+
+     def forward(self, hidden_states, attention_mask, relative_embedding):
+         hidden_states, attention_probs = [hidden_states], []
+
+         for layer in self.layers:
+             if self.activation_checkpointing:
+                 hidden_state, attention_p = checkpoint.checkpoint(layer, hidden_states[-1], attention_mask, relative_embedding)
+             else:
+                 hidden_state, attention_p = layer(hidden_states[-1], attention_mask, relative_embedding)
+
+             hidden_states.append(hidden_state)
+             attention_probs.append(attention_p)
+
+         return hidden_states, attention_probs
+
+
+ class MaskClassifier(nn.Module):
+     def __init__(self, config, subword_embedding):
+         super().__init__()
+         self.nonlinearity = nn.Sequential(
+             nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False),
+             nn.Linear(config.hidden_size, config.hidden_size),
+             nn.GELU(),
+             nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False),
+             nn.Dropout(config.hidden_dropout_prob),
+             nn.Linear(subword_embedding.size(1), subword_embedding.size(0))
+         )
+         self.initialize(config.hidden_size, subword_embedding)
+
+     def initialize(self, hidden_size, embedding):
+         std = math.sqrt(2.0 / (5.0 * hidden_size))
+         nn.init.trunc_normal_(self.nonlinearity[1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+         self.nonlinearity[-1].weight = embedding
+         self.nonlinearity[1].bias.data.zero_()
+         self.nonlinearity[-1].bias.data.zero_()
+
+     def forward(self, x, masked_lm_labels=None):
+         if masked_lm_labels is not None:
+             x = torch.index_select(x.flatten(0, 1), 0, torch.nonzero(masked_lm_labels.flatten() != -100).squeeze())
+         x = self.nonlinearity(x)
+         return x
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.attention = Attention(config)
+         self.mlp = FeedForward(config)
+
+     def forward(self, x, padding_mask, relative_embedding):
+         attention_output, attention_probs = self.attention(x, padding_mask, relative_embedding)
+         x = x + attention_output
+         x = x + self.mlp(x)
+         return x, attention_probs
+
+
+ class GeGLU(nn.Module):
+     def forward(self, x):
+         x, gate = x.chunk(2, dim=-1)
+         x = x * gelu_new(gate)
+         return x
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False),
+             nn.Linear(config.hidden_size, 2*config.intermediate_size, bias=False),
+             GeGLU(),
+             nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps, elementwise_affine=False),
+             nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
+             nn.Dropout(config.hidden_dropout_prob)
+         )
+         self.initialize(config.hidden_size)
+
+     def initialize(self, hidden_size):
+         std = math.sqrt(2.0 / (5.0 * hidden_size))
+         nn.init.trunc_normal_(self.mlp[1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+         nn.init.trunc_normal_(self.mlp[-2].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+
+     def forward(self, x):
+         return self.mlp(x)
+
+
+ class MaskedSoftmax(torch.autograd.Function):
+     @staticmethod
+     def forward(self, x, mask, dim):
+         self.dim = dim
+         x.masked_fill_(mask, float('-inf'))
+         x = torch.softmax(x, self.dim)
+         x.masked_fill_(mask, 0.0)
+         self.save_for_backward(x)
+         return x
+
+     @staticmethod
+     def backward(self, grad_output):
+         output, = self.saved_tensors
+         input_grad = softmax_backward_data(self, grad_output, output, self.dim, output)
+         return input_grad, None, None
+
+
+ class Attention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+
+         self.config = config
+
+         if config.hidden_size % config.num_attention_heads != 0:
+             raise ValueError(f"The hidden size {config.hidden_size} is not a multiple of the number of attention heads {config.num_attention_heads}")
+
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_size = config.hidden_size // config.num_attention_heads
+
+         self.in_proj_qk = nn.Linear(config.hidden_size, 2*config.hidden_size, bias=True)
+         self.in_proj_v = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
+         self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
+
+         self.pre_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False)
+         self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
+
+         position_indices = torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(1) \
+             - torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0)
+         position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, config.max_position_embeddings)
+         position_indices = config.position_bucket_size - 1 + position_indices
+         self.register_buffer("position_indices", position_indices, persistent=True)
+
+         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+         self.scale = 1.0 / math.sqrt(3 * self.head_size)
+         self.initialize()
+
+     def make_log_bucket_position(self, relative_pos, bucket_size, max_position):
+         sign = torch.sign(relative_pos)
+         mid = bucket_size // 2
+         abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, torch.abs(relative_pos).clamp(max=max_position - 1))
+         log_pos = torch.ceil(torch.log(abs_pos / mid) / math.log((max_position-1) / mid) * (mid - 1)).int() + mid
+         bucket_pos = torch.where(abs_pos <= mid, relative_pos, log_pos * sign).long()
+         return bucket_pos
+
+     def initialize(self):
+         std = math.sqrt(2.0 / (5.0 * self.hidden_size))
+         nn.init.trunc_normal_(self.in_proj_qk.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+         nn.init.trunc_normal_(self.in_proj_v.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+         nn.init.trunc_normal_(self.out_proj.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+         self.in_proj_qk.bias.data.zero_()
+         self.in_proj_v.bias.data.zero_()
+         self.out_proj.bias.data.zero_()
+
+     def compute_attention_scores(self, hidden_states, relative_embedding):
+         key_len, batch_size, _ = hidden_states.size()
+         query_len = key_len
+
+         if self.position_indices.size(0) < query_len:
+             position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
+                 - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
+             position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
+             position_indices = self.config.position_bucket_size - 1 + position_indices
+             self.position_indices = position_indices.to(hidden_states.device)
+
+         hidden_states = self.pre_layer_norm(hidden_states)
+
+         query, key = self.in_proj_qk(hidden_states).chunk(2, dim=2)  # shape: [T, B, D]
+         value = self.in_proj_v(hidden_states)  # shape: [T, B, D]
+
+         query = query.reshape(query_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
+         key = key.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
+         value = value.view(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
+
+         attention_scores = torch.bmm(query, key.transpose(1, 2) * self.scale)
+
+         pos = self.in_proj_qk(self.dropout(relative_embedding))  # shape: [2T-1, 2D]
+         query_pos, key_pos = pos.view(-1, self.num_heads, 2*self.head_size).chunk(2, dim=2)
+         query = query.view(batch_size, self.num_heads, query_len, self.head_size)
+         key = key.view(batch_size, self.num_heads, query_len, self.head_size)
+
+         attention_c_p = torch.einsum("bhqd,khd->bhqk", query, key_pos.squeeze(1) * self.scale)
+         attention_p_c = torch.einsum("bhkd,qhd->bhqk", key * self.scale, query_pos.squeeze(1))
+
+         position_indices = self.position_indices[:query_len, :key_len].expand(batch_size, self.num_heads, -1, -1)
+         attention_c_p = attention_c_p.gather(3, position_indices)
+         attention_p_c = attention_p_c.gather(2, position_indices)
+
+         attention_scores = attention_scores.view(batch_size, self.num_heads, query_len, key_len)
+         attention_scores.add_(attention_c_p)
+         attention_scores.add_(attention_p_c)
+
+         return attention_scores, value
+
+     def compute_output(self, attention_probs, value):
+         attention_probs = self.dropout(attention_probs)
+         context = torch.bmm(attention_probs.flatten(0, 1), value)  # shape: [B*H, Q, D]
+         context = context.transpose(0, 1).reshape(context.size(1), -1, self.hidden_size)  # shape: [Q, B, H*D]
+         context = self.out_proj(context)
+         context = self.post_layer_norm(context)
+         context = self.dropout(context)
+         return context
+
+     def forward(self, hidden_states, attention_mask, relative_embedding):
+         attention_scores, value = self.compute_attention_scores(hidden_states, relative_embedding)
+         attention_probs = MaskedSoftmax.apply(attention_scores, attention_mask, -1)
+         return self.compute_output(attention_probs, value), attention_probs.detach()
+
+
+ class Embedding(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+
+         self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.word_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+         self.relative_embedding = nn.Parameter(torch.empty(2 * config.position_bucket_size - 1, config.hidden_size))
+         self.relative_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+         self.initialize()
+
+     def initialize(self):
+         std = math.sqrt(2.0 / (5.0 * self.hidden_size))
+         nn.init.trunc_normal_(self.relative_embedding, mean=0.0, std=std, a=-2*std, b=2*std)
+         nn.init.trunc_normal_(self.word_embedding.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+
+     def forward(self, input_ids):
+         word_embedding = self.dropout(self.word_layer_norm(self.word_embedding(input_ids)))
+         relative_embeddings = self.relative_layer_norm(self.relative_embedding)
+         return word_embedding, relative_embeddings
+
+
+ #
+ # HuggingFace wrappers
+ #
+
+ class NorbertPreTrainedModel(PreTrainedModel):
+     config_class = NorbertConfig
+     base_model_prefix = "norbert3"
+     supports_gradient_checkpointing = True
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if isinstance(module, Encoder):
+             module.activation_checkpointing = value
+
+     def _init_weights(self, module):
+         pass  # everything is already initialized
+
+
+ class NorbertModel(NorbertPreTrainedModel):
+     def __init__(self, config, add_mlm_layer=False, gradient_checkpointing=False, **kwargs):
+         super().__init__(config, **kwargs)
+         self.config = config
+
+         self.embedding = Embedding(config)
+         self.transformer = Encoder(config, activation_checkpointing=gradient_checkpointing)
+         self.classifier = MaskClassifier(config, self.embedding.word_embedding.weight) if add_mlm_layer else None
+
+     def get_input_embeddings(self):
+         return self.embedding.word_embedding
+
+     def set_input_embeddings(self, value):
+         self.embedding.word_embedding = value
+
+     def get_contextualized_embeddings(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None
+     ) -> List[torch.Tensor]:
+         if input_ids is not None:
+             input_shape = input_ids.size()
+         else:
+             raise ValueError("You have to specify input_ids")
+
+         batch_size, seq_length = input_shape
+         device = input_ids.device
+
+         if attention_mask is None:
+             attention_mask = torch.zeros(batch_size, seq_length, dtype=torch.bool, device=device)
+         else:
+             attention_mask = ~attention_mask.bool()
+         attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+         static_embeddings, relative_embedding = self.embedding(input_ids.t())
+         contextualized_embeddings, attention_probs = self.transformer(static_embeddings, attention_mask, relative_embedding)
+         contextualized_embeddings = [e.transpose(0, 1) for e in contextualized_embeddings]
+         last_layer = contextualized_embeddings[-1]
+         # return the embedding output followed by per-layer increments
+         # (consumed by the scalar layer-mixing in Model below)
+         contextualized_embeddings = [contextualized_embeddings[0]] + [
+             contextualized_embeddings[i] - contextualized_embeddings[i - 1]
+             for i in range(1, len(contextualized_embeddings))
+         ]
+         return last_layer, contextualized_embeddings, attention_probs
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs
+     ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
+
+         if not return_dict:
+             return (
+                 sequence_output,
+                 *([contextualized_embeddings] if output_hidden_states else []),
+                 *([attention_probs] if output_attentions else [])
+             )
+
+         return BaseModelOutput(
+             last_hidden_state=sequence_output,
+             hidden_states=contextualized_embeddings if output_hidden_states else None,
+             attentions=attention_probs if output_attentions else None
+         )
+
+
+ class Classifier(nn.Module):
+     def __init__(self, hidden_size, vocab_size, dropout):
+         super().__init__()
+
+         self.transform = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size),
+             nn.GELU(),
+             nn.LayerNorm(hidden_size, elementwise_affine=False),
+             nn.Dropout(dropout),
+             nn.Linear(hidden_size, vocab_size)
+         )
+         self.initialize(hidden_size)
+
+     def initialize(self, hidden_size):
+         std = math.sqrt(2.0 / (5.0 * hidden_size))
+         nn.init.trunc_normal_(self.transform[0].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+         nn.init.trunc_normal_(self.transform[-1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+         self.transform[0].bias.data.zero_()
+         self.transform[-1].bias.data.zero_()
+
+     def forward(self, x):
+         return self.transform(x)
+
+
+ class ZeroClassifier(nn.Module):
+     def forward(self, x):
+         output = torch.zeros(x.size(0), x.size(1), 2, device=x.device, dtype=x.dtype)
+         output[:, :, 0] = 1.0
+         output[:, :, 1] = -1.0
+         return output
+
+
+ class EdgeClassifier(nn.Module):
+     def __init__(self, hidden_size, dep_hidden_size, vocab_size, dropout):
+         super().__init__()
+
+         self.head_dep_transform = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size),
+             nn.GELU(),
+             nn.LayerNorm(hidden_size, elementwise_affine=False),
+             nn.Dropout(dropout)
+         )
+         self.head_root_transform = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size),
+             nn.GELU(),
+             nn.LayerNorm(hidden_size, elementwise_affine=False),
+             nn.Dropout(dropout)
+         )
+         self.head_bilinear = nn.Parameter(torch.zeros(hidden_size, hidden_size))
+         self.head_linear_dep = nn.Linear(hidden_size, 1, bias=False)
+         self.head_linear_root = nn.Linear(hidden_size, 1, bias=False)
+         self.head_bias = nn.Parameter(torch.zeros(1))
+
+         self.dep_dep_transform = nn.Sequential(
+             nn.Linear(hidden_size, dep_hidden_size),
+             nn.GELU(),
+             nn.LayerNorm(dep_hidden_size, elementwise_affine=False),
+             nn.Dropout(dropout)
+         )
+         self.dep_root_transform = nn.Sequential(
+             nn.Linear(hidden_size, dep_hidden_size),
+             nn.GELU(),
+             nn.LayerNorm(dep_hidden_size, elementwise_affine=False),
+             nn.Dropout(dropout)
+         )
+         self.dep_bilinear = nn.Parameter(torch.zeros(dep_hidden_size, dep_hidden_size, vocab_size))
+         self.dep_linear_dep = nn.Linear(dep_hidden_size, vocab_size, bias=False)
+         self.dep_linear_root = nn.Linear(dep_hidden_size, vocab_size, bias=False)
+         self.dep_bias = nn.Parameter(torch.zeros(vocab_size))
+
+         self.hidden_size = hidden_size
+         self.dep_hidden_size = dep_hidden_size
+
+         self.mask_value = float("-inf")
+         self.initialize(hidden_size)
+
+     def initialize(self, hidden_size):
+         std = math.sqrt(2.0 / (5.0 * hidden_size))
+         nn.init.trunc_normal_(self.head_dep_transform[0].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+         nn.init.trunc_normal_(self.head_root_transform[0].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+         nn.init.trunc_normal_(self.dep_dep_transform[0].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+         nn.init.trunc_normal_(self.dep_root_transform[0].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+
+         nn.init.trunc_normal_(self.head_linear_dep.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+         nn.init.trunc_normal_(self.head_linear_root.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+         nn.init.trunc_normal_(self.dep_linear_dep.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+         nn.init.trunc_normal_(self.dep_linear_root.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+
+         self.head_dep_transform[0].bias.data.zero_()
+         self.head_root_transform[0].bias.data.zero_()
+         self.dep_dep_transform[0].bias.data.zero_()
+         self.dep_root_transform[0].bias.data.zero_()
+
+     def forward(self, head_x, dep_x, lengths, head_gold=None):
+         head_dep = self.head_dep_transform(head_x[:, 1:, :])
+         head_root = self.head_root_transform(head_x)
+         head_prediction = torch.einsum("bkn,nm,blm->bkl", head_dep, self.head_bilinear, head_root / math.sqrt(self.hidden_size)) \
+             + self.head_linear_dep(head_dep) + self.head_linear_root(head_root).transpose(1, 2) + self.head_bias
+
+         mask = (torch.arange(head_x.size(1), device=head_x.device).unsqueeze(0) >= lengths.unsqueeze(1)).unsqueeze(1)
+         mask = mask | (torch.ones(head_x.size(1) - 1, head_x.size(1), dtype=torch.bool, device=head_x.device).tril(1) & torch.ones(head_x.size(1) - 1, head_x.size(1), dtype=torch.bool, device=head_x.device).triu(1))
+         head_prediction = head_prediction.masked_fill(mask, self.mask_value)
+
+         if head_gold is None:
+             head_logp = torch.log_softmax(head_prediction, dim=-1)
+             head_logp = F.pad(head_logp, (0, 0, 1, 0), value=torch.nan).cpu()
+             head_gold = []
+             for i, length in enumerate(lengths.tolist()):
+                 head = self.max_spanning_tree(head_logp[i, :length, :length])
+                 head = head + ((head_x.size(1) - 1) - len(head)) * [0]
+                 head_gold.append(torch.tensor(head))
+             head_gold = torch.stack(head_gold).to(head_x.device)
+
+         dep_dep = self.dep_dep_transform(dep_x[:, 1:])
+         dep_root = dep_x.gather(1, head_gold.unsqueeze(-1).expand(-1, -1, dep_x.size(-1)).clamp(min=0))
+         dep_root = self.dep_root_transform(dep_root)
+         dep_prediction = torch.einsum("btm,mnl,btn->btl", dep_dep, self.dep_bilinear, dep_root / math.sqrt(self.dep_hidden_size)) \
+             + self.dep_linear_dep(dep_dep) + self.dep_linear_root(dep_root) + self.dep_bias
+
+         return head_prediction, dep_prediction, head_gold
+
+     def max_spanning_tree(self, weight_matrix):
+         weight_matrix = weight_matrix.clone()
+         # weight_matrix[:, 0] = torch.nan
+
+         # we need to make sure that the root is the parent of a single node
+         # first, we try to use the default weights, it should work in most cases
+         parents, _ = dependency_decoding.chu_liu_edmonds(weight_matrix.numpy().astype(float))
+
+         assert parents[0] == -1, f"{parents}\n{weight_matrix}"
+         parents = parents[1:]
+
+         # check if the root is the parent of a single node
+         if parents.count(0) == 1:
+             return parents
+
+         # if not, we need to modify the weights and try all possibilities
+         # we try to find the node that is the parent of the root
+         best_score = float("-inf")
+         best_parents = None
+
+         for i in range(len(parents)):
+             weight_matrix_mod = weight_matrix.clone()
+             weight_matrix_mod[:i+1, 0] = torch.nan
+             weight_matrix_mod[i+2:, 0] = torch.nan
+             parents, score = dependency_decoding.chu_liu_edmonds(weight_matrix_mod.numpy().astype(float))
+             parents = parents[1:]
+
+             if score > best_score:
+                 best_score = score
+                 best_parents = parents
+
+         def print_whole_matrix(matrix):
+             for i in range(matrix.shape[0]):
+                 print(" ".join([str(x) for x in matrix[i]]))
+
+         assert best_parents is not None, f"{best_parents}\n{print_whole_matrix(weight_matrix)}"
+         return best_parents
+
+
+ class Model(nn.Module):
+     def __init__(self, dataset):
+         super().__init__()
+
+         # config = BertConfig("../../configs/base.json")
+         # self.bert = Bert(config)
+         # checkpoint = torch.load("../../checkpoints/test_wd=0.01/model.bin", map_location="cpu")
+         # self.bert.load_state_dict(checkpoint["model"], strict=False)
+
+         config = NorbertConfig.from_json_file("config.json")
+         self.bert = NorbertModel(config)
+
+         self.n_layers = config.num_hidden_layers
+
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.layer_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False)
+         self.upos_layer_score = nn.Parameter(torch.zeros(self.n_layers + 1, dtype=torch.float))
+         self.xpos_layer_score = nn.Parameter(torch.zeros(self.n_layers + 1, dtype=torch.float))
+         self.feats_layer_score = nn.Parameter(torch.zeros(self.n_layers + 1, dtype=torch.float))
+         self.lemma_layer_score = nn.Parameter(torch.zeros(self.n_layers + 1, dtype=torch.float))
+         self.head_layer_score = nn.Parameter(torch.zeros(self.n_layers + 1, dtype=torch.float))
+         self.dep_layer_score = nn.Parameter(torch.zeros(self.n_layers + 1, dtype=torch.float))
+         self.ner_layer_score = nn.Parameter(torch.zeros(self.n_layers + 1, dtype=torch.float))
+
+         self.lemma_classifier = Classifier(config.hidden_size, len(dataset.lemma_vocab), config.hidden_dropout_prob)
+         self.upos_classifier = Classifier(config.hidden_size, len(dataset.upos_vocab), config.hidden_dropout_prob) if len(dataset.upos_vocab) > 2 else ZeroClassifier()
+         self.xpos_classifier = Classifier(config.hidden_size, len(dataset.xpos_vocab), config.hidden_dropout_prob) if len(dataset.xpos_vocab) > 2 else ZeroClassifier()
+         self.feats_classifier = Classifier(config.hidden_size, len(dataset.feats_vocab), config.hidden_dropout_prob) if len(dataset.feats_vocab) > 2 else ZeroClassifier()
+         self.edge_classifier = EdgeClassifier(config.hidden_size, 128, len(dataset.arc_dep_vocab), config.hidden_dropout_prob)
+         self.ner_classifier = Classifier(config.hidden_size, len(dataset.ne_vocab), config.hidden_dropout_prob) if len(dataset.ne_vocab) > 2 else ZeroClassifier()
+
+     def forward(self, x, alignment_mask, subword_lengths, word_lengths, head_gold=None):
+         padding_mask = (torch.arange(x.size(1), device=x.device).unsqueeze(0) < subword_lengths.unsqueeze(1))
+         # hidden_states holds the embedding output followed by per-layer increments
+         # (see NorbertModel.get_contextualized_embeddings)
+         x = self.bert(x, padding_mask, output_hidden_states=True).hidden_states
+         x = torch.stack(x, dim=0)
+
+         upos_x = torch.einsum("lbtd, l -> btd", x, torch.softmax(self.upos_layer_score, dim=0))
+         xpos_x = torch.einsum("lbtd, l -> btd", x, torch.softmax(self.xpos_layer_score, dim=0))
+         feats_x = torch.einsum("lbtd, l -> btd", x, torch.softmax(self.feats_layer_score, dim=0))
+         lemma_x = torch.einsum("lbtd, l -> btd", x, torch.softmax(self.lemma_layer_score, dim=0))
+         head_x = torch.einsum("lbtd, l -> btd", x, torch.softmax(self.head_layer_score, dim=0))
+         dep_x = torch.einsum("lbtd, l -> btd", x, torch.softmax(self.dep_layer_score, dim=0))
+         ne_x = torch.einsum("lbtd, l -> btd", x, torch.softmax(self.ner_layer_score, dim=0))
+
+         upos_x = torch.einsum("bsd,bst->btd", upos_x, alignment_mask) / alignment_mask.sum(1).unsqueeze(-1).clamp(min=1.0)
+         xpos_x = torch.einsum("bsd,bst->btd", xpos_x, alignment_mask) / alignment_mask.sum(1).unsqueeze(-1).clamp(min=1.0)
+         feats_x = torch.einsum("bsd,bst->btd", feats_x, alignment_mask) / alignment_mask.sum(1).unsqueeze(-1).clamp(min=1.0)
+         lemma_x = torch.einsum("bsd,bst->btd", lemma_x, alignment_mask) / alignment_mask.sum(1).unsqueeze(-1).clamp(min=1.0)
+         head_x = torch.einsum("bsd,bst->btd", head_x, alignment_mask) / alignment_mask.sum(1).unsqueeze(-1).clamp(min=1.0)
+         dep_x = torch.einsum("bsd,bst->btd", dep_x, alignment_mask) / alignment_mask.sum(1).unsqueeze(-1).clamp(min=1.0)
+         ne_x = torch.einsum("bsd, bst -> btd", ne_x, alignment_mask) / alignment_mask.sum(1).unsqueeze(-1).clamp(min=1.0)
+
+         upos_x = self.dropout(self.layer_norm(upos_x[:, 1:-1, :]))
+         xpos_x = self.dropout(self.layer_norm(xpos_x[:, 1:-1, :]))
+         feats_x = self.dropout(self.layer_norm(feats_x[:, 1:-1, :]))
+         lemma_x = self.dropout(self.layer_norm(lemma_x[:, 1:-1, :]))
+         head_x = self.dropout(self.layer_norm(head_x[:, 0:-1, :]))
+         dep_x = self.dropout(self.layer_norm(dep_x[:, 0:-1, :]))
+         ne_x = self.dropout(self.layer_norm(ne_x[:, 1:-1, :]))
+
+         lemma_preds = self.lemma_classifier(lemma_x)
+         upos_preds = self.upos_classifier(upos_x)
+         xpos_preds = self.xpos_classifier(xpos_x)
+         feats_preds = self.feats_classifier(feats_x)
+         ne_preds = self.ner_classifier(ne_x)
+         head_prediction, dep_prediction, head_liu = self.edge_classifier(head_x, dep_x, word_lengths, head_gold)
+
+         return lemma_preds, upos_preds, xpos_preds, feats_preds, head_prediction, dep_prediction, ne_preds, head_liu
+
+
+ class Parser:
+     def __init__(self):
+         checkpoint = torch.load("checkpoint.bin", map_location="cpu")
+
+         self.dataset = Dataset()
+         self.dataset.load_state_dict(checkpoint["dataset"])
+
+         self.model = Model(self.dataset)
+         self.model.load_state_dict(checkpoint["model"])
+         self.model.eval()
+         del checkpoint
+
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model.to(self.device)
+
+     def parse(self, sentence):
+         sentence = ftfy.fix_text(sentence.strip())
+         forms, subwords, alignment = self.dataset.prepare_input(sentence)
+
+         with torch.no_grad():
+             output = self.model(
+                 subwords.to(self.device),
+                 alignment.to(self.device),
+                 torch.tensor([subwords.size(1)], device=self.device),  # subword_lengths
+                 torch.tensor([len(forms) + 1], device=self.device)  # word_lengths (words + root)
+             )
+
+         lemma_p, upos_p, xpos_p, feats_p, _, dep_p, ne_p, head_p = output
+         lemmas, upos, xpos, feats, heads, deprel, ne = self.dataset.decode_output(
+             forms, lemma_p, upos_p, xpos_p, feats_p, dep_p, ne_p, head_p
+         )
+
+         return {
+             "forms": forms,
+             "lemmas": lemmas,
+             "upos": upos,
+             "xpos": xpos,
+             "feats": feats,
+             "heads": heads,
+             "deprel": deprel,
+             "ne": ne
+         }
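
End-to-end usage mirrors what the updated app.py does at startup; a minimal sketch, assuming checkpoint.bin and config.json sit in the working directory (the example sentence is arbitrary):

from model import Parser

parser = Parser()  # loads vocabularies and weights from checkpoint.bin
output = parser.parse("Setningen blir analysert.")

for i, form in enumerate(output["forms"]):
    print(i + 1, form, output["lemmas"][i], output["upos"][i],
          output["heads"][i], output["deprel"][i], output["ne"][i])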
requirements.txt CHANGED
@@ -1,4 +1,8 @@
  tabulate
  matplotlib
  networkx
- pygraphviz
+ pygraphviz
+ ftfy
+ torch
+ transformers
+ dependency_decoding
tokenizer.py ADDED
@@ -0,0 +1,231 @@
+ # Natural Language Toolkit: NLTK's very own tokenizer, slightly modified.
+ #
+ # Copyright (C) 2001-2023 NLTK Project
+ # Author: Liling Tan
+ #         Tom Aarsen <> (modifications)
+ # URL: <https://www.nltk.org>
+
+
+ import re
+ import warnings
+ from typing import Iterator, List, Tuple
+
+
+ def align_tokens(tokens, sentence):
+     """
+     This function attempts to find the offsets of the tokens in *sentence*, as a
+     sequence of ``(start, end)`` tuples, given the tokens and the source string.
+
+     >>> from nltk.tokenize import TreebankWordTokenizer
+     >>> from nltk.tokenize.util import align_tokens
+     >>> s = str("The plane, bound for St Petersburg, crashed in Egypt's "
+     ... "Sinai desert just 23 minutes after take-off from Sharm el-Sheikh "
+     ... "on Saturday.")
+     >>> tokens = TreebankWordTokenizer().tokenize(s)
+     >>> expected = [(0, 3), (4, 9), (9, 10), (11, 16), (17, 20), (21, 23),
+     ... (24, 34), (34, 35), (36, 43), (44, 46), (47, 52), (52, 54),
+     ... (55, 60), (61, 67), (68, 72), (73, 75), (76, 83), (84, 89),
+     ... (90, 98), (99, 103), (104, 109), (110, 119), (120, 122),
+     ... (123, 131), (131, 132)]
+     >>> output = list(align_tokens(tokens, s))
+     >>> len(tokens) == len(expected) == len(output)  # Check that length of tokens and tuples are the same.
+     True
+     >>> expected == list(align_tokens(tokens, s))  # Check that the output is as expected.
+     True
+     >>> tokens == [s[start:end] for start, end in output]  # Check that the slices of the string correspond to the tokens.
+     True
+
+     :param tokens: The list of strings that are the result of tokenization
+     :type tokens: list(str)
+     :param sentence: The original string
+     :type sentence: str
+     :rtype: list(tuple(int,int))
+     """
+     point = 0
+     offsets = []
+     for token in tokens:
+         try:
+             start = sentence.index(token, point)
+         except ValueError as e:
+             raise ValueError(f'substring "{token}" not found in "{sentence}"') from e
+         point = start + len(token)
+         offsets.append((start, point))
+     return offsets
+
+
+ class NLTKWordTokenizer:
+     """
+     The NLTK tokenizer that has improved upon the TreebankWordTokenizer.
+
+     This is the method that is invoked by ``word_tokenize()``. It assumes that the
+     text has already been segmented into sentences, e.g. using ``sent_tokenize()``.
+
+     The tokenizer is "destructive" in that the applied regexes munge the
+     input string to a state beyond re-construction. It is possible to apply
+     `TreebankWordDetokenizer.detokenize` to the tokenized outputs of
+     `NLTKDestructiveWordTokenizer.tokenize`, but there is no guarantee of
+     recovering the original string.
+     """
+
+     # Starting quotes.
+     STARTING_QUOTES = [
+         (re.compile("([«“‘„]|[`]+)", re.U), r" \1 "),
+         (re.compile(r"^\""), r' " '),
+         (re.compile(r"(``)"), r" \1 "),
+         (re.compile(r"([ \(\[{<])(\"|\'{2})"), r'\1 " '),
+         # (re.compile(r"(?i)(\')(?!re|ve|ll|m|t|s|d|n)(\w)\b", re.U), r"\1 \2"),
+     ]
+
+     # Ending quotes.
+     ENDING_QUOTES = [
+         (re.compile("([»”’])", re.U), r" \1 "),
+         (re.compile(r"''"), " '' "),
+         (re.compile(r'"'), ' " '),
+         (re.compile(r"([^' ])('[sS]|'[mM]|'[dD]|') "), r"\1 \2 "),
+         # (re.compile(r"([^' ])('ll|'LL|'re|'RE|'ve|'VE|n't|N'T) "), r"\1 \2 "),
+     ]
+
+     # For improvements for starting/closing quotes from TreebankWordTokenizer,
+     # see discussion on https://github.com/nltk/nltk/pull/1437
+     # Adding to TreebankWordTokenizer, nltk.word_tokenize now splits on
+     # - chevron quotes u'\xab' and u'\xbb'
+     # - unicode quotes u'\u2018', u'\u2019', u'\u201c' and u'\u201d'
+     # See https://github.com/nltk/nltk/issues/1995#issuecomment-376741608
+     # Also, behavior of splitting on clitics now follows Stanford CoreNLP
+     # - clitics covered (?!re|ve|ll|m|t|s|d)(\w)\b
+
+     # Punctuation.
+     PUNCTUATION = [
+         (re.compile(r'([^\.])(\.)([\]\)}>"\'' "»”’ " r"]*)\s*$", re.U), r"\1 \2 \3 "),
+         (re.compile(r"([:,])([^\d])"), r" \1 \2"),
+         (re.compile(r"([:,])$"), r" \1 "),
+         (
+             re.compile(r"\.{2,}", re.U),
+             r" \g<0> ",
+         ),  # See https://github.com/nltk/nltk/pull/2322
+         (re.compile(r"[;@#$%&]"), r" \g<0> "),
+         (
+             re.compile(r'([^\.])(\.)([\]\)}>"\']*)\s*$'),
+             r"\1 \2\3 ",
+         ),  # Handles the final period.
+         (re.compile(r"[?!]"), r" \g<0> "),
+         (re.compile(r"([^'])' "), r"\1 ' "),
+         (
+             re.compile(r"[*]", re.U),
+             r" \g<0> ",
+         ),  # See https://github.com/nltk/nltk/pull/2322
+     ]
+
+     # Pads parentheses
+     PARENS_BRACKETS = (re.compile(r"[\]\[\(\)\{\}\<\>]"), r" \g<0> ")
+
+     # Optionally: Convert parentheses and brackets to PTB symbols.
+     # CONVERT_PARENTHESES = [
+     #     (re.compile(r"\("), "-LRB-"),
+     #     (re.compile(r"\)"), "-RRB-"),
+     #     (re.compile(r"\["), "-LSB-"),
+     #     (re.compile(r"\]"), "-RSB-"),
+     #     (re.compile(r"\{"), "-LCB-"),
+     #     (re.compile(r"\}"), "-RCB-"),
+     # ]
+
+     DOUBLE_DASHES = (re.compile(r"--"), r" -- ")
+
+     # List of contractions adapted from Robert MacIntyre's tokenizer.
+     # _contractions = MacIntyreContractions()
+     # CONTRACTIONS2 = list(map(re.compile, _contractions.CONTRACTIONS2))
+     # CONTRACTIONS3 = list(map(re.compile, _contractions.CONTRACTIONS3))
+
+     def tokenize(
+         self, text: str
+     ) -> List[str]:
+         r"""Return a tokenized copy of `text`.
+
+         >>> from nltk.tokenize import NLTKWordTokenizer
+         >>> s = '''Good muffins cost $3.88 (roughly 3,36 euros)\nin New York. Please buy me\ntwo of them.\nThanks.'''
+         >>> NLTKWordTokenizer().tokenize(s)  # doctest: +NORMALIZE_WHITESPACE
+         ['Good', 'muffins', 'cost', '$', '3.88', '(', 'roughly', '3,36',
+         'euros', ')', 'in', 'New', 'York.', 'Please', 'buy', 'me', 'two',
+         'of', 'them.', 'Thanks', '.']
+
+         :param text: A string with a sentence or sentences.
+         :type text: str
+         :return: List of tokens from `text`.
+         :rtype: List[str]
+         """
+
+         for regexp, substitution in self.STARTING_QUOTES:
+             text = regexp.sub(substitution, text)
+
+         for regexp, substitution in self.PUNCTUATION:
+             text = regexp.sub(substitution, text)
+
+         # Handles parentheses.
+         regexp, substitution = self.PARENS_BRACKETS
+         text = regexp.sub(substitution, text)
+
+         # Handles double dash.
+         regexp, substitution = self.DOUBLE_DASHES
+         text = regexp.sub(substitution, text)
+
+         # add extra space to make things easier
+         text = " " + text + " "
+
+         for regexp, substitution in self.ENDING_QUOTES:
+             text = regexp.sub(substitution, text)
+
+         return text.split()
+
+     def span_tokenize(self, text: str) -> Iterator[Tuple[int, int]]:
+         r"""
+         Returns the spans of the tokens in ``text``.
+         Uses the post-hoc nltk.tokens.align_tokens to return the offset spans.
+
+         >>> from nltk.tokenize import NLTKWordTokenizer
+         >>> s = '''Good muffins cost $3.88\nin New (York).  Please (buy) me\ntwo of them.\n(Thanks).'''
+         >>> expected = [(0, 4), (5, 12), (13, 17), (18, 19), (19, 23),
+         ... (24, 26), (27, 30), (31, 32), (32, 36), (36, 37), (37, 38),
+         ... (40, 46), (47, 48), (48, 51), (51, 52), (53, 55), (56, 59),
+         ... (60, 62), (63, 68), (69, 70), (70, 76), (76, 77), (77, 78)]
+         >>> list(NLTKWordTokenizer().span_tokenize(s)) == expected
+         True
+         >>> expected = ['Good', 'muffins', 'cost', '$', '3.88', 'in',
+         ... 'New', '(', 'York', ')', '.', 'Please', '(', 'buy', ')',
+         ... 'me', 'two', 'of', 'them.', '(', 'Thanks', ')', '.']
+         >>> [s[start:end] for start, end in NLTKWordTokenizer().span_tokenize(s)] == expected
+         True
+
+         :param text: A string with a sentence or sentences.
+         :type text: str
+         :yield: Tuple[int, int]
+         """
+         raw_tokens = self.tokenize(text)
+
+         # Convert converted quotes back to original double quotes
+         # Do this only if original text contains double quote(s) or double
+         # single-quotes (because '' might be transformed to `` if it is
+         # treated as starting quotes).
+         # if ('"' in text) or ("''" in text):
+         #     # Find double quotes and converted quotes
+         #     matched = [m.group() for m in re.finditer(r"``|'{2}|\"", text)]
+
+         #     # Replace converted quotes back to double quotes
+         #     tokens = [
+         #         matched.pop(0) if tok in ['"', "``", "''"] else tok
+         #         for tok in raw_tokens
+         #     ]
+         # else:
+         tokens = raw_tokens
+
+         yield from align_tokens(tokens, text)
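
The parser relies on span_tokenize rather than tokenize so that each surface form can be sliced verbatim out of the original string (see Dataset.prepare_input above); a short illustration with an arbitrary sentence:

from tokenizer import NLTKWordTokenizer

s = "Hun bor i Oslo, ikke i Bergen."
spans = list(NLTKWordTokenizer().span_tokenize(s))
forms = [s[start:end] for start, end in spans]
# forms == ['Hun', 'bor', 'i', 'Oslo', ',', 'ikke', 'i', 'Bergen', '.']
# the (start, end) offsets index the untouched input, even though tokenize()
# itself applies destructive substitutions internally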