ArneBinder committed on
Commit a8df5fb
1 parent: b0174fa

from https://github.com/ArneBinder/pie-document-level/pull/238

Files changed (5):
  1. app.py +4 -4
  2. document_store.py +108 -40
  3. embedding.py +4 -4
  4. rendering_utils.py +32 -12
  5. vector_store.py +9 -1
app.py CHANGED
@@ -354,7 +354,7 @@ def main():
                 minimum=0.0,
                 maximum=1.0,
                 step=0.01,
-                value=0.8,
+                value=0.95,
             )
             top_k = gr.Slider(
                 label="Top K",
@@ -398,10 +398,10 @@ def main():
         )
 
         show_overview_kwargs = dict(
-            fn=lambda document_store, show_max_sims: document_store.overview(
-                with_max_cross_doc_sims=show_max_sims
-            ),
-            inputs=[document_store_state, show_max_cross_docu_sims],
+            fn=lambda document_store, show_max_sims, min_sim: document_store.overview(
+                with_max_cross_doc_sims=show_max_sims, min_similarity=min_sim
+            ),
+            inputs=[document_store_state, show_max_cross_docu_sims, min_similarity],
             outputs=[processed_documents_df],
         )
         predict_btn.click(fn=open_accordion, inputs=[], outputs=[output_accordion]).then(
@@ -505,7 +505,7 @@ def main():
                 DocumentStore.get_all2all_adu_similarities,
                 columns=all2all_adu_similarities.headers,
             ),
-            inputs=[document_store_state],
+            inputs=[document_store_state, min_similarity],
             outputs=[all2all_adu_similarities],
         )
 
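The app.py changes wire the minimum-similarity slider (presumably the one whose default moves from 0.8 to 0.95 above) into both similarity callbacks. Gradio passes the components listed in `inputs` positionally to `fn`, so the list order must match the callback's parameters. A minimal, self-contained sketch of this wiring pattern; all component names and the callback body are illustrative, not taken from app.py:

import gradio as gr

with gr.Blocks() as demo:
    # hypothetical components mirroring the pattern used in app.py
    min_similarity = gr.Slider(
        label="Minimal Similarity", minimum=0.0, maximum=1.0, step=0.01, value=0.95
    )
    show_max_sims = gr.Checkbox(label="Show max cross-document similarities")
    result = gr.Textbox(label="Overview")

    show_btn = gr.Button("Show overview")
    # inputs are passed positionally: show_max_sims -> show, min_similarity -> min_sim
    show_btn.click(
        fn=lambda show, min_sim: f"with_max_cross_doc_sims={show}, min_similarity={min_sim}",
        inputs=[show_max_sims, min_similarity],
        outputs=[result],
    )

demo.launch()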
document_store.py CHANGED
@@ -16,6 +16,7 @@ from pytorch_ie.documents import (
     TextBasedDocument,
     TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
 )
+from scipy.sparse import csr_matrix
 from vector_store import VectorStore
 
 logger = logging.getLogger(__name__)
@@ -342,6 +343,38 @@ class DocumentStore:
             f"Added {len(documents_json)} documents to the index ({len(self.documents)} documents in total)."
         )
 
+    def get_payloads_for_missing_and_unexpected_embeddings(self) -> dict[str, dict[str, Any]]:
+        """Get the payloads for missing and unexpected embeddings in the vector store. An embedding
+        is missing if its annotation is in the documents but the embedding is not in the vector
+        store. An embedding is unexpected if it is in the vector store but the annotation is not in
+        the documents.
+
+        Returns:
+            A dictionary with the missing and unexpected payloads.
+        """
+        expected_payloads = []
+        for document in self.documents.values():
+            for annotation in document[self.span_layer_name].predictions:
+                annotation_id = labeled_span_to_id(annotation)
+                payload = self.construct_embedding_payload(document, annotation_id)
+                expected_payloads.append(payload)
+        vector_store_payloads = self.vector_store.as_indices_vectors_payloads()[2]
+        # construct mappings from serialized payloads to payloads to compare expected and actual
+        expected_mapping = {
+            json.dumps(payload, sort_keys=True): payload for payload in expected_payloads
+        }
+        vector_store_mapping = {
+            json.dumps(payload, sort_keys=True): payload for payload in vector_store_payloads
+        }
+        missing = set(expected_mapping) - set(vector_store_mapping)
+        unexpected = set(vector_store_mapping) - set(expected_mapping)
+
+        # return the missing and unexpected payloads
+        return {
+            "missing": {payload: expected_mapping[payload] for payload in missing},
+            "unexpected": {payload: vector_store_mapping[payload] for payload in unexpected},
+        }
+
     def add_documents_from_zip(self, file_path: str) -> None:
         temp_dir = os.path.join(tempfile.gettempdir(), "document_store")
         # remove the temporary directory if it already exists
@@ -418,7 +451,9 @@ class DocumentStore:
 
         return document
 
-    def overview(self, with_max_cross_doc_sims: bool = False) -> pd.DataFrame:
+    def overview(
+        self, with_max_cross_doc_sims: bool = False, min_similarity: float = 0.9
+    ) -> pd.DataFrame:
         rows = []
         for doc_id, document in self.documents.items():
             layers = {
@@ -433,13 +468,8 @@ class DocumentStore:
 
         # add highest cross-document similarity score for each document
        if with_max_cross_doc_sims and len(self.documents) > 1:
-            # Setting min_similarity to None and top_k to None to get all similarities. Otherwise,
-            # the max cross-doc sim may be occluded for some documents, e.g. when more than top_k
-            # ADUs in the reference document are more similar to each other than to any ADU in
-            # another document, or when the cross-doc similarity falls below the min_similarity
-            # threshold.
             all2all_adu_similarities = self.get_all2all_adu_similarities(
-                min_similarity=None, top_k=None, columns=["doc_id", "other_doc_id", "sim_score"]
+                min_similarity=min_similarity, columns=["doc_id", "other_doc_id", "sim_score"]
             )
             max_doc2doc_similarities = all2all_adu_similarities.pivot_table(
                 values="sim_score", index="doc_id", columns="other_doc_id", aggfunc="max"
@@ -478,50 +508,88 @@ class DocumentStore:
     def get_all2all_adu_similarities(
         self,
         min_similarity: Optional[float] = 0.5,
-        top_k: Optional[int] = 100,
         columns: Optional[List[str]] = None,
     ) -> pd.DataFrame:
         """Get the similarities between all ADUs in the store.
 
         Args:
-            min_similarity: The minimum similarity score to consider.
-            top_k: The number of similar ADUs to return.
+            min_similarity: The minimum similarity score to consider. If None, all similarities
+                are included.
             columns: The columns to include in the result DataFrame. If None, all columns are included.
 
         Returns:
             A DataFrame with the columns: doc_id, text, other_doc_id, other_text, sim_score.
         """
-        result = []
-        document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-        for doc_id, document in self.documents.items():
-            for adu in document.labeled_spans.predictions:
-                adu_id = labeled_span_to_id(adu)
-                similar_entries = self.vector_store.retrieve_similar(
-                    ref_payload=self.construct_embedding_payload(document, adu_id),
-                    min_similarity=min_similarity,
-                    top_k=top_k,
-                )
-                for _, payload, score in similar_entries:
-                    other_doc_id = payload["doc_id"]
-                    other_document = self.documents[other_doc_id]
-                    other_adu = get_annotation_from_document(
-                        other_document,
-                        payload["annotation_id"],
-                        self.span_layer_name,
-                        use_predictions=self.use_predictions,
-                    )
-                    result.append(
-                        {
-                            "sim_score": score,
-                            "doc_id": doc_id,
-                            "other_doc_id": other_doc_id,
-                            "adu_id": adu_id,
-                            "other_adu_id": payload["annotation_id"],
-                            "text": str(adu),
-                            "other_text": str(other_adu),
-                        }
-                    )
-        result_df = pd.DataFrame(result)
+
+        # shape of all_embeddings: (num_embeddings, embedding_dim)
+        (
+            all_embed_ids,
+            all_embeddings,
+            all_payloads,
+        ) = self.vector_store.as_indices_vectors_payloads()
+
+        # map (doc_id, annotation_id) to the annotation text for gold and predicted spans
+        doc_id_and_annotation_id2annotation_text = {}
+        for doc in self.documents.values():
+            for annotation in doc[self.span_layer_name]:
+                doc_id_and_annotation_id2annotation_text[
+                    (doc.id, labeled_span_to_id(annotation))
+                ] = str(annotation)
+            for annotation in doc[self.span_layer_name].predictions:
+                doc_id_and_annotation_id2annotation_text[
+                    (doc.id, labeled_span_to_id(annotation))
+                ] = str(annotation)
+
+        # calculate cosine similarities between all embeddings
+        dot_prod = np.dot(all_embeddings, all_embeddings.T)
+        norm = np.linalg.norm(all_embeddings, axis=1)
+        norm_prod = np.outer(norm, norm)
+        similarities = dot_prod / norm_prod
+
+        gr.Info(f"Similarities shape: {similarities.shape}")
+
+        if min_similarity is not None:
+            gr.Info(f"Filtering similarities below {min_similarity}.")
+            # set similarities below min_similarity to 0
+            similarities[similarities < min_similarity] = 0.0
+
+        # zero out the diagonal and the lower triangle to keep each pair only once
+        similarities = np.triu(similarities, k=1)
+        # create a sparse matrix
+        sparse_matrix = csr_matrix(similarities)
+        sparse_matrix.eliminate_zeros()
+        # get the non-zero values and their indices
+        non_zero_idx = sparse_matrix.nonzero()
+        # sparse_matrix.data is aligned with the (row, column) pairs from nonzero()
+        scores = sparse_matrix.data
+
+        gr.Info(f"Number of similarities above {min_similarity}: {len(scores)}")
+
+        # construct the DataFrame
+        records = []
+        for score, idx1, idx2 in zip(scores, non_zero_idx[0], non_zero_idx[1]):
+            doc_id1 = all_payloads[idx1]["doc_id"]
+            doc_id2 = all_payloads[idx2]["doc_id"]
+            annotation_id1 = all_payloads[idx1]["annotation_id"]
+            annotation_id2 = all_payloads[idx2]["annotation_id"]
+            annotation_text1 = doc_id_and_annotation_id2annotation_text[
+                (doc_id1, annotation_id1)
+            ]
+            annotation_text2 = doc_id_and_annotation_id2annotation_text[
+                (doc_id2, annotation_id2)
+            ]
+            records.append(
+                {
+                    "sim_score": score,
+                    "doc_id": doc_id1,
+                    "other_doc_id": doc_id2,
+                    "adu_id": annotation_id1,
+                    "other_adu_id": annotation_id2,
+                    "text": annotation_text1,
+                    "other_text": annotation_text2,
+                }
+            )
+        result_df = pd.DataFrame(records)
+        gr.Info(f"DataFrame shape: {result_df.shape}")
+
         if columns is not None:
             result_df = result_df[columns]
         return result_df
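
The rewritten get_all2all_adu_similarities trades the per-ADU vector-store queries (and the top_k parameter) for one dense cosine-similarity matrix over all stored embeddings; np.triu(..., k=1) discards the diagonal and the symmetric lower triangle so each pair is reported once, and the sparse matrix keeps only the scores that survive the threshold. A self-contained sketch of the technique, with toy 2-D embeddings standing in for the store's vectors:

import numpy as np
from scipy.sparse import csr_matrix

embeddings = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]])

# cosine similarity: pairwise dot products divided by the outer product of the norms
dot_prod = embeddings @ embeddings.T
norm = np.linalg.norm(embeddings, axis=1)
similarities = dot_prod / np.outer(norm, norm)

# drop pairs below the threshold, then the diagonal and the redundant lower triangle
min_similarity = 0.9
similarities[similarities < min_similarity] = 0.0
similarities = np.triu(similarities, k=1)

sparse_matrix = csr_matrix(similarities)
sparse_matrix.eliminate_zeros()
rows, cols = sparse_matrix.nonzero()
# sparse_matrix.data lines up with the (row, col) pairs returned by nonzero()
for i, j, score in zip(rows, cols, sparse_matrix.data):
    print(f"pair ({i}, {j}): {score:.3f}")  # -> pair (0, 1): 0.994

This alignment of sparse_matrix.data with nonzero() is what makes the per-pair score lookup in the DataFrame loop correct; indexing the scores by the row index alone would silently pick the wrong value.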
embedding.py CHANGED
@@ -114,10 +114,10 @@ class HuggingfaceEmbeddingModel(EmbeddingModel):
             )
             text_ann = tok2text_ann[tok_ann]
 
-            if text_ann in embeddings:
-                logger.warning(
-                    f"Overwriting embedding for annotation '{text_ann}' (do you use striding?)"
-                )
+            # if text_ann in embeddings:
+            #     logger.warning(
+            #         f"Overwriting embedding for annotation '{text_ann}' (do you use striding?)"
+            #     )
             embeddings[text_ann] = embedding
             example_idx += 1
 
rendering_utils.py CHANGED
@@ -4,12 +4,20 @@ from collections import defaultdict
 from typing import Dict, List, Optional, Union
 
 from annotation_utils import labeled_span_to_id
-from pytorch_ie.annotations import BinaryRelation
+from pytorch_ie.annotations import BinaryRelation, LabeledSpan
 from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
 from rendering_utils_displacy import EntityRenderer
 
 logger = logging.getLogger(__name__)
 
+# adjusted from rendering_utils_displacy.TPL_ENT
+TPL_ENT_WITH_ID = """
+<mark class="entity" id="{id}" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
+    {text}
+    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
+</mark>
+"""
+
 
 def render_pretty_table(
     document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
@@ -36,15 +44,27 @@ def render_displacy(
     **render_kwargs,
 ):
 
-    spans = list(document.labeled_spans) + list(document.labeled_spans.predictions)
+    labeled_spans = list(document.labeled_spans) + list(document.labeled_spans.predictions)
     spacy_doc = {
         "text": document.text,
         "ents": [
-            {"start": entity.start, "end": entity.end, "label": entity.label} for entity in spans
+            {
+                "start": labeled_span.start,
+                "end": labeled_span.end,
+                "label": labeled_span.label,
+                # Pass the ID as a parameter to the entity. The ID is required to fetch the
+                # entity annotations on hover and to inject the relation data.
+                "params": {"id": labeled_span_to_id(labeled_span)},
+            }
+            for labeled_span in labeled_spans
         ],
         "title": None,
     }
 
+    # copy to avoid modifying the original options
+    entity_options = entity_options.copy()
+    # use the custom template with the entity ID
+    entity_options["template"] = TPL_ENT_WITH_ID
     renderer = EntityRenderer(options=entity_options)
     html = renderer.render([spacy_doc], page=True, minify=True).strip()
 
@@ -53,10 +73,9 @@ def render_displacy(
     binary_relations = list(document.binary_relations) + list(
         document.binary_relations.predictions
     )
-    sorted_entities = sorted(spans, key=lambda x: (x.start, x.end))
     html = inject_relation_data(
         html,
-        sorted_entities=sorted_entities,
+        labeled_spans=labeled_spans,
         binary_relations=binary_relations,
         additional_colors=colors_hover,
     )
@@ -65,7 +84,7 @@ def render_displacy(
 
 def inject_relation_data(
     html: str,
-    sorted_entities,
+    labeled_spans: List[LabeledSpan],
     binary_relations: List[BinaryRelation],
     additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
 ) -> str:
@@ -80,11 +99,10 @@ def inject_relation_data(
         entity2heads[relation.tail].append((relation.head, relation.label))
         entity2tails[relation.head].append((relation.tail, relation.label))
 
+    ann_id2annotation = {labeled_span_to_id(entity): entity for entity in labeled_spans}
     # Add unique IDs to each entity
     entities = soup.find_all(class_="entity")
-    for idx, entity in enumerate(entities):
-        annotation = sorted_entities[idx]
-        entity["id"] = labeled_span_to_id(annotation)
+    for entity in entities:
         original_color = entity["style"].split("background:")[1].split(";")[0].strip()
         entity["data-color-original"] = original_color
         if additional_colors is not None:
@@ -92,9 +110,11 @@ def inject_relation_data(
             entity[f"data-color-{key}"] = (
                 json.dumps(color) if isinstance(color, dict) else color
             )
-        entity_annotation = sorted_entities[idx]
-        # sanity check
-        if str(entity_annotation) != entity.next:
+        entity_annotation = ann_id2annotation[entity["id"]]
+        # sanity check: compare only the start of the rendered text, because the entity label
+        # is attached at its end
+        annotation_text_without_newline = str(entity_annotation).replace("\n", "")
+        if not entity.text.startswith(annotation_text_without_newline):
             logger.warning(f"Entity text mismatch: {entity_annotation} != {entity.text}")
         entity["data-label"] = entity_annotation.label
         entity["data-relation-tails"] = json.dumps(
vector_store.py CHANGED
@@ -52,12 +52,16 @@ class VectorStore(Generic[T, E], abc.ABC):
     def get(self, emb_id: Optional[str] = None, payload: Optional[T] = None) -> Optional[E]:
         return self._get(emb_id=self._get_emb_id(emb_id=emb_id, payload=payload))
 
+    def has(self, emb_id: Optional[str] = None, payload: Optional[T] = None) -> bool:
+        return self.get(emb_id=emb_id, payload=payload) is not None
+
     @abc.abstractmethod
     def _retrieve_similar(
         self, ref_id: str, top_k: Optional[int] = None, min_similarity: Optional[float] = None
     ) -> List[Tuple[T, float]]:
         """Retrieve IDs, payloads and the respective similarity scores with respect to the
-        reference entry. Note that this requires the reference entry to be present in the store.
+        reference entry. If the reference entry is not in the store, an empty list is returned.
 
         Args:
             ref_id: The ID of the reference entry.
@@ -74,6 +78,8 @@ class VectorStore(Generic[T, E], abc.ABC):
     def retrieve_similar(
         self, ref_id: Optional[str] = None, ref_payload: Optional[T] = None, **kwargs
     ) -> List[Tuple[T, float]]:
+        if not self.has(emb_id=ref_id, payload=ref_payload):
+            return []
         return self._retrieve_similar(
             ref_id=self._get_emb_id(emb_id=ref_id, payload=ref_payload), **kwargs
         )
@@ -244,6 +250,8 @@ class QdrantVectorStore(VectorStore[T, List[float]]):
         )
 
     def _get(self, emb_id: str) -> Optional[List[float]]:
+        if emb_id not in self.emb_id2point_id:
+            return None
         points = self.client.retrieve(
             collection_name=self.COLLECTION_NAME,
             ids=[self.emb_id2point_id[emb_id]],
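
Together, the vector_store.py changes make similarity lookups total instead of partial: _get returns None for unknown IDs, has builds on get, and retrieve_similar short-circuits to an empty list when the reference entry was never embedded, rather than raising a KeyError. A minimal in-memory sketch of the same guard pattern; a toy class, not the real VectorStore:

from typing import Dict, List, Optional, Tuple

class ToyVectorStore:
    """Illustrates the has()/guard pattern from vector_store.py."""

    def __init__(self) -> None:
        self._vectors: Dict[str, List[float]] = {}

    def get(self, emb_id: str) -> Optional[List[float]]:
        return self._vectors.get(emb_id)

    def has(self, emb_id: str) -> bool:
        return self.get(emb_id) is not None

    def retrieve_similar(self, ref_id: str) -> List[Tuple[str, float]]:
        # guard: an unknown reference yields an empty result instead of a KeyError
        if not self.has(ref_id):
            return []
        ref = self._vectors[ref_id]
        # dot-product similarity, for illustration only
        return [
            (other_id, sum(a * b for a, b in zip(ref, vec)))
            for other_id, vec in self._vectors.items()
            if other_id != ref_id
        ]

store = ToyVectorStore()
assert store.retrieve_similar("never-added") == []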