bowphs committed on
Commit
3bc4816
1 Parent(s): 0f418b3

Add initial attempt of a code framework.

Browse files
Files changed (6) hide show
  1. README.md +5 -5
  2. app.py +342 -0
  3. models.py +112 -0
  4. requirements.txt +8 -0
  5. scrollbar.css +30 -0
  6. utils.py +161 -0
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Athenas Lens
3
- emoji: 📈
4
- colorFrom: yellow
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 3.40.1
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
1
  ---
2
+ title: Athena's Lens
3
+ emoji: 🦉
4
+ colorFrom: red
5
+ colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 3.3.1
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
app.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, List, Union, Dict, Mapping
2
+ import base64
3
+ import os
4
+
5
+ from bs4 import BeautifulSoup
6
+ import gradio as gr
7
+ from spacy import displacy
8
+ from transformers import (
9
+ AutoTokenizer,
10
+ AutoModelForTokenClassification,
11
+ BatchEncoding,
12
+ AutoModelForSeq2SeqLM,
13
+ DataCollatorForTokenClassification,
14
+ )
15
+ import torch
16
+
17
+ from utils import get_dependencies, preprocess_text
18
+ from models import (
19
+ DependencyRobertaForTokenClassification,
20
+ LabelRobertaForTokenClassification,
21
+ )
22
+
23
+
24
# Sentence pre-filled in the input box and rendered when the page first loads.
DEFAULT_TEXT = "τίω δέ μιν ἐν καρὸς αἴσῃ."

# Inline style applied to the "Download as SVG" anchor. The value is one
# long declaration list; it is split per declaration here for readability
# only, the concatenated string is unchanged.
BUTTON_CSS = (
    "float: right; "
    "--tw-border-opacity: 1; "
    "border-color: rgb(229 231 235 / var(--tw-border-opacity)); "
    "--tw-gradient-from: rgb(243 244 246 / 0.7); "
    "--tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to, rgb(243 244 246 / 0)); "
    "--tw-gradient-to: rgb(229 231 235 / 0.8); "
    "--tw-text-opacity: 1; "
    "color: rgb(55 65 81 / var(--tw-text-opacity)); "
    "border-width: 1px; "
    "--tw-bg-opacity: 1; "
    "background-color: rgb(255 255 255 / var(--tw-bg-opacity)); "
    "background-image: linear-gradient(to bottom right, var(--tw-gradient-stops)); "
    "display: inline-flex; "
    "flex: 1 1 0%; "
    "align-items: center; "
    "justify-content: center; "
    "--tw-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05); "
    "--tw-shadow-colored: 0 1px 2px 0 var(--tw-shadow-color); "
    "box-shadow: var(--tw-ring-offset-shadow, 0 0 #0000), var(--tw-ring-shadow, 0 0 #0000), var(--tw-shadow); "
    "-webkit-appearance: button; "
    "border-radius: 0.5rem; "
    "padding-top: 0.5rem; "
    "padding-bottom: 0.5rem; "
    "padding-left: 1rem; "
    "padding-right: 1rem; "
    "font-size: 1rem; "
    "line-height: 1.5rem; "
    "font-weight: 600;"
)

# Default background colour passed to the renderer.
DEFAULT_COLOR = "white"

# Hub repository id of the checkpoint used for each task.
MODEL_PATHS = dict(
    POS="bowphs/testid",
    LEMMATIZATION="bowphs/lemmatization-demo",
    DEPENDENCY="bowphs/depenBERTa_perseus",
    LABELS="bowphs/depenBERTa_labler_perseus",
)

# Passed as ``model_max_length`` to every tokenizer.
MODEL_MAX_LENGTH = 512
35
+
36
# Token forwarded as ``use_auth_token`` to every ``from_pretrained`` call;
# falls back to ``True`` when the TOKEN environment variable is unset or empty.
AUTH_TOKEN = os.environ.get("TOKEN") or True


def _load_tokenizer(task: str):
    """Load the tokenizer of the checkpoint registered for *task*."""
    return AutoTokenizer.from_pretrained(
        MODEL_PATHS[task],
        model_max_length=MODEL_MAX_LENGTH,
        use_auth_token=AUTH_TOKEN,
    )


