davda54 committed
Commit 55f9b9d • 1 Parent(s): a8838b6

parser

Browse files:
- .gitattributes  +1 -0
- app.py  +36 -3
- config.json  +27 -0
- dataset.py  +74 -0
- lemma_rule.py  +101 -0
- model.py  +660 -0
- requirements.txt  +5 -1
- tokenizer.py  +231 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+checkpoint.bin filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -3,6 +3,19 @@ import tabulate
 import matplotlib.pyplot as plt
 import networkx as nx
 
+from model import Parser
+
+
+parser = Parser()
+
+def parse(text):
+    output = parser.parse(text)
+
+    dependency_tree = render_dependency_tree(output["forms"], output["heads"], output["deprel"])
+    table = render_table(output["forms"], output["lemmas"], output["upos"], output["xpos"], output["feats"], output["ne"])
+
+    return dependency_tree, table
+
 
 def render_dependency_tree(words, parents, labels):
     fig, ax = plt.subplots(figsize=(32, 16))
@@ -111,12 +124,17 @@ edge_labels = [
     if line and not line.startswith("#")
 ]
 
-def render_table(forms, lemmas, upos, xpos, feats, metadata, edges, edge_labels):
+def render_table(forms, lemmas, upos, xpos, feats, named_entities):
     feats = [[f"*{f.split('=')[0]}:* {f.split('=')[1]}" for f in (feat.split("|")) if '=' in f] for feat in feats]
     max_len = max(1, max([len(feat) for feat in feats]))
     feats = [feat + [""] * (max_len - len(feat)) for feat in feats]
     feats = list(zip(*feats))
 
+    named_entities = [
+        "" if ne == "O" else f"<< {ne.split('-')[1]} >>" if ne.startswith("B") else ne.split('-')[1] if ne.startswith("I") and i + 1 < len(named_entities) and named_entities[i + 1].startswith("I") else f"{ne.split('-')[1]} >>"
+        for i, ne in enumerate(named_entities)
+    ]
+
     array = [
         [""] + forms,
         ["*LEMMAS:*"] + lemmas,
@@ -124,6 +142,7 @@ def render_table(forms, lemmas, upos, xpos, feats, metadata, edges, edge_labels)
         ["*XPOS:*"] + xpos,
         ["*UFEATS:*"] + list(feats[0]),
-        *([""] + list(row) for row in feats[1:])
+        *([""] + list(row) for row in feats[1:]),
+        ["*NE:*"] + named_entities,
     ]
 
     #return tabulate.tabulate(array, headers="firstrow", tablefmt="unsafehtml")
@@ -141,13 +160,13 @@ with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
     gr.HTML(description)
 
     with gr.Row():
-        with gr.Column(scale=1):
+        with gr.Column(scale=1, variant="panel"):
             source = gr.Textbox(
                 label="Input sentence", placeholder="Write a sentence to parse", show_label=False, lines=1, max_lines=5, autofocus=True
             )
             submit = gr.Button("Submit", variant="primary")
 
-        with gr.Column(scale=1):
+        with gr.Column(scale=1, variant="panel"):
             dataset = gr.Dataset(components=[gr.Textbox(visible=False)],
                 label="Input examples",
                 samples=[
@@ -161,4 +180,18 @@ with gr.Blocks(theme='sudeepshouche/minimalist', css=custom_css) as demo:
     table = gr.DataFrame(**render_table(forms, lemmas, upos, xpos, feats, metadata, edges, edge_labels), interactive=False, datatype="markdown")
     dependency_plot = gr.Plot(render_dependency_tree(forms, edges, edge_labels), container=False)
 
+    source.submit(
+        fn=parse, inputs=[source], outputs=[dependency_plot, table], queue=True
+    )
+    submit.click(
+        fn=parse, inputs=[source], outputs=[dependency_plot, table], queue=True
+    )
+    dataset.click(
+        fn=lambda text: text, inputs=[dataset], outputs=[source]
+    ).then(
+        fn=parse, inputs=[source], outputs=[dependency_plot, table], queue=True
+    )
+
+
+demo.queue(max_size=32, concurrency_count=2)
 demo.launch()
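Editor's note: the dense BIO-to-marker conditional added to render_table above is easiest to read as a trace. The snippet below is a sketch (not part of the commit) that runs the same expression on a toy tag sequence, assuming the `i + 1 < len(named_entities)` look-ahead bound:

# Hypothetical trace of the NE-cell markup used in render_table (app.py above).
named_entities = ["O", "B-PER", "I-PER", "I-PER", "O"]
cells = [
    "" if ne == "O"
    else f"<< {ne.split('-')[1]} >>" if ne.startswith("B")
    else ne.split('-')[1] if ne.startswith("I") and i + 1 < len(named_entities) and named_entities[i + 1].startswith("I")
    else f"{ne.split('-')[1]} >>"
    for i, ne in enumerate(named_entities)
]
print(cells)  # ['', '<< PER >>', 'PER', 'PER >>', '']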
config.json ADDED
@@ -0,0 +1,27 @@
+{
+    "architectures": [
+        "NorbertForMaskedLM"
+    ],
+    "auto_map": {
+        "AutoConfig": "modeling_norbert.NorbertConfig",
+        "AutoModel": "modeling_norbert.NorbertModel",
+        "AutoModelForMaskedLM": "modeling_norbert.NorbertForMaskedLM",
+        "AutoModelForSequenceClassification": "modeling_norbert.NorbertForSequenceClassification",
+        "AutoModelForTokenClassification": "modeling_norbert.NorbertForTokenClassification",
+        "AutoModelForQuestionAnswering": "modeling_norbert.NorbertForQuestionAnswering",
+        "AutoModelForMultipleChoice": "modeling_norbert.NorbertForMultipleChoice"
+    },
+    "attention_probs_dropout_prob": 0.1,
+    "hidden_dropout_prob": 0.1,
+    "hidden_size": 1024,
+    "intermediate_size": 2730,
+    "layer_norm_eps": 1e-07,
+    "max_position_embeddings": 512,
+    "num_attention_heads": 16,
+    "num_hidden_layers": 24,
+    "position_bucket_size": 32,
+    "torch_dtype": "float32",
+    "transformers_version": "4.23.1",
+    "vocab_size": 50000
+}
+
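Editor's sketch (not part of the commit): model.py below reads this file through the standard `PretrainedConfig` loader, so the hyperparameters can be inspected the same way `Model.__init__` consumes them:

# Load the committed config the way Model.__init__ does; from_json_file is a
# standard PretrainedConfig classmethod.
from model import NorbertConfig

config = NorbertConfig.from_json_file("config.json")
print(config.hidden_size, config.num_hidden_layers, config.vocab_size)  # 1024 24 50000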
dataset.py ADDED
@@ -0,0 +1,74 @@
+import torch
+import torch.nn.functional as F
+from transformers import AutoTokenizer
+
+from tokenizer import NLTKWordTokenizer
+from lemma_rule import apply_lemma_rule
+
+
+class Dataset:
+    def __init__(self):
+        self.word_tokenizer = NLTKWordTokenizer()
+        self.subword_tokenizer = AutoTokenizer.from_pretrained("ltg/norbert3-large")
+
+    def prepare_input(self, sentence: str):
+        word_spans = list(self.word_tokenizer.span_tokenize(sentence))
+        forms = [sentence[start:end] for start, end in word_spans]
+
+        subwords, alignment = [self.subword_tokenizer.convert_tokens_to_ids("[CLS]")], [0]
+        for i, word in enumerate(forms):
+            space_before = (i == 0) or sentence[word_spans[i - 1][1]] == " "
+
+            # very very ugly hack ;(
+            encoding = self.subword_tokenizer(f"| {word}" if space_before else f"|{word}", add_special_tokens=False)
+            subwords += encoding.input_ids[1:]
+            alignment += (len(encoding.input_ids) - 1) * [i + 1]
+
+        subwords.append(self.subword_tokenizer.convert_tokens_to_ids("[SEP]"))
+        alignment.append(alignment[-1] + 1)
+
+        subwords = torch.tensor([subwords])
+        alignment = torch.tensor([alignment])
+        alignment = F.one_hot(alignment, num_classes=len(forms) + 2).float()
+
+        return forms, subwords, alignment
+
+    def decode_output(self, forms, lemma_p, upos_p, xpos_p, feats_p, dep_p, ne_p, head_p):
+        lemmas = [apply_lemma_rule(form, self.lemma_vocab[lemma_p[0, i, :].argmax().item()]) for i, form in enumerate(forms)]
+        upos = [self.upos_vocab[upos_p[0, i, :].argmax().item()] for i in range(len(forms))]
+        xpos = [self.xpos_vocab[xpos_p[0, i, :].argmax().item()] for i in range(len(forms))]
+        feats = [self.feats_vocab[feats_p[0, i, :].argmax().item()] for i in range(len(forms))]
+        heads = [head_p[0, i].item() for i in range(len(forms))]
+        deprel = [self.arc_dep_vocab[dep_p[0, i, :].argmax().item()] for i in range(len(forms))]
+        ne = [self.ne_vocab[ne_p[0, i, :].argmax().item()] for i in range(len(forms))]
+
+        return lemmas, upos, xpos, feats, heads, deprel, ne
+
+    # save state dict
+    def state_dict(self):
+        return {
+            "forms_vocab": self.forms_vocab,
+            "lemma_vocab": self.lemma_vocab,
+            "upos_vocab": self.upos_vocab,
+            "xpos_vocab": self.xpos_vocab,
+            "feats_vocab": self.feats_vocab,
+            "arc_dep_vocab": self.arc_dep_vocab,
+            "ne_vocab": self.ne_vocab
+        }
+
+    # load state dict
+    def load_state_dict(self, state_dict):
+        self.forms_vocab = state_dict["forms_vocab"]
+        self.lemma_vocab = state_dict["lemma_vocab"]
+        self.upos_vocab = state_dict["upos_vocab"]
+        self.xpos_vocab = state_dict["xpos_vocab"]
+        self.feats_vocab = state_dict["feats_vocab"]
+        self.arc_dep_vocab = state_dict["arc_dep_vocab"]
+        self.ne_vocab = state_dict["ne_vocab"]
+
+        self.lemma_indexer = {i: n for n, i in enumerate(self.lemma_vocab)}
+        self.upos_indexer = {i: n for n, i in enumerate(self.upos_vocab)}
+        self.xpos_indexer = {i: n for n, i in enumerate(self.xpos_vocab)}
+        self.feats_indexer = {i: n for n, i in enumerate(self.feats_vocab)}
+        self.ne_indexer = {i: n for n, i in enumerate(self.ne_vocab)}
+        self.arc_dep_indexer = {i: n for n, i in enumerate(self.arc_dep_vocab)}
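Editor's sketch (not from the commit) of what the one-hot `alignment` tensor built in prepare_input is for: model.py uses it to mean-pool subword vectors into one vector per word via an einsum. The toy shapes and token split below are assumptions for illustration:

import torch
import torch.nn.functional as F

# Word index for each subword position: [CLS]=0, "un"=1, "##fair"=1, "."=2, [SEP]=3.
subword_to_word = torch.tensor([[0, 1, 1, 2, 3]])
alignment = F.one_hot(subword_to_word, num_classes=4).float()  # [1, 5, 4]

x = torch.randn(1, 5, 8)  # fake subword embeddings [batch, subwords, dim]

# Mean-pool subwords per word, mirroring model.py's
# einsum("bsd,bst->btd", x, alignment) / alignment.sum(1)... expression.
word_x = torch.einsum("bsd,bst->btd", x, alignment) / alignment.sum(1).unsqueeze(-1).clamp(min=1.0)
print(word_x.shape)  # torch.Size([1, 4, 8]) — one pooled vector per word slot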
lemma_rule.py ADDED
@@ -0,0 +1,101 @@
+def min_edit_script(source, target, allow_copy):
+    a = [[(len(source) + len(target) + 1, None)] * (len(target) + 1) for _ in range(len(source) + 1)]
+    for i in range(0, len(source) + 1):
+        for j in range(0, len(target) + 1):
+            if i == 0 and j == 0:
+                a[i][j] = (0, "")
+            else:
+                if allow_copy and i and j and source[i - 1] == target[j - 1] and a[i-1][j-1][0] < a[i][j][0]:
+                    a[i][j] = (a[i-1][j-1][0], a[i-1][j-1][1] + "→")
+                if i and a[i-1][j][0] < a[i][j][0]:
+                    a[i][j] = (a[i-1][j][0] + 1, a[i-1][j][1] + "-")
+                if j and a[i][j-1][0] < a[i][j][0]:
+                    a[i][j] = (a[i][j-1][0] + 1, a[i][j-1][1] + "+" + target[j - 1])
+    return a[-1][-1][1]
+
+
+def gen_lemma_rule(form, lemma, allow_copy):
+    form = form.lower()
+
+    previous_case = -1
+    lemma_casing = ""
+    for i, c in enumerate(lemma):
+        case = "↑" if c.lower() != c else "↓"
+        if case != previous_case:
+            lemma_casing += "{}{}{}".format("¦" if lemma_casing else "", case, i if i <= len(lemma) // 2 else i - len(lemma))
+        previous_case = case
+    lemma = lemma.lower()
+
+    best, best_form, best_lemma = 0, 0, 0
+    for l in range(len(lemma)):
+        for f in range(len(form)):
+            cpl = 0
+            while f + cpl < len(form) and l + cpl < len(lemma) and form[f + cpl] == lemma[l + cpl]: cpl += 1
+            if cpl > best:
+                best = cpl
+                best_form = f
+                best_lemma = l
+
+    rule = lemma_casing + ";"
+    if not best:
+        rule += "a" + lemma
+    else:
+        rule += "d{}¦{}".format(
+            min_edit_script(form[:best_form], lemma[:best_lemma], allow_copy),
+            min_edit_script(form[best_form + best:], lemma[best_lemma + best:], allow_copy),
+        )
+    return rule
+
+
+def apply_lemma_rule(form, lemma_rule):
+    if lemma_rule == "<unk>":
+        return form
+
+    if ';' not in lemma_rule:
+        raise ValueError('lemma_rule %r for form %r missing semicolon' % (lemma_rule, form))
+
+    casing, rule = lemma_rule.split(";", 1)
+    if rule.startswith("a"):
+        lemma = rule[1:]
+    else:
+        form = form.lower()
+        rules, rule_sources = rule[1:].split("¦"), []
+        assert len(rules) == 2
+        for rule in rules:
+            source, i = 0, 0
+            while i < len(rule):
+                if rule[i] == "→" or rule[i] == "-":
+                    source += 1
+                else:
+                    assert rule[i] == "+"
+                    i += 1
+                i += 1
+            rule_sources.append(source)
+
+        try:
+            lemma, form_offset = "", 0
+            for i in range(2):
+                j, offset = 0, (0 if i == 0 else len(form) - rule_sources[1])
+                while j < len(rules[i]):
+                    if rules[i][j] == "→":
+                        lemma += form[offset]
+                        offset += 1
+                    elif rules[i][j] == "-":
+                        offset += 1
+                    else:
+                        assert(rules[i][j] == "+")
+                        lemma += rules[i][j + 1]
+                        j += 1
+                    j += 1
+                if i == 0:
+                    lemma += form[rule_sources[0] : len(form) - rule_sources[1]]
+        except:
+            lemma = form
+
+    for rule in casing.split("¦"):
+        if rule == "↓0": continue  # The lemma is lowercased initially
+        if not rule: continue  # Empty lemma might generate empty casing rule
+        case, offset = rule[0], int(rule[1:])
+        lemma = lemma[:offset] + (lemma[offset:].upper() if case == "↑" else lemma[offset:].lower())
+
+    return lemma
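A short worked example (editor's sketch, not part of the commit): gen_lemma_rule compresses a (form, lemma) pair into a casing script plus two edit scripts around their longest common substring, and apply_lemma_rule inverts it:

from lemma_rule import gen_lemma_rule, apply_lemma_rule

print(gen_lemma_rule("Dogs", "dog", allow_copy=True))  # ↓0;d¦-  (lowercase everything; delete the trailing "s")

# Round trip: a rule generated from (form, lemma) reconstructs the lemma.
for form, lemma in [("Dogs", "dog"), ("ran", "run"), ("Geese", "goose")]:
    rule = gen_lemma_rule(form, lemma, allow_copy=True)
    assert apply_lemma_rule(form, rule) == lemma, (form, rule)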
model.py ADDED
@@ -0,0 +1,660 @@
+import math
+from typing import List, Optional, Tuple, Union
+import dependency_decoding
+import ftfy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils import checkpoint
+
+from transformers.modeling_utils import PreTrainedModel
+from transformers.activations import gelu_new
+from transformers.modeling_outputs import (
+    MaskedLMOutput,
+    MultipleChoiceModelOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+    BaseModelOutput
+)
+from transformers.pytorch_utils import softmax_backward_data
+from transformers.configuration_utils import PretrainedConfig
+
+from dataset import Dataset
+
+
+class NorbertConfig(PretrainedConfig):
+    """Configuration class to store the configuration of a `NorbertModel`.
+    """
+    def __init__(
+        self,
+        vocab_size=50000,
+        attention_probs_dropout_prob=0.1,
+        hidden_dropout_prob=0.1,
+        hidden_size=768,
+        intermediate_size=2048,
+        max_position_embeddings=512,
+        position_bucket_size=32,
+        num_attention_heads=12,
+        num_hidden_layers=12,
+        layer_norm_eps=1.0e-7,
+        output_all_encoded_layers=True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.output_all_encoded_layers = output_all_encoded_layers
+        self.position_bucket_size = position_bucket_size
+        self.layer_norm_eps = layer_norm_eps
+
+
+class Encoder(nn.Module):
+    def __init__(self, config, activation_checkpointing=False):
+        super().__init__()
+        self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.num_hidden_layers)])
+
+        for i, layer in enumerate(self.layers):
+            layer.mlp.mlp[1].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
+            layer.mlp.mlp[-2].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
+
+        self.activation_checkpointing = activation_checkpointing
+
+    def forward(self, hidden_states, attention_mask, relative_embedding):
+        hidden_states, attention_probs = [hidden_states], []
+
+        for layer in self.layers:
+            if self.activation_checkpointing:
+                hidden_state, attention_p = checkpoint.checkpoint(layer, hidden_states[-1], attention_mask, relative_embedding)
+            else:
+                hidden_state, attention_p = layer(hidden_states[-1], attention_mask, relative_embedding)
+
+            hidden_states.append(hidden_state)
+            attention_probs.append(attention_p)
+
+        return hidden_states, attention_probs
+
+
+class MaskClassifier(nn.Module):
+    def __init__(self, config, subword_embedding):
+        super().__init__()
+        self.nonlinearity = nn.Sequential(
+            nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False),
+            nn.Linear(config.hidden_size, config.hidden_size),
+            nn.GELU(),
+            nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False),
+            nn.Dropout(config.hidden_dropout_prob),
+            nn.Linear(subword_embedding.size(1), subword_embedding.size(0))
+        )
+        self.initialize(config.hidden_size, subword_embedding)
+
+    def initialize(self, hidden_size, embedding):
+        std = math.sqrt(2.0 / (5.0 * hidden_size))
+        nn.init.trunc_normal_(self.nonlinearity[1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+        self.nonlinearity[-1].weight = embedding
+        self.nonlinearity[1].bias.data.zero_()
+        self.nonlinearity[-1].bias.data.zero_()
+
+    def forward(self, x, masked_lm_labels=None):
+        if masked_lm_labels is not None:
+            x = torch.index_select(x.flatten(0, 1), 0, torch.nonzero(masked_lm_labels.flatten() != -100).squeeze())
+        x = self.nonlinearity(x)
+        return x
+
+
+class EncoderLayer(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.attention = Attention(config)
+        self.mlp = FeedForward(config)
+
+    def forward(self, x, padding_mask, relative_embedding):
+        attention_output, attention_probs = self.attention(x, padding_mask, relative_embedding)
+        x = x + attention_output
+        x = x + self.mlp(x)
+        return x, attention_probs
+
+
+class GeGLU(nn.Module):
+    def forward(self, x):
+        x, gate = x.chunk(2, dim=-1)
+        x = x * gelu_new(gate)
+        return x
+
+
+class FeedForward(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False),
+            nn.Linear(config.hidden_size, 2*config.intermediate_size, bias=False),
+            GeGLU(),
+            nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps, elementwise_affine=False),
+            nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
+            nn.Dropout(config.hidden_dropout_prob)
+        )
+        self.initialize(config.hidden_size)
+
+    def initialize(self, hidden_size):
+        std = math.sqrt(2.0 / (5.0 * hidden_size))
+        nn.init.trunc_normal_(self.mlp[1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+        nn.init.trunc_normal_(self.mlp[-2].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+
+    def forward(self, x):
+        return self.mlp(x)
+
+
+class MaskedSoftmax(torch.autograd.Function):
+    @staticmethod
+    def forward(self, x, mask, dim):
+        self.dim = dim
+        x.masked_fill_(mask, float('-inf'))
+        x = torch.softmax(x, self.dim)
+        x.masked_fill_(mask, 0.0)
+        self.save_for_backward(x)
+        return x
+
+    @staticmethod
+    def backward(self, grad_output):
+        output, = self.saved_tensors
+        input_grad = softmax_backward_data(self, grad_output, output, self.dim, output)
+        return input_grad, None, None
+
+
+class Attention(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        self.config = config
+
+        if config.hidden_size % config.num_attention_heads != 0:
+            raise ValueError(f"The hidden size {config.hidden_size} is not a multiple of the number of attention heads {config.num_attention_heads}")
+
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_size = config.hidden_size // config.num_attention_heads
+
+        self.in_proj_qk = nn.Linear(config.hidden_size, 2*config.hidden_size, bias=True)
+        self.in_proj_v = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
+        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
+
+        self.pre_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False)
+        self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
+
+        position_indices = torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(1) \
+            - torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0)
+        position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, config.max_position_embeddings)
+        position_indices = config.position_bucket_size - 1 + position_indices
+        self.register_buffer("position_indices", position_indices, persistent=True)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.scale = 1.0 / math.sqrt(3 * self.head_size)
+        self.initialize()
+
+    def make_log_bucket_position(self, relative_pos, bucket_size, max_position):
+        sign = torch.sign(relative_pos)
+        mid = bucket_size // 2
+        abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, torch.abs(relative_pos).clamp(max=max_position - 1))
+        log_pos = torch.ceil(torch.log(abs_pos / mid) / math.log((max_position-1) / mid) * (mid - 1)).int() + mid
+        bucket_pos = torch.where(abs_pos <= mid, relative_pos, log_pos * sign).long()
+        return bucket_pos
+
+    def initialize(self):
+        std = math.sqrt(2.0 / (5.0 * self.hidden_size))
+        nn.init.trunc_normal_(self.in_proj_qk.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+        nn.init.trunc_normal_(self.in_proj_v.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+        nn.init.trunc_normal_(self.out_proj.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+        self.in_proj_qk.bias.data.zero_()
+        self.in_proj_v.bias.data.zero_()
+        self.out_proj.bias.data.zero_()
+
+    def compute_attention_scores(self, hidden_states, relative_embedding):
+        key_len, batch_size, _ = hidden_states.size()
+        query_len = key_len
+
+        if self.position_indices.size(0) < query_len:
+            position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
+                - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
+            position_indices = self.make_log_bucket_position(position_indices, self.position_bucket_size, 512)
+            position_indices = self.position_bucket_size - 1 + position_indices
+            self.position_indices = position_indices.to(hidden_states.device)
+
+        hidden_states = self.pre_layer_norm(hidden_states)
+
+        query, key = self.in_proj_qk(hidden_states).chunk(2, dim=2)  # shape: [T, B, D]
+        value = self.in_proj_v(hidden_states)  # shape: [T, B, D]
+
+        query = query.reshape(query_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
+        key = key.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
+        value = value.view(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
+
+        attention_scores = torch.bmm(query, key.transpose(1, 2) * self.scale)
+
+        pos = self.in_proj_qk(self.dropout(relative_embedding))  # shape: [2T-1, 2D]
+        query_pos, key_pos = pos.view(-1, self.num_heads, 2*self.head_size).chunk(2, dim=2)
+        query = query.view(batch_size, self.num_heads, query_len, self.head_size)
+        key = key.view(batch_size, self.num_heads, query_len, self.head_size)
+
+        attention_c_p = torch.einsum("bhqd,khd->bhqk", query, key_pos.squeeze(1) * self.scale)
+        attention_p_c = torch.einsum("bhkd,qhd->bhqk", key * self.scale, query_pos.squeeze(1))
+
+        position_indices = self.position_indices[:query_len, :key_len].expand(batch_size, self.num_heads, -1, -1)
+        attention_c_p = attention_c_p.gather(3, position_indices)
+        attention_p_c = attention_p_c.gather(2, position_indices)
+
+        attention_scores = attention_scores.view(batch_size, self.num_heads, query_len, key_len)
+        attention_scores.add_(attention_c_p)
+        attention_scores.add_(attention_p_c)
+
+        return attention_scores, value
+
+    def compute_output(self, attention_probs, value):
+        attention_probs = self.dropout(attention_probs)
+        context = torch.bmm(attention_probs.flatten(0, 1), value)  # shape: [B*H, Q, D]
+        context = context.transpose(0, 1).reshape(context.size(1), -1, self.hidden_size)  # shape: [Q, B, H*D]
+        context = self.out_proj(context)
+        context = self.post_layer_norm(context)
+        context = self.dropout(context)
+        return context
+
+    def forward(self, hidden_states, attention_mask, relative_embedding):
+        attention_scores, value = self.compute_attention_scores(hidden_states, relative_embedding)
+        attention_probs = MaskedSoftmax.apply(attention_scores, attention_mask, -1)
+        return self.compute_output(attention_probs, value), attention_probs.detach()
+
+
+class Embedding(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.word_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+        self.relative_embedding = nn.Parameter(torch.empty(2 * config.position_bucket_size - 1, config.hidden_size))
+        self.relative_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        self.initialize()
+
+    def initialize(self):
+        std = math.sqrt(2.0 / (5.0 * self.hidden_size))
+        nn.init.trunc_normal_(self.relative_embedding, mean=0.0, std=std, a=-2*std, b=2*std)
+        nn.init.trunc_normal_(self.word_embedding.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+
+    def forward(self, input_ids):
+        word_embedding = self.dropout(self.word_layer_norm(self.word_embedding(input_ids)))
+        relative_embeddings = self.relative_layer_norm(self.relative_embedding)
+        return word_embedding, relative_embeddings
+
+
+#
+# HuggingFace wrappers
+#
+
+class NorbertPreTrainedModel(PreTrainedModel):
+    config_class = NorbertConfig
+    base_model_prefix = "norbert3"
+    supports_gradient_checkpointing = True
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, Encoder):
+            module.activation_checkpointing = value
+
+    def _init_weights(self, module):
+        pass  # everything is already initialized
+
+
+class NorbertModel(NorbertPreTrainedModel):
+    def __init__(self, config, add_mlm_layer=False, gradient_checkpointing=False, **kwargs):
+        super().__init__(config, **kwargs)
+        self.config = config
+
+        self.embedding = Embedding(config)
+        self.transformer = Encoder(config, activation_checkpointing=gradient_checkpointing)
+        self.classifier = MaskClassifier(config, self.embedding.word_embedding.weight) if add_mlm_layer else None
+
+    def get_input_embeddings(self):
+        return self.embedding.word_embedding
+
+    def set_input_embeddings(self, value):
+        self.embedding.word_embedding = value
+
+    def get_contextualized_embeddings(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None
+    ) -> List[torch.Tensor]:
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            raise ValueError("You have to specify input_ids")
+
+        batch_size, seq_length = input_shape
+        device = input_ids.device
+
+        if attention_mask is None:
+            attention_mask = torch.zeros(batch_size, seq_length, dtype=torch.bool, device=device)
+        else:
+            attention_mask = ~attention_mask.bool()
+        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+        static_embeddings, relative_embedding = self.embedding(input_ids.t())
+        contextualized_embeddings, attention_probs = self.transformer(static_embeddings, attention_mask, relative_embedding)
+        contextualized_embeddings = [e.transpose(0, 1) for e in contextualized_embeddings]
+        last_layer = contextualized_embeddings[-1]
+        contextualized_embeddings = [contextualized_embeddings[0]] + [
+            contextualized_embeddings[i] - contextualized_embeddings[i - 1]
+            for i in range(1, len(contextualized_embeddings))
+        ]
+        return last_layer, contextualized_embeddings, attention_probs
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
+
+        if not return_dict:
+            return (
+                sequence_output,
+                *([contextualized_embeddings] if output_hidden_states else []),
+                *([attention_probs] if output_attentions else [])
+            )
+
+        return BaseModelOutput(
+            last_hidden_state=sequence_output,
+            hidden_states=contextualized_embeddings if output_hidden_states else None,
+            attentions=attention_probs if output_attentions else None
+        )
+
+
+class Classifier(nn.Module):
+    def __init__(self, hidden_size, vocab_size, dropout):
+        super().__init__()
+
+        self.transform = nn.Sequential(
+            nn.Linear(hidden_size, hidden_size),
+            nn.GELU(),
+            nn.LayerNorm(hidden_size, elementwise_affine=False),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_size, vocab_size)
+        )
+        self.initialize(hidden_size)
+
+    def initialize(self, hidden_size):
+        std = math.sqrt(2.0 / (5.0 * hidden_size))
+        nn.init.trunc_normal_(self.transform[0].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+        nn.init.trunc_normal_(self.transform[-1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+        self.transform[0].bias.data.zero_()
+        self.transform[-1].bias.data.zero_()
+
+    def forward(self, x):
+        return self.transform(x)
+
+
+class ZeroClassifier(nn.Module):
+    def forward(self, x):
+        output = torch.zeros(x.size(0), x.size(1), 2, device=x.device, dtype=x.dtype)
+        output[:, :, 0] = 1.0
+        output[:, :, 1] = -1.0
+        return output
+
+
+class EdgeClassifier(nn.Module):
+    def __init__(self, hidden_size, dep_hidden_size, vocab_size, dropout):
+        super().__init__()
+
+        self.head_dep_transform = nn.Sequential(
+            nn.Linear(hidden_size, hidden_size),
+            nn.GELU(),
+            nn.LayerNorm(hidden_size, elementwise_affine=False),
+            nn.Dropout(dropout)
+        )
+        self.head_root_transform = nn.Sequential(
+            nn.Linear(hidden_size, hidden_size),
+            nn.GELU(),
+            nn.LayerNorm(hidden_size, elementwise_affine=False),
+            nn.Dropout(dropout)
+        )
+        self.head_bilinear = nn.Parameter(torch.zeros(hidden_size, hidden_size))
+        self.head_linear_dep = nn.Linear(hidden_size, 1, bias=False)
+        self.head_linear_root = nn.Linear(hidden_size, 1, bias=False)
+        self.head_bias = nn.Parameter(torch.zeros(1))
+
+        self.dep_dep_transform = nn.Sequential(
+            nn.Linear(hidden_size, dep_hidden_size),
+            nn.GELU(),
+            nn.LayerNorm(dep_hidden_size, elementwise_affine=False),
+            nn.Dropout(dropout)
+        )
+        self.dep_root_transform = nn.Sequential(
+            nn.Linear(hidden_size, dep_hidden_size),
+            nn.GELU(),
+            nn.LayerNorm(dep_hidden_size, elementwise_affine=False),
+            nn.Dropout(dropout)
+        )
+        self.dep_bilinear = nn.Parameter(torch.zeros(dep_hidden_size, dep_hidden_size, vocab_size))
+        self.dep_linear_dep = nn.Linear(dep_hidden_size, vocab_size, bias=False)
+        self.dep_linear_root = nn.Linear(dep_hidden_size, vocab_size, bias=False)
+        self.dep_bias = nn.Parameter(torch.zeros(vocab_size))
+
+        self.hidden_size = hidden_size
+        self.dep_hidden_size = dep_hidden_size
+
+        self.mask_value = float("-inf")
+        self.initialize(hidden_size)
+
+    def initialize(self, hidden_size):
+        std = math.sqrt(2.0 / (5.0 * hidden_size))
+        nn.init.trunc_normal_(self.head_dep_transform[0].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+        nn.init.trunc_normal_(self.head_root_transform[0].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+        nn.init.trunc_normal_(self.dep_dep_transform[0].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+        nn.init.trunc_normal_(self.dep_root_transform[0].weight, mean=0.0, std=std, a=-2*std, b=2*std)
+
+        nn.init.trunc_normal_(self.head_linear_dep.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+        nn.init.trunc_normal_(self.head_linear_root.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+        nn.init.trunc_normal_(self.dep_linear_dep.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+        nn.init.trunc_normal_(self.dep_linear_root.weight, mean=0.0, std=std, a=-2*std, b=2*std)
+
+        self.head_dep_transform[0].bias.data.zero_()
+        self.head_root_transform[0].bias.data.zero_()
+        self.dep_dep_transform[0].bias.data.zero_()
+        self.dep_root_transform[0].bias.data.zero_()
+
+    def forward(self, head_x, dep_x, lengths, head_gold=None):
+        head_dep = self.head_dep_transform(head_x[:, 1:, :])
+        head_root = self.head_root_transform(head_x)
+        head_prediction = torch.einsum("bkn,nm,blm->bkl", head_dep, self.head_bilinear, head_root / math.sqrt(self.hidden_size)) \
+            + self.head_linear_dep(head_dep) + self.head_linear_root(head_root).transpose(1, 2) + self.head_bias
+
+        mask = (torch.arange(head_x.size(1)).unsqueeze(0) >= lengths.unsqueeze(1)).unsqueeze(1).to(head_x.device)
+        mask = mask | (torch.ones(head_x.size(1) - 1, head_x.size(1), dtype=torch.bool, device=head_x.device).tril(1) & torch.ones(head_x.size(1) - 1, head_x.size(1), dtype=torch.bool, device=head_x.device).triu(1))
+        head_prediction = head_prediction.masked_fill(mask, self.mask_value)
+
+        if head_gold is None:
+            head_logp = torch.log_softmax(head_prediction, dim=-1)
+            head_logp = F.pad(head_logp, (0, 0, 1, 0), value=torch.nan).cpu()
+            head_gold = []
+            for i, length in enumerate(lengths.tolist()):
+                head = self.max_spanning_tree(head_logp[i, :length, :length])
+                head = head + ((head_x.size(1) - 1) - len(head)) * [0]
+                head_gold.append(torch.tensor(head))
+            head_gold = torch.stack(head_gold).to(head_x.device)
+
+        dep_dep = self.dep_dep_transform(dep_x[:, 1:])
+        dep_root = dep_x.gather(1, head_gold.unsqueeze(-1).expand(-1, -1, dep_x.size(-1)).clamp(min=0))
+        dep_root = self.dep_root_transform(dep_root)
+        dep_prediction = torch.einsum("btm,mnl,btn->btl", dep_dep, self.dep_bilinear, dep_root / math.sqrt(self.dep_hidden_size)) \
+            + self.dep_linear_dep(dep_dep) + self.dep_linear_root(dep_root) + self.dep_bias
+
+        return head_prediction, dep_prediction, head_gold
+
+    def max_spanning_tree(self, weight_matrix):
+        weight_matrix = weight_matrix.clone()
+        # weight_matrix[:, 0] = torch.nan
+
+        # we need to make sure that the root is the parent of a single node
+        # first, we try to use the default weights, it should work in most cases
+        parents, _ = dependency_decoding.chu_liu_edmonds(weight_matrix.numpy().astype(float))
+
+        assert parents[0] == -1, f"{parents}\n{weight_matrix}"
+        parents = parents[1:]
+
+        # check if the root is the parent of a single node
+        if parents.count(0) == 1:
+            return parents
+
+        # if not, we need to modify the weights and try all possibilities
+        # we try to find the node that is the parent of the root
+        best_score = float("-inf")
+        best_parents = None
+
+        for i in range(len(parents)):
+            weight_matrix_mod = weight_matrix.clone()
+            weight_matrix_mod[:i+1, 0] = torch.nan
+            weight_matrix_mod[i+2:, 0] = torch.nan
+            parents, score = dependency_decoding.chu_liu_edmonds(weight_matrix_mod.numpy().astype(float))
+            parents = parents[1:]
+
+            if score > best_score:
+                best_score = score
+                best_parents = parents
+
+        def print_whole_matrix(matrix):
+            for i in range(matrix.shape[0]):
+                print(" ".join([str(x) for x in matrix[i]]))
+
+        assert best_parents is not None, f"{best_parents}\n{print_whole_matrix(weight_matrix)}"
+        return best_parents
+
+
+class Model(nn.Module):
+    def __init__(self, dataset):
+        super().__init__()
+
+        # config = BertConfig("../../configs/base.json")
+        # self.bert = Bert(config)
+        # checkpoint = torch.load("../../checkpoints/test_wd=0.01/model.bin", map_location="cpu")
+        # self.bert.load_state_dict(checkpoint["model"], strict=False)
+
+        config = NorbertConfig.from_json_file("config.json")
+        self.bert = NorbertModel(config)
+
+        self.n_layers = config.num_hidden_layers
+
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.layer_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False)
+        self.upos_layer_score = nn.Parameter(torch.zeros(self.n_layers + 1, dtype=torch.float))
+        self.xpos_layer_score = nn.Parameter(torch.zeros(self.n_layers + 1, dtype=torch.float))
+        self.feats_layer_score = nn.Parameter(torch.zeros(self.n_layers + 1, dtype=torch.float))
+        self.lemma_layer_score = nn.Parameter(torch.zeros(self.n_layers + 1, dtype=torch.float))
+        self.head_layer_score = nn.Parameter(torch.zeros(self.n_layers + 1, dtype=torch.float))
+        self.dep_layer_score = nn.Parameter(torch.zeros(self.n_layers + 1, dtype=torch.float))
+        self.ner_layer_score = nn.Parameter(torch.zeros(self.n_layers + 1, dtype=torch.float))
+
+        self.lemma_classifier = Classifier(config.hidden_size, len(dataset.lemma_vocab), config.hidden_dropout_prob)
+        self.upos_classifier = Classifier(config.hidden_size, len(dataset.upos_vocab), config.hidden_dropout_prob) if len(dataset.upos_vocab) > 2 else ZeroClassifier()
+        self.xpos_classifier = Classifier(config.hidden_size, len(dataset.xpos_vocab), config.hidden_dropout_prob) if len(dataset.xpos_vocab) > 2 else ZeroClassifier()
+        self.feats_classifier = Classifier(config.hidden_size, len(dataset.feats_vocab), config.hidden_dropout_prob) if len(dataset.feats_vocab) > 2 else ZeroClassifier()
+        self.edge_classifier = EdgeClassifier(config.hidden_size, 128, len(dataset.arc_dep_vocab), config.hidden_dropout_prob)
+        self.ner_classifier = Classifier(config.hidden_size, len(dataset.ne_vocab), config.hidden_dropout_prob) if len(dataset.ne_vocab) > 2 else ZeroClassifier()
+
+    def forward(self, x, alignment_mask, subword_lengths, word_lengths, head_gold=None):
+        padding_mask = (torch.arange(x.size(1)).unsqueeze(0) < subword_lengths.unsqueeze(1)).to(x.device)
+        x = self.bert(x, padding_mask, output_hidden_states=True).hidden_states
+        x = torch.stack(x, dim=0)
+
+        upos_x = torch.einsum("lbtd, l -> btd", x, torch.softmax(self.upos_layer_score, dim=0))
+        xpos_x = torch.einsum("lbtd, l -> btd", x, torch.softmax(self.xpos_layer_score, dim=0))
+        feats_x = torch.einsum("lbtd, l -> btd", x, torch.softmax(self.feats_layer_score, dim=0))
+        lemma_x = torch.einsum("lbtd, l -> btd", x, torch.softmax(self.lemma_layer_score, dim=0))
+        head_x = torch.einsum("lbtd, l -> btd", x, torch.softmax(self.head_layer_score, dim=0))
+        dep_x = torch.einsum("lbtd, l -> btd", x, torch.softmax(self.dep_layer_score, dim=0))
+        ne_x = torch.einsum("lbtd, l -> btd", x, torch.softmax(self.ner_layer_score, dim=0))
+
+        upos_x = torch.einsum("bsd,bst->btd", upos_x, alignment_mask) / alignment_mask.sum(1).unsqueeze(-1).clamp(min=1.0)
+        xpos_x = torch.einsum("bsd,bst->btd", xpos_x, alignment_mask) / alignment_mask.sum(1).unsqueeze(-1).clamp(min=1.0)
+        feats_x = torch.einsum("bsd,bst->btd", feats_x, alignment_mask) / alignment_mask.sum(1).unsqueeze(-1).clamp(min=1.0)
+        lemma_x = torch.einsum("bsd,bst->btd", lemma_x, alignment_mask) / alignment_mask.sum(1).unsqueeze(-1).clamp(min=1.0)
+        head_x = torch.einsum("bsd,bst->btd", head_x, alignment_mask) / alignment_mask.sum(1).unsqueeze(-1).clamp(min=1.0)
+        dep_x = torch.einsum("bsd,bst->btd", dep_x, alignment_mask) / alignment_mask.sum(1).unsqueeze(-1).clamp(min=1.0)
+        ne_x = torch.einsum("bsd, bst -> btd", ne_x, alignment_mask) / alignment_mask.sum(1).unsqueeze(-1).clamp(min=1.0)
+
+        upos_x = self.dropout(self.layer_norm(upos_x[:, 1:-1, :]))
+        xpos_x = self.dropout(self.layer_norm(xpos_x[:, 1:-1, :]))
+        feats_x = self.dropout(self.layer_norm(feats_x[:, 1:-1, :]))
+        lemma_x = self.dropout(self.layer_norm(lemma_x[:, 1:-1, :]))
+        head_x = self.dropout(self.layer_norm(head_x[:, 0:-1, :]))
+        dep_x = self.dropout(self.layer_norm(dep_x[:, 0:-1, :]))
+        ne_x = self.dropout(self.layer_norm(ne_x[:, 1:-1, :]))
+
+        lemma_preds = self.lemma_classifier(lemma_x)
+        upos_preds = self.upos_classifier(upos_x)
+        xpos_preds = self.xpos_classifier(xpos_x)
+        feats_preds = self.feats_classifier(feats_x)
+        ne_preds = self.ner_classifier(feats_x)
+        head_prediction, dep_prediction, head_liu = self.edge_classifier(head_x, dep_x, word_lengths, head_gold)
+
+        return lemma_preds, upos_preds, xpos_preds, feats_preds, head_prediction, dep_prediction, ne_preds, head_liu
+
+
+class Parser:
+    def __init__(self):
+        checkpoint = torch.load("checkpoint.bin", map_location="cpu")
+
+        self.dataset = Dataset()
+        self.dataset.load_state_dict(checkpoint["dataset"])
+
+        self.model = Model(self.dataset)
+        self.model.load_state_dict(checkpoint["model"])
+        self.model.eval()
+        del checkpoint
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.model.to(self.device)
+
+    def parse(self, sentence):
+        sentence = ftfy.fix_text(sentence.strip())
+        forms, subwords, alignment = self.dataset.prepare_input(sentence)
+
+        with torch.no_grad():
+            output = self.model(
+                subwords.to(self.device),
+                alignment.to(self.device),
+                torch.tensor([len(forms) + 1], device=self.device),
+                torch.tensor([subwords.size(1)], device=self.device)
+            )
+
+        lemma_p, upos_p, xpos_p, feats_p, _, dep_p, ne_p, head_p = output
+        lemmas, upos, xpos, feats, heads, deprel, ne = self.dataset.decode_output(
+            forms, lemma_p, upos_p, xpos_p, feats_p, dep_p, ne_p, head_p
+        )
+
+        return {
+            "forms": forms,
+            "lemmas": lemmas,
+            "upos": upos,
+            "xpos": xpos,
+            "feats": feats,
+            "heads": heads,
+            "deprel": deprel,
+            "ne": ne
+        }
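End-to-end usage sketch (editor's note, not part of the commit; it assumes checkpoint.bin and config.json sit next to the script, as Parser.__init__ and Model.__init__ expect):

from model import Parser

parser = Parser()
output = parser.parse("Dette er en test.")  # any input sentence

# parse() returns parallel per-token lists; print a CoNLL-U-like view.
for i, form in enumerate(output["forms"]):
    print(form, output["lemmas"][i], output["upos"][i], output["heads"][i], output["deprel"][i], output["ne"][i])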
requirements.txt CHANGED
@@ -1,4 +1,8 @@
 tabulate
 matplotlib
 networkx
-pygraphviz
+pygraphviz
+ftfy
+torch
+transformers
+dependency_decoding
tokenizer.py ADDED
@@ -0,0 +1,231 @@
+# Natural Language Toolkit: NLTK's very own tokenizer, slightly modified.
+#
+# Copyright (C) 2001-2023 NLTK Project
+# Author: Liling Tan
+#         Tom Aarsen <> (modifications)
+# URL: <https://www.nltk.org>
+
+
+import re
+import warnings
+from typing import Iterator, List, Tuple
+
+
+def align_tokens(tokens, sentence):
+    """
+    This module attempts to find the offsets of the tokens in *s*, as a sequence
+    of ``(start, end)`` tuples, given the tokens and also the source string.
+
+    >>> from nltk.tokenize import TreebankWordTokenizer
+    >>> from nltk.tokenize.util import align_tokens
+    >>> s = str("The plane, bound for St Petersburg, crashed in Egypt's "
+    ... "Sinai desert just 23 minutes after take-off from Sharm el-Sheikh "
+    ... "on Saturday.")
+    >>> tokens = TreebankWordTokenizer().tokenize(s)
+    >>> expected = [(0, 3), (4, 9), (9, 10), (11, 16), (17, 20), (21, 23),
+    ... (24, 34), (34, 35), (36, 43), (44, 46), (47, 52), (52, 54),
+    ... (55, 60), (61, 67), (68, 72), (73, 75), (76, 83), (84, 89),
+    ... (90, 98), (99, 103), (104, 109), (110, 119), (120, 122),
+    ... (123, 131), (131, 132)]
+    >>> output = list(align_tokens(tokens, s))
+    >>> len(tokens) == len(expected) == len(output)  # Check that length of tokens and tuples are the same.
+    True
+    >>> expected == list(align_tokens(tokens, s))  # Check that the output is as expected.
+    True
+    >>> tokens == [s[start:end] for start, end in output]  # Check that the slices of the string correspond to the tokens.
+    True
+
+    :param tokens: The list of strings that are the result of tokenization
+    :type tokens: list(str)
+    :param sentence: The original string
+    :type sentence: str
+    :rtype: list(tuple(int,int))
+    """
+    point = 0
+    offsets = []
+    for token in tokens:
+        try:
+            start = sentence.index(token, point)
+        except ValueError as e:
+            raise ValueError(f'substring "{token}" not found in "{sentence}"') from e
+        point = start + len(token)
+        offsets.append((start, point))
+    return offsets
+
+
+class NLTKWordTokenizer:
+    """
+    The NLTK tokenizer that has improved upon the TreebankWordTokenizer.
+
+    This is the method that is invoked by ``word_tokenize()``. It assumes that the
+    text has already been segmented into sentences, e.g. using ``sent_tokenize()``.
+
+    The tokenizer is "destructive" such that the regexes applied will munge the
+    input string to a state beyond re-construction. It is possible to apply
+    `TreebankWordDetokenizer.detokenize` to the tokenized outputs of
+    `NLTKDestructiveWordTokenizer.tokenize` but there's no guarantee to
+    revert to the original string.
+    """
+
+    # Starting quotes.
+    STARTING_QUOTES = [
+        (re.compile("([«“‘„]|[`]+)", re.U), r" \1 "),
+        (re.compile(r"^\""), r' " '),
+        (re.compile(r"(``)"), r" \1 "),
+        (re.compile(r"([ \(\[{<])(\"|\'{2})"), r'\1 " '),
+        # (re.compile(r"(?i)(\')(?!re|ve|ll|m|t|s|d|n)(\w)\b", re.U), r"\1 \2"),
+    ]
+
+    # Ending quotes.
+    ENDING_QUOTES = [
+        (re.compile("([»”’])", re.U), r" \1 "),
+        (re.compile(r"''"), " '' "),
+        (re.compile(r'"'), ' " '),
+        (re.compile(r"([^' ])('[sS]|'[mM]|'[dD]|') "), r"\1 \2 "),
+        # (re.compile(r"([^' ])('ll|'LL|'re|'RE|'ve|'VE|n't|N'T) "), r"\1 \2 "),
+    ]
+
+    # For improvements for starting/closing quotes from TreebankWordTokenizer,
+    # see discussion on https://github.com/nltk/nltk/pull/1437
+    # Adding to TreebankWordTokenizer, nltk.word_tokenize now splits on
+    # - chevron quotes u'\xab' and u'\xbb' .
+    # - unicode quotes u'\u2018', u'\u2019', u'\u201c' and u'\u201d'
+    # See https://github.com/nltk/nltk/issues/1995#issuecomment-376741608
+    # Also, behavior of splitting on clitics now follows Stanford CoreNLP
+    # - clitics covered (?!re|ve|ll|m|t|s|d)(\w)\b
+
+    # Punctuation.
+    PUNCTUATION = [
+        (re.compile(r'([^\.])(\.)([\]\)}>"\'' "»”’ " r"]*)\s*$", re.U), r"\1 \2 \3 "),
+        (re.compile(r"([:,])([^\d])"), r" \1 \2"),
+        (re.compile(r"([:,])$"), r" \1 "),
+        (
+            re.compile(r"\.{2,}", re.U),
+            r" \g<0> ",
+        ),  # See https://github.com/nltk/nltk/pull/2322
+        (re.compile(r"[;@#$%&]"), r" \g<0> "),
+        (
+            re.compile(r'([^\.])(\.)([\]\)}>"\']*)\s*$'),
+            r"\1 \2\3 ",
+        ),  # Handles the final period.
+        (re.compile(r"[?!]"), r" \g<0> "),
+        (re.compile(r"([^'])' "), r"\1 ' "),
+        (
+            re.compile(r"[*]", re.U),
+            r" \g<0> ",
+        ),  # See https://github.com/nltk/nltk/pull/2322
+    ]
+
+    # Pads parentheses
+    PARENS_BRACKETS = (re.compile(r"[\]\[\(\)\{\}\<\>]"), r" \g<0> ")
+
+    # Optionally: Convert parentheses, brackets and converts them to PTB symbols.
+    # CONVERT_PARENTHESES = [
+    #     (re.compile(r"\("), "-LRB-"),
+    #     (re.compile(r"\)"), "-RRB-"),
+    #     (re.compile(r"\["), "-LSB-"),
+    #     (re.compile(r"\]"), "-RSB-"),
+    #     (re.compile(r"\{"), "-LCB-"),
+    #     (re.compile(r"\}"), "-RCB-"),
+    # ]
+
+    DOUBLE_DASHES = (re.compile(r"--"), r" -- ")
+
+    # List of contractions adapted from Robert MacIntyre's tokenizer.
+    # _contractions = MacIntyreContractions()
+    # CONTRACTIONS2 = list(map(re.compile, _contractions.CONTRACTIONS2))
+    # CONTRACTIONS3 = list(map(re.compile, _contractions.CONTRACTIONS3))
+
+    def tokenize(
+        self, text: str
+    ) -> List[str]:
+        r"""Return a tokenized copy of `text`.
+
+        >>> from nltk.tokenize import NLTKWordTokenizer
+        >>> s = '''Good muffins cost $3.88 (roughly 3,36 euros)\nin New York.  Please buy me\ntwo of them.\nThanks.'''
+        >>> NLTKWordTokenizer().tokenize(s)  # doctest: +NORMALIZE_WHITESPACE
+        ['Good', 'muffins', 'cost', '$', '3.88', '(', 'roughly', '3,36',
+        'euros', ')', 'in', 'New', 'York.', 'Please', 'buy', 'me', 'two',
+        'of', 'them.', 'Thanks', '.']
+        >>> NLTKWordTokenizer().tokenize(s, convert_parentheses=True)  # doctest: +NORMALIZE_WHITESPACE
+        ['Good', 'muffins', 'cost', '$', '3.88', '-LRB-', 'roughly', '3,36',
+        'euros', '-RRB-', 'in', 'New', 'York.', 'Please', 'buy', 'me', 'two',
+        'of', 'them.', 'Thanks', '.']
+
+
+        :param text: A string with a sentence or sentences.
+        :type text: str
+        :param convert_parentheses: if True, replace parentheses to PTB symbols,
+            e.g. `(` to `-LRB-`. Defaults to False.
+        :type convert_parentheses: bool, optional
+        :param return_str: If True, return tokens as space-separated string,
+            defaults to False.
+        :type return_str: bool, optional
+        :return: List of tokens from `text`.
+        :rtype: List[str]
+        """
+
+        for regexp, substitution in self.STARTING_QUOTES:
+            text = regexp.sub(substitution, text)
+
+        for regexp, substitution in self.PUNCTUATION:
+            text = regexp.sub(substitution, text)
+
+        # Handles parentheses.
+        regexp, substitution = self.PARENS_BRACKETS
+        text = regexp.sub(substitution, text)
+
+        # Handles double dash.
+        regexp, substitution = self.DOUBLE_DASHES
+        text = regexp.sub(substitution, text)
+
+        # add extra space to make things easier
+        text = " " + text + " "
+
+        for regexp, substitution in self.ENDING_QUOTES:
+            text = regexp.sub(substitution, text)
+
+        return text.split()
+
+    def span_tokenize(self, text: str) -> Iterator[Tuple[int, int]]:
+        r"""
+        Returns the spans of the tokens in ``text``.
+        Uses the post-hoc nltk.tokens.align_tokens to return the offset spans.
+
+        >>> from nltk.tokenize import NLTKWordTokenizer
+        >>> s = '''Good muffins cost $3.88\nin New (York).  Please (buy) me\ntwo of them.\n(Thanks).'''
+        >>> expected = [(0, 4), (5, 12), (13, 17), (18, 19), (19, 23),
+        ... (24, 26), (27, 30), (31, 32), (32, 36), (36, 37), (37, 38),
+        ... (40, 46), (47, 48), (48, 51), (51, 52), (53, 55), (56, 59),
+        ... (60, 62), (63, 68), (69, 70), (70, 76), (76, 77), (77, 78)]
+        >>> list(NLTKWordTokenizer().span_tokenize(s)) == expected
+        True
+        >>> expected = ['Good', 'muffins', 'cost', '$', '3.88', 'in',
+        ... 'New', '(', 'York', ')', '.', 'Please', '(', 'buy', ')',
+        ... 'me', 'two', 'of', 'them.', '(', 'Thanks', ')', '.']
+        >>> [s[start:end] for start, end in NLTKWordTokenizer().span_tokenize(s)] == expected
+        True
+
+        :param text: A string with a sentence or sentences.
+        :type text: str
+        :yield: Tuple[int, int]
+        """
+        raw_tokens = self.tokenize(text)
+
+        # Convert converted quotes back to original double quotes
+        # Do this only if original text contains double quote(s) or double
+        # single-quotes (because '' might be transformed to `` if it is
+        # treated as starting quotes).
+        # if ('"' in text) or ("''" in text):
+        #     # Find double quotes and converted quotes
+        #     matched = [m.group() for m in re.finditer(r"``|'{2}|\"", text)]
+
+        #     # Replace converted quotes back to double quotes
+        #     tokens = [
+        #         matched.pop(0) if tok in ['"', "``", "''"] else tok
+        #         for tok in raw_tokens
+        #     ]
+        # else:
+        tokens = raw_tokens
+
+        yield from align_tokens(tokens, text)
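Quick usage check of the vendored tokenizer (editor's sketch; the expected behaviour follows the doctests above):

from tokenizer import NLTKWordTokenizer

tok = NLTKWordTokenizer()
s = "Good muffins cost $3.88 in New (York)."
print(tok.tokenize(s))
# ['Good', 'muffins', 'cost', '$', '3.88', 'in', 'New', '(', 'York', ')', '.']
print([s[start:end] for start, end in tok.span_tokenize(s)])  # character spans aligned to s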