import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Iterable, List, Optional, Sequence

import gradio as gr
import pandas as pd
from acl_anthology import Anthology
from pie_datasets import Dataset, IterableDataset, load_dataset
from pytorch_ie import Pipeline
from pytorch_ie.documents import (
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
)
from tqdm import tqdm

from src.demo.annotation_utils import annotate_documents, create_documents
from src.demo.data_utils import load_text_from_arxiv
from src.demo.rendering_utils import (
    RENDER_WITH_DISPLACY,
    RENDER_WITH_PRETTY_TABLE,
    render_displacy,
    render_pretty_table,
)
from src.demo.retriever_utils import get_text_spans_and_relations_from_document
from src.langchain_modules import (
    DocumentAwareSpanRetriever,
    DocumentAwareSpanRetrieverWithRelations,
)
from src.utils.pdf_utils.acl_anthology_utils import XML2RawPapers
from src.utils.pdf_utils.process_pdf import FulltextExtractor, PDFDownloader

logger = logging.getLogger(__name__)


def add_annotated_pie_documents(
    retriever: DocumentAwareSpanRetriever,
    pie_documents: Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions],
    use_predicted_annotations: bool,
    verbose: bool = False,
) -> None:
    if verbose:
        gr.Info(f"Creating span embeddings for {len(pie_documents)} documents...")
    num_docs_before = len(retriever.docstore)
    retriever.add_pie_documents(pie_documents, use_predicted_annotations=use_predicted_annotations)
    # number of documents that were overwritten
    num_overwritten_docs = num_docs_before + len(pie_documents) - len(retriever.docstore)
    # warn if documents were overwritten
    if num_overwritten_docs > 0:
        gr.Warning(f"{num_overwritten_docs} documents were overwritten.")


def process_texts(
    texts: Iterable[str],
    doc_ids: Iterable[str],
    argumentation_model: Pipeline,
    retriever: DocumentAwareSpanRetriever,
    split_regex_escaped: Optional[str],
    handle_parts_of_same: bool = False,
    verbose: bool = False,
) -> None:
    # materialize doc_ids so the uniqueness check does not exhaust a generator
    doc_ids = list(doc_ids)
    # check that doc_ids are unique
    if len(set(doc_ids)) != len(doc_ids):
        raise gr.Error("Document IDs must be unique.")
    pie_documents = create_documents(
        texts=texts,
        doc_ids=doc_ids,
        split_regex=split_regex_escaped,
    )
    if verbose:
        gr.Info(f"Annotating {len(pie_documents)} documents...")
    pie_documents = annotate_documents(
        documents=pie_documents,
        argumentation_model=argumentation_model,
        handle_parts_of_same=handle_parts_of_same,
    )
    add_annotated_pie_documents(
        retriever=retriever,
        pie_documents=pie_documents,
        use_predicted_annotations=True,
        verbose=verbose,
    )


def add_annotated_pie_documents_from_dataset(
    retriever: DocumentAwareSpanRetriever, verbose: bool = False, **load_dataset_kwargs
) -> None:
    try:
        gr.Info(
            "Loading PIE dataset with parameters:\n" + json.dumps(load_dataset_kwargs, indent=2)
        )
        dataset = load_dataset(**load_dataset_kwargs)
        if not isinstance(dataset, (Dataset, IterableDataset)):
            raise gr.Error("Loaded dataset is not of type PIE (Iterable)Dataset.")
        dataset_converted = dataset.to_document_type(
            TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
        )
        add_annotated_pie_documents(
            retriever=retriever,
            pie_documents=dataset_converted,
            use_predicted_annotations=False,
            verbose=verbose,
        )
    except Exception as e:
        raise gr.Error(f"Failed to load dataset: {e}")