def _load_model(model_cls, task: str):
    """Load the weights registered for *task* into *model_cls*."""
    return model_cls.from_pretrained(MODEL_PATHS[task], use_auth_token=AUTH_TOKEN)


# PoS
pos_tokenizer = _load_tokenizer("POS")
pos_model = _load_model(AutoModelForTokenClassification, "POS")

# Lemmatization
lemmatizer_tokenizer = _load_tokenizer("LEMMATIZATION")
lemmatizer_model = _load_model(AutoModelForSeq2SeqLM, "LEMMATIZATION")

# Dependency Parsing: one tokenizer shared by the arc model and the label model.
dependency_tokenizer = _load_tokenizer("DEPENDENCY")
arcs_model = _load_model(DependencyRobertaForTokenClassification, "DEPENDENCY")
labels_model = _load_model(LabelRobertaForTokenClassification, "LABELS")

data_collator = DataCollatorForTokenClassification(dependency_tokenizer)
69
+
70
+
71
def is_valid_selection(col_arcs, col_labels) -> bool:
    """Tell whether the checkbox combination can be rendered.

    The only rejected combination is labels requested without arcs.
    """
    labels_without_arcs = col_labels and not col_arcs
    return not labels_without_arcs
75
+
76
+
77
def get_pos_predictions(inputs) -> torch.Tensor:
    """Get part of speech predictions.

    Args:
        inputs: Tokenizer output; only its ``"input_ids"`` entry is read.

    Returns:
        The arg-max over the last logits axis, i.e. one predicted label id
        per input position.
    """
    # Pure inference: skip autograd bookkeeping so no graph is kept alive
    # for every request served by the app.
    with torch.no_grad():
        logits = pos_model(inputs["input_ids"]).logits  # type: ignore
    return logits.argmax(-1)
80
+
81
+
82
def execute_parse(
    text_input: str,
    col_pos: bool,
    col_arcs: bool,
    col_labels: bool,
    col_lemmata: bool,
    compact: bool,
    bg: str,
    text: str,
) -> Tuple[str, str]:
    """Validate the checkbox selection, then delegate to ``parse``.

    Returns:
        Whatever ``parse`` returns for a valid selection; otherwise a hint
        message paired with an empty string.
    """
    if not is_valid_selection(col_arcs, col_labels):
        return "Please check 'Dependency Arcs' before checking 'Dependency Labels'", ""

    return parse(
        text_input=text_input,
        col_pos=col_pos,
        col_arcs=col_arcs,
        col_labels=col_labels,
        col_lemmata=col_lemmata,
        compact=compact,
        bg=bg,
        text=text,
    )
97
+
98
+
99
def lemmatize(tokens: List[str]) -> List[str]:
    """Predict one lemma per token with the seq2seq lemmatizer.

    Each token is lemmatized in its own generation call. The prompt holds
    the left context, the target word, the target spelled out character by
    character, and the right context, separated by ``<extra_id_N>`` markers.

    Args:
        tokens: The words of the sentence, in order.

    Returns:
        The decoded lemma for every token, in the same order.
    """

    def construct_task(word_idx: int) -> str:
        left_context = " ".join(tokens[:word_idx])
        target = tokens[word_idx]
        spelled_out = " ".join(target)  # joining a str iterates its characters
        right_context = " ".join(tokens[word_idx + 1 :])
        return f"lemmatize: {left_context} <extra_id_0> {target} <extra_id_1> {spelled_out} <extra_id_2> {right_context}"

    lemmata: List[str] = []
    for position in range(len(tokens)):
        encoded = lemmatizer_tokenizer(construct_task(position), return_tensors="pt")
        generated = lemmatizer_model.generate(
            encoded["input_ids"],
            max_length=20,
            num_beams=5,
            num_return_sequences=1,
            early_stopping=True,
        )
        lemmata.append(
            lemmatizer_tokenizer.decode(generated[0], skip_special_tokens=True)
        )

    return lemmata
