import json
import logging
from typing import Iterable, Optional, Sequence, Union

import gradio as gr
import pandas as pd
from pie_datasets import Dataset, IterableDataset, load_dataset
from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger
from pytorch_ie import Pipeline
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.auto import AutoPipeline
from pytorch_ie.documents import (
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)
from typing_extensions import Protocol

from src.langchain_modules import DocumentAwareSpanRetriever
from src.langchain_modules.span_retriever import (
    DocumentAwareSpanRetrieverWithRelations,
    _parse_config,
)

logger = logging.getLogger(__name__)


def annotate_document(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    argumentation_model: Pipeline,
    handle_parts_of_same: bool = False,
) -> Union[
    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
]:
    """Annotate a document with the provided pipeline.

    Args:
        document: The document to annotate.
        argumentation_model: The pipeline to use for annotation.
        handle_parts_of_same: Whether to merge spans that are linked via "parts_of_same"
            relations into single multi spans.
    """

    # the pipeline adds its annotations to the document in-place
    argumentation_model(document)

    if handle_parts_of_same:
        merger = SpansViaRelationMerger(
            relation_layer="binary_relations",
            link_relation_label="parts_of_same",
            create_multi_spans=True,
            result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
            result_field_mapping={
                "labeled_spans": "labeled_multi_spans",
                "binary_relations": "binary_relations",
                "labeled_partitions": "labeled_partitions",
            },
        )
        document = merger(document)

    return document
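

# Example usage (sketch, kept as a comment so it does not run on import): annotate a
# single document end to end. The model name below is a placeholder, not a checkpoint
# shipped with this module.
#
#     pipeline = load_argumentation_model("my-org/my-argumentation-model", device="cpu")
#     doc = create_document(text="Foo. Bar.", doc_id="doc-1")
#     annotated = annotate_document(doc, pipeline, handle_parts_of_same=True)
#     # with handle_parts_of_same=True, the result is a document with a
#     # "labeled_multi_spans" layer instead of "labeled_spans"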


def create_document(
    text: str, doc_id: str, split_regex: Optional[str] = None
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
    """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the
    provided text.

    Args:
        text: The text to process.
        doc_id: The ID of the document.
        split_regex: A regular expression pattern to use for splitting the text into partitions.

    Returns:
        The processed document.
    """

    document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
        id=doc_id, text=text, metadata={}
    )
    if split_regex is not None:
        partitioner = RegexPartitioner(
            pattern=split_regex, partition_layer_name="labeled_partitions"
        )
        document = partitioner(document)
    else:
        # without a split regex, use the whole text as a single partition
        document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
    return document
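

# Example usage (sketch): create a document that is split into partitions at empty
# lines; the regex is just an illustrative choice.
#
#     doc = create_document(
#         text="First part.\n\nSecond part.", doc_id="doc-1", split_regex="\n\n"
#     )
#     # the RegexPartitioner fills doc.labeled_partitions; without a split_regex, the
#     # whole text becomes a single partition labeled "text"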


def add_annotated_pie_documents(
    retriever: DocumentAwareSpanRetriever,
    pie_documents: Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions],
    use_predicted_annotations: bool,
    verbose: bool = False,
) -> None:
    if verbose:
        gr.Info(f"Creating span embeddings for {len(pie_documents)} documents...")
    num_docs_before = len(retriever.docstore)
    retriever.add_pie_documents(pie_documents, use_predicted_annotations=use_predicted_annotations)

    # documents with an already known ID are overwritten instead of added, so the
    # docstore grows by less than len(pie_documents) in that case
    num_overwritten_docs = num_docs_before + len(pie_documents) - len(retriever.docstore)
    if num_overwritten_docs > 0:
        gr.Warning(f"{num_overwritten_docs} documents were overwritten.")


def process_texts(
    texts: Iterable[str],
    doc_ids: Iterable[str],
    argumentation_model: Pipeline,
    retriever: DocumentAwareSpanRetriever,
    split_regex_escaped: Optional[str],
    handle_parts_of_same: bool = False,
    verbose: bool = False,
) -> None:
    # materialize doc_ids so that it can be iterated over more than once
    doc_ids = list(doc_ids)
    if len(set(doc_ids)) != len(doc_ids):
        raise gr.Error("Document IDs must be unique.")
    pie_documents = [
        create_document(text=text, doc_id=doc_id, split_regex=split_regex_escaped)
        for text, doc_id in zip(texts, doc_ids)
    ]
    if verbose:
        gr.Info(f"Annotating {len(pie_documents)} documents...")
    pie_documents = [
        annotate_document(
            document=pie_document,
            argumentation_model=argumentation_model,
            handle_parts_of_same=handle_parts_of_same,
        )
        for pie_document in pie_documents
    ]
    add_annotated_pie_documents(
        retriever=retriever,
        pie_documents=pie_documents,
        use_predicted_annotations=True,
        verbose=verbose,
    )
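

# Example usage (sketch): annotate a small batch of texts and index them in the
# retriever. `pipeline` and `retriever` are assumed to come from
# load_argumentation_model() and load_retriever() below.
#
#     process_texts(
#         texts=["First text.", "Second text."],
#         doc_ids=["doc-1", "doc-2"],
#         argumentation_model=pipeline,
#         retriever=retriever,
#         split_regex_escaped=None,
#     )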


def add_annotated_pie_documents_from_dataset(
    retriever: DocumentAwareSpanRetriever, verbose: bool = False, **load_dataset_kwargs
) -> None:
    try:
        gr.Info(
            "Loading PIE dataset with parameters:\n" + json.dumps(load_dataset_kwargs, indent=2)
        )
        dataset = load_dataset(**load_dataset_kwargs)
        if not isinstance(dataset, (Dataset, IterableDataset)):
            raise gr.Error("Loaded dataset is not of type PIE (Iterable)Dataset.")
        dataset_converted = dataset.to_document_type(
            TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
        )
        add_annotated_pie_documents(
            retriever=retriever,
            pie_documents=dataset_converted,
            use_predicted_annotations=False,
            verbose=verbose,
        )
    except Exception as e:
        raise gr.Error(f"Failed to load dataset: {e}")
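

# Example usage (sketch): index gold annotations from a PIE dataset. The path and
# split values are placeholders; any kwargs accepted by pie_datasets.load_dataset()
# can be passed through.
#
#     add_annotated_pie_documents_from_dataset(
#         retriever=retriever,
#         verbose=True,
#         path="pie/some_dataset",
#         split="test",
#     )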


def load_argumentation_model(
    model_name: str,
    revision: Optional[str] = None,
    device: str = "cpu",
) -> Pipeline:
    try:
        # map the device string to the integer device id expected by the pipeline
        # (-1 for CPU, otherwise the CUDA device index)
        if device == "cuda":
            pipeline_device = 0
        elif device.startswith("cuda:"):
            pipeline_device = int(device.split(":")[1])
        elif device == "cpu":
            pipeline_device = -1
        else:
            raise gr.Error(f"Invalid device: {device}")

        model = AutoPipeline.from_pretrained(
            model_name,
            device=pipeline_device,
            num_workers=0,
            taskmodule_kwargs=dict(revision=revision),
            model_kwargs=dict(revision=revision),
        )
        gr.Info(
            f"Loaded argumentation model: model_name={model_name}, revision={revision}, device={device}"
        )
    except Exception as e:
        raise gr.Error(f"Failed to load argumentation model: {e}")

    return model
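

# Example usage (sketch; the model name and revision are placeholders):
#
#     pipeline = load_argumentation_model(
#         "my-org/my-argumentation-model", revision="main", device="cuda:0"
#     )
#     # accepted device strings: "cpu", "cuda", or "cuda:<index>"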


def load_retriever(
    retriever_config: str,
    config_format: str,
    device: str = "cpu",
    previous_retriever: Optional[DocumentAwareSpanRetrieverWithRelations] = None,
) -> DocumentAwareSpanRetrieverWithRelations:
    try:
        parsed_config = _parse_config(retriever_config, format=config_format)
        # set the device for the embedding pipeline
        parsed_config["vectorstore"]["embedding"]["pipeline_kwargs"]["device"] = device
        result = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config(parsed_config)

        # if a previous retriever is provided, carry over its documents and span embeddings
        if previous_retriever is not None:
            all_doc_ids = list(previous_retriever.docstore.yield_keys())
            gr.Info(f"Storing {len(all_doc_ids)} documents from previous retriever...")
            all_docs = previous_retriever.docstore.mget(all_doc_ids)
            result.docstore.mset([(doc.id, doc) for doc in all_docs])

            all_span_ids = list(previous_retriever.vectorstore.yield_keys())
            all_spans = previous_retriever.vectorstore.mget(all_span_ids)
            result.vectorstore.mset([(span.id, span) for span in all_spans])

        gr.Info("Retriever loaded successfully.")
        return result
    except Exception as e:
        raise gr.Error(f"Failed to load retriever: {e}")
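

# Example usage (sketch): instantiate a retriever from a serialized config. The shown
# config structure is an assumption, inferred only from the device override above
# ("vectorstore" -> "embedding" -> "pipeline_kwargs"), and "json" as a value for
# config_format is likewise an assumption about what _parse_config() accepts.
#
#     config_str = '{"vectorstore": {"embedding": {"pipeline_kwargs": {}}}}'  # truncated
#     retriever = load_retriever(config_str, config_format="json", device="cpu")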


def retrieve_similar_spans(
    retriever: DocumentAwareSpanRetriever,
    query_span_id: str,
    **kwargs,
) -> pd.DataFrame:
    if not query_span_id.strip():
        raise gr.Error("No query span selected.")
    try:
        retrieval_result = retriever.invoke(input=query_span_id, **kwargs)
        records = []
        for similar_span_doc in retrieval_result:
            pie_doc, metadata = retriever.docstore.unwrap_with_metadata(similar_span_doc)
            span_ann = metadata["attached_span"]
            records.append(
                {
                    "doc_id": pie_doc.id,
                    "span_id": similar_span_doc.id,
                    "score": metadata["relevance_score"],
                    "label": span_ann.label,
                    "text": str(span_ann),
                }
            )
        return (
            pd.DataFrame(records, columns=["doc_id", "score", "label", "text", "span_id"])
            .sort_values(by="score", ascending=False)
            .round(3)
        )
    except Exception as e:
        raise gr.Error(f"Failed to retrieve similar ADUs: {e}")
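

# Example usage (sketch): query span IDs are the keys of the retriever's vectorstore,
# so a real ID has to be obtained from there (the ID below is a placeholder).
#
#     df = retrieve_similar_spans(retriever=retriever, query_span_id="<span-id>")
#     # -> DataFrame with columns doc_id, score, label, text, span_id, sorted by
#     #    descending relevance score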


def retrieve_relevant_spans(
    retriever: DocumentAwareSpanRetriever,
    query_span_id: str,
    relation_label_mapping: Optional[dict[str, str]] = None,
    **kwargs,
) -> pd.DataFrame:
    if not query_span_id.strip():
        raise gr.Error("No query span selected.")
    try:
        relation_label_mapping = relation_label_mapping or {}
        retrieval_result = retriever.invoke(input=query_span_id, return_related=True, **kwargs)
        records = []
        for relevant_span_doc in retrieval_result:
            pie_doc, metadata = retriever.docstore.unwrap_with_metadata(relevant_span_doc)
            span_ann = metadata["attached_span"]
            tail_span_ann = metadata["attached_tail_span"]
            mapped_relation_label = relation_label_mapping.get(
                metadata["relation_label"], metadata["relation_label"]
            )
            records.append(
                {
                    "doc_id": pie_doc.id,
                    "type": mapped_relation_label,
                    "rel_score": metadata["relation_score"],
                    "text": str(tail_span_ann),
                    "span_id": relevant_span_doc.id,
                    "label": tail_span_ann.label,
                    "ref_score": metadata["relevance_score"],
                    "ref_label": span_ann.label,
                    "ref_text": str(span_ann),
                    "ref_span_id": metadata["head_id"],
                }
            )
        return (
            pd.DataFrame(
                records,
                columns=[
                    "type",
                    "rel_score",
                    "ref_score",
                    "label",
                    "text",
                    "ref_label",
                    "ref_text",
                    "doc_id",
                    "span_id",
                    "ref_span_id",
                ],
            )
            .sort_values(by=["ref_score"], ascending=False)
            .round(3)
        )
    except Exception as e:
        raise gr.Error(f"Failed to retrieve relevant ADUs: {e}")
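

# Example usage (sketch): retrieve spans that are related to the hits, optionally
# renaming relation labels for display (the mapping below is illustrative).
#
#     df = retrieve_relevant_spans(
#         retriever=retriever,
#         query_span_id="<span-id>",
#         relation_label_mapping={"supports": "supported by"},
#     )
#     # "ref_"-prefixed columns describe the span that matched the query, while
#     # "label"/"text"/"rel_score" describe the span related to it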


class RetrieverCallable(Protocol):
    def __call__(
        self,
        retriever: DocumentAwareSpanRetriever,
        query_span_id: str,
        **kwargs,
    ) -> Optional[pd.DataFrame]:
        pass


def _retrieve_for_all_spans(
    retriever: DocumentAwareSpanRetriever,
    query_doc_id: str,
    retrieve_func: RetrieverCallable,
    query_span_id_column: str = "query_span_id",
    **kwargs,
) -> Optional[pd.DataFrame]:
    if not query_doc_id.strip():
        raise gr.Error("No query document selected.")
    try:
        span_id2idx = retriever.get_span_id2idx_from_doc(query_doc_id)
        gr.Info(f"Retrieving results for {len(span_id2idx)} ADUs in document {query_doc_id}...")
        span_results = {
            query_span_id: retrieve_func(
                retriever=retriever,
                query_span_id=query_span_id,
                **kwargs,
            )
            for query_span_id in span_id2idx.keys()
        }
        span_results_not_empty = {
            query_span_id: df
            for query_span_id, df in span_results.items()
            if df is not None and not df.empty
        }

        # add a column with the query span id to each result DataFrame
        for query_span_id, query_span_result in span_results_not_empty.items():
            query_span_result[query_span_id_column] = query_span_id

        if len(span_results_not_empty) == 0:
            gr.Info(f"No results found for any ADU in document {query_doc_id}.")
            return None
        else:
            result = pd.concat(span_results_not_empty.values(), ignore_index=True)
            gr.Info(f"Retrieved {len(result)} ADUs for document {query_doc_id}.")
            return result
    except Exception as e:
        raise gr.Error(
            f'Failed to retrieve results for all ADUs in document "{query_doc_id}": {e}'
        )


def retrieve_all_similar_spans(
    retriever: DocumentAwareSpanRetriever,
    query_doc_id: str,
    **kwargs,
) -> Optional[pd.DataFrame]:
    return _retrieve_for_all_spans(
        retriever=retriever,
        query_doc_id=query_doc_id,
        retrieve_func=retrieve_similar_spans,
        **kwargs,
    )


def retrieve_all_relevant_spans(
    retriever: DocumentAwareSpanRetriever,
    query_doc_id: str,
    **kwargs,
) -> Optional[pd.DataFrame]:
    return _retrieve_for_all_spans(
        retriever=retriever,
        query_doc_id=query_doc_id,
        retrieve_func=retrieve_relevant_spans,
        **kwargs,
    )
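

# Example usage (sketch): run retrieval for every ADU of one document (the document ID
# is a placeholder).
#
#     df = retrieve_all_similar_spans(retriever=retriever, query_doc_id="doc-1")
#     # the result gains a "query_span_id" column identifying the query ADU of each row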


class RetrieverForAllSpansCallable(Protocol):
    def __call__(
        self,
        retriever: DocumentAwareSpanRetriever,
        query_doc_id: str,
        **kwargs,
    ) -> Optional[pd.DataFrame]:
        pass


def _retrieve_for_all_documents(
    retriever: DocumentAwareSpanRetriever,
    retrieve_func: RetrieverForAllSpansCallable,
    query_doc_id_column: str = "query_doc_id",
    **kwargs,
) -> Optional[pd.DataFrame]:
    try:
        all_doc_ids = list(retriever.docstore.yield_keys())
        gr.Info(f"Retrieving results for {len(all_doc_ids)} documents...")
        doc_results = {
            doc_id: retrieve_func(retriever=retriever, query_doc_id=doc_id, **kwargs)
            for doc_id in all_doc_ids
        }
        doc_results_not_empty = {
            doc_id: df for doc_id, df in doc_results.items() if df is not None and not df.empty
        }

        # add a column with the query doc id to each result DataFrame
        for doc_id, doc_result in doc_results_not_empty.items():
            doc_result[query_doc_id_column] = doc_id

        if len(doc_results_not_empty) == 0:
            gr.Info("No results found for any document.")
            return None
        else:
            result = pd.concat(doc_results_not_empty.values(), ignore_index=True)
            gr.Info(f"Retrieved {len(result)} ADUs for all documents.")
            return result
    except Exception as e:
        raise gr.Error(f"Failed to retrieve results for all documents: {e}")


def retrieve_all_similar_spans_for_all_documents(
    retriever: DocumentAwareSpanRetriever,
    **kwargs,
) -> Optional[pd.DataFrame]:
    return _retrieve_for_all_documents(
        retriever=retriever,
        retrieve_func=retrieve_all_similar_spans,
        **kwargs,
    )


def retrieve_all_relevant_spans_for_all_documents(
    retriever: DocumentAwareSpanRetriever,
    **kwargs,
) -> Optional[pd.DataFrame]:
    return _retrieve_for_all_documents(
        retriever=retriever,
        retrieve_func=retrieve_all_relevant_spans,
        **kwargs,
    )
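

# Example usage (sketch): run retrieval over the entire docstore.
#
#     df = retrieve_all_relevant_spans_for_all_documents(retriever=retriever)
#     # rows carry both "query_span_id" and "query_doc_id" columns; the functions
#     # return None when no results are found at all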