ArneBinder committed on
Commit
25fcabc
1 Parent(s): 0596e00

full pipeline

Files changed (4)
  1. app.py +425 -70
  2. rendering_utils.py +4 -5
  3. requirements.txt +1 -1
  4. vector_store.py +65 -0
app.py CHANGED
@@ -1,9 +1,12 @@
import json
import logging
+ import os.path
+ from collections import defaultdict
from functools import partial
- from typing import Any, Optional, Tuple
+ from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
+ import pandas as pd
from pie_modules.document.processing import tokenize_document
from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from pie_modules.models import *  # noqa: F403
@@ -16,6 +19,7 @@ from pytorch_ie.models import * # noqa: F403
from pytorch_ie.taskmodules import *  # noqa: F403
from rendering_utils import render_displacy, render_pretty_table
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
+ from vector_store import SimpleVectorStore

logger = logging.getLogger(__name__)

@@ -41,18 +45,26 @@ def embed_text_annotations(
    # tokenize_document does not yet consider predictions, so we need to add them manually
    document[text_layer_name].extend(document[text_layer_name].predictions.clear())
    added_annotations = []
-     # TODO: set return_overflowing_tokens=True and max_length=...?
-     tokenizer_kwargs = {}
+     tokenizer_kwargs = {
+         "max_length": 512,
+         "stride": 64,
+         "truncation": True,
+         "return_overflowing_tokens": True,
+     }
    tokenized_documents = tokenize_document(
        document,
        tokenizer=tokenizer,
        result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
        partition_layer="labeled_partitions",
        added_annotations=added_annotations,
+         strict_span_conversion=False,
        **tokenizer_kwargs,
    )
    # just tokenize again to get tensors in the correct format for the model
+     # TODO: fix for A34.txt from sciarg corpus
    model_inputs = tokenizer(document.text, return_tensors="pt", **tokenizer_kwargs)
+     # this is added when using return_overflowing_tokens=True, but the model does not accept it
+     model_inputs.pop("overflow_to_sample_mapping", None)
    assert len(model_inputs.encodings) == len(tokenized_documents)
    model_output = model(**model_inputs)

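The new tokenizer_kwargs split one long text into overlapping 512-token windows instead of silently truncating it. A minimal sketch of that behavior, assuming a standard fast tokenizer (the model name and text are illustrative, and padding is added here so the windows stack into one tensor):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative model
text = "scholarly argumentation mining " * 300  # long enough to overflow
enc = tok(
    text,
    max_length=512,
    stride=64,  # consecutive windows overlap by 64 tokens
    truncation=True,
    return_overflowing_tokens=True,
    padding=True,  # pad the shorter last window so return_tensors="pt" works
    return_tensors="pt",
)
print(enc["input_ids"].shape)  # (num_windows, 512)
print(enc["overflow_to_sample_mapping"])  # tensor([0, 0, ...]): window -> source text
# the encoder's forward() does not accept this extra key, hence the pop() above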
 
 
@@ -80,22 +92,16 @@ def embed_text_annotations(
    return embeddings


- def predict(
-     text: str,
+ def annotate(
+     document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    pipeline: Pipeline,
    embedding_model: Optional[PreTrainedModel] = None,
    embedding_tokenizer: Optional[PreTrainedTokenizer] = None,
- ) -> Tuple[dict, str]:
-     document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(text=text)
-
-     # add single partition from the whole text (the model only considers text in partitions)
-     document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
+ ) -> None:

    # execute prediction pipeline
    pipeline(document)

-     document_dict = document.asdict()
-
    if embedding_model is not None and embedding_tokenizer is not None:
        adu_embeddings = embed_text_annotations(
            document=document,
@@ -105,24 +111,19 @@ def predict(
        )
        # convert keys to str because JSON keys must be strings
        adu_embeddings_dict = {str(k._id): v.detach().tolist() for k, v in adu_embeddings.items()}
-         document_dict["embeddings"] = adu_embeddings_dict
+         document.metadata["embeddings"] = adu_embeddings_dict
    else:
        gr.Warning(
-             "No embedding model provided. Skipping embedding extraction. You can load an embedding model in the 'Model Configuration' section."
+             "No embedding model provided. Skipping embedding extraction. You can load an embedding "
+             "model in the 'Model Configuration' section."
        )

-     # Return as dict and JSON string. The latter is required because the JSON component converts floats
-     # to ints which destroys de-serialization of the document (the scores of the annotations need to be floats)
-     return document_dict, json.dumps(document_dict)
-

- def render(document_txt: str, render_with: str, render_kwargs_json: str) -> str:
-     document_dict = json.loads(document_txt)
-     # remove embeddings from document_dict to make it de-serializable
-     document_dict.pop("embeddings", None)
-     document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions.fromdict(
-         document_dict
-     )
+ def render_annotated_document(
+     document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+     render_with: str,
+     render_kwargs_json: str,
+ ) -> str:
    render_kwargs = json.loads(render_kwargs_json)
    if render_with == RENDER_WITH_PRETTY_TABLE:
        html = render_pretty_table(document, **render_kwargs)
@@ -135,26 +136,255 @@ def render(document_txt: str, render_with: str, render_kwargs_json: str) -> str:


def add_to_index(
-     output_txt: str, doc_id: str, processed_documents: dict, vector_store: Any
+     document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+     processed_documents: dict,
+     vector_store: SimpleVectorStore,
) -> None:
    try:
-         if doc_id in processed_documents:
-             gr.Warning(f"Document {doc_id} already in index. Overwriting.")
-         output = json.loads(output_txt)
-         # get the embeddings from the output and remove them from the output
-         embeddings = output.pop("embeddings")
+         if document.id in processed_documents:
+             gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
        # save the processed document to the index
-         processed_documents[doc_id] = output
+         processed_documents[document.id] = document
        # save the embeddings to the vector store
-         for adu_id, embedding in embeddings.items():
-             emb_id = f"{doc_id}:{adu_id}"
-             # TODO: save embedding to vector store at emb_id (embedding is a list of 768 floats)
-
+         for adu_id, embedding in document.metadata["embeddings"].items():
+             vector_store.save((document.id, adu_id), embedding)
        gr.Info(
-             f"Added document {doc_id} to index (index contains {len(processed_documents)} entries). (NOT YET IMPLEMENTED)"
+             f"Added document {document.id} to index (index contains {len(processed_documents)} "
+             f"documents and {len(vector_store)} embeddings)."
        )
    except Exception as e:
-         raise gr.Error(f"Failed to add document {doc_id} to index: {e}")
+         raise gr.Error(f"Failed to add document {document.id} to index: {e}")
+
+
+ def process_text(
+     text: str,
+     doc_id: str,
+     models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
+     processed_documents: dict[
+         str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+     ],
+     vector_store: SimpleVectorStore,
+ ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
+     try:
+         document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
+             id=doc_id, text=text, metadata={}
+         )
+         # add single partition from the whole text (the model only considers text in partitions)
+         document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
+         # annotate the document
+         annotate(
+             document=document,
+             pipeline=models[0],
+             embedding_model=models[1],
+             embedding_tokenizer=models[2],
+         )
+         # add the document to the index
+         add_to_index(document, processed_documents, vector_store)
+
+         return document
+     except Exception as e:
+         raise gr.Error(f"Failed to process text: {e}")
+
+
+ def wrapped_process_text(
+     text: str,
+     doc_id: str,
+     models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
+     processed_documents: dict[
+         str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+     ],
+     vector_store: SimpleVectorStore,
+ ) -> Tuple[dict, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
+     document = process_text(
+         text=text,
+         doc_id=doc_id,
+         models=models,
+         processed_documents=processed_documents,
+         vector_store=vector_store,
+     )
+     # return as dict and as document object to avoid serialization issues
+     return document.asdict(), document
+
+
+ def process_uploaded_file(
+     file_names: List[str],
+     models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
+     processed_documents: dict[
+         str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+     ],
+     vector_store: SimpleVectorStore,
+ ) -> None:
+     try:
+         for file_name in file_names:
+             if file_name.lower().endswith(".txt"):
+                 # read the file content
+                 with open(file_name, "r", encoding="utf-8") as f:
+                     text = f.read()
+                 base_file_name = os.path.basename(file_name)
+                 gr.Info(f"Processing file '{base_file_name}' ...")
+                 process_text(text, base_file_name, models, processed_documents, vector_store)
+             else:
+                 raise gr.Error(f"Unsupported file format: {file_name}")
+     except Exception as e:
+         raise gr.Error(f"Failed to process uploaded files: {e}")
+
+
+ def _get_annotation_from_document(
+     document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+     annotation_id: str,
+     annotation_layer: str,
+ ) -> LabeledSpan:
+     # use predictions
+     annotations = document[annotation_layer].predictions
+     id2annotation = {str(annotation._id): annotation for annotation in annotations}
+     annotation = id2annotation.get(annotation_id)
+     if annotation is None:
+         raise gr.Error(
+             f"annotation '{annotation_id}' not found in document '{document.id}'. Available "
+             f"annotations: {id2annotation}"
+         )
+     return annotation
+
+
+ def _get_annotation(
+     doc_id: str,
+     annotation_id: str,
+     annotation_layer: str,
+     processed_documents: dict[
+         str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+     ],
+ ) -> LabeledSpan:
+     document = processed_documents.get(doc_id)
+     if document is None:
+         raise gr.Error(
+             f"Document '{doc_id}' not found in index. Available documents: {list(processed_documents)}"
+         )
+     return _get_annotation_from_document(document, annotation_id, annotation_layer)
+
+
+ def _get_similar_entries_from_vector_store(
+     ref_annotation_id: str,
+     ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+     vector_store: SimpleVectorStore[Tuple[str, str]],
+     **retrieval_kwargs,
+ ) -> List[Tuple[Tuple[str, str], float]]:
+     embeddings = ref_document.metadata["embeddings"]
+     ref_embedding = embeddings.get(ref_annotation_id)
+     if ref_embedding is None:
+         raise gr.Error(
+             f"Embedding for annotation '{ref_annotation_id}' not found in metadata of "
+             f"document '{ref_document.id}'. Annotations with embeddings: {list(embeddings)}"
+         )
+
+     try:
+         similar_entries = vector_store.retrieve_similar(
+             ref_id=(ref_document.id, ref_annotation_id), **retrieval_kwargs
+         )
+     except Exception as e:
+         raise gr.Error(f"Failed to retrieve similar ADUs: {e}")
+
+     return similar_entries
+
+
+ def get_similar_adus(
+     ref_annotation_id: str,
+     ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+     vector_store: SimpleVectorStore,
+     processed_documents: dict[
+         str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+     ],
+     min_similarity: float,
+ ) -> pd.DataFrame:
+     similar_entries = _get_similar_entries_from_vector_store(
+         ref_annotation_id=ref_annotation_id,
+         ref_document=ref_document,
+         vector_store=vector_store,
+         min_similarity=min_similarity,
+     )
+
+     similar_annotations = [
+         _get_annotation(
+             doc_id=doc_id,
+             annotation_id=annotation_id,
+             annotation_layer="labeled_spans",
+             processed_documents=processed_documents,
+         )
+         for (doc_id, annotation_id), _ in similar_entries
+     ]
+     df = pd.DataFrame(
+         [
+             # unpack the tuple (doc_id, annotation_id) into separate columns
+             # and add the similarity score and the text of the annotation
+             (doc_id, annotation_id, score, str(annotation))
+             for ((doc_id, annotation_id), score), annotation in zip(
+                 similar_entries, similar_annotations
+             )
+         ],
+         columns=["doc_id", "adu_id", "sim_score", "text"],
+     )
+
+     return df
+
+
+ def get_relevant_adus(
+     ref_annotation_id: str,
+     ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+     vector_store: SimpleVectorStore,
+     processed_documents: dict[
+         str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+     ],
+     min_similarity: float,
+ ) -> pd.DataFrame:
+     similar_entries = _get_similar_entries_from_vector_store(
+         ref_annotation_id=ref_annotation_id,
+         ref_document=ref_document,
+         vector_store=vector_store,
+         min_similarity=min_similarity,
+     )
+     ref_annotation = _get_annotation(
+         doc_id=ref_document.id,
+         annotation_id=ref_annotation_id,
+         annotation_layer="labeled_spans",
+         processed_documents=processed_documents,
+     )
+     result = []
+     for (doc_id, annotation_id), score in similar_entries:
+         # skip entries from the same document
+         if doc_id == ref_document.id:
+             continue
+         document = processed_documents[doc_id]
+         tail2rels = defaultdict(list)
+         head2rels = defaultdict(list)
+         for rel in document.binary_relations.predictions:
+             # skip non-argumentative relations
+             if rel.label in ["parts_of_same", "semantically_same"]:
+                 continue
+             head2rels[rel.head].append(rel)
+             tail2rels[rel.tail].append(rel)
+
+         id2annotation = {
+             str(annotation._id): annotation for annotation in document.labeled_spans.predictions
+         }
+         annotation = id2annotation.get(annotation_id)
+         # note: we do not need to check whether the annotation differs from the reference
+         # annotation, because they come from different documents and we already skip entries
+         # from the same document
+         for rel in head2rels.get(annotation, []):
+             result.append(
+                 {
+                     "doc_id": doc_id,
+                     "reference_adu": str(annotation),
+                     "sim_score": score,
+                     "rel_score": rel.score,
+                     "relation": rel.label,
+                     "text": str(rel.tail),
+                 }
+             )
+
+     # define column order
+     df = pd.DataFrame(
+         result, columns=["text", "relation", "doc_id", "reference_adu", "sim_score", "rel_score"]
+     )
+     return df


def open_accordion():
@@ -202,13 +432,46 @@ def load_models(
    return argumentation_model, embedding_model, embedding_tokenizer


+ def update_processed_documents_df(
+     processed_documents: dict[str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]
+ ) -> pd.DataFrame:
+     df = pd.DataFrame(
+         [
+             (
+                 doc_id,
+                 len(document.labeled_spans.predictions),
+                 len(document.binary_relations.predictions),
+             )
+             for doc_id, document in processed_documents.items()
+         ],
+         columns=["doc_id", "num_adus", "num_relations"],
+     )
+     return df
+
+
+ def select_processed_document(
+     evt: gr.SelectData,
+     processed_documents_df: pd.DataFrame,
+     processed_documents: Dict[
+         str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+     ],
+ ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
+     row_idx, col_idx = evt.index
+     doc_id = processed_documents_df.iloc[row_idx]["doc_id"]
+     gr.Info(f"Selected document: {doc_id}")
+     doc = processed_documents[doc_id]
+     return doc
+
+
def main():

    example_text = "Scholarly Argumentation Mining (SAM) has recently gained attention due to its potential to help scholars with the rapid growth of published scientific literature. It comprises two subtasks: argumentative discourse unit recognition (ADUR) and argumentative relation extraction (ARE), both of which are challenging since they require e.g. the integration of domain knowledge, the detection of implicit statements, and the disambiguation of argument structure. While previous work focused on dataset construction and baseline methods for specific document sections, such as abstract or results, full-text scholarly argumentation mining has seen little progress. In this work, we introduce a sequential pipeline model combining ADUR and ARE for full-text SAM, and provide a first analysis of the performance of pretrained language models (PLMs) on both subtasks. We establish a new SotA for ADUR on the Sci-Arg corpus, outperforming the previous best reported result by a large margin (+7% F1). We also present the first results for ARE, and thus for the full AM pipeline, on this benchmark dataset. Our detailed error analysis reveals that non-contiguous ADUs as well as the interpretation of discourse connectors pose major challenges and that data annotation needs to be more consistent."

-     print("Loading argumentation model ...")
-     argumentation_model = load_argumentation_model(
-         model_name=DEFAULT_MODEL_NAME, revision=DEFAULT_MODEL_REVISION
+     print("Loading models ...")
+     argumentation_model, embedding_model, embedding_tokenizer = load_models(
+         model_name=DEFAULT_MODEL_NAME,
+         revision=DEFAULT_MODEL_REVISION,
+         embedding_model_name=DEFAULT_EMBEDDING_MODEL_NAME,
    )

    default_render_kwargs = {
@@ -236,21 +499,18 @@ def main():
        },
    }

-     # TODO: setup the vector store
-     vector_store = None
-
    with gr.Blocks() as demo:
        processed_documents_state = gr.State(dict())
-         vector_store_state = gr.State(vector_store)
+         vector_store_state = gr.State(SimpleVectorStore())
        # wrap the pipeline and the embedding model/tokenizer in a tuple so that gradio does not try to call them
-         models_state = gr.State((argumentation_model, None, None))
+         models_state = gr.State((argumentation_model, embedding_model, embedding_tokenizer))
        with gr.Row():
            with gr.Column(scale=1):
                doc_id = gr.Textbox(
                    label="Document ID",
                    value="user_input",
                )
-                 text = gr.Textbox(
+                 doc_text = gr.Textbox(
                    label="Text",
                    lines=20,
                    value=example_text,
@@ -277,12 +537,12 @@ def main():

                predict_btn = gr.Button("Analyse")

-                 output_txt = gr.Textbox(visible=False)
+                 document_state = gr.State()

            with gr.Column(scale=1):

                with gr.Accordion("See plain result ...", open=False) as output_accordion:
-                     output_json = gr.JSON(label="Model Output")
+                     document_json = gr.JSON(label="Model Output")

                with gr.Accordion("Render Options", open=False):
                    render_as = gr.Dropdown(
@@ -299,34 +559,121 @@ def main():

                rendered_output = gr.HTML(label="Rendered Output")

-                 add_to_index_btn = gr.Button("Add current result to Index")
+                 # add_to_index_btn = gr.Button("Add current result to Index")
+                 upload_btn = gr.UploadButton(
+                     "Upload & Analyse Documents", file_types=["text"], file_count="multiple"
+                 )
+
+             with gr.Column(scale=1):
+                 with gr.Accordion("Indexed Documents", open=False):
+                     processed_documents_df = gr.DataFrame(
+                         headers=["id", "num_adus", "num_relations"],
+                         interactive=False,
+                     )
+
+                 with gr.Accordion("Reference ADU", open=False):
+                     reference_adu_id = gr.Textbox(label="ID", elem_id="reference_adu_id")
+                     reference_adu_text = gr.Textbox(label="Text")
+
+                 with gr.Accordion("Retrieval Configuration", open=False):
+                     min_similarity = gr.Slider(
+                         label="Minimum Similarity",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.01,
+                         value=0.8,
+                     )
+                 retrieve_similar_adus_btn = gr.Button("Retrieve similar ADUs")
+                 similar_adus = gr.DataFrame(headers=["doc_id", "adu_id", "score", "text"])
+
+                 # retrieve_relevant_adus_btn = gr.Button("Retrieve relevant ADUs")
+                 relevant_adus = gr.DataFrame(
+                     label="Relevant ADUs from other documents",
+                     headers=[
+                         "text",
+                         "relation",
+                         "doc_id",
+                         "reference_adu",
+                         "sim_score",
+                         "rel_score",
+                     ],
+                 )

-         render_button_kwargs = dict(
-             fn=render, inputs=[output_txt, render_as, render_kwargs], outputs=rendered_output
-         )
+         render_event_kwargs = dict(
+             fn=render_annotated_document,
+             inputs=[document_state, render_as, render_kwargs],
+             outputs=rendered_output,
+         )

-         def _predict(
-             text: str,
-             models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
-         ) -> Tuple[dict, str]:
-             return predict(text, *models)
-
-         predict_btn.click(open_accordion, inputs=[], outputs=[output_accordion]).then(
-             fn=_predict,
-             inputs=[text, models_state],
-             outputs=[output_json, output_txt],
+         predict_btn.click(fn=open_accordion, inputs=[], outputs=[output_accordion]).then(
+             fn=wrapped_process_text,
+             inputs=[doc_text, doc_id, models_state, processed_documents_state, vector_store_state],
+             outputs=[document_json, document_state],
            api_name="predict",
-         ).success(**render_button_kwargs).success(
-             close_accordion, inputs=[], outputs=[output_accordion]
+         ).success(
+             fn=update_processed_documents_df,
+             inputs=[processed_documents_state],
+             outputs=[processed_documents_df],
        )
-         render_btn.click(**render_button_kwargs, api_name="render")
+         render_btn.click(**render_event_kwargs, api_name="render")
+
+         document_state.change(
+             fn=lambda doc: doc.asdict(),
+             inputs=[document_state],
+             outputs=[document_json],
+         ).success(close_accordion, inputs=[], outputs=[output_accordion]).then(
+             **render_event_kwargs
+         )

-         add_to_index_btn.click(
-             fn=add_to_index,
-             inputs=[output_txt, doc_id, processed_documents_state, vector_store_state],
+         upload_btn.upload(
+             fn=process_uploaded_file,
+             inputs=[upload_btn, models_state, processed_documents_state, vector_store_state],
            outputs=[],
+         ).success(
+             fn=update_processed_documents_df,
+             inputs=[processed_documents_state],
+             outputs=[processed_documents_df],
+         )
+         processed_documents_df.select(
+             select_processed_document,
+             inputs=[processed_documents_df, processed_documents_state],
+             outputs=[document_state],
        )
+
+         retrieve_relevant_adus_event_kwargs = dict(
+             fn=get_relevant_adus,
+             inputs=[
+                 reference_adu_id,
+                 document_state,
+                 vector_store_state,
+                 processed_documents_state,
+                 min_similarity,
+             ],
+             outputs=[relevant_adus],
+         )
+
+         reference_adu_id.change(
+             fn=partial(_get_annotation_from_document, annotation_layer="labeled_spans"),
+             inputs=[document_state, reference_adu_id],
+             outputs=[reference_adu_text],
+         ).success(**retrieve_relevant_adus_event_kwargs)
+
+         retrieve_similar_adus_btn.click(
+             fn=get_similar_adus,
+             inputs=[
+                 reference_adu_id,
+                 document_state,
+                 vector_store_state,
+                 processed_documents_state,
+                 min_similarity,
+             ],
+             outputs=[similar_adus],
+         )
+
+         # retrieve_relevant_adus_btn.click(
+         #     **retrieve_relevant_adus_event_kwargs
+         # )

        js = """
        () => {
            function maybeSetColor(entity, colorAttributeKey, colorDictKey) {
@@ -337,8 +684,6 @@ def main():
                color = colors[colorDictKey];
            } catch (e) {}
            if (color) {
-                 console.log('setting color', color);
-                 console.log('entity', entity);
                entity.style.backgroundColor = color;
                entity.style.color = '#000';
            }
@@ -391,6 +736,15 @@ def main():
                });
            }
        }
+         function setReferenceAduId(entityId) {
+             // get the textarea element that holds the reference adu id
+             let referenceAduIdDiv = document.querySelector('#reference_adu_id textarea');
+             // set the value of the input field
+             referenceAduIdDiv.value = entityId;
+             // trigger an input event to update the state
+             var event = new Event('input');
+             referenceAduIdDiv.dispatchEvent(event);
+         }

        const entities = document.querySelectorAll('.entity');
        entities.forEach(entity => {
@@ -400,6 +754,7 @@ def main():
            }
            entity.addEventListener('mouseover', () => {
                highlightRelationArguments(entity.id);
+                 setReferenceAduId(entity.id);
            });
            entity.addEventListener('mouseout', () => {
                highlightRelationArguments(null);
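
Taken together, the new wiring makes this a full index-then-retrieve loop: every analysed document is stored in processed_documents_state, and each ADU embedding goes into the vector store under the composite key (doc_id, adu_id). A minimal sketch of that flow without the Gradio UI, reusing only the functions defined above (the texts and ids are made up, and load_models is assumed to return the embedding model and tokenizer as well):

processed_documents = {}
vector_store = SimpleVectorStore()
models = load_models(
    model_name=DEFAULT_MODEL_NAME,
    revision=DEFAULT_MODEL_REVISION,
    embedding_model_name=DEFAULT_EMBEDDING_MODEL_NAME,
)

# each call annotates the text, embeds its ADUs, and adds both to the index
doc_a = process_text(
    "Deep models overfit. Thus, regularization is needed.",
    "doc_a", models, processed_documents, vector_store,
)
process_text(
    "Large networks memorize noise. Hence, weight decay helps.",
    "doc_b", models, processed_documents, vector_store,
)

# take any predicted ADU of doc_a as the reference and query across documents
ref_adu_id = next(iter(doc_a.metadata["embeddings"]))
print(get_similar_adus(ref_adu_id, doc_a, vector_store, processed_documents, min_similarity=0.5))
print(get_relevant_adus(ref_adu_id, doc_a, vector_store, processed_documents, min_similarity=0.5))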
rendering_utils.py CHANGED
@@ -76,12 +76,11 @@ def inject_relation_data(
        entity2heads[relation.tail].append((relation.head, relation.label))
        entity2tails[relation.head].append((relation.tail, relation.label))

-     entity2id = {entity: f"entity-{idx}" for idx, entity in enumerate(sorted_entities)}
-
    # Add unique IDs to each entity
    entities = soup.find_all(class_="entity")
    for idx, entity in enumerate(entities):
-         entity["id"] = f"entity-{idx}"
+         annotation = sorted_entities[idx]
+         entity["id"] = str(annotation._id)
        original_color = entity["style"].split("background:")[1].split(";")[0].strip()
        entity["data-color-original"] = original_color
        if additional_colors is not None:
@@ -96,13 +95,13 @@ def inject_relation_data(
        entity["data-label"] = entity_annotation.label
        entity["data-relation-tails"] = json.dumps(
            [
-                 {"entity-id": entity2id[tail], "label": label}
+                 {"entity-id": str(tail._id), "label": label}
                for tail, label in entity2tails.get(entity_annotation, [])
            ]
        )
        entity["data-relation-heads"] = json.dumps(
            [
-                 {"entity-id": entity2id[head], "label": label}
+                 {"entity-id": str(head._id), "label": label}
                for head, label in entity2heads.get(entity_annotation, [])
            ]
        )
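
The point of this change: rendered entities now carry the annotation's own id (str(annotation._id)) instead of a positional "entity-{idx}", so the id that the JavaScript hover handler writes into the Reference ADU textbox can be resolved directly by _get_annotation_from_document. A small illustration with invented ids:

from bs4 import BeautifulSoup

# hypothetical output of render_displacy + inject_relation_data; the ids are made up
html = (
    '<span class="entity" id="9151610382270" data-label="claim" '
    'data-relation-tails=\'[{"entity-id": "9151610382271", "label": "supports"}]\'>'
    "data annotation needs to be more consistent</span>"
)
entity = BeautifulSoup(html, "html.parser").find(class_="entity")
print(entity["id"])                   # "9151610382270", i.e. str(annotation._id)
print(entity["data-relation-tails"])  # JSON that points at other annotation ids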
requirements.txt CHANGED
@@ -1,4 +1,4 @@
- gradio==4.31.4
+ gradio==4.36.0
prettytable==3.10.0
pie-modules==0.12.0
beautifulsoup4==4.12.3
vector_store.py ADDED
@@ -0,0 +1,65 @@
+ from typing import Generic, Hashable, List, Optional, Tuple, TypeVar
+
+
+ def vector_norm(vector: List[float]) -> float:
+     return sum(x**2 for x in vector) ** 0.5
+
+
+ def cosine_similarity(a: List[float], b: List[float]) -> float:
+     return sum(a * b for a, b in zip(a, b)) / (vector_norm(a) * vector_norm(b))
+
+
+ T = TypeVar("T", bound=Hashable)
+
+
+ class SimpleVectorStore(Generic[T]):
+     def __init__(self):
+         self.vectors: dict[T, List[float]] = {}
+         self._cache = {}
+         self._sim = cosine_similarity
+
+     def save(self, emb_id: T, embedding: List[float]) -> None:
+         self.vectors[emb_id] = embedding
+
+     def get(self, emb_id: T) -> Optional[List[float]]:
+         return self.vectors.get(emb_id)
+
+     def delete(self, emb_id: T) -> None:
+         if emb_id in self.vectors:
+             del self.vectors[emb_id]
+         # remove from cache
+         self._cache = {k: v for k, v in self._cache.items() if emb_id not in k}
+
+     def clear(self) -> None:
+         self.vectors.clear()
+         self._cache.clear()
+
+     def __len__(self):
+         return len(self.vectors)
+
+     def retrieve_similar(
+         self, ref_id: T, top_k: Optional[int] = None, min_similarity: Optional[float] = None
+     ) -> List[Tuple[T, float]]:
+         ref_embedding = self.get(ref_id)
+         if ref_embedding is None:
+             raise ValueError(f"Reference embedding '{ref_id}' not found.")
+
+         # calculate similarity to all embeddings
+         similarities = {}
+         for emb_id, embedding in self.vectors.items():
+             if (emb_id, ref_id) not in self._cache:
+                 # use cosine similarity
+                 self._cache[(emb_id, ref_id)] = self._sim(ref_embedding, embedding)
+             similarities[emb_id] = self._cache[(emb_id, ref_id)]
+
+         # sort by similarity
+         similar_entries = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
+
+         if min_similarity is not None:
+             similar_entries = [
+                 (emb_id, sim) for emb_id, sim in similar_entries if sim >= min_similarity
+             ]
+         if top_k is not None:
+             similar_entries = similar_entries[:top_k]
+
+         return similar_entries
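
A short usage sketch of the store as app.py uses it: composite (doc_id, adu_id) keys, brute-force cosine similarity over all stored vectors, and filtering via top_k/min_similarity. The toy 3-dimensional vectors are illustrative; note that the reference entry itself is returned with similarity 1.0:

store = SimpleVectorStore()
store.save(("doc_a", "adu_1"), [1.0, 0.0, 0.0])
store.save(("doc_a", "adu_2"), [0.9, 0.1, 0.0])
store.save(("doc_b", "adu_7"), [0.0, 1.0, 0.0])

print(len(store))  # 3
print(store.retrieve_similar(ref_id=("doc_a", "adu_1"), min_similarity=0.8))
# [(('doc_a', 'adu_1'), 1.0), (('doc_a', 'adu_2'), 0.9938...)]
print(store.retrieve_similar(ref_id=("doc_a", "adu_1"), top_k=1))
# [(('doc_a', 'adu_1'), 1.0)]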