120
+
121
+
122
def add_lemma_visualization(soup, lemmata: List[str], col_arcs: bool) -> str:
    """Insert a lemma line below the tag of every rendered token.

    Args:
        soup: Parsed SVG whose token nodes carry the ``displacy-token``
            class and contain a ``displacy-tag`` child.
        lemmata: One lemma per real (non-ROOT) token, in sentence order.
        col_arcs: Whether arcs were rendered. ``parse`` prepends a ROOT word
            in that case, so the first token node is skipped.

    Returns:
        The modified document serialized back to a string.
    """
    # Make the bool-as-offset explicit instead of slicing with a bool.
    first_real_token = 1 if col_arcs else 0
    token_nodes = soup.find_all(class_="displacy-token")[first_real_token:]

    for token, lemma in zip(token_nodes, lemmata):
        pos_tag = token.find(class_="displacy-tag")
        if pos_tag is None:
            # Nothing to anchor the lemma to; leave this token untouched.
            continue

        lemma_tag = soup.new_tag(
            "tspan",
            dy="2em",
            fill="currentColor",
            x=pos_tag.attrs["x"],
        )
        # ``new_tag`` copies keyword names verbatim, so ``class_=...`` would
        # create an attribute literally called "class_". Set the real one.
        lemma_tag["class"] = "displacy-lemma"
        lemma_tag.string = lemma
        pos_tag.insert_after(lemma_tag)

    return str(soup)
135
+
136
+
137
def download_svg(svg):
    """Build an HTML anchor that downloads *svg* as ``displacy.svg``.

    The markup is embedded in the link as a base64 ``data:`` URI and the
    anchor is styled with ``BUTTON_CSS``.
    """
    payload = base64.b64encode(bytes(svg, "utf-8")).decode("ascii")
    img = "data:image/svg+xml;base64," + payload
    return f'<a download="displacy.svg" href="{img}" style="{BUTTON_CSS}">Download as SVG</a>'
142
+
143
+
144
def prepare_doc(
    tokens: List[str], col_pos: bool, pos_outputs: torch.Tensor, inputs: BatchEncoding,
) -> Dict[str, List[Dict[str, str]]]:
    """Assemble the manual-render document: one entry per word, no arcs.

    Args:
        tokens: Words of the sentence.
        col_pos: When true each word carries its predicted tag, otherwise
            an empty string.
        pos_outputs: Predicted label ids; row 0 is read at the position of
            each word's first sub-token.
        inputs: Tokenizer output providing ``word_ids()``.

    Returns:
        ``{"words": [...], "arcs": []}``.
    """
    words: List[Dict[str, str]] = []
    last_word_id = None

    for position, word_id in enumerate(inputs.word_ids()):
        # Skip special tokens (None) and continuation pieces of the word
        # that was just emitted.
        if word_id is None or word_id == last_word_id:
            continue

        if col_pos:
            tag_repr = pos_model.config.id2label[pos_outputs[0][position].item()]
        else:
            tag_repr = ""
        words.append({"text": tokens[word_id], "tag": tag_repr})
        last_word_id = word_id

    return {"words": words, "arcs": []}
163
+
164
+
165
def parse(
    text_input: str,
    col_pos: bool,
    col_arcs: bool,
    col_labels: bool,
    col_lemmata: bool,
    compact: bool,
    bg: str,
    text: str,
) -> Tuple[str, str]:
    """Run the selected analyses on *text_input* and render them.

    Args:
        text_input: Raw sentence.
        col_pos: Show part-of-speech tags.
        col_arcs: Draw dependency arcs (adds a ROOT word in front).
        col_labels: Label the arcs.
        col_lemmata: Add a lemma line under each word.
        compact: Forwarded to the renderer as the ``compact`` option.
        bg: Forwarded as the ``bg`` option.
        text: Forwarded as the ``color`` option.

    Returns:
        The rendered markup and the HTML download link for it.
    """
    tokens = preprocess_text(text_input)
    inputs = pos_tokenizer(
        tokens,
        return_tensors="pt",
        truncation=True,
        padding=True,
        is_split_into_words=True,
    )
    doc = prepare_doc(tokens, col_pos, get_pos_predictions(inputs), inputs)

    if col_arcs:
        # Arc indices count the ROOT node as word 0.
        doc["words"].insert(0, {"text": "ROOT", "tag": ""})
        dependencies = get_dependencies(
            arcs_model,
            labels_model,
            dependency_tokenizer,
            data_collator,
            col_labels,
            tokens,
        )
        doc["arcs"] = dependencies["arcs"]

    svg = displacy.render(
        doc,
        manual=True,
        style="dep",
        options={"compact": compact, "bg": bg, "color": text},
    )

    if col_lemmata:
        soup = BeautifulSoup(svg, "lxml-xml")
        svg = add_lemma_visualization(soup, lemmatize(tokens), col_arcs)

    return svg, download_svg(svg)
