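"""Gradio backend helpers for building and querying a span retriever over annotated documents.

The functions in this module ingest documents from several sources (pasted text, uploaded
.txt and .pdf files, arXiv, ACL Anthology venues, and pre-annotated PIE datasets), annotate
them with an argumentation model where necessary, add them to a DocumentAwareSpanRetriever,
and render or (de)serialize the resulting document store.
"""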
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Iterable, List, Optional, Sequence

import gradio as gr
import pandas as pd
from acl_anthology import Anthology
from pie_datasets import Dataset, IterableDataset, load_dataset
from pytorch_ie import Pipeline
from pytorch_ie.documents import (
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
)
from tqdm import tqdm

from src.demo.annotation_utils import annotate_documents, create_documents
from src.demo.data_utils import load_text_from_arxiv
from src.demo.rendering_utils import (
    RENDER_WITH_DISPLACY,
    RENDER_WITH_PRETTY_TABLE,
    render_displacy,
    render_pretty_table,
)
from src.demo.retriever_utils import get_text_spans_and_relations_from_document
from src.langchain_modules import (
    DocumentAwareSpanRetriever,
    DocumentAwareSpanRetrieverWithRelations,
)
from src.utils.pdf_utils.acl_anthology_utils import XML2RawPapers
from src.utils.pdf_utils.process_pdf import FulltextExtractor, PDFDownloader

logger = logging.getLogger(__name__)


def add_annotated_pie_documents(
    retriever: DocumentAwareSpanRetriever,
    pie_documents: Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions],
    use_predicted_annotations: bool,
    verbose: bool = False,
) -> None:
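    """Create span embeddings for the given PIE documents and add them to the retriever's
    docstore.

    Warns if documents with the same IDs were already present and got overwritten.
    """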
    if verbose:
        gr.Info(f"Create span embeddings for {len(pie_documents)} documents...")
    num_docs_before = len(retriever.docstore)
    retriever.add_pie_documents(pie_documents, use_predicted_annotations=use_predicted_annotations)
    # number of documents that were overwritten
    num_overwritten_docs = num_docs_before + len(pie_documents) - len(retriever.docstore)
    # warn if documents were overwritten
    if num_overwritten_docs > 0:
        gr.Warning(f"{num_overwritten_docs} documents were overwritten.")


def process_texts(
    texts: Iterable[str],
    doc_ids: Iterable[str],
    argumentation_model: Pipeline,
    retriever: DocumentAwareSpanRetriever,
    split_regex_escaped: Optional[str],
    handle_parts_of_same: bool = False,
    verbose: bool = False,
) -> None:
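    """Convert raw texts into PIE documents (optionally split into partitions via
    split_regex_escaped), annotate them with the argumentation model, and add the
    annotated documents to the retriever.
    """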
    # materialize the inputs so that they can be iterated over multiple times
    texts = list(texts)
    doc_ids = list(doc_ids)
    # check that doc_ids are unique
    if len(set(doc_ids)) != len(doc_ids):
        raise gr.Error("Document IDs must be unique.")
    pie_documents = create_documents(
        texts=texts,
        doc_ids=doc_ids,
        split_regex=split_regex_escaped,
    )
    if verbose:
        gr.Info(f"Annotate {len(pie_documents)} documents...")
    pie_documents = annotate_documents(
        documents=pie_documents,
        argumentation_model=argumentation_model,
        handle_parts_of_same=handle_parts_of_same,
    )
    add_annotated_pie_documents(
        retriever=retriever,
        pie_documents=pie_documents,
        use_predicted_annotations=True,
        verbose=verbose,
    )


def add_annotated_pie_documents_from_dataset(
    retriever: DocumentAwareSpanRetriever, verbose: bool = False, **load_dataset_kwargs
) -> None:
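    """Load a PIE dataset (all load_dataset_kwargs are passed to load_dataset), convert it
    to the required document type, and add its gold annotations to the retriever.
    """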
    try:
        gr.Info(
            "Loading PIE dataset with parameters:\n" + json.dumps(load_dataset_kwargs, indent=2)
        )
        dataset = load_dataset(**load_dataset_kwargs)
        if not isinstance(dataset, (Dataset, IterableDataset)):
            raise gr.Error("Loaded dataset is not of type PIE (Iterable)Dataset.")
        dataset_converted = dataset.to_document_type(
            TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
        )
        add_annotated_pie_documents(
            retriever=retriever,
            pie_documents=dataset_converted,
            use_predicted_annotations=False,
            verbose=verbose,
        )
    except Exception as e:
        raise gr.Error(f"Failed to load dataset: {e}")


def wrapped_process_text(
    doc_id: str, text: str, retriever: DocumentAwareSpanRetriever, **kwargs
) -> str:
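    """Process a single text and return its document ID; errors are surfaced as gr.Error."""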
    try:
        process_texts(doc_ids=[doc_id], texts=[text], retriever=retriever, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process text: {e}")
    # return only the document ID (not the document itself) to avoid serialization issues
    return doc_id


def process_uploaded_files(
    file_names: List[str],
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
    **kwargs,
) -> pd.DataFrame:
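    """Read uploaded .txt files, process their contents, and return an overview of the
    retriever's document store.
    """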
    try:
        doc_ids = []
        texts = []
        for file_name in file_names:
            if file_name.lower().endswith(".txt"):
                # read the file content
                with open(file_name, "r", encoding="utf-8") as f:
                    text = f.read()
                base_file_name = os.path.basename(file_name)
                doc_ids.append(base_file_name)
                texts.append(text)
            else:
                raise gr.Error(f"Unsupported file format: {file_name}")
        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process uploaded files: {e}")

    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def process_uploaded_pdf_files(
    pdf_fulltext_extractor: Optional[FulltextExtractor],
    file_names: List[str],
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
    **kwargs,
) -> pd.DataFrame:
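    """Extract the fulltext from uploaded PDF files, process it, and return an overview
    of the retriever's document store.
    """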
    try:
        if pdf_fulltext_extractor is None:
            raise gr.Error("PDF fulltext extractor is not available.")
        doc_ids = []
        texts = []
        for file_name in file_names:
            if file_name.lower().endswith(".pdf"):
                # extract the fulltext from the pdf
                text_and_extraction_data = pdf_fulltext_extractor(file_name)
                if text_and_extraction_data is None:
                    raise gr.Error(f"Failed to extract fulltext from PDF: {file_name}")
                text, _ = text_and_extraction_data

                base_file_name = os.path.basename(file_name)
                doc_ids.append(base_file_name)
                texts.append(text)

            else:
                raise gr.Error(f"Unsupported file format: {file_name}")
        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process uploaded files: {e}")

    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def load_acl_anthology_venues(
    venues: List[str],
    pdf_fulltext_extractor: Optional[FulltextExtractor],
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
    acl_anthology_data_dir: Optional[str],
    pdf_output_dir: Optional[str],
    show_progress: bool = True,
    **kwargs,
) -> pd.DataFrame:
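    """Collect all papers of the given ACL Anthology venues from the local Anthology data,
    download their PDFs, extract the fulltext, process it, and return an overview of the
    retriever's document store.
    """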
    try:
        if pdf_fulltext_extractor is None:
            raise gr.Error("PDF fulltext extractor is not available.")
        if acl_anthology_data_dir is None:
            raise gr.Error("ACL Anthology data directory is not provided.")
        if pdf_output_dir is None:
            raise gr.Error("PDF output directory is not provided.")
        xml2raw_papers = XML2RawPapers(
            anthology=Anthology(datadir=Path(acl_anthology_data_dir)),
            venue_id_whitelist=venues,
            verbose=False,
        )
        pdf_downloader = PDFDownloader()
        doc_ids = []
        texts = []
        os.makedirs(pdf_output_dir, exist_ok=True)
        papers = xml2raw_papers()
        if show_progress:
            papers_list = list(papers)
            papers = tqdm(papers_list, desc="extracting fulltext")
            gr.Info(
                f"Downloading and extracting fulltext from {len(papers_list)} papers in venues: {venues}"
            )
        for paper in papers:
            if paper.url is not None:
                pdf_save_path = pdf_downloader.download(
                    paper.url, opath=Path(pdf_output_dir) / f"{paper.name}.pdf"
                )
                fulltext_extraction_output = pdf_fulltext_extractor(pdf_save_path)

                if fulltext_extraction_output:
                    text, _ = fulltext_extraction_output
                    doc_id = f"aclanthology.org/{paper.name}"
                    doc_ids.append(doc_id)
                    texts.append(text)
                else:
                    gr.Warning(f"Failed to extract fulltext from PDF: {paper.url}")

        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to process uploaded files: {e}")

    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def wrapped_add_annotated_pie_documents_from_dataset(
    retriever: DocumentAwareSpanRetriever, verbose: bool, layer_captions: dict[str, str], **kwargs
) -> pd.DataFrame:
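    """Add annotated PIE documents from a dataset and return an overview of the retriever's
    document store; errors are surfaced as gr.Error.
    """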
    try:
        add_annotated_pie_documents_from_dataset(retriever=retriever, verbose=verbose, **kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to add annotated PIE documents from dataset: {e}")
    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def download_processed_documents(
    retriever: DocumentAwareSpanRetriever,
    file_name: str = "retriever_store",
) -> Optional[str]:
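    """Save the retriever store as a zip archive in the system temp directory and return
    its path, or None if the store is empty.
    """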
    if len(retriever.docstore) == 0:
        gr.Warning("No documents to download.")
        return None

    # zip the retriever store into the system temp directory
    file_path = os.path.join(tempfile.gettempdir(), file_name)

    gr.Info(f"Zipping the retriever store to '{file_name}' ...")
    result_file_path = retriever.save_to_archive(base_name=file_path, format="zip")

    return result_file_path


def upload_processed_documents(
    file_name: str,
    retriever: DocumentAwareSpanRetriever,
    layer_captions: dict[str, str],
) -> pd.DataFrame:
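    """Restore a previously saved retriever store from a zip file or directory and return
    an overview of the document store.
    """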
    # load the documents from the zip file or directory
    retriever.load_from_disc(file_name)
    # return the overview of the document store
    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


def process_text_from_arxiv(
    arxiv_id: str, retriever: DocumentAwareSpanRetriever, abstract_only: bool = False, **kwargs
) -> str:
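    """Fetch a paper (or only its abstract) from arXiv and process it like a pasted text."""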
    try:
        text, doc_id = load_text_from_arxiv(arxiv_id=arxiv_id, abstract_only=abstract_only)
    except Exception as e:
        raise gr.Error(f"Failed to load text from arXiv: {e}")
    return wrapped_process_text(doc_id=doc_id, text=text, retriever=retriever, **kwargs)


def render_annotated_document(
    retriever: DocumentAwareSpanRetrieverWithRelations,
    document_id: str,
    render_with: str,
    render_kwargs_json: str,
    highlight_span_ids: Optional[List[str]] = None,
) -> str:
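    """Fetch the text, spans, and relations of a stored document and render them as HTML,
    using either the pretty-table or the displaCy renderer.
    """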
    text, spans, span_id2idx, relations = get_text_spans_and_relations_from_document(
        retriever=retriever, document_id=document_id
    )

    render_kwargs = json.loads(render_kwargs_json)
    if render_with == RENDER_WITH_PRETTY_TABLE:
        html = render_pretty_table(
            text=text,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=relations,
            **render_kwargs,
        )
    elif render_with == RENDER_WITH_DISPLACY:
        html = render_displacy(
            text=text,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=relations,
            highlight_span_ids=highlight_span_ids,
            **render_kwargs,
        )
    else:
        raise ValueError(f"Unknown render_with value: {render_with}")

    return html