import json
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import gradio as gr
import pandas as pd
from annotation_utils import labeled_span_to_id
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from vector_store import SimpleVectorStore, VectorStore

logger = logging.getLogger(__name__)


def get_annotation_from_document(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    annotation_id: str,
    annotation_layer: str,
) -> LabeledSpan:
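    """Resolve an annotation id to the predicted LabeledSpan in the given layer of the
    document. Raises a gr.Error if the id is unknown.
    """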
    annotations = document[annotation_layer].predictions
    id2annotation = {labeled_span_to_id(annotation): annotation for annotation in annotations}
    annotation = id2annotation.get(annotation_id)
    if annotation is None:
        raise gr.Error(
            f"Annotation '{annotation_id}' not found in document '{document.id}'. Available "
            f"annotations: {list(id2annotation)}"
        )
    return annotation


class DocumentStore:
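    """In-memory store for processed documents plus a vector store over ADU embeddings.

    Documents are indexed by their id; span embeddings are keyed by (doc_id, annotation_id).
    """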

    DOCUMENT_TYPE = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions

    def __init__(self, vector_store: Optional[VectorStore[Tuple[str, str], List[float]]] = None):
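        # processed documents, indexed by their id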
        self.documents: Dict[
            str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
        ] = {}
        # embeddings of the labeled spans, keyed by (doc_id, annotation_id);
        # falls back to the in-memory SimpleVectorStore if none is provided
        self.vector_store: VectorStore[Tuple[str, str], List[float]] = (
            vector_store or SimpleVectorStore()
        )

    def get_annotation(
        self,
        doc_id: str,
        annotation_id: str,
        annotation_layer: str,
    ) -> LabeledSpan:
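        """Look up the document by id and resolve the annotation id within it."""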
        document = self.documents.get(doc_id)
        if document is None:
            raise gr.Error(
                f"Document '{doc_id}' not found in index. Available documents: {list(self.documents)}"
            )
        return get_annotation_from_document(document, annotation_id, annotation_layer)

    def get_similar_adus_df(
        self,
        ref_annotation_id: str,
        ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
        min_similarity: float,
        top_k: int,
    ) -> pd.DataFrame:
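        """Retrieve the ADUs most similar to the reference annotation as a DataFrame
        with columns: doc_id, adu_id, sim_score, text.
        """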
        similar_entries = self.vector_store.retrieve_similar(
            ref_id=(ref_document.id, ref_annotation_id),
            min_similarity=min_similarity,
            top_k=top_k,
        )
        # resolve the retrieved (doc_id, annotation_id) pairs to the actual spans
        similar_annotations = [
            self.get_annotation(
                doc_id=doc_id,
                annotation_id=annotation_id,
                annotation_layer="labeled_spans",
            )
            for (doc_id, annotation_id), _ in similar_entries
        ]
        df = pd.DataFrame(
            [
                (doc_id, annotation_id, score, str(annotation))
                for ((doc_id, annotation_id), score), annotation in zip(
                    similar_entries, similar_annotations
                )
            ],
            columns=["doc_id", "adu_id", "sim_score", "text"],
        )
        return df

    def get_relevant_adus_df(
        self,
        ref_annotation_id: str,
        ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
        min_similarity: float,
        top_k: int,
        relation_types: List[str],
        columns: List[str],
    ) -> pd.DataFrame:
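        """Retrieve ADUs connected via relations of the given types to the ADUs that are
        most similar to the reference annotation. Entries from the reference document
        itself are skipped.
        """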
        similar_entries = self.vector_store.retrieve_similar(
            ref_id=(ref_document.id, ref_annotation_id),
            min_similarity=min_similarity,
            top_k=top_k,
        )
        result = []
        for (doc_id, annotation_id), score in similar_entries:
            # skip entries from the reference document itself
            if doc_id == ref_document.id:
                continue
            document = self.documents[doc_id]
            # index the predicted relations of the requested types by head and tail
            tail2rels = defaultdict(list)
            head2rels = defaultdict(list)
            for rel in document.binary_relations.predictions:
                if rel.label not in relation_types:
                    continue
                head2rels[rel.head].append(rel)
                tail2rels[rel.tail].append(rel)

            id2annotation = {
                labeled_span_to_id(annotation): annotation
                for annotation in document.labeled_spans.predictions
            }
            annotation = id2annotation.get(annotation_id)
            # note: only relations where the similar ADU is the head are collected;
            # tail2rels is built for symmetry but not yet used
            for rel in head2rels.get(annotation, []):
                result.append(
                    {
                        "doc_id": doc_id,
                        "reference_adu": str(annotation),
                        "sim_score": score,
                        "rel_score": rel.score,
                        "relation": rel.label,
                        "adu": str(rel.tail),
                    }
                )

        df = pd.DataFrame(result, columns=columns)
        return df

    def add_document(
        self, document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
    ) -> None:
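        """Add a document to the store and index its ADU embeddings in the vector store."""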
        try:
            if document.id in self.documents:
                gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
            self.documents[document.id] = document
            # save the embeddings of the labeled spans in the vector store
            for adu_id, embedding in document.metadata["embeddings"].items():
                self.vector_store.save((document.id, adu_id), embedding)
        except Exception as e:
            raise gr.Error(f"Failed to add document '{document.id}' to index: {e}") from e

    def add_document_from_dict(self, document_dict: dict) -> None:
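        """Deserialize a document from its dict representation and add it to the store."""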
        document = self.DOCUMENT_TYPE.fromdict(document_dict)
        # set the metadata (e.g. the embeddings) explicitly, since it is not
        # restored by fromdict() above
        document.metadata = document_dict["metadata"]
        self.add_document(document)

    def add_documents(
        self, documents: List[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]
    ) -> None:
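        """Add multiple documents to the store and report how many were added."""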
        size_before = len(self.documents)
        for document in documents:
            self.add_document(document)
        size_after = len(self.documents)
        gr.Info(
            f"Added {size_after - size_before} documents to the index ({size_after} documents in total)."
        )

    def add_from_json(self, file_path: str) -> None:
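        """Load processed documents from a JSON file (mapping doc_id to document dict)
        and add them to the store.
        """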
        size_before = len(self.documents)
        with open(file_path, "r", encoding="utf-8") as f:
            processed_documents_json = json.load(f)
        for document_json in processed_documents_json.values():
            self.add_document_from_dict(document_dict=document_json)
        size_after = len(self.documents)
        gr.Info(
            f"Added {size_after - size_before} documents to the index ({size_after} documents in total)."
        )

    def save_to_json(self, file_path: str, **kwargs) -> None:
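        """Serialize all documents to a JSON file; kwargs are passed on to json.dump."""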
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(self.as_dict(), f, **kwargs)

    def get_document(
        self, doc_id: str
    ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
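        """Return the document with the given id; raises a KeyError if it is unknown."""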
        return self.documents[doc_id]

    def overview(self) -> pd.DataFrame:
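        """Summarize the store as a DataFrame with one row per document
        (doc_id, num_adus, num_relations).
        """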
        df = pd.DataFrame(
            [
                (
                    doc_id,
                    len(document.labeled_spans.predictions),
                    len(document.binary_relations.predictions),
                )
                for doc_id, document in self.documents.items()
            ],
            columns=["doc_id", "num_adus", "num_relations"],
        )
        return df

    def as_dict(self) -> dict:
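        """Serialize all documents to a mapping from doc_id to document dict."""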
        return {doc_id: document.asdict() for doc_id, document in self.documents.items()}