209
+
210
+
211
def setup_parser_ui():
    """Build the Gradio interface and return it without launching.

    The page has a header, a text-input box, the parser box (option
    checkboxes, colour fields, rendered output, download link) and a
    contact/citation box. Both buttons trigger ``execute_parse`` with the
    same inputs and outputs.

    Returns:
        The assembled ``gr.Blocks`` object. The caller is responsible for
        calling ``launch()`` on it; previously this function launched the
        app itself and returned ``None``, which made ``main`` fail on
        ``demo.launch()``.
    """
    demo = gr.Blocks(css="scrollbar.css")
    with demo:
        with gr.Box():
            with gr.Row():
                with gr.Column():
                    gr.Markdown("# Athena's Lens")
                    gr.Markdown(
                        "### From Ἀlkaios to Ὠrigen: A Modern Lens on Timeless Texts"
                    )
        with gr.Box():
            with gr.Column():
                gr.Markdown(" ## Enter some text")
                with gr.Row():
                    with gr.Column(scale=0.5):
                        text_input = gr.Textbox(
                            value=DEFAULT_TEXT, interactive=True, label="Input Text"
                        )
                with gr.Row():
                    with gr.Column(scale=0.25):
                        button = gr.Button("Update", variant="primary").style(
                            full_width=False
                        )
        with gr.Box():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("## Parser")
                with gr.Row():
                    with gr.Column():
                        col_pos = gr.Checkbox(label="PoS Labels", value=True)
                        col_arcs = gr.Checkbox(label="Dependency Arcs", value=False)
                        col_labels = gr.Checkbox(label="Dependency Labels", value=False)
                        col_lemmata = gr.Checkbox(label="Lemmata", value=False)
                        compact = gr.Checkbox(label="Compact", value=False)
                    with gr.Column():
                        bg = gr.Textbox(label="Background Color", value=DEFAULT_COLOR)
                    with gr.Column():
                        text = gr.Textbox(label="Text Color", value="black")

                # Render the default sentence once, with the same settings
                # as the initial checkbox/colour values, and reuse both
                # results for the two HTML components below.
                initial_svg, initial_download = parse(
                    DEFAULT_TEXT,
                    True,
                    False,
                    False,
                    False,
                    False,
                    DEFAULT_COLOR,
                    "black",
                )
                with gr.Row():
                    dep_output = gr.HTML(value=initial_svg)
                with gr.Row():
                    with gr.Column(scale=0.25):
                        dep_button = gr.Button(
                            "Update Parser", variant="primary"
                        ).style(full_width=False)
                    with gr.Column():
                        dep_download_button = gr.HTML(value=initial_download)

        with gr.Box():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("## Contact")
                        gr.Markdown(
                            "If you have any questions, suggestions, comments, or problems, feel free to [reach out](mailto:riemenschneider@cl.uni-heidelberg.de)."
                        )
                        gr.Markdown("## Citation")
                        gr.Markdown(
                            "This space uses models from [this](https://aclanthology.org/2023.acl-long.846.pdf) paper."
                        )
                        gr.Markdown(
                            """```bibtex
@incollection{riemenschneider-frank-2023-exploring,
title = "Exploring Large Language Models for Classical Philology",
author = "Riemenschneider, Frederick and Frank, Anette",
booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
month = jul,
year = "2023",
address = "Toronto, Canada",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2023.acl-long.846",
doi = "10.18653/v1/2023.acl-long.846",
pages = "15181--15199",
}
```
"""
                        )

        # Both buttons run the same callback on the same components.
        parse_inputs = [
            text_input,
            col_pos,
            col_arcs,
            col_labels,
            col_lemmata,
            compact,
            bg,
            text,
        ]
        parse_outputs = [dep_output, dep_download_button]
        for trigger in (button, dep_button):
            trigger.click(execute_parse, inputs=parse_inputs, outputs=parse_outputs)

    return demo
