import logging
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import gradio as gr
import pandas as pd
from pie_modules.document.processing import tokenize_document
from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from pytorch_ie import Pipeline
from pytorch_ie.annotations import LabeledSpan, Span
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from rendering_utils import labeled_span_to_id
from transformers import PreTrainedModel, PreTrainedTokenizer
from vector_store import SimpleVectorStore, VectorStore

logger = logging.getLogger(__name__)


def _embed_text_annotations(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    text_layer_name: str,
) -> Dict[Span, List[float]]:
    # work on a copy to not modify the original document
    document = document.copy()
    # tokenize_document does not yet consider predictions, so we need to add them manually
    document[text_layer_name].extend(document[text_layer_name].predictions.clear())
    added_annotations = []
    tokenizer_kwargs = {
        "max_length": 512,
        "stride": 64,
        "truncation": True,
        "return_overflowing_tokens": True,
    }
    tokenized_documents = tokenize_document(
        document,
        tokenizer=tokenizer,
        result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
        partition_layer="labeled_partitions",
        added_annotations=added_annotations,
        strict_span_conversion=False,
        **tokenizer_kwargs,
    )
    # tokenize again to get tensors in the correct format for the model
    model_inputs = tokenizer(document.text, return_tensors="pt", **tokenizer_kwargs)
    # this entry is added when using return_overflowing_tokens=True, but the model does not
    # accept it
    model_inputs.pop("overflow_to_sample_mapping", None)
    assert len(model_inputs.encodings) == len(tokenized_documents)
    model_output = model(**model_inputs)

    # get embeddings for all text annotations
    embeddings = {}
    for batch_idx in range(len(model_output.last_hidden_state)):
        text2tok_ann = added_annotations[batch_idx][text_layer_name]
        tok2text_ann = {v: k for k, v in text2tok_ann.items()}
        for tok_ann in tokenized_documents[batch_idx].labeled_spans:
            # skip "empty" annotations
            if tok_ann.start == tok_ann.end:
                continue
            # use max pooling to get a single embedding for the annotation text
            embedding = model_output.last_hidden_state[batch_idx, tok_ann.start : tok_ann.end].max(
                dim=0
            )[0]
            text_ann = tok2text_ann[tok_ann]
            if text_ann in embeddings:
                logger.warning(
                    f"Overwriting embedding for annotation '{text_ann}' (do you use striding?)"
                )
            embeddings[text_ann] = embedding

    return embeddings


def _annotate(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    pipeline: Pipeline,
    embedding_model: Optional[PreTrainedModel] = None,
    embedding_tokenizer: Optional[PreTrainedTokenizer] = None,
) -> None:
    # execute the prediction pipeline
    pipeline(document)

    if embedding_model is not None and embedding_tokenizer is not None:
        adu_embeddings = _embed_text_annotations(
            document=document,
            model=embedding_model,
            tokenizer=embedding_tokenizer,
            text_layer_name="labeled_spans",
        )
        # convert keys to str because JSON keys must be strings
        adu_embeddings_dict = {
            labeled_span_to_id(k): v.detach().tolist() for k, v in adu_embeddings.items()
        }
        document.metadata["embeddings"] = adu_embeddings_dict
    else:
        gr.Warning(
            "No embedding model provided. Skipping embedding extraction. You can load an "
            "embedding model in the 'Model Configuration' section."
        )

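
# Illustrative sketch of the span pooling used in `_embed_text_annotations`, with a toy
# hidden-state tensor (`toy_hidden`, `span_start`, and `span_end` are hypothetical names
# and are not used elsewhere in this module):
#
#     import torch
#
#     toy_hidden = torch.randn(1, 12, 16)  # (batch, num_tokens, hidden_size)
#     span_start, span_end = 3, 7          # token offsets of one labeled span
#     # max over the token axis collapses the slice to a single (hidden_size,) vector
#     span_embedding = toy_hidden[0, span_start:span_end].max(dim=0)[0]
#     assert span_embedding.shape == (16,)
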
def create_and_annotate_document(
    text: str,
    doc_id: str,
    models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
    """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the
    provided text, annotate it, and return it.

    Parameters:
        text: The text to process.
        doc_id: The ID of the document.
        models: A tuple containing the prediction pipeline, the embedding model, and the
            embedding tokenizer.

    Returns:
        The processed document.
    """
    try:
        document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
            id=doc_id, text=text, metadata={}
        )
        # add a single partition spanning the whole text (the model only considers text in
        # partitions)
        document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
        # annotate the document
        _annotate(
            document=document,
            pipeline=models[0],
            embedding_model=models[1],
            embedding_tokenizer=models[2],
        )
        return document
    except Exception as e:
        raise gr.Error(f"Failed to process text: {e}")


def get_annotation_from_document(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    annotation_id: str,
    annotation_layer: str,
) -> LabeledSpan:
    # look up the annotation in the predictions of the given layer
    annotations = document[annotation_layer].predictions
    id2annotation = {labeled_span_to_id(annotation): annotation for annotation in annotations}
    annotation = id2annotation.get(annotation_id)
    if annotation is None:
        raise gr.Error(
            f"annotation '{annotation_id}' not found in document '{document.id}'. Available "
            f"annotations: {list(id2annotation)}"
        )
    return annotation

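
# Hedged sketch of the ID round-trip that `get_annotation_from_document` relies on,
# assuming `labeled_span_to_id` is deterministic for a given span (`doc` is a
# hypothetical, already annotated document):
#
#     span = doc.labeled_spans.predictions[0]
#     adu_id = labeled_span_to_id(span)
#     assert get_annotation_from_document(doc, adu_id, "labeled_spans") is span
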
class DocumentStore:
    DOCUMENT_TYPE = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions

    def __init__(self, vector_store: Optional[VectorStore[Tuple[str, str]]] = None):
        # annotated documents, keyed by their ID
        self.documents = {}
        # embeddings, keyed by (doc_id, annotation_id)
        self.vector_store = vector_store or SimpleVectorStore()

    def get_annotation(
        self,
        doc_id: str,
        annotation_id: str,
        annotation_layer: str,
    ) -> LabeledSpan:
        document = self.documents.get(doc_id)
        if document is None:
            raise gr.Error(
                f"Document '{doc_id}' not found in index. Available documents: "
                f"{list(self.documents)}"
            )
        return get_annotation_from_document(document, annotation_id, annotation_layer)

    def get_similar_adus_df(
        self,
        ref_annotation_id: str,
        ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
        min_similarity: float,
        top_k: int,
    ) -> pd.DataFrame:
        similar_entries = self.vector_store.retrieve_similar(
            ref_id=(ref_document.id, ref_annotation_id),
            min_similarity=min_similarity,
            top_k=top_k,
        )
        similar_annotations = [
            self.get_annotation(
                doc_id=doc_id,
                annotation_id=annotation_id,
                annotation_layer="labeled_spans",
            )
            for (doc_id, annotation_id), _ in similar_entries
        ]
        df = pd.DataFrame(
            [
                # unpack the tuple (doc_id, annotation_id) into separate columns
                # and add the similarity score and the text of the annotation
                (doc_id, annotation_id, score, str(annotation))
                for ((doc_id, annotation_id), score), annotation in zip(
                    similar_entries, similar_annotations
                )
            ],
            columns=["doc_id", "adu_id", "sim_score", "text"],
        )
        return df

    def get_relevant_adus_df(
        self,
        ref_annotation_id: str,
        ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
        min_similarity: float,
        top_k: int,
        relation_types: List[str],
        columns: List[str],
    ) -> pd.DataFrame:
        similar_entries = self.vector_store.retrieve_similar(
            ref_id=(ref_document.id, ref_annotation_id),
            min_similarity=min_similarity,
            top_k=top_k,
        )
        result = []
        for (doc_id, annotation_id), score in similar_entries:
            # skip entries from the same document
            if doc_id == ref_document.id:
                continue
            document = self.documents[doc_id]
            # index the argumentative relations by head and tail (tail2rels is currently
            # unused; only outgoing, i.e. head, relations are reported below)
            tail2rels = defaultdict(list)
            head2rels = defaultdict(list)
            for rel in document.binary_relations.predictions:
                # skip non-argumentative relations
                if rel.label not in relation_types:
                    continue
                head2rels[rel.head].append(rel)
                tail2rels[rel.tail].append(rel)

            id2annotation = {
                labeled_span_to_id(annotation): annotation
                for annotation in document.labeled_spans.predictions
            }
            annotation = id2annotation.get(annotation_id)
            # note: we do not need to check whether the annotation differs from the reference
            # annotation because they come from different documents and we already skip
            # entries from the same document
            for rel in head2rels.get(annotation, []):
                result.append(
                    {
                        "doc_id": doc_id,
                        "reference_adu": str(annotation),
                        "sim_score": score,
                        "rel_score": rel.score,
                        "relation": rel.label,
                        "adu": str(rel.tail),
                    }
                )

        # enforce the column order provided by the caller
        df = pd.DataFrame(result, columns=columns)
        return df

    def add_document(
        self, document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
    ) -> None:
        try:
            if document.id in self.documents:
                gr.Warning(f"Document '{document.id}' already in index. Overwriting.")

            # save the processed document to the index
            self.documents[document.id] = document

            # save the embeddings to the vector store; embeddings may be missing if no
            # embedding model was provided during annotation
            for adu_id, embedding in document.metadata.get("embeddings", {}).items():
                self.vector_store.save((document.id, adu_id), embedding)

            gr.Info(
                f"Added document {document.id} to index (index contains {len(self.documents)} "
                f"documents and {len(self.vector_store)} embeddings)."
            )
        except Exception as e:
            raise gr.Error(f"Failed to add document {document.id} to index: {e}")
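
    # Note: `add_document` expects `document.metadata["embeddings"]`, when present, to map
    # ADU ids (as produced by `labeled_span_to_id`) to plain lists of floats, e.g.
    # (illustrative values only):
    #
    #     {"<adu-id>": [0.12, -0.56, 0.03, ...]}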
    def add_document_from_dict(self, document_dict: dict) -> None:
        document = self.DOCUMENT_TYPE.fromdict(document_dict)
        # metadata is not automatically deserialized, so we need to set it manually
        document.metadata = document_dict["metadata"]
        self.add_document(document)

    def add_documents(
        self, documents: List[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]
    ) -> None:
        for document in documents:
            self.add_document(document)

    def get_document(
        self, doc_id: str
    ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
        return self.documents[doc_id]

    def overview(self) -> pd.DataFrame:
        df = pd.DataFrame(
            [
                (
                    doc_id,
                    len(document.labeled_spans.predictions),
                    len(document.binary_relations.predictions),
                )
                for doc_id, document in self.documents.items()
            ],
            columns=["doc_id", "num_adus", "num_relations"],
        )
        return df

    def as_dict(self) -> dict:
        return {doc_id: document.asdict() for doc_id, document in self.documents.items()}
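

# Minimal usage sketch (assumptions: a loaded pytorch-ie `Pipeline` named `pipeline` plus
# optional `embedding_model`/`embedding_tokenizer`; these objects are not constructed here):
#
#     store = DocumentStore()
#     doc = create_and_annotate_document(
#         text="Some argumentative text.",
#         doc_id="doc-1",
#         models=(pipeline, embedding_model, embedding_tokenizer),
#     )
#     store.add_document(doc)
#     print(store.overview())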