import json
import logging
import os
import shutil
import tempfile
import zipfile
from collections import defaultdict
from typing import Any, Dict, List, Optional

import gradio as gr
import pandas as pd
from annotation_utils import labeled_span_to_id
from pytorch_ie import Annotation
from pytorch_ie.documents import (
    TextBasedDocument,
    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)
from vector_store import VectorStore

logger = logging.getLogger(__name__)


def get_annotation_from_document(
    document: TextBasedDocument,
    annotation_id: str,
    annotation_layer: str,
    use_predictions: bool,
) -> Annotation:
    """Get an annotation from a document by its id. Note that the annotation id is
    constructed from the annotation itself, so it is unique within the document.

    Args:
        document: The document to get the annotation from.
        annotation_id: The id of the annotation.
        annotation_layer: The name of the annotation layer.
        use_predictions: Whether to use the predictions of the annotation layer.

    Returns:
        The annotation with the given id.
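
    Example (illustrative; assumes doc has a populated "labeled_spans" predictions
    layer and that the id was built with labeled_span_to_id):

        span = doc["labeled_spans"].predictions[0]
        span_id = labeled_span_to_id(span)
        same_span = get_annotation_from_document(
            doc, span_id, "labeled_spans", use_predictions=True
        )
        assert same_span is span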
    """

    annotations = document[annotation_layer]
    if use_predictions:
        annotations = annotations.predictions

    if annotation_layer == "labeled_spans":
        annotation_to_id_func = labeled_span_to_id
    else:
        raise gr.Error(f"Unknown annotation layer '{annotation_layer}'.")

    id2annotation = {annotation_to_id_func(annotation): annotation for annotation in annotations}
    annotation = id2annotation.get(annotation_id)
    if annotation is None:
        raise gr.Error(
            f"Annotation '{annotation_id}' not found in document '{document.id}'. Available "
            f"annotations: {list(id2annotation)}"
        )
    return annotation


def get_related_annotation_records_from_document(
    document: TextBasedDocument,
    reference_annotation: Annotation,
    relation_layer_name: str,
    use_predictions: bool,
    annotation_caption: str,
    relation_types: Optional[List[str]] = None,
    additional_static_columns: Optional[Dict[str, str]] = None,
) -> List[Dict[str, Any]]:
    """Get related annotations from a document for a given reference annotation. The related
    annotations are all annotations that are targets (tails) of relations with the reference
    annotation as source (head).

    Args:
        document: The document to get the related annotations from.
        reference_annotation: The reference annotation. Should be an annotation from the document.
        relation_layer_name: The name of the relation layer.
        use_predictions: Whether to use the predictions of the relation layer.
        annotation_caption: The caption for the related annotations in the result.
        relation_types: The types of relations to consider. If None, all relation types are
            considered.
        additional_static_columns: Additional static columns to add to the result.

    Returns:
        A list of dictionaries with the related annotations and additional columns.
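
    Example (illustrative; "supports" stands in for whatever relation labels the
    relation layer actually uses):

        span = document["labeled_spans"].predictions[0]
        records = get_related_annotation_records_from_document(
            document=document,
            reference_annotation=span,
            relation_layer_name="binary_relations",
            use_predictions=True,
            annotation_caption="span",
            relation_types=["supports"],
        )
        # each record: {"doc_id": ..., "reference_span": ..., "rel_score": ...,
        #               "relation": "supports", "span": ...}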
    """

    result = []

    relation_layer = document[relation_layer_name]
    if use_predictions:
        relation_layer = relation_layer.predictions

    # Index the relations by their head so that all relations outgoing from the
    # reference annotation can be looked up directly.
    head2rels = defaultdict(list)
    for rel in relation_layer:
        if relation_types is not None and rel.label not in relation_types:
            continue
        head2rels[rel.head].append(rel)

    for rel in head2rels.get(reference_annotation, []):
        result.append(
            {
                "doc_id": document.id,
                f"reference_{annotation_caption}": str(reference_annotation),
                "rel_score": rel.score,
                "relation": rel.label,
                annotation_caption: str(rel.tail),
                **(additional_static_columns or {}),
            }
        )
    return result


class DocumentStore:
    """A document store that allows adding, retrieving, and searching for documents and
    annotations.

    The store keeps the documents in memory and stores the embeddings of the labeled spans
    in a vector store to efficiently retrieve similar or related spans.

    Args:
        vector_store: The vector store to use for the span embeddings (e.g. a
            SimpleVectorStore).
        document_type: The type of the documents to store. Should be a subclass of
            TextBasedDocument with a span and a relation layer (see below).
        span_layer_name: The name of the span annotation layer. This should be a valid
            annotation layer of type LabeledSpan in the document type.
        relation_layer_name: The name of the argumentative relation annotation layer. This
            should be a valid annotation layer of type BinaryRelation in the document type.
        span_annotation_caption: The caption for the span annotations (e.g. in the
            statistical overview).
        relation_annotation_caption: The caption for the relation annotations (e.g. in the
            statistical overview).
        use_predictions: Whether to use the predictions of the annotation layers. If True,
            the predictions are used, otherwise the gold annotations are used.
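
    Example (sketch; assumes SimpleVectorStore is a concrete VectorStore
    implementation and that "documents.zip" was produced by save_to_file()):

        store = DocumentStore(vector_store=SimpleVectorStore())
        store.add_documents_from_file("documents.zip")
        print(store.overview())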
    """

    JSON_FILE_NAME = "documents.json"

    def __init__(
        self,
        vector_store: VectorStore[Dict[str, Any], List[float]],
        document_type: type[
            TextBasedDocument
        ] = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
        span_layer_name: str = "labeled_spans",
        relation_layer_name: str = "binary_relations",
        span_annotation_caption: str = "span",
        relation_annotation_caption: str = "relation",
        use_predictions: bool = True,
    ):
        # The documents, keyed by their id.
        self.documents: Dict[str, TextBasedDocument] = {}
        # The vector store that holds the span embeddings.
        self.vector_store = vector_store

        self.document_type = document_type
        self.span_layer_name = span_layer_name
        self.relation_layer_name = relation_layer_name
        self.use_predictions = use_predictions
        self.layer_captions = {
            self.span_layer_name: span_annotation_caption,
            self.relation_layer_name: relation_annotation_caption,
        }

    def get_annotation(
        self,
        doc_id: str,
        annotation_id: str,
        annotation_layer: str,
        use_predictions: bool,
    ) -> Annotation:
        """Get an annotation by its id from the document with the given id."""
        document = self.documents.get(doc_id)
        if document is None:
            raise gr.Error(
                f"Document '{doc_id}' not found in index. Available documents: {list(self.documents)}"
            )
        return get_annotation_from_document(
            document, annotation_id, annotation_layer, use_predictions=use_predictions
        )

    def construct_embedding_payload(self, document: TextBasedDocument, annotation_id: str) -> dict:
        """Construct the payload that identifies an annotation embedding in the vector store."""
        payload = {"doc_id": document.id, "annotation_id": annotation_id}
        return payload

    def get_similar_annotations_df(
        self,
        ref_annotation_id: str,
        ref_document: TextBasedDocument,
        annotation_layer: str,
        **similarity_kwargs,
    ) -> pd.DataFrame:
        """Get similar annotations from documents in the store, sorted by similarity. Usually,
        the reference annotation is returned as the most similar annotation.

        Args:
            ref_annotation_id: The id of the reference annotation.
            ref_document: The document of the reference annotation.
            annotation_layer: The name of the annotation layer to consider.
            **similarity_kwargs: Additional keyword arguments that will be passed to the vector
                store to retrieve similar entries (see VectorStore.retrieve_similar()).

        Returns:
            A DataFrame with the similar annotations with columns: doc_id, annotation_id,
            sim_score, and text.
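
        Example (illustrative; store is a populated DocumentStore, and min_similarity
        and top_k are forwarded to VectorStore.retrieve_similar()):

            df = store.get_similar_annotations_df(
                ref_annotation_id=span_id,
                ref_document=doc,
                annotation_layer="labeled_spans",
                min_similarity=0.8,
                top_k=10,
            )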
        """

        similar_entries = self.vector_store.retrieve_similar(
            ref_payload=self.construct_embedding_payload(ref_document, ref_annotation_id),
            **similarity_kwargs,
        )

        similar_annotations = [
            self.get_annotation(
                doc_id=payload["doc_id"],
                annotation_id=payload["annotation_id"],
                annotation_layer=annotation_layer,
                use_predictions=self.use_predictions,
            )
            for _, payload, _ in similar_entries
        ]
        df = pd.DataFrame(
            [
                (payload["doc_id"], payload["annotation_id"], score, str(annotation))
                for (_, payload, score), annotation in zip(similar_entries, similar_annotations)
            ],
            columns=["doc_id", "annotation_id", "sim_score", "text"],
        )

        return df

    def get_related_annotations_from_other_documents_df(
        self,
        ref_annotation_id: str,
        ref_document: TextBasedDocument,
        min_similarity: float,
        top_k: int,
        relation_types: List[str],
        columns: List[str],
    ) -> pd.DataFrame:
        """Get related annotations from documents in the store for a given reference annotation.
        First, similar annotations are retrieved from the vector store. Then, annotations that
        are linked to them via relations are returned. Only annotations from other documents
        are considered.

        Args:
            ref_annotation_id: The id of the reference annotation.
            ref_document: The document of the reference annotation.
            min_similarity: The minimum similarity score to consider.
            top_k: The number of related annotations to return.
            relation_types: The types of relations to consider.
            columns: The columns to include in the result DataFrame.

        Returns:
            A DataFrame with the requested columns, which can contain: the related annotation,
            the relation type, the similar annotation, the similarity score, and the relation
            score.
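
        Example (illustrative; "supports" stands in for an actual relation label, and
        the column names follow the records built by
        get_related_annotation_records_from_document()):

            df = store.get_related_annotations_from_other_documents_df(
                ref_annotation_id=span_id,
                ref_document=doc,
                min_similarity=0.8,
                top_k=10,
                relation_types=["supports"],
                columns=[
                    "doc_id", "reference_span", "relation", "span", "rel_score", "sim_score"
                ],
            )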
        """

        similar_entries = self.vector_store.retrieve_similar(
            ref_payload=self.construct_embedding_payload(ref_document, ref_annotation_id),
            min_similarity=min_similarity,
            top_k=top_k,
        )
        result = []
        for _, payload, score in similar_entries:
            doc_id = payload["doc_id"]
            # Skip entries from the reference document itself.
            if doc_id == ref_document.id:
                continue
            document = self.documents[doc_id]
            reference_annotation = get_annotation_from_document(
                document=document,
                annotation_id=payload["annotation_id"],
                annotation_layer=self.span_layer_name,
                use_predictions=self.use_predictions,
            )

            new_entries = get_related_annotation_records_from_document(
                document=document,
                reference_annotation=reference_annotation,
                relation_types=relation_types,
                relation_layer_name=self.relation_layer_name,
                use_predictions=self.use_predictions,
                annotation_caption=self.layer_captions[self.span_layer_name],
                additional_static_columns={"sim_score": str(score)},
            )
            result.extend(new_entries)

        df = pd.DataFrame(result, columns=columns)
        return df

    def add_document(self, document: TextBasedDocument) -> None:
        """Add a single document to the store and index its embeddings, if available."""
        try:
            if document.id in self.documents:
                gr.Warning(f"Document '{document.id}' already in index. Overwriting.")

            # Copy the document so that it can be modified without affecting the caller.
            document = document.copy()

            self.documents[document.id] = document

            # If the document carries precomputed embeddings in its metadata, move them
            # into the vector store.
            if "embeddings" in document.metadata:
                for annotation_id, embedding in document.metadata["embeddings"].items():
                    payload = self.construct_embedding_payload(document, annotation_id)
                    self.vector_store.add(payload=payload, embedding=embedding)

                document.metadata = {
                    k: v for k, v in document.metadata.items() if k != "embeddings"
                }

        except Exception as e:
            raise gr.Error(f"Failed to add document {document.id} to index: {e}")

    def add_document_from_dict(self, document_dict: dict) -> None:
        """Deserialize a document from a dict and add it to the store."""
        document = self.document_type.fromdict(document_dict)
        self.add_document(document)

    def add_documents(self, documents: List[TextBasedDocument]) -> None:
        """Add multiple documents to the store."""
        for document in documents:
            self.add_document(document)
        gr.Info(
            f"Added {len(documents)} documents to the index ({len(self.documents)} documents in total)."
        )

    def add_documents_from_json(self, file_path: str) -> None:
        """Add documents from a JSON file that maps document ids to serialized documents."""
        with open(file_path, "r", encoding="utf-8") as f:
            documents_json = json.load(f)
        for _, document_json in documents_json.items():
            self.add_document_from_dict(document_dict=document_json)
        gr.Info(
            f"Added {len(documents_json)} documents to the index ({len(self.documents)} documents in total)."
        )

    def add_documents_from_zip(self, file_path: str) -> None:
        """Add documents and embeddings from a zip archive as written by save_to_zip()."""
        temp_dir = os.path.join(tempfile.gettempdir(), "document_store")
        # Start from a clean extraction directory.
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
        with zipfile.ZipFile(file_path, "r") as zipf:
            zipf.extractall(temp_dir)
        json_file_path = os.path.join(temp_dir, self.JSON_FILE_NAME)
        self.add_documents_from_json(json_file_path)
        # The archive also contains the serialized vector store.
        self.vector_store.load_from_directory(temp_dir)

        shutil.rmtree(temp_dir)

    def add_documents_from_file(self, file_path: str) -> None:
        """Add documents from a .json or .zip file."""
        if file_path.endswith(".json"):
            self.add_documents_from_json(file_path)
        elif file_path.endswith(".zip"):
            self.add_documents_from_zip(file_path)
        else:
            raise gr.Error(f"Unsupported file format: {file_path}")

    def save_to_json(self, file_path: str, include_embeddings: bool = True, **kwargs) -> None:
        """Save all documents to a JSON file; kwargs are passed to json.dump()."""
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(self.as_dict(include_embeddings=include_embeddings), f, **kwargs)

    def save_to_zip(self, file_path: str, **kwargs) -> None:
        """Save the documents and the vector store to a zip archive."""
        temp_dir = os.path.join(tempfile.gettempdir(), "document_store")
        # Start from a clean staging directory.
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
        os.makedirs(temp_dir)
        temp_file_path = os.path.join(temp_dir, self.JSON_FILE_NAME)
        # The embeddings are persisted by the vector store itself, so they are
        # excluded from the JSON file.
        self.save_to_json(temp_file_path, include_embeddings=False, **kwargs)
        self.vector_store.save_to_directory(temp_dir)

        # Zip the staging directory, keeping paths relative to its root.
        with zipfile.ZipFile(file_path, "w") as zipf:
            for root, _, files in os.walk(temp_dir):
                for file in files:
                    zipf.write(
                        os.path.join(root, file),
                        os.path.relpath(os.path.join(root, file), temp_dir),
                    )

        shutil.rmtree(temp_dir)

    def save_to_file(self, file_path: str, **kwargs) -> None:
        """Save the store to a .json or .zip file."""
        if file_path.endswith(".json"):
            self.save_to_json(file_path, **kwargs)
        elif file_path.endswith(".zip"):
            self.save_to_zip(file_path, **kwargs)
        else:
            raise gr.Error(f"Unsupported file format: {file_path}")

    def get_document(self, doc_id: str, with_embeddings: bool = False) -> TextBasedDocument:
        """Get a document by its id, optionally with its span embeddings in the metadata."""
        document = self.documents[doc_id]
        if not with_embeddings:
            return document

        # Copy the document so that the embeddings can be attached to its metadata
        # without modifying the stored document.
        document = document.copy()

        embeddings = {}
        annotations = document[self.span_layer_name]
        if self.use_predictions:
            annotations = annotations.predictions
        for annotation in annotations:
            annotation_id = labeled_span_to_id(annotation)
            payload = self.construct_embedding_payload(document, annotation_id)
            embedding = self.vector_store.get(payload=payload)
            if embedding is not None:
                embeddings[annotation_id] = embedding
        document.metadata["embeddings"] = embeddings

        return document

    def overview(self) -> pd.DataFrame:
        """Get a statistical overview with one row per document and the number of annotations
        per layer (with the default captions: doc_id, num_spans, num_relations)."""
        rows = []
        for doc_id, document in self.documents.items():
            layers = {
                caption: document[layer_name]
                for layer_name, caption in self.layer_captions.items()
            }
            if self.use_predictions:
                layers = {caption: layer.predictions for caption, layer in layers.items()}
            layer_sizes = {f"num_{caption}s": len(layer) for caption, layer in layers.items()}
            rows.append({"doc_id": doc_id, **layer_sizes})
        df = pd.DataFrame(rows)
        return df

    def as_dict(self, include_embeddings: bool = True) -> dict:
        """Serialize all documents to a dict, keyed by document id."""
        result = {}
        for doc_id, document in self.documents.items():
            doc_dict = document.asdict()
            if not include_embeddings and "embeddings" in (doc_dict.get("metadata") or {}):
                doc_dict["metadata"] = {
                    k: v for k, v in doc_dict["metadata"].items() if k != "embeddings"
                }
            result[doc_id] = doc_dict
        return result
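

if __name__ == "__main__":
    # Minimal smoke test (sketch): assumes that the vector_store module provides a
    # concrete VectorStore implementation named SimpleVectorStore and that
    # "documents.zip" was previously produced by DocumentStore.save_to_file().
    # Note that gr.Info/gr.Warning may behave differently outside a running Gradio app.
    from vector_store import SimpleVectorStore

    store = DocumentStore(vector_store=SimpleVectorStore())
    store.add_documents_from_file("documents.zip")
    print(store.overview())
    store.save_to_file("documents_roundtrip.zip")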