334
+
335
+
336
def main():
    """Build the interface with ``setup_parser_ui`` and call ``launch()`` on its result."""
    setup_parser_ui().launch()


if __name__ == "__main__":
    main()
models.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import RobertaPreTrainedModel
4
+ from transformers.modeling_outputs import TokenClassifierOutput
5
+ from transformers.models.roberta.modeling_roberta import RobertaConfig, RobertaModel
6
+
7
+ from utils import batched_index_select
8
+
9
+
10
class DependencyRobertaForTokenClassification(RobertaPreTrainedModel):
    """RoBERTa encoder with an additive-attention head-selection layer.

    For every position ``i`` the model scores all positions of the sequence
    as candidate heads of ``i`` and normalizes the scores with a
    log-softmax over the sequence axis.
    """

    config_class = RobertaConfig  # type: ignore

    def __init__(self, config):
        super().__init__(config)
        # Size the scoring layers from the config instead of assuming 768,
        # so encoders of other widths work as well.
        hidden_size = config.hidden_size
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.u_a = nn.Linear(hidden_size, hidden_size)
        self.w_a = nn.Linear(hidden_size, hidden_size)
        self.v_a_inv = nn.Linear(hidden_size, 1, bias=False)
        self.criterion = nn.NLLLoss()
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs,
    ):
        """Score head candidates for every position.

        Args:
            input_ids, attention_mask, token_type_ids: Forwarded to the
                encoder.
            labels: Optional gold head index per position, indexed as
                ``labels[:, i]``. Positions whose whole batch column is
                ``-100`` do not contribute to the loss.
            **kwargs: Accepted and ignored.

        Returns:
            ``(output, preds, parent_prob_table)`` and additionally
            ``labels`` when labels were given. ``output`` is a
            ``TokenClassifierOutput`` whose ``logits`` field holds the
            arg-max head indices; ``parent_prob_table[b, i, j]`` is the
            probability that position ``j`` is the head of position ``i``.
        """
        loss = 0.0
        output = self.roberta(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )[0]
        batch_size, seq_len, _ = output.size()

        head_probs = []
        for i in range(0, seq_len):
            # Broadcast the i-th hidden state against every position.
            target = output[:, i, :].expand(seq_len, batch_size, -1).transpose(0, 1)
            # Mask positions whose first hidden feature equals that of
            # position i (this always includes i itself).
            mask = output.eq(target)[:, :, 0].unsqueeze(2)
            p_head = self.attention(output, target, mask)
            if labels is not None:
                current_loss = self.criterion(p_head.squeeze(-1), labels[:, i])
                if not torch.all(labels[:, i] == -100):
                    loss += current_loss
            head_probs.append(torch.exp(p_head))

        # Concatenating on the last axis gives [batch, candidate, dependent];
        # the transpose makes it [batch, dependent, candidate].
        parent_prob_table = torch.cat(head_probs, dim=2).data.transpose(1, 2)
        _, topi = parent_prob_table.topk(k=1, dim=2)
        preds = topi.squeeze(-1)
        loss = loss / seq_len
        output = TokenClassifierOutput(loss=loss, logits=preds)

        if labels is not None:
            return output, preds, parent_prob_table, labels
        else:
            return output, preds, parent_prob_table

    def attention(self, source, target, mask=None):
        """Return log-probabilities over the sequence axis (dim 1).

        Scores are ``v_a_inv(tanh(u_a(source) + w_a(target)))``; masked
        positions are set to ``-1e4`` before the log-softmax.
        """
        function_g = self.v_a_inv(torch.tanh(self.u_a(source) + self.w_a(target)))
        if mask is not None:
            function_g.masked_fill_(mask, -1e4)
        return nn.functional.log_softmax(function_g, dim=1)
