import json
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import gradio as gr
import pandas as pd
from annotation_utils import labeled_span_to_id
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from vector_store import SimpleVectorStore, VectorStore

logger = logging.getLogger(__name__)


def get_annotation_from_document(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    annotation_id: str,
    annotation_layer: str,
) -> LabeledSpan:
    """Resolve an annotation id (as produced by labeled_span_to_id) to the corresponding
    predicted annotation in the given layer of the document."""
    # use predictions
    annotations = document[annotation_layer].predictions
    id2annotation = {labeled_span_to_id(annotation): annotation for annotation in annotations}
    annotation = id2annotation.get(annotation_id)
    if annotation is None:
        raise gr.Error(
            f"annotation '{annotation_id}' not found in document '{document.id}'. Available "
            f"annotations: {list(id2annotation)}"
        )
    return annotation
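

# Sketch of the VectorStore interface as this module uses it (see vector_store.py for the
# actual definitions; the positional parameter names for `save` are illustrative, while the
# keyword names for `retrieve_similar` are taken from the call sites below):
#   vector_store.save((doc_id, adu_id), embedding)
#   vector_store.retrieve_similar(ref_id=(doc_id, adu_id), min_similarity=..., top_k=...)
#     -> list of ((doc_id, annotation_id), similarity_score) pairs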


class DocumentStore:
    DOCUMENT_TYPE = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions

    def __init__(self, vector_store: Optional[VectorStore[Tuple[str, str], List[float]]] = None):
        # The annotated documents. As key, we use the document id. All documents keep the
        # embeddings of the ADUs in the metadata.
        self.documents: Dict[
            str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
        ] = {}
        # The vector store to efficiently retrieve similar ADUs. Can be constructed from the
        # documents.
        self.vector_store: VectorStore[Tuple[str, str], List[float]] = (
            vector_store or SimpleVectorStore()
        )

    def get_annotation(
        self,
        doc_id: str,
        annotation_id: str,
        annotation_layer: str,
    ) -> LabeledSpan:
        document = self.documents.get(doc_id)
        if document is None:
            raise gr.Error(
                f"Document '{doc_id}' not found in index. Available documents: "
                f"{list(self.documents)}"
            )
        return get_annotation_from_document(document, annotation_id, annotation_layer)

    def get_similar_adus_df(
        self,
        ref_annotation_id: str,
        ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
        min_similarity: float,
        top_k: int,
    ) -> pd.DataFrame:
        """Retrieve the ADUs most similar to the reference ADU as a DataFrame with the
        columns doc_id, adu_id, sim_score, and text."""
        similar_entries = self.vector_store.retrieve_similar(
            ref_id=(ref_document.id, ref_annotation_id),
            min_similarity=min_similarity,
            top_k=top_k,
        )
        similar_annotations = [
            self.get_annotation(
                doc_id=doc_id,
                annotation_id=annotation_id,
                annotation_layer="labeled_spans",
            )
            for (doc_id, annotation_id), _ in similar_entries
        ]
        df = pd.DataFrame(
            [
                # unpack the tuple (doc_id, annotation_id) to separate columns
                # and add the similarity score and the text of the annotation
                (doc_id, annotation_id, score, str(annotation))
                for ((doc_id, annotation_id), score), annotation in zip(
                    similar_entries, similar_annotations
                )
            ],
            columns=["doc_id", "adu_id", "sim_score", "text"],
        )
        return df

    def get_relevant_adus_df(
        self,
        ref_annotation_id: str,
        ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
        min_similarity: float,
        top_k: int,
        relation_types: List[str],
        columns: List[str],
    ) -> pd.DataFrame:
        """Retrieve ADUs that are connected, via a relation of one of the given types, to an
        ADU similar to the reference ADU. Entries from the reference document are skipped."""
        similar_entries = self.vector_store.retrieve_similar(
            ref_id=(ref_document.id, ref_annotation_id),
            min_similarity=min_similarity,
            top_k=top_k,
        )
        result = []
        for (doc_id, annotation_id), score in similar_entries:
            # skip entries from the same document
            if doc_id == ref_document.id:
                continue
            document = self.documents[doc_id]
            tail2rels = defaultdict(list)
            head2rels = defaultdict(list)
            for rel in document.binary_relations.predictions:
                # skip non-argumentative relations
                if rel.label not in relation_types:
                    continue
                head2rels[rel.head].append(rel)
                tail2rels[rel.tail].append(rel)

            id2annotation = {
                labeled_span_to_id(annotation): annotation
                for annotation in document.labeled_spans.predictions
            }
            annotation = id2annotation.get(annotation_id)
            # note: we do not need to check if the annotation is different from the reference
            # annotation, because they come from different documents and we already skip
            # entries from the same document
            for rel in head2rels.get(annotation, []):
                result.append(
                    {
                        "doc_id": doc_id,
                        "reference_adu": str(annotation),
                        "sim_score": score,
                        "rel_score": rel.score,
                        "relation": rel.label,
                        "adu": str(rel.tail),
                    }
                )

        # define column order
        df = pd.DataFrame(result, columns=columns)
        return df
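
    # Index-maintenance methods. Each added document is expected to carry its ADU embeddings
    # in document.metadata["embeddings"], a mapping from annotation id (compatible with
    # labeled_span_to_id) to embedding vector; these entries are mirrored into the vector
    # store under the key (doc_id, adu_id).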
Overwriting.") # save the processed document to the index self.documents[document.id] = document # save the embeddings to the vector store for adu_id, embedding in document.metadata["embeddings"].items(): self.vector_store.save((document.id, adu_id), embedding) except Exception as e: raise gr.Error(f"Failed to add document {document.id} to index: {e}") def add_document_from_dict(self, document_dict: dict) -> None: document = self.DOCUMENT_TYPE.fromdict(document_dict) # metadata is not automatically deserialized, so we need to set it manually document.metadata = document_dict["metadata"] self.add_document(document) def add_documents( self, documents: List[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions] ) -> None: size_before = len(self.documents) for document in documents: self.add_document(document) size_after = len(self.documents) gr.Info( f"Added {size_after - size_before} documents to the index ({size_after} documents in total)." ) def add_from_json(self, file_path: str) -> None: size_before = len(self.documents) with open(file_path, "r", encoding="utf-8") as f: processed_documents_json = json.load(f) for _, document_json in processed_documents_json.items(): self.add_document_from_dict(document_dict=document_json) size_after = len(self.documents) gr.Info( f"Added {size_after - size_before} documents to the index ({size_after} documents in total)." ) def save_to_json(self, file_path: str, **kwargs) -> None: with open(file_path, "w", encoding="utf-8") as f: json.dump(self.as_dict(), f, **kwargs) def get_document( self, doc_id: str ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions: return self.documents[doc_id] def overview(self) -> pd.DataFrame: df = pd.DataFrame( [ ( doc_id, len(document.labeled_spans.predictions), len(document.binary_relations.predictions), ) for doc_id, document in self.documents.items() ], columns=["doc_id", "num_adus", "num_relations"], ) return df def as_dict(self) -> dict: return {doc_id: document.asdict() for doc_id, document in self.documents.items()}