def wrapped_process_text(
    doc_id: str, text: str, retriever: DocumentAwareSpanRetriever, **kwargs
) -> str:
    try:
        process_texts(doc_ids=[doc_id], texts=[text], retriever=retriever, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process text: {e}")
    # return only the document ID (not the document itself) to avoid serialization issues
    return doc_id


def process_uploaded_files(
    file_names: List[str],
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
    **kwargs,
) -> pd.DataFrame:
    try:
        doc_ids = []
        texts = []
        for file_name in file_names:
            if file_name.lower().endswith(".txt"):
                # read the file content
                with open(file_name, "r", encoding="utf-8") as f:
                    text = f.read()
                base_file_name = os.path.basename(file_name)
                doc_ids.append(base_file_name)
                texts.append(text)
            else:
                raise gr.Error(f"Unsupported file format: {file_name}")
        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process uploaded files: {e}")

    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def process_uploaded_pdf_files(
    pdf_fulltext_extractor: Optional[FulltextExtractor],
    file_names: List[str],
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
    **kwargs,
) -> pd.DataFrame:
    try:
        if pdf_fulltext_extractor is None:
            raise gr.Error("PDF fulltext extractor is not available.")
        doc_ids = []
        texts = []
        for file_name in file_names:
            if file_name.lower().endswith(".pdf"):
                # extract the fulltext from the pdf
                text_and_extraction_data = pdf_fulltext_extractor(file_name)
                if text_and_extraction_data is None:
                    raise gr.Error(f"Failed to extract fulltext from PDF: {file_name}")
                text, _ = text_and_extraction_data
                base_file_name = os.path.basename(file_name)
                doc_ids.append(base_file_name)
                texts.append(text)
            else:
                raise gr.Error(f"Unsupported file format: {file_name}")
        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process uploaded files: {e}")

    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def load_acl_anthology_venues(
    venues: List[str],
    pdf_fulltext_extractor: Optional[FulltextExtractor],
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
    acl_anthology_data_dir: Optional[str],
    pdf_output_dir: Optional[str],
    show_progress: bool = True,
    **kwargs,
) -> pd.DataFrame:
    try:
        if pdf_fulltext_extractor is None:
            raise gr.Error("PDF fulltext extractor is not available.")
        if acl_anthology_data_dir is None:
            raise gr.Error("ACL Anthology data directory is not provided.")
        if pdf_output_dir is None:
            raise gr.Error("PDF output directory is not provided.")
        xml2raw_papers = XML2RawPapers(
            anthology=Anthology(datadir=Path(acl_anthology_data_dir)),
            venue_id_whitelist=venues,
            verbose=False,
        )
        pdf_downloader = PDFDownloader()
        doc_ids = []
        texts = []
        os.makedirs(pdf_output_dir, exist_ok=True)
        papers = xml2raw_papers()
        if show_progress:
            papers_list = list(papers)
            papers = tqdm(papers_list, desc="extracting fulltext")
            gr.Info(
                f"Downloading and extracting fulltext from {len(papers_list)} papers "
                f"in venues: {venues}"
            )
        for paper in papers:
            if paper.url is not None:
                pdf_save_path = pdf_downloader.download(
                    paper.url, opath=Path(pdf_output_dir) / f"{paper.name}.pdf"
                )
                fulltext_extraction_output = pdf_fulltext_extractor(pdf_save_path)
                if fulltext_extraction_output:
                    text, _ = fulltext_extraction_output
                    doc_id = f"aclanthology.org/{paper.name}"
                    doc_ids.append(doc_id)
                    texts.append(text)
                else:
                    gr.Warning(f"Failed to extract fulltext from PDF: {paper.url}")
        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to load ACL Anthology venues: {e}")

    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def wrapped_add_annotated_pie_documents_from_dataset(
    retriever: DocumentAwareSpanRetriever, verbose: bool, layer_captions: dict[str, str], **kwargs
) -> pd.DataFrame:
    try:
        add_annotated_pie_documents_from_dataset(retriever=retriever, verbose=verbose, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to add annotated PIE documents from dataset: {e}")

    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def download_processed_documents(
    retriever: DocumentAwareSpanRetriever,
    file_name: str = "retriever_store",
) -> Optional[str]:
    if len(retriever.docstore) == 0:
        gr.Warning("No documents to download.")
        return None

    # zip the retriever store into a temporary archive
    file_path = os.path.join(tempfile.gettempdir(), file_name)

    gr.Info(f"Zipping the retriever store to '{file_name}' ...")
    result_file_path = retriever.save_to_archive(base_name=file_path, format="zip")

    return result_file_path


def upload_processed_documents(
    file_name: str,
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
) -> pd.DataFrame:
    # load the documents from the zip file or directory
    retriever.load_from_disc(file_name)
    # return the overview of the document store
    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def process_text_from_arxiv(
    arxiv_id: str, retriever: DocumentAwareSpanRetriever, abstract_only: bool = False, **kwargs
) -> str:
    try:
        text, doc_id = load_text_from_arxiv(arxiv_id=arxiv_id, abstract_only=abstract_only)
    except Exception as e:
        raise gr.Error(f"Failed to load text from arXiv: {e}")
    return wrapped_process_text(doc_id=doc_id, text=text, retriever=retriever, **kwargs)


def render_annotated_document(
    retriever: DocumentAwareSpanRetrieverWithRelations,
    document_id: str,
    render_with: str,
    render_kwargs_json: str,
    highlight_span_ids: Optional[List[str]] = None,
) -> str:
    # fetch the text, spans, and relations for the document from the retriever
    text, spans, span_id2idx, relations = get_text_spans_and_relations_from_document(
        retriever=retriever, document_id=document_id
    )
    render_kwargs = json.loads(render_kwargs_json)
    if render_with == RENDER_WITH_PRETTY_TABLE:
        html = render_pretty_table(
            text=text,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=relations,
            **render_kwargs,
        )
    elif render_with == RENDER_WITH_DISPLACY:
        html = render_displacy(
            text=text,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=relations,
            highlight_span_ids=highlight_span_ids,
            **render_kwargs,
        )
    else:
        raise ValueError(f"Unknown render_with value: {render_with}")

    return html