63
+
64
+
65
class LabelRobertaForTokenClassification(RobertaPreTrainedModel):
    """RoBERTa encoder with a relation classifier over (token, head) pairs.

    Each position's hidden state is concatenated with the hidden state of
    its head (given via ``head_labels``) and classified into one of
    ``num_labels`` relations by a two-layer feed-forward network.
    """

    config_class = RobertaConfig  # type: ignore

    def __init__(self, config):
        super().__init__(config)
        # Size the classifier from the config instead of assuming 768.
        hidden_size = config.hidden_size
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.num_labels = 33
        self.hidden = nn.Linear(hidden_size * 2, hidden_size)
        self.relu = nn.ReLU()
        self.out = nn.Linear(hidden_size, self.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        # NOTE(review): unlike the arc model, no init_weights() call is made
        # here; confirm whether that is intentional.

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs,
    ):
        """Classify the relation between every position and its head.

        Args:
            input_ids, attention_mask, token_type_ids: Forwarded to the
                encoder.
            labels: Optional gold relation id per position, indexed as
                ``labels[:, i]``. Positions whose whole batch column is
                ``-100`` do not contribute to the loss.
            **kwargs: Must contain ``head_labels``, the head position of
                every token, indexed as ``head_labels[:, i]``. Entries equal
                to ``-100`` are treated as position 0.

        Returns:
            A ``TokenClassifierOutput`` whose ``logits`` stack the
            per-position predictions along dim 1.

        Raises:
            KeyError: If ``head_labels`` is not passed.
        """
        loss = 0.0
        output = self.roberta(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )[0]
        batch_size, seq_len, _ = output.size()
        head_labels = kwargs["head_labels"]
        logits = []
        for i in range(seq_len):
            current_token = output[:, i, :]
            # Work on a copy: writing through the column view would
            # overwrite the -100 entries of the caller's tensor.
            connected_with_index = head_labels[:, i].clone()
            connected_with_index[connected_with_index == -100] = 0
            # Gathering only reads from `output`, so no copy is needed.
            connected_with_embedding = batched_index_select(
                output, 1, connected_with_index
            )
            combined_embeddings = torch.cat(
                (current_token, connected_with_embedding.squeeze(1)), -1
            )
            pred = self.out(self.relu(self.hidden(combined_embeddings)))
            pred = pred.view(-1, self.num_labels)
            logits.append(pred)
            if labels is not None:
                current_loss = self.loss_fct(pred, labels[:, i].view(-1))
                if not torch.all(labels[:, i] == -100):
                    loss += current_loss

        loss = loss / seq_len
        logits = torch.stack(logits, dim=1)
        output = TokenClassifierOutput(loss=loss, logits=logits)
        return output
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ pandas==1.4.2
2
+ gradio==3.3.1
3
+ beautifulsoup4
4
+ lxml
5
+ ufal.chu-liu-edmonds
6
+ spacy
7
+ transformers
8
+ torch
scrollbar.css ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* Let wide output scroll horizontally inside its container. */
.output-html {
  overflow-x: auto;
}

/* Drop the native WebKit scrollbar look so the rules below apply. */
.output-html::-webkit-scrollbar {
  -webkit-appearance: none;
}

/* No room for a vertical bar ... */
.output-html::-webkit-scrollbar:vertical {
  width: 0px;
}

/* ... and a fixed-height horizontal one. */
.output-html::-webkit-scrollbar:horizontal {
  height: 11px;
}

/* Rounded, semi-transparent thumb with a white rim. */
.output-html::-webkit-scrollbar-thumb {
  border-radius: 8px;
  border: 2px solid white;
  background-color: rgba(0, 0, 0, .5);
}

/* White rounded track behind the thumb. */
.output-html::-webkit-scrollbar-track {
  background-color: #fff;
  border-radius: 8px;
}

