"""document_store.py

In-memory store of annotated documents together with a vector store over the
embeddings of their ADUs (argumentative discourse units), used to retrieve
similar and related ADUs across documents.
"""
import json
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
import gradio as gr
import pandas as pd
from annotation_utils import labeled_span_to_id
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from vector_store import SimpleVectorStore, VectorStore
logger = logging.getLogger(__name__)
def get_annotation_from_document(
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
annotation_id: str,
annotation_layer: str,
) -> LabeledSpan:
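    """Resolve an annotation id to the corresponding predicted LabeledSpan.

    Raises a gr.Error if the id is not among the predictions of the given layer.
    """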
    # resolve the annotation among the layer's predictions (not the gold annotations)
annotations = document[annotation_layer].predictions
id2annotation = {labeled_span_to_id(annotation): annotation for annotation in annotations}
annotation = id2annotation.get(annotation_id)
if annotation is None:
raise gr.Error(
f"annotation '{annotation_id}' not found in document '{document.id}'. Available "
f"annotations: {id2annotation}"
)
return annotation
class DocumentStore:
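    """In-memory index of annotated documents and their ADU embeddings.

    Documents are stored by id; the ADU embeddings found in each document's
    metadata are mirrored into a vector store for similarity retrieval.
    """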
DOCUMENT_TYPE = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
def __init__(self, vector_store: Optional[VectorStore[Tuple[str, str], List[float]]] = None):
        # The annotated documents, keyed by document id. Each document keeps the
        # embeddings of its ADUs in its metadata.
self.documents: Dict[
str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
] = {}
# The vector store to efficiently retrieve similar ADUs. Can be constructed from the
# documents.
self.vector_store: VectorStore[Tuple[str, str], List[float]] = (
vector_store or SimpleVectorStore()
)
def get_annotation(
self,
doc_id: str,
annotation_id: str,
annotation_layer: str,
) -> LabeledSpan:
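        """Fetch a predicted annotation by document id, annotation id, and layer name."""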
document = self.documents.get(doc_id)
if document is None:
raise gr.Error(
f"Document '{doc_id}' not found in index. Available documents: {list(self.documents)}"
)
return get_annotation_from_document(document, annotation_id, annotation_layer)
def get_similar_adus_df(
self,
ref_annotation_id: str,
ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
min_similarity: float,
top_k: int,
) -> pd.DataFrame:
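        """Retrieve ADUs similar to the reference ADU from the vector store.

        Returns a DataFrame with the columns doc_id, adu_id, sim_score, and text
        (the covered text of the similar ADU).
        """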
similar_entries = self.vector_store.retrieve_similar(
ref_id=(ref_document.id, ref_annotation_id),
min_similarity=min_similarity,
top_k=top_k,
)
similar_annotations = [
self.get_annotation(
doc_id=doc_id,
annotation_id=annotation_id,
annotation_layer="labeled_spans",
)
for (doc_id, annotation_id), _ in similar_entries
]
df = pd.DataFrame(
[
# unpack the tuple (doc_id, annotation_id) to separate columns
# and add the similarity score and the text of the annotation
(doc_id, annotation_id, score, str(annotation))
for ((doc_id, annotation_id), score), annotation in zip(
similar_entries, similar_annotations
)
],
columns=["doc_id", "adu_id", "sim_score", "text"],
)
return df
def get_relevant_adus_df(
self,
ref_annotation_id: str,
ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
min_similarity: float,
top_k: int,
relation_types: List[str],
columns: List[str],
) -> pd.DataFrame:
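        """Collect ADUs that are related to ADUs similar to the reference ADU.

        For each similar ADU from *other* documents, follow its outgoing relations
        of the given types and collect the relation tails. The columns argument
        selects and orders the columns of the returned DataFrame.
        """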
similar_entries = self.vector_store.retrieve_similar(
ref_id=(ref_document.id, ref_annotation_id),
min_similarity=min_similarity,
top_k=top_k,
)
result = []
for (doc_id, annotation_id), score in similar_entries:
# skip entries from the same document
if doc_id == ref_document.id:
continue
document = self.documents[doc_id]
tail2rels = defaultdict(list)
head2rels = defaultdict(list)
for rel in document.binary_relations.predictions:
# skip non-argumentative relations
if rel.label not in relation_types:
continue
head2rels[rel.head].append(rel)
tail2rels[rel.tail].append(rel)
id2annotation = {
labeled_span_to_id(annotation): annotation
for annotation in document.labeled_spans.predictions
}
annotation = id2annotation.get(annotation_id)
            # note: we do not need to check that the annotation differs from the reference
            # annotation, because entries from the reference document are skipped above
for rel in head2rels.get(annotation, []):
result.append(
{
"doc_id": doc_id,
"reference_adu": str(annotation),
"sim_score": score,
"rel_score": rel.score,
"relation": rel.label,
"adu": str(rel.tail),
}
)
        # build the result frame with the requested column selection and order
df = pd.DataFrame(result, columns=columns)
return df
def add_document(
self, document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
) -> None:
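        """Add a document to the index and its ADU embeddings to the vector store."""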
try:
if document.id in self.documents:
gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
# save the processed document to the index
self.documents[document.id] = document
# save the embeddings to the vector store
for adu_id, embedding in document.metadata["embeddings"].items():
self.vector_store.save((document.id, adu_id), embedding)
        except Exception as e:
            raise gr.Error(f"Failed to add document {document.id} to index: {e}") from e
def add_document_from_dict(self, document_dict: dict) -> None:
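        """Deserialize a document (including its metadata) from a dict and add it."""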
document = self.DOCUMENT_TYPE.fromdict(document_dict)
# metadata is not automatically deserialized, so we need to set it manually
document.metadata = document_dict["metadata"]
self.add_document(document)
def add_documents(
self, documents: List[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]
) -> None:
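        """Add multiple documents to the index and report how many were added."""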
size_before = len(self.documents)
for document in documents:
self.add_document(document)
size_after = len(self.documents)
gr.Info(
f"Added {size_after - size_before} documents to the index ({size_after} documents in total)."
)
def add_from_json(self, file_path: str) -> None:
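        """Load documents from a JSON file that maps document ids to document dicts."""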
size_before = len(self.documents)
with open(file_path, "r", encoding="utf-8") as f:
processed_documents_json = json.load(f)
        for document_json in processed_documents_json.values():
self.add_document_from_dict(document_dict=document_json)
size_after = len(self.documents)
gr.Info(
f"Added {size_after - size_before} documents to the index ({size_after} documents in total)."
)
def save_to_json(self, file_path: str, **kwargs) -> None:
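        """Serialize all indexed documents to a JSON file; kwargs go to json.dump."""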
with open(file_path, "w", encoding="utf-8") as f:
json.dump(self.as_dict(), f, **kwargs)
def get_document(
self, doc_id: str
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
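        """Return the stored document with the given id (raises KeyError if missing)."""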
return self.documents[doc_id]
def overview(self) -> pd.DataFrame:
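        """Summarize the index as a DataFrame with one row per document."""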
df = pd.DataFrame(
[
(
doc_id,
len(document.labeled_spans.predictions),
len(document.binary_relations.predictions),
)
for doc_id, document in self.documents.items()
],
columns=["doc_id", "num_adus", "num_relations"],
)
return df
def as_dict(self) -> dict:
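        """Serialize the whole store into a dict keyed by document id."""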
return {doc_id: document.asdict() for doc_id, document in self.documents.items()}
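

# Usage sketch (a minimal example, not used by the app): assumes a JSON dump
# produced by DocumentStore.save_to_json() exists at the hypothetical path
# "processed_documents.json", and that every stored document carries an
# "embeddings" dict (ADU id -> vector) in its metadata, as add_document() expects.
if __name__ == "__main__":
    store = DocumentStore()
    store.add_from_json("processed_documents.json")
    print(store.overview())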