.spans {
  min-height: 75px;
}
utils.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import List, Dict, Set
3
+ import numpy as np
4
+ import torch
5
+ from ufal.chu_liu_edmonds import chu_liu_edmonds
6
+
7
# Relation names in label-id order: the position in this list is the id.
DEPENDENCY_RELATIONS = [
    "acl",
    "advcl",
    "advmod",
    "amod",
    "appos",
    "aux",
    "case",
    "cc",
    "ccomp",
    "conj",
    "cop",
    "csubj",
    "det",
    "iobj",
    "mark",
    "nmod",
    "nsubj",
    "nummod",
    "obj",
    "obl",
    "parataxis",
    "punct",
    "root",
    "vocative",
    "xcomp",
]
# Lookup tables in both directions between label id and relation name.
INDEX2TAG = dict(enumerate(DEPENDENCY_RELATIONS))
TAG2INDEX = {tag: idx for idx, tag in INDEX2TAG.items()}
36
+
37
+
38
def preprocess_text(text: str) -> List[str]:
    """Split *text* into words, with each punctuation mark as its own token.

    A space is inserted before and after every character of
    ``.,!?()·;:`` wherever one is missing, then the text is split on
    whitespace.
    """
    needs_space = "(?<! )(?=[.,!?()·;:])|(?<=[.,!?()·;:])(?! )"
    spaced = re.sub(needs_space, r" ", text.strip())
    return spaced.split()
42
+
43
+
44
def batched_index_select(
    input: torch.Tensor, dim: int, index: torch.Tensor
) -> torch.Tensor:
    """Gather, for every batch row, the entries of *input* chosen by *index* along *dim*.

    *index* is reshaped to ``[batch, ..., -1, ...]`` (size 1 on every axis
    except *dim*) and expanded over the remaining axes of *input* before
    being handed to ``torch.gather``.
    """
    view_shape = [input.shape[0]]
    for axis in range(1, input.dim()):
        view_shape.append(-1 if axis == dim else 1)

    expand_shape = list(input.shape)
    expand_shape[0] = -1
    expand_shape[dim] = -1

    expanded_index = index.view(view_shape).expand(expand_shape)
    return torch.gather(input, dim, expanded_index)
55
+
56
+
57
def get_relevant_tokens(tokenized: torch.Tensor, start_ids: Set[int]) -> List[int]:
    """Return, as Python scalars, the entries of *tokenized* at the positions in *start_ids*."""
    return [
        value.item()
        for position, value in enumerate(tokenized)
        if position in start_ids
    ]
59
+
60
+
61
def resolve(
    edmonds_head: List[int], word_ids: List[int], parent_probs_table: torch.Tensor
) -> torch.Tensor:
    """Leave only one entry attached to position 0 (the root).

    When more than one entry of *edmonds_head* is 0, the candidate whose
    index occurs most often in *edmonds_head* is kept. For every other
    candidate the table cell ``[0][word_ids.index(candidate)][0]`` is set
    to 0. The table is modified in place and returned.
    """
    root_candidates = [
        position for position, head in enumerate(edmonds_head) if head == 0
    ]
    if len(root_candidates) <= 1:
        return parent_probs_table

    kept_root = max(root_candidates, key=edmonds_head.count)
    for candidate in root_candidates:
        if candidate == kept_root:
            continue
        parent_probs_table[0][word_ids.index(candidate)][0] = 0
    return parent_probs_table
71
+
72
+
73
def apply_chu_liu_edmonds(
    parent_probs_table: torch.Tensor, tokenized_input: Dict, start_ids: Set[int]
) -> List[int]:
    """Decode heads with Chu-Liu/Edmonds and keep those of word-start tokens.

    Side effect: the decoded heads of all positions are stored in
    ``tokenized_input["head_labels"]`` with a leading batch axis.

    Returns:
        The decoded head of every position listed in *start_ids*.
    """
    is_square = parent_probs_table.shape[1] == parent_probs_table.shape[2]
    if not is_square:
        # Drop the first candidate column to make the score matrix square.
        parent_probs_table = parent_probs_table[:, :, 1:]

    scores = parent_probs_table.squeeze(0).cpu().numpy().astype("double")
    decoded_heads, _ = chu_liu_edmonds(scores)

    heads = torch.tensor(decoded_heads).unsqueeze(0)
    heads[heads == -1] = 0  # -1 entries are mapped to position 0
    tokenized_input["head_labels"] = heads
    return get_relevant_tokens(heads[0], start_ids)
88
+
89
+
90
def get_word_endings(tokenized_input):
    """Map token positions to the word they belong to.

    Returns:
        A triple ``(word_endings, start_ids, word_ids)``:

        * ``word_endings`` maps every position from a word's start up to
          and including the span end reported by ``word_to_tokens`` to
          ``(end, word_id + 1)``; position 0 maps to ``(1, 0)``.
        * ``start_ids`` is the set of start positions of all words.
        * ``word_ids`` is ``tokenized_input.word_ids(batch_index=0)``.
    """
    word_ids = tokenized_input.word_ids(batch_index=0)
    start_ids = set()
    word_endings = {0: (1, 0)}

    for word_id in word_ids:
        if word_id is None:
            continue
        start, end = tokenized_input.word_to_tokens(
            batch_or_word_index=0, word_index=word_id
        )
        start_ids.add(start)
        span_info = (end, word_id + 1)
        for position in range(start, end + 1):
            word_endings[position] = span_info

    return word_endings, start_ids, word_ids
104
+
105
+
106
def get_dependencies(
    dependency_parser,
    label_parser,
    tokenizer,
    collator,
    labels: bool,
    sentence: List[str],
) -> Dict:
    """Predict dependency arcs (and optionally their labels) for a sentence.

    Args:
        dependency_parser: Model returning ``(_, _, parent_probs_table)``.
        label_parser: Model whose output has a ``logits`` attribute; it
            receives the decoded heads as ``head_labels``. Only used when
            *labels* is true.
        tokenizer: Tokenizer called with ``is_split_into_words=True``.
        collator: Turns the single encoding into a model-ready batch.
        labels: Whether to predict a relation name for every arc.
        sentence: The words of the sentence.

    Returns:
        A dict with ``"words"`` (just the ROOT entry) and ``"arcs"``: one
        arc per word that survived tokenization, with ``start``/``end``
        word indices (ROOT is 0), ``label`` and ``dir``.
    """
    tokenized_input = tokenizer(
        sentence, truncation=True, is_split_into_words=True, add_special_tokens=True
    )
    dep_dict: Dict[str, List[Dict[str, str]]] = {
        "words": [{"text": "ROOT", "tag": ""}],
        "arcs": [],
    }

    word_endings, start_ids, word_ids = get_word_endings(tokenized_input)
    tokenized_input = collator([tokenized_input])
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        _, _, parent_probs_table = dependency_parser(**tokenized_input)

    # Every position that is neither position 0 nor the first sub-token of
    # a word is excluded both as dependent and as head candidate.
    irrelevant = torch.tensor(
        [
            idx
            for idx in range(parent_probs_table.size(1))
            if idx != 0 and idx not in start_ids
        ]
    )
    if irrelevant.nelement() > 0:
        parent_probs_table.index_fill_(1, irrelevant, torch.nan)
        parent_probs_table.index_fill_(2, irrelevant, torch.nan)

    # Decode once, suppress surplus root attachments, then decode again.
    edmonds_head = apply_chu_liu_edmonds(parent_probs_table, tokenized_input, start_ids)
    parent_probs_table = resolve(edmonds_head, word_ids, parent_probs_table)
    edmonds_head = apply_chu_liu_edmonds(parent_probs_table, tokenized_input, start_ids)

    if labels:
        with torch.no_grad():
            label_logits = label_parser(**tokenized_input).logits
        predictions_labels = np.argmax(label_logits.detach().cpu().numpy(), axis=-1)
        label_ids = get_relevant_tokens(predictions_labels[0], start_ids)
        # Map the ids that exist rather than indexing by len(sentence):
        # truncation can leave fewer word starts than words. Ids with no
        # entry in INDEX2TAG get an empty label instead of a KeyError.
        predicted_relations = [INDEX2TAG.get(label_id, "") for label_id in label_ids]
    else:
        predicted_relations = [""] * len(edmonds_head)

    for idx, (head, relation) in enumerate(zip(edmonds_head, predicted_relations)):
        dependent = idx + 1  # word indices are shifted by the ROOT node
        head_word = word_endings[head][1]
        arc = {
            "start": min(dependent, head_word),
            "end": max(dependent, head_word),
            "label": relation,
            "dir": "left" if dependent < head_word else "right",
        }
        dep_dict["arcs"].append(arc)

    return dep_dict