update from https://github.com/ArneBinder/pie-document-level/pull/397
- argumentation_model/_joint.yaml +4 -0
- argumentation_model/_pipelined.yaml +17 -0
- argumentation_model/joint.yaml +10 -0
- argumentation_model/joint_hps.yaml +7 -0
- argumentation_model/pipelined.yaml +8 -0
- argumentation_model/pipelined_deprecated.yaml +9 -0
- argumentation_model/pipelined_hps.yaml +8 -0
- argumentation_model/pipelined_new.yaml +14 -0
- demo.yaml +85 -0
- pdf_fulltext_extractor/grobid_local.yaml +18 -0
- pdf_fulltext_extractor/none.yaml +0 -0
- requirements.txt +8 -2
- retriever/related_span_retriever_with_relations_from_other_docs.yaml +49 -0
- src/analysis/__init__.py +0 -0
- src/analysis/combine_job_returns.py +169 -0
- src/analysis/common.py +47 -0
- src/analysis/compare_job_returns.py +407 -0
- src/data/acl_anthology_crawler.py +117 -0
- src/data/calc_iaa_for_brat.py +272 -0
- src/data/construct_sciarg_abstracts_remaining_gold_retrieval.py +238 -0
- src/data/prepare_sciarg_crosssection_annotations.py +398 -0
- src/data/split_sciarg_abstracts.py +132 -0
- src/demo/annotation_utils.py +88 -41
- src/demo/backend_utils.py +106 -13
- src/demo/frontend_utils.py +12 -0
- src/demo/rendering_utils.py +23 -3
- src/demo/rendering_utils_displacy.py +12 -1
- src/demo/retrieve_and_dump_all_relevant.py +61 -2
- src/demo/retriever_utils.py +8 -6
- src/document/processing.py +212 -77
- src/hydra_callbacks/save_job_return_value.py +178 -40
- src/langchain_modules/pie_document_store.py +1 -1
- src/langchain_modules/span_retriever.py +13 -16
- src/pipeline/ner_re_pipeline.py +45 -15
- src/predict.py +6 -2
- src/start_demo.py +161 -36
- src/train.py +10 -0
- src/utils/__init__.py +6 -1
- src/utils/config_utils.py +15 -1
- src/utils/pdf_utils/README.MD +35 -0
- src/utils/pdf_utils/__init__.py +0 -0
- src/utils/pdf_utils/acl_anthology_utils.py +77 -0
- src/utils/pdf_utils/client.py +193 -0
- src/utils/pdf_utils/grobid_client.py +203 -0
- src/utils/pdf_utils/grobid_util.py +413 -0
- src/utils/pdf_utils/process_pdf.py +276 -0
- src/utils/pdf_utils/raw_paper.py +90 -0
- src/utils/pdf_utils/s2orc_paper.py +478 -0
- src/utils/pdf_utils/s2orc_utils.py +61 -0
- src/utils/pdf_utils/utils.py +904 -0
argumentation_model/_joint.yaml
ADDED
@@ -0,0 +1,4 @@
_target_: pytorch_ie.auto.AutoPipeline.from_pretrained
pretrained_model_name_or_path: ???
# this batch_size works good (fastest) on a single RTX2080Ti (11GB) (see https://github.com/ArneBinder/pie-document-level/issues/334#issuecomment-2613232344)
batch_size: 1

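For orientation, this config only parameterizes a `pytorch_ie.auto.AutoPipeline.from_pretrained` call. A minimal Python sketch of what the joint setup resolves to; the model name and revision are taken from `joint.yaml` below, and forwarding `batch_size` and `revision` through `from_pretrained` is assumed to behave as in the config:

```python
# Sketch only: the rough equivalent of _joint.yaml + joint.yaml, not an official entry point.
from pytorch_ie.auto import AutoPipeline

pipeline = AutoPipeline.from_pretrained(
    "ArneBinder/sam-pointer-bart-base-v0.4",
    revision="0445c69bafa31f8153aaeafc1767fad84919926a",
    batch_size=1,  # fastest setting found for a single RTX2080Ti (11GB), see issue #334
)
```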
argumentation_model/_pipelined.yaml
ADDED
@@ -0,0 +1,17 @@
_target_: src.pipeline.NerRePipeline
ner_model_path: ???
re_model_path: ???
entity_layer: labeled_spans
relation_layer: binary_relations
# this works good on a single RTX2080Ti (11GB)
ner_pipeline:
  batch_size: 256
re_pipeline:
  batch_size: 64
  # convert the RE model to half precision for mixed precision inference (speedup approx. 4x)
  half_precision_model: true
taskmodule_kwargs:
  # don't show statistics after encoding
  collect_statistics: false
# don't show pipeline steps
verbose: false

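For comparison with the joint variant, a minimal sketch of building the pipelined model directly in Python. The `NerRePipeline` keyword names are assumed from the config keys above, and the model paths are placeholders (see the `pipelined*.yaml` variants below for the actual ones):

```python
# Sketch only: mirrors _pipelined.yaml; keyword names assumed from the config, paths are placeholders.
from src.pipeline import NerRePipeline

pipeline = NerRePipeline(
    ner_model_path="path/or/hub-id/of/adur-model",  # placeholder
    re_model_path="path/or/hub-id/of/are-model",    # placeholder
    entity_layer="labeled_spans",
    relation_layer="binary_relations",
    ner_pipeline=dict(batch_size=256),
    re_pipeline=dict(batch_size=64, half_precision_model=True),
    taskmodule_kwargs=dict(collect_statistics=False),
    verbose=False,
)
```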
argumentation_model/joint.yaml
ADDED
@@ -0,0 +1,10 @@
defaults:
  - _joint

# best model based on the validation set (see https://github.com/ArneBinder/pie-document-level/issues/334#issuecomment-2613232344 for details)
# i.e. models from https://github.com/ArneBinder/pie-document-level/issues/334#issuecomment-2578422544, but with last checkpoint (instead of best validation checkpoint)
# model_name_or_path: models/dataset-sciarg/task-ner_re/v0.4/2025-01-09_01-50-53
# ckpt_path: logs/training/multiruns/dataset-sciarg/task-ner_re/v0.4/2025-01-09_01-50-52/2/checkpoints/last.ckpt
# w&b run (for the loaded checkpoint): [icy-glitter-5](https://wandb.ai/arne/dataset-sciarg-task-ner_re-v0.4-training/runs/it5toj6w)
pretrained_model_name_or_path: "ArneBinder/sam-pointer-bart-base-v0.4"
revision: "0445c69bafa31f8153aaeafc1767fad84919926a"

argumentation_model/joint_hps.yaml
ADDED
@@ -0,0 +1,7 @@
defaults:
  - _joint

# from: hparams_search for all datasets
# see https://github.com/ArneBinder/pie-document-level/pull/381#issuecomment-2682711151
# THESE ARE LOCAL PATHS, NOT HUGGINGFACE MODELS!
pretrained_model_name_or_path: models/dataset-sciarg/task-ner_re/2025-02-23_05-16-45

argumentation_model/pipelined.yaml
ADDED
@@ -0,0 +1,8 @@
defaults:
  - _pipelined

# from: train pipeline models with bigger train set,
# see https://github.com/ArneBinder/pie-document-level/issues/355#issuecomment-2612958658
# THESE ARE LOCAL PATHS, NOT HUGGINGFACE MODELS!
ner_model_path: models/dataset-sciarg/task-adus/v0.4/2025-01-20_05-50-00
re_model_path: models/dataset-sciarg/task-relations/v0.4/2025-01-22_20-36-23

argumentation_model/pipelined_deprecated.yaml
ADDED
@@ -0,0 +1,9 @@
defaults:
  - _pipelined

# from: train pipeline models with bigger train set, but with strange choice of models,
# see edit history of https://github.com/ArneBinder/pie-document-level/issues/355#issuecomment-2612958658
# NOTE: these were originally in the pipelined.yaml
# THESE ARE LOCAL PATHS, NOT HUGGINGFACE MODELS!
ner_model_path: models/dataset-sciarg/task-adus/v0.4/2025-01-20_09-09-11
re_model_path: models/dataset-sciarg/task-relations/v0.4/2025-01-22_12-44-51

argumentation_model/pipelined_hps.yaml
ADDED
@@ -0,0 +1,8 @@
defaults:
  - _pipelined

# from: hparams_search for all datasets,
# see https://github.com/ArneBinder/pie-document-level/pull/381#issuecomment-2684865102
# THESE ARE LOCAL PATHS, NOT HUGGINGFACE MODELS!
ner_model_path: models/dataset-sciarg/task-adur/2025-02-26_07-14-59
re_model_path: models/dataset-sciarg/task-are/2025-02-20_18-09-25

argumentation_model/pipelined_new.yaml
ADDED
@@ -0,0 +1,14 @@
defaults:
  - _pipelined

# from: Update scientific ARE experiment configs,
# see https://github.com/ArneBinder/pie-document-level/pull/379#issuecomment-2651669398
# i.e. the models are now on Hugging Face
# ner_model_path: models/dataset-sciarg/task-adur/2025-02-09_23-08-37
# re_model_path: models/dataset-sciarg/task-are/2025-02-10_19-24-52
ner_model_path: ArneBinder/sam-adur-sciarg
ner_pipeline:
  revision: bcbef4e585a5f637009ff702661cf824abede6b0
re_model_path: ArneBinder/sam-are-sciarg
re_pipeline:
  revision: 93024388330c58daf20963c2020e08f54553e74c

demo.yaml
ADDED
@@ -0,0 +1,85 @@
defaults:
  - _self_
  # default retriever, see subfolder retriever for more details
  - retriever: related_span_retriever_with_relations_from_other_docs
  # default argumentation model, see subfolder argumentation_model for more details
  - argumentation_model: pipelined_new
  # since this requires a running GROBID server, we disable it by default
  - pdf_fulltext_extractor: none

# Whether to handle segmented entities in the document. If True, labeled_spans are converted
# to labeled_multi_spans and binary_relations with label "parts_of_same" are used to merge them.
# This requires the networkx package to be installed.
handle_parts_of_same: true
# Split the document text into sections that are processed separately.
default_split_regex: "\n\n\n+"

# retriever details (query parameters)
default_min_similarity: 0.95
default_top_k: 10

# data import details
default_arxiv_id: "1706.03762"
default_load_pie_dataset_kwargs:
  path: "pie/sciarg"
  name: "resolve_parts_of_same"
  split: "train"

# set to the data directory of https://github.com/acl-org/acl-anthology
# to enable ACL venue PDF import (requires to also have a valid pdf_fulltext_extractor)
# acl_anthology_data_dir=../acl-anthology/data
# temporary directory to store downloaded PDFs
acl_anthology_pdf_dir: "data/acl-anthology/pdf"

# for better readability in the UI
render_mode_captions:
  displacy: "displaCy + highlighted arguments"
  pretty_table: "Pretty Table"
layer_caption_mapping:
  labeled_multi_spans: "adus"
  binary_relations: "relations"
  labeled_partitions: "partitions"
relation_name_mapping:
  supports_reversed: "supported by"
  contradicts_reversed: "contradicts"

default_render_mode: "displacy"
default_render_kwargs:
  entity_options:
    # we need to have the keys as uppercase because the spacy rendering function converts the labels to uppercase
    colors:
      OWN_CLAIM: "#009933"
      BACKGROUND_CLAIM: "#99ccff"
      DATA: "#993399"
    colors_hover:
      selected: "#ffa"
      # tail options for relationships
      tail:
        # green
        supports: "#9f9"
        # red
        contradicts: "#f99"
        # do not highlight
        parts_of_same: null
      head: null # "#faf"
      other: null

example_text: >
  Scholarly Argumentation Mining (SAM) has recently gained attention due to its
  potential to help scholars with the rapid growth of published scientific literature.
  It comprises two subtasks: argumentative discourse unit recognition (ADUR) and
  argumentative relation extraction (ARE), both of which are challenging since they
  require e.g. the integration of domain knowledge, the detection of implicit statements,
  and the disambiguation of argument structure.

  While previous work focused on dataset construction and baseline methods for
  specific document sections, such as abstract or results, full-text scholarly argumentation
  mining has seen little progress. In this work, we introduce a sequential pipeline model
  combining ADUR and ARE for full-text SAM, and provide a first analysis of the
  performance of pretrained language models (PLMs) on both subtasks.

  We establish a new SotA for ADUR on the Sci-Arg corpus, outperforming the previous best
  reported result by a large margin (+7% F1). We also present the first results for ARE, and
  thus for the full AM pipeline, on this benchmark dataset. Our detailed error analysis reveals
  that non-contiguous ADUs as well as the interpretation of discourse connectors pose major
  challenges and that data annotation needs to be more consistent.

pdf_fulltext_extractor/grobid_local.yaml
ADDED
@@ -0,0 +1,18 @@
# This requires a running GROBID server. To start the server via Docker, run:
# docker run --rm --init --ulimit core=0 -p 8070:8070 lfoppiano/grobid:0.8.0

_target_: src.utils.pdf_utils.process_pdf.GrobidFulltextExtractor
section_seperator: "\n\n\n"
paragraph_seperator: "\n\n"
grobid_config:
  grobid_server: localhost
  grobid_port: 8070
  batch_size: 1000
  sleep_time: 5
  generateIDs: false
  consolidate_header: false
  consolidate_citations: false
  include_raw_citations: true
  include_raw_affiliations: false
  max_workers: 2
  verbose: false

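A minimal usage sketch for this extractor, assuming a GROBID server is already running on localhost:8070 as described in the comment above. The call convention (returning plain text plus an extraction dict with `sections`/`abstract`) is taken from how `acl_anthology_crawler.py` below uses it; the PDF path and the exact shape of `grobid_config` are assumptions:

```python
# Sketch only: requires a running GROBID server (see the docker command above).
from src.utils.pdf_utils.process_pdf import GrobidFulltextExtractor

extractor = GrobidFulltextExtractor(
    section_seperator="\n\n\n",
    paragraph_seperator="\n\n",
    grobid_config={"grobid_server": "localhost", "grobid_port": 8070},
)
result = extractor("data/acl-anthology/pdf/example.pdf")  # placeholder path
if result:
    plain_text, extraction_data = result
    print(extraction_data.get("abstract"))
```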
pdf_fulltext_extractor/none.yaml
ADDED
File without changes
requirements.txt
CHANGED
@@ -1,7 +1,11 @@
+# -------- dl backend -------- #
+torch==2.0.0
+pytorch-lightning==2.1.2
+
 # --------- pytorch-ie --------- #
-pytorch-ie>=0.
+pytorch-ie>=0.31.4,<0.32.0
 pie-datasets>=0.10.5,<0.11.0
-pie-modules>=0.14.
+pie-modules>=0.14.2,<0.15.0

 # --------- models -------- #
 adapters>=0.1.2,<0.2.0
@@ -17,6 +21,8 @@ qdrant-client>=1.12.0,<2.0.0
 # --------- demo -------- #
 gradio~=5.5.0
 arxiv~=2.1.3
+# data preparation
+acl-anthology-py>=0.4.3

 # --------- hydra --------- #
 hydra-core>=1.3.0

retriever/related_span_retriever_with_relations_from_other_docs.yaml
ADDED
@@ -0,0 +1,49 @@
_target_: src.langchain_modules.DocumentAwareSpanRetrieverWithRelations
symmetric_relations:
  - contradicts
reversed_relations_suffix: _reversed
relation_labels:
  - supports_reversed
  - contradicts
retrieve_from_same_document: false
retrieve_from_different_documents: true
pie_document_type:
  _target_: pie_modules.utils.resolve_type
  type_or_str: pytorch_ie.documents.TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
docstore:
  _target_: src.langchain_modules.DatasetsPieDocumentStore
search_kwargs:
  k: 10
search_type: similarity_score_threshold
vectorstore:
  _target_: src.langchain_modules.QdrantSpanVectorStore
  embedding:
    _target_: src.langchain_modules.HuggingFaceSpanEmbeddings
    model:
      _target_: src.models.utils.load_model_with_adapter
      model_kwargs:
        pretrained_model_name_or_path: allenai/specter2_base
      adapter_kwargs:
        adapter_name_or_path: allenai/specter2
        load_as: proximity
        source: hf
    pipeline_kwargs:
      tokenizer: allenai/specter2_base
      stride: 64
      batch_size: 32
      model_max_length: 512
  client:
    _target_: qdrant_client.QdrantClient
    location: ":memory:"
  collection_name: adus
  vector_params:
    distance:
      _target_: qdrant_client.http.models.Distance
      value: Cosine
label_mapping:
  background_claim:
    - background_claim
    - own_claim
  own_claim:
    - background_claim
    - own_claim

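Since this is a plain Hydra `_target_` tree, the retriever can also be built outside the demo. A minimal sketch, assuming `hydra.utils.instantiate` resolves the nested targets and that the file lives under the demo config directory (the config path is a placeholder):

```python
# Sketch only: instantiate the retriever config directly; the config path is a placeholder.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load(
    "configs/demo/retriever/related_span_retriever_with_relations_from_other_docs.yaml"
)
retriever = instantiate(cfg)
```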
src/analysis/__init__.py
ADDED
File without changes
src/analysis/combine_job_returns.py
ADDED
@@ -0,0 +1,169 @@
import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=False,
)

import argparse
import os

import pandas as pd

from src.analysis.common import read_nested_jsons


def separate_path_and_id(path_and_maybe_id: str, separator: str = ":") -> tuple[str | None, str]:
    parts = path_and_maybe_id.split(separator, 1)
    if len(parts) == 1:
        return None, parts[0]
    return parts[0], parts[1]


def get_file_paths(paths_file: str, file_name: str, use_aggregated: bool) -> dict[str, str]:
    with open(paths_file, "r") as f:
        paths_maybe_with_ids = f.readlines()
    ids, paths = zip(*[separate_path_and_id(path.strip()) for path in paths_maybe_with_ids])

    if use_aggregated:
        file_base_name, ext = os.path.splitext(file_name)
        file_name = f"{file_base_name}.aggregated{ext}"
    file_paths = [os.path.join(path, file_name) for path in paths]
    return {
        id if id is not None else f"idx={idx}": path
        for idx, (id, path) in enumerate(zip(ids, file_paths))
    }


def main(
    paths_file: str,
    file_name: str,
    use_aggregated: bool,
    columns: list[str] | None,
    round_precision: int | None,
    format: str,
    transpose: bool = False,
    unpack_multirun_results: bool = False,
    in_percent: bool = False,
):
    file_paths = get_file_paths(
        paths_file=paths_file, file_name=file_name, use_aggregated=use_aggregated
    )
    data = read_nested_jsons(json_paths=file_paths)

    if columns is not None:
        columns_multi_index = [tuple(col.split("/")) for col in columns]
        try:
            data_series = [data[col] for col in columns_multi_index]
        except KeyError as e:
            print(
                f"Columns {columns_multi_index} not found in the data. Available columns are {list(data.columns)}."
            )
            raise e
        data = pd.concat(data_series, axis=1)

    # drop rows that are all NaN
    data = data.dropna(how="all")

    # if more than one data point, drop the index levels that are everywhere the same
    if len(data) > 1:
        unique_levels = [
            idx
            for idx, level in enumerate(data.index.levels)
            if len(data.index.get_level_values(idx).unique()) == 1
        ]
        for level in sorted(unique_levels, reverse=True):
            data.index = data.index.droplevel(level)

    # if more than one column, drop the columns that are everywhere the same
    if len(data.columns) > 1:
        unique_column_levels = [
            idx
            for idx, level in enumerate(data.columns.levels)
            if len(data.columns.get_level_values(idx).unique()) == 1
        ]
        for level in sorted(unique_column_levels, reverse=True):
            data.columns = data.columns.droplevel(level)

    if unpack_multirun_results:
        index_names = list(data.index.names)
        data_series_lists = data.unstack()
        data = pd.DataFrame.from_records(
            data_series_lists.values, index=data_series_lists.index
        ).stack()
        for _, index_name in enumerate(index_names):
            data = data.unstack(index_name)
        data = data.T

    if transpose:
        data = data.T

    # needs to happen before rounding, otherwise the rounding will be off
    if in_percent:
        data = data * 100

    if round_precision is not None:
        data = data.round(round_precision)

    if format == "markdown":
        print(data.to_markdown())
    elif format == "markdown_mean_and_std":
        if transpose:
            data = data.T
        if "mean" not in data.columns or "std" not in data.columns:
            raise ValueError("Columns 'mean' and 'std' are required for this format.")
        # create a single column with mean and std in the format: mean ± std
        data = pd.DataFrame(
            data["mean"].astype(str) + " ± " + data["std"].astype(str), columns=["mean ± std"]
        )
        if transpose:
            data = data.T
        print(data.to_markdown())
    elif format == "json":
        print(data.to_json())
    else:
        raise ValueError(f"Invalid format: {format}. Use 'markdown' or 'json'.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Combine job returns and show as Markdown table")
    parser.add_argument(
        "--paths-file", type=str, help="Path to the file containing the paths to the job returns"
    )
    parser.add_argument(
        "--use-aggregated", action="store_true", help="Whether to use the aggregated job returns"
    )
    parser.add_argument(
        "--file-name",
        type=str,
        default="job_return_value.json",
        help="Name of the file to write the aggregated job returns to",
    )
    parser.add_argument(
        "--columns", type=str, nargs="+", help="Columns to select from the combined job returns"
    )
    parser.add_argument(
        "--unpack-multirun-results", action="store_true", help="Unpack multirun results"
    )
    parser.add_argument("--transpose", action="store_true", help="Transpose the table")
    parser.add_argument(
        "--round-precision",
        type=int,
        help="Round the values in the combined job returns to the specified precision",
    )
    parser.add_argument(
        "--in-percent", action="store_true", help="Show the values in percent (multiply by 100)"
    )
    parser.add_argument(
        "--format",
        type=str,
        default="markdown",
        choices=["markdown", "markdown_mean_and_std", "json"],
        help="Format to output the combined job returns",
    )

    args = parser.parse_args()
    kwargs = vars(args)
    main(**kwargs)

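For reference, a programmatic call equivalent to the CLI above. The paths file and the selected columns are placeholders; the paths file lists one job-return directory per line, optionally prefixed with an identifier and a colon (e.g. `epochs=75:<path>`):

```python
# Hypothetical usage of combine_job_returns.main; arguments mirror the argparse options above.
from src.analysis.combine_job_returns import main

main(
    paths_file="paths.txt",                      # placeholder
    file_name="job_return_value.json",
    use_aggregated=True,                         # reads job_return_value.aggregated.json instead
    columns=["test/MACRO/f1", "test/MICRO/f1"],  # placeholder "/"-separated column paths
    round_precision=3,
    format="markdown",
    in_percent=True,
)
```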
src/analysis/common.py
ADDED
@@ -0,0 +1,47 @@
import json
from typing import Dict, List, Optional

import pandas as pd


def parse_identifier(
    identifier_str, defaults: Dict[str, str], parts_sep: str = ",", key_val_sep: str = "="
) -> Dict[str, str]:
    parts = [
        part.split(key_val_sep)
        for part in identifier_str.strip().split(parts_sep)
        if key_val_sep in part
    ]
    parts_dict = dict(parts)
    return {**defaults, **parts_dict}


def read_nested_json(path: str) -> pd.DataFrame:
    # Read the nested JSON data into a pandas DataFrame
    with open(path, "r") as f:
        data = json.load(f)
    result = pd.json_normalize(data, sep="/")
    result.index.name = "entry"
    return result


def read_nested_jsons(
    json_paths: Dict[str, str],
    default_key_values: Optional[Dict[str, str]] = None,
    column_level_names: Optional[List[str]] = None,
) -> pd.DataFrame:
    identifier_strings = json_paths.keys()
    dfs = [read_nested_json(json_paths[identifier_str]) for identifier_str in identifier_strings]
    new_index_levels = pd.MultiIndex.from_frame(
        pd.DataFrame(
            [
                parse_identifier(identifier_str, default_key_values or {})
                for identifier_str in identifier_strings
            ]
        )
    )
    dfs_concat = pd.concat(dfs, keys=list(new_index_levels), names=new_index_levels.names, axis=0)
    dfs_concat.columns = pd.MultiIndex.from_tuples(
        [col.split("/") for col in dfs_concat.columns], names=column_level_names
    )
    return dfs_concat

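A short sketch of how these helpers are meant to be combined (paths are placeholders; the column level names match the ones used in `compare_job_returns.py` below): each JSON file is flattened into "/"-joined columns and the files are stacked along a MultiIndex parsed from the identifier strings:

```python
# Hypothetical example: combine two aggregated job-return files into one DataFrame.
from src.analysis.common import read_nested_jsons

df = read_nested_jsons(
    json_paths={
        "epochs=75": "logs/run_a/job_return_value.aggregated.json",   # placeholder paths
        "epochs=150": "logs/run_b/job_return_value.aggregated.json",
    },
    default_key_values={"checkpoint": "best_val"},
    column_level_names=["split", "average", "metric", "aggr"],
)
print(df.index.names, df.columns.nlevels)
```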
src/analysis/compare_job_returns.py
ADDED
@@ -0,0 +1,407 @@
import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=False,
)

import argparse
import io
import re
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import pandas as pd
import plotly.graph_objects as go

from src.analysis.common import parse_identifier, read_nested_jsons


def read_markdown_table(
    markdown_data: str,
    default_key_values: Optional[Dict[str, str]] = None,
    column_level_names: Optional[List[str]] = None,
) -> pd.DataFrame:
    # Read the markdown data into a pandas DataFrame
    df = pd.read_csv(io.StringIO(markdown_data), sep="|", engine="python", skiprows=1)

    # Clean up the DataFrame
    # drop the first and last columns
    df = df.drop(columns=[df.columns[0], df.columns[-1]])
    # drop the first row
    df = df.drop(0)
    # make the index from the first column: parse the string and extract the values
    df.index = pd.MultiIndex.from_tuples(
        [tuple(x.strip()[2:-2].split("', '")) for x in df[df.columns[0]]]
    )
    # drop the first column
    df = df.drop(columns=[df.columns[0]])
    # parse the column names and create a MultiIndex
    columns = pd.DataFrame(
        [parse_identifier(col, defaults=default_key_values or {}) for col in df.columns]
    )
    df.columns = pd.MultiIndex.from_frame(columns)

    # Function to parse the values and errors
    def parse_value_error(value_error_str: str) -> Tuple[float, float]:
        match = re.match(r"([0-9.]+) \(?± ?([0-9.]+)\)?", value_error_str.strip())
        if match:
            return float(match.group(1)), float(match.group(2))
        raise ValueError(f"Invalid value error string: {value_error_str}")

    df_mean_and_std_cells = df.map(lambda x: parse_value_error(x))
    # make a new DataFrame with the mean and std values as new rows
    result = pd.concat(
        {
            "mean": df_mean_and_std_cells.map(lambda x: x[0]),
            "std": df_mean_and_std_cells.map(lambda x: x[1]),
        },
        axis=0,
    )
    # transpose the DataFrame
    result = result.T
    # move new column index level to the most inner level
    result.columns = pd.MultiIndex.from_tuples(
        [col[1:] + (col[0],) for col in result.columns], names=column_level_names
    )
    return result


def rearrange_for_plotting(
    data: Union[pd.DataFrame, pd.Series], x_axis: str, x_is_numeric: bool
) -> pd.DataFrame:
    # rearrange the DataFrame for plotting
    while not isinstance(data, pd.Series):
        data = data.unstack()
    result = data.unstack(x_axis)
    if x_is_numeric:
        result.columns = result.columns.astype(float)
    return result


# Function to create plots
def create_plot(
    title,
    x_axis: str,
    data: pd.DataFrame,
    data_err: Optional[pd.DataFrame] = None,
    x_is_numeric: bool = False,
    marker_getter: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
    line_getter: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
):
    data = rearrange_for_plotting(data, x_axis, x_is_numeric)
    # sort the columns by the x_axis values
    data = data[data.columns.sort_values()]

    if data_err is not None:
        data_err = rearrange_for_plotting(data_err, x_axis, x_is_numeric)
        data_err = data_err[data.columns]

    fig = go.Figure()
    for trace_idx, row_data_mean in data.iterrows():
        trace_meta = dict(zip(data.index.names, trace_idx))
        if data_err is not None:
            error_y = dict(type="data", array=data_err.loc[trace_idx])
        else:
            error_y = None
        fig.add_trace(
            go.Scatter(
                x=row_data_mean.index,
                y=row_data_mean,
                error_y=error_y,
                mode="lines+markers",
                marker=marker_getter(trace_meta) if marker_getter is not None else None,
                line=line_getter(trace_meta) if line_getter is not None else None,
                name=", ".join(trace_idx),
            )
        )
    fig.update_layout(title=title, xaxis_title=x_axis, yaxis_title="Values")
    fig.show()


def prepare_for_markdown(
    df: pd.DataFrame,
    aggregation_column_level: Optional[str] = None,
    round_precision: Optional[int] = None,
) -> pd.DataFrame:
    result = df.copy()
    # simplify index: create single index from all levels in the format "level1_name=level1_val,level2_name=level2_val,..."
    if isinstance(result.index, pd.MultiIndex):
        result.index = [
            ",".join([f"{name}={val}" for name, val in zip(result.index.names, idx)])
            for idx in result.index
        ]
    else:
        result.index = [f"{result.index.name}={idx}" for idx in result.index]
    result = result.T
    if round_precision is not None:
        result = result.round(round_precision)
    if aggregation_column_level is not None:
        result_mean = result.xs("mean", level=aggregation_column_level, axis="index")
        result_std = result.xs("std", level=aggregation_column_level, axis="index")
        # combine each cell with mean and std into a single string
        result = pd.DataFrame(
            {
                col: [f"{mean} (±{std})" for mean, std in zip(result_mean[col], result_std[col])]
                for col in result_mean.columns
            },
            index=result_mean.index,
        )

    return result


def combine_job_returns_and_plot(
    x_axis: str,
    plot_column_level: str,
    job_return_paths: Optional[Dict[str, str]] = None,
    markdown_str: Optional[str] = None,
    default_key_values: Optional[Dict[str, str]] = None,
    column_level_names: Optional[List[str]] = None,
    drop_columns: Optional[Dict[str, str]] = None,
    aggregation_column_level: Optional[str] = None,
    title_prefix: Optional[str] = None,
    x_is_not_numeric: bool = False,
    show_as: str = "plot",
    markdown_round_precision: Optional[int] = None,
    marker_getter: Optional[Callable] = None,
    line_getter: Optional[Callable] = None,
    # placeholder to allow description in CONFIGS
    description: Optional[str] = None,
):

    if job_return_paths is not None:
        df_all = read_nested_jsons(
            json_paths=job_return_paths,
            default_key_values=default_key_values,
            column_level_names=column_level_names,
        )
    elif markdown_str is not None:
        df_all = read_markdown_table(
            markdown_data=markdown_str,
            default_key_values=default_key_values,
            column_level_names=column_level_names,
        )
    else:
        raise ValueError("Either job_return_paths or markdown_str must be provided")

    for metric, value in (drop_columns or {}).items():
        df_all = df_all.drop(columns=value, level=metric)

    # drop index levels where all values are the same
    index_levels_to_drop = [
        i
        for i in range(df_all.index.nlevels)
        if len(df_all.index.get_level_values(i).unique()) == 1
    ]
    dropped_index_levels = {
        df_all.index.names[i]: df_all.index.get_level_values(i).unique()[0]
        for i in index_levels_to_drop
    }
    if len(index_levels_to_drop) > 0:
        print(f"Drop index levels: {dropped_index_levels}")
        df_all = df_all.droplevel(index_levels_to_drop, axis="index")
    # drop column levels where all values are the same
    column_levels_to_drop = [
        i
        for i in range(df_all.columns.nlevels)
        if len(df_all.columns.get_level_values(i).unique()) == 1
    ]
    dropped_column_levels = {
        df_all.columns.names[i]: df_all.columns.get_level_values(i).unique()[0]
        for i in column_levels_to_drop
    }
    if len(column_levels_to_drop) > 0:
        print(f"Drop column levels: {dropped_column_levels}")
        df_all = df_all.droplevel(column_levels_to_drop, axis="columns")

    if show_as == "markdown":
        print(
            prepare_for_markdown(
                df_all,
                aggregation_column_level=aggregation_column_level,
                round_precision=markdown_round_precision,
            ).to_markdown()
        )
    elif show_as == "plots":
        # create plots for each "average" value, i.e. MACRO, MICRO, but also label specific values
        plot_names = df_all.columns.get_level_values(plot_column_level).unique()
        for plot_name in plot_names:
            data_plot = df_all.xs(plot_name, level=plot_column_level, axis="columns")
            data_err = None
            if aggregation_column_level is not None:
                data_err = data_plot.xs("std", level=aggregation_column_level, axis="columns")
                data_plot = data_plot.xs("mean", level=aggregation_column_level, axis="columns")

            # Create plot for MACRO values
            if title_prefix is not None:
                plot_name = f"{title_prefix}: {plot_name}"
            create_plot(
                title=plot_name,
                data=data_plot,
                data_err=data_err,
                x_axis=x_axis,
                x_is_numeric=not x_is_not_numeric,
                marker_getter=marker_getter,
                line_getter=line_getter,
            )
    else:
        raise ValueError(f"Invalid show_as: {show_as}")


CONFIGS = {
    "joint model (adus) - last vs best val checkpoint @test": dict(
        job_return_paths={
            "epochs=75": "logs/document_evaluation/multiruns/default/2025-01-12_13-28-54/job_return_value.aggregated.json",
            "epochs=75,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_18-36-35/job_return_value.aggregated.json",
            "epochs=150": "logs/document_evaluation/multiruns/default/2025-01-15_16-02-04/job_return_value.aggregated.json",
            "epochs=150,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_22-07-14/job_return_value.aggregated.json",
            "epochs=300": "logs/document_evaluation/multiruns/default/2025-01-16_18-50-43/job_return_value.aggregated.json",
            "epochs=300,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_23-12-52/job_return_value.aggregated.json",
        },
        x_axis="epochs",
        default_key_values={"checkpoint": "best_val"},
        description="data from https://github.com/ArneBinder/pie-document-level/issues/334",
    ),
    "joint model (relations) - last vs best val checkpoint @test": dict(
        job_return_paths={
            "epochs=75": "logs/document_evaluation/multiruns/default/2025-01-12_13-30-25/job_return_value.aggregated.json",
            "epochs=75,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_18-38-55/job_return_value.aggregated.json",
            "epochs=150": "logs/document_evaluation/multiruns/default/2025-01-15_13-32-33/job_return_value.aggregated.json",
            "epochs=150,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_22-08-43/job_return_value.aggregated.json",
            "epochs=300": "logs/document_evaluation/multiruns/default/2025-01-11_16-42-17/job_return_value.aggregated.json",
            "epochs=300,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_23-14-13/job_return_value.aggregated.json",
        },
        x_axis="epochs",
        default_key_values={"checkpoint": "best_val"},
        description="data from https://github.com/ArneBinder/pie-document-level/issues/334",
    ),
    "joint model (adus) - last vs best val checkpoint @val": dict(
        job_return_paths={
            "epochs=75": "logs/document_evaluation/multiruns/default/2025-01-17_17-13-46/job_return_value.aggregated.json",
            "epochs=75,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-17_22-26-17/job_return_value.aggregated.json",
            "epochs=150": "logs/document_evaluation/multiruns/default/2025-01-17_19-41-17/job_return_value.aggregated.json",
            "epochs=150,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-17_22-40-54/job_return_value.aggregated.json",
            "epochs=300": "logs/document_evaluation/multiruns/default/2025-01-17_20-00-01/job_return_value.aggregated.json",
            "epochs=300,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-17_22-51-51/job_return_value.aggregated.json",
        },
        x_axis="epochs",
        default_key_values={"checkpoint": "best_val"},
        description="data from https://github.com/ArneBinder/pie-document-level/issues/334",
    ),
    "joint model (relations) - last vs best val checkpoint @val": dict(
        job_return_paths={
            "epochs=75": "logs/document_evaluation/multiruns/default/2025-01-17_17-16-01/job_return_value.aggregated.json",
            "epochs=75,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-17_22-28-24/job_return_value.aggregated.json",
            "epochs=150": "logs/document_evaluation/multiruns/default/2025-01-17_19-42-59/job_return_value.aggregated.json",
            "epochs=150,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-17_22-41-19/job_return_value.aggregated.json",
            "epochs=300": "logs/document_evaluation/multiruns/default/2025-01-17_20-01-16/job_return_value.aggregated.json",
            "epochs=300,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-17_22-52-16/job_return_value.aggregated.json",
        },
        x_axis="epochs",
        default_key_values={"checkpoint": "best_val"},
        description="data from https://github.com/ArneBinder/pie-document-level/issues/334",
    ),
    "joint model (adus) - 27 vs 31 train docs @test": dict(
        job_return_paths={
            "num_train_docs=27,epochs=75,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_18-36-35/job_return_value.aggregated.json",
            "num_train_docs=31,epochs=75,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_11-35-12/job_return_value.aggregated.json",
        },
        x_axis="num_train_docs",
        description="data from https://github.com/ArneBinder/pie-document-level/issues/334",
    ),
    "joint model (relations) - 27 vs 31 train docs @test": dict(
        job_return_paths={
            "num_train_docs=27,epochs=75,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_18-38-55/job_return_value.aggregated.json",
            "num_train_docs=31,epochs=75,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_11-36-52/job_return_value.aggregated.json",
        },
        x_axis="num_train_docs",
        description="data from https://github.com/ArneBinder/pie-document-level/issues/334",
    ),
}

DEFAULT_KWARGS = {
    "column_level_names": ["split", "average", "metric", "aggr"],
    "plot_column_level": "average",
    "x_is_not_numeric": False,
    "aggregation_column_level": "aggr",
    "drop_columns": {"metric": "s"},
    "show_as": "plots",
    "markdown_round_precision": 3,
}

if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Compare multiple job results for predefined setups (see positional choice argument) "
        "by creating plots or a markdown table."
    )
    parser.add_argument(
        "config",
        type=str,
        help="config name (will also be the title prefix)",
        choices=CONFIGS.keys(),
    )
    parser.add_argument(
        "--column-level-names",
        type=lambda x: x.strip().split(","),
        help="comma separated list of column level names. Note that column levels are "
        "created for each nesting level in the JSON data",
    )
    parser.add_argument("--plot-column-level", type=str, help="column level to create plots for")
    parser.add_argument("--x-axis", type=str, help="column level to use as x-axis")
    parser.add_argument("--x-is-not-numeric", help="set if x-axis is not numeric")
    parser.add_argument(
        "--aggregation-column-level",
        type=str,
        help="column level that contains the aggregation type (e.g. mean, std)",
    )
    parser.add_argument(
        "--drop-columns",
        type=lambda x: dict(part.split(":") for part in x.strip().split(",")),
        help="a comma separated list of key-value pairs in the format level_name=level_value to "
        "drop columns with the specific level values",
    )
    parser.add_argument(
        "--show-as",
        type=str,
        help="show the data as 'plots' or 'markdown'",
    )
    parser.add_argument(
        "--markdown-round-precision", type=int, help="round precision for show markdown"
    )
    args = parser.parse_args()

    user_kwargs = vars(args)
    config_name = user_kwargs.pop("config")

    def get_marker(trace_meta):
        checkpoint2marker_style = {"best_val": "circle-open", "last": "x"}
        return dict(symbol=checkpoint2marker_style[trace_meta.get("checkpoint", "last")], size=12)

    def get_line(trace_meta):
        metric2checkpoint2color = {
            "f1": {"best_val": "lightblue", "last": "blue"},
            "f": {"best_val": "lightblue", "last": "blue"},
            "p": {"best_val": "lightgreen", "last": "green"},
            "r": {"best_val": "lightcoral", "last": "red"},
        }
        return dict(
            color=metric2checkpoint2color[trace_meta["metric"]][
                trace_meta.get("checkpoint", "last")
            ],
            width=2,
        )

    kwargs = {
        "title_prefix": config_name,
        "line_getter": get_line,
        "marker_getter": get_marker,
        **DEFAULT_KWARGS,
        **CONFIGS[config_name],
    }
    for key, value in user_kwargs.items():
        if value is not None:
            kwargs[key] = value
    combine_job_returns_and_plot(**kwargs)

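The predefined comparisons can also be rendered without the CLI; a minimal sketch that prints one of the configured setups as a markdown table instead of plots:

```python
# Sketch only: programmatic equivalent of
# `python src/analysis/compare_job_returns.py <config> --show-as markdown`.
from src.analysis.compare_job_returns import CONFIGS, DEFAULT_KWARGS, combine_job_returns_and_plot

config_name = "joint model (adus) - last vs best val checkpoint @test"
kwargs = {**DEFAULT_KWARGS, **CONFIGS[config_name], "show_as": "markdown"}
combine_job_returns_and_plot(title_prefix=config_name, **kwargs)
```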
src/data/acl_anthology_crawler.py
ADDED
@@ -0,0 +1,117 @@
import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

import os
from argparse import ArgumentParser, RawTextHelpFormatter
from dataclasses import dataclass, field
from pathlib import Path

from acl_anthology import Anthology
from tqdm import tqdm

from src.utils.pdf_utils.acl_anthology_utils import XML2RawPapers
from src.utils.pdf_utils.process_pdf import (
    FulltextExtractor,
    GrobidFulltextExtractor,
    PDFDownloader,
)

HELP_MSG = """
Generate paper json files from an ACL Anthology collection, with fulltext extraction.

Iterate over entries in the ACL Anthology metadata, and for each entry:
1. extract relevant paper info from the xml entry
2. download pdf file
3. extract fulltext
4. format a json file and save

pre-requisites:
- Install the requirements: pip install acl-anthology-py>=0.4.3 bs4 jsonschema
- Get the meta data from ACL Anthology: git clone git@github.com:acl-org/acl-anthology.git
- Start Grobid Docker container: docker run --rm --init --ulimit core=0 -p 8070:8070 lfoppiano/grobid:0.8.0
"""


@dataclass
class XML2Jsons:
    base_output_dir: Path
    pdf_output_dir: Path

    xml2raw_papers: XML2RawPapers
    pdf_downloader: PDFDownloader = field(default_factory=PDFDownloader)
    fulltext_extractor: FulltextExtractor = field(default_factory=GrobidFulltextExtractor)
    show_progress: bool = True

    @classmethod
    def from_cli(cls) -> "XML2Jsons":
        parser = ArgumentParser(description=HELP_MSG, formatter_class=RawTextHelpFormatter)
        parser.add_argument(
            "--base-output-dir", type=str, help="Directory to save all the paper json files"
        )
        parser.add_argument(
            "--pdf-output-dir", type=str, help="Directory to save all the downloaded pdf files"
        )
        parser.add_argument(
            "--anthology-data-dir",
            type=str,
            help="Path to ACL Anthology metadata directory, e.g., /path/to/acl-anthology-repo/data. "
            "You can obtain the data via: git clone git@github.com:acl-org/acl-anthology.git",
        )
        parser.add_argument(
            "--collection-id-filters",
            nargs="+",
            type=str,
            default=None,
            help="If provided, only papers from the collections whose id (Anthology ID) contains the "
            "specified strings will be processed.",
        )
        parser.add_argument(
            "--venue-id-whitelist",
            nargs="+",
            type=str,
            default=None,
            help="If provided, only papers from the specified venues will be processed. See here for "
            "the list of venues: https://aclanthology.org/venues",
        )
        args = parser.parse_args()

        return cls(
            base_output_dir=Path(args.base_output_dir),
            pdf_output_dir=Path(args.pdf_output_dir),
            xml2raw_papers=XML2RawPapers(
                anthology=Anthology(datadir=args.anthology_data_dir),
                collection_id_filters=args.collection_id_filters,
                venue_id_whitelist=args.venue_id_whitelist,
            ),
        )

    def run(self):
        os.makedirs(self.pdf_output_dir, exist_ok=True)
        papers = self.xml2raw_papers()
        if self.show_progress:
            papers = tqdm(list(papers), desc="extracting fulltext")
        for paper in papers:
            volume_dir = self.base_output_dir / paper.volume_id
            if paper.url is not None:
                pdf_save_path = self.pdf_downloader.download(
                    paper.url, opath=self.pdf_output_dir / f"{paper.name}.pdf"
                )
                fulltext_extraction_output = self.fulltext_extractor(pdf_save_path)

                if fulltext_extraction_output:
                    plain_text, extraction_data = fulltext_extraction_output
                    paper.fulltext = extraction_data.get("sections")
                    if not paper.abstract:
                        paper.abstract = extraction_data.get("abstract")
            paper.save(str(volume_dir))


if __name__ == "__main__":
    xml2jsons = XML2Jsons.from_cli()
    xml2jsons.run()

src/data/calc_iaa_for_brat.py
ADDED
@@ -0,0 +1,272 @@
from collections.abc import Iterable

import pyrootutils
from pytorch_ie import Document

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

import argparse
from functools import partial
from typing import Callable, List, Optional, Union

import pandas as pd
from pie_datasets import Dataset, load_dataset
from pie_datasets.builders.brat import BratDocument, BratDocumentWithMergedSpans
from pie_modules.document.processing import RelationArgumentSorter, SpansViaRelationMerger
from pytorch_ie.metrics import F1Metric

from src.document.processing import align_predicted_span_annotations


def add_annotations_as_predictions(document: BratDocument, other: BratDocument) -> BratDocument:
    document = document.copy()
    other = other.copy()
    document.spans.predictions.extend(other.spans.clear())
    gold2gold_span_mapping = {span: span for span in document.spans}
    predicted2maybe_gold_span = {}
    for span in document.spans.predictions:
        predicted2maybe_gold_span[span] = gold2gold_span_mapping.get(span, span)
    predicted_relations = [
        rel.copy(
            head=predicted2maybe_gold_span[rel.head], tail=predicted2maybe_gold_span[rel.tail]
        )
        for rel in other.relations.clear()
    ]
    document.relations.predictions.extend(predicted_relations)
    return document


def remove_annotations_existing_in_other(
    document: BratDocumentWithMergedSpans, other: BratDocumentWithMergedSpans
) -> BratDocumentWithMergedSpans:
    result = document.copy(with_annotations=False)
    document = document.copy()
    other = other.copy()

    spans = set(document.spans.clear()) - set(other.spans.clear())
    relations = set(document.relations.clear()) - set(other.relations.clear())
    result.spans.extend(spans)
    result.relations.extend(relations)

    return result


def unnest_dict(d):
    result = {}
    for key, value in d.items():
        if isinstance(value, dict):
            unnested = unnest_dict(value)
            for k, v in unnested.items():
                result[(key,) + k] = v
        else:
            result[(key,)] = value
    return result


def calc_brat_iaas(
    annotator_dirs: List[str],
    ignore_annotation_dir: Optional[str] = None,
    combine_fragmented_spans_via_relation: Optional[str] = None,
    sort_arguments_of_relations: Optional[List[str]] = None,
    align_spans: bool = False,
    show_results: bool = False,
    per_file: bool = False,
) -> Union[pd.Series, List[pd.Series]]:
    if len(annotator_dirs) < 2:
        raise ValueError("At least two annotation dirs must be provided")

    span_aligner = None
    if align_spans:
        span_aligner = partial(align_predicted_span_annotations, span_layer="spans")

    if combine_fragmented_spans_via_relation is not None:
        print(f"Combine fragmented spans via {combine_fragmented_spans_via_relation} relations")
        merger = SpansViaRelationMerger(
            relation_layer="relations",
            link_relation_label=combine_fragmented_spans_via_relation,
            create_multi_spans=True,
            result_document_type=BratDocument,
            result_field_mapping={"spans": "spans", "relations": "relations"},
        )
    else:
        merger = None

    if sort_arguments_of_relations is not None and len(sort_arguments_of_relations) > 0:
        print(f"Sort arguments of relations with labels {sort_arguments_of_relations}")
        relation_argument_sorter = RelationArgumentSorter(
            relation_layer="relations",
            label_whitelist=sort_arguments_of_relations,  # ["parts_of_same", "semantically_same", "contradicts"],
        )
    else:
        relation_argument_sorter = None

    all_docs = [
        load_dataset(
            "pie/brat",
            name="merge_fragmented_spans",
            base_dataset_kwargs=dict(data_dir=annotation_dir),
            split="train",
        ).map(lambda doc: doc.deduplicate_annotations())
        for annotation_dir in annotator_dirs
    ]

    if ignore_annotation_dir is not None:
        print(f"Ignoring annotations loaded from {ignore_annotation_dir}")
        ignore_annotation_docs = load_dataset(
            "pie/brat",
            name="merge_fragmented_spans",
            base_dataset_kwargs=dict(data_dir=ignore_annotation_dir),
            split="train",
        )
        ignore_annotation_docs_dict = {doc.id: doc for doc in ignore_annotation_docs}
        all_docs = [
            docs.map(
                lambda doc: remove_annotations_existing_in_other(
                    doc, other=ignore_annotation_docs_dict[doc.id]
                )
            )
            for docs in all_docs
        ]

    if relation_argument_sorter is not None:
        all_docs = [docs.map(relation_argument_sorter) for docs in all_docs]

    if per_file:
        results_per_doc = []
        for docs_tuple in zip(*all_docs):
            if show_results:
                print(f"\ncalculate scores for document id={docs_tuple[0].id} ...")
            docs = [Dataset.from_documents([doc]) for doc in docs_tuple]
            result_per_doc = calc_brat_iaas_for_docs(
                docs, span_aligner=span_aligner, merger=merger, show_results=show_results
            )
            results_per_doc.append(result_per_doc)
        return results_per_doc

    else:
        return calc_brat_iaas_for_docs(
            all_docs, span_aligner=span_aligner, merger=merger, show_results=show_results
        )


def calc_brat_iaas_for_docs(
    all_docs: List[Dataset],
    span_aligner: Optional[Callable] = None,
    merger: Optional[Callable] = None,
    show_results: bool = False,
) -> pd.Series:
    num_annotators = len(all_docs)
    all_docs_dict = [{doc.id: doc for doc in docs} for docs in all_docs]
    gold_predicted = {}
    for gold_annotator_idx in range(num_annotators):
        gold = all_docs[gold_annotator_idx]
        for predicted_annotator_idx in range(num_annotators):
            if gold_annotator_idx == predicted_annotator_idx:
                continue
            predicted_dict = all_docs_dict[predicted_annotator_idx]
            gold_predicted[(gold_annotator_idx, predicted_annotator_idx)] = gold.map(
                lambda doc: add_annotations_as_predictions(doc, other=predicted_dict[doc.id])
            )

    spans_metric = F1Metric(layer="spans", labels="INFERRED", show_as_markdown=True)
    relations_metric = F1Metric(layer="relations", labels="INFERRED", show_as_markdown=True)

    metric_values = {}
    for gold_annotator_idx, predicted_annotator_idx in gold_predicted:
        print(
            f"calculate scores for annotations {gold_annotator_idx} -> {predicted_annotator_idx}"
        )
        for doc in gold_predicted[(gold_annotator_idx, predicted_annotator_idx)]:
            if span_aligner is not None:
                doc = span_aligner(doc)
            if merger is not None:
                doc = merger(doc)
            spans_metric(doc)
            relations_metric(doc)
        metric_id = f"gold:{gold_annotator_idx},predicted:{predicted_annotator_idx}"
        metric_values[metric_id] = {
            "spans": spans_metric.compute(reset=True),
            "relations": relations_metric.compute(reset=True),
        }

    result = pd.Series(unnest_dict(metric_values))
    if show_results:
        metric_values_series_mean = result.unstack(0).mean(axis=1)
        metric_values_relations = metric_values_series_mean.xs("relations").unstack()
        metric_values_spans = metric_values_series_mean.xs("spans").unstack()

        print("\nspans:")
        print(metric_values_spans.round(decimals=3).to_markdown())

        print("\nrelations:")
        print(metric_values_relations.round(decimals=3).to_markdown())

    return result


if __name__ == "__main__":

    """
    example call:
    python calc_iaa_for_brat.py \
        --annotation-dirs annotations/sciarg/v0.9/with_abstracts_rin annotations/sciarg/v0.9/with_abstracts_alisa \
+
--ignore-annotation-dir annotations/sciarg/v0.9/original
|
219 |
+
"""
|
220 |
+
|
221 |
+
parser = argparse.ArgumentParser(
|
222 |
+
description="Calculate inter-annotator agreement for spans and relations in means of F1 "
|
223 |
+
"(exact match, i.e. offsets / arguments and labels must match) for two or more BRAT "
|
224 |
+
"annotation directories."
|
225 |
+
)
|
226 |
+
parser.add_argument(
|
227 |
+
"--annotation-dirs",
|
228 |
+
type=str,
|
229 |
+
required=True,
|
230 |
+
nargs="+",
|
231 |
+
help="List of annotation directories. At least two directories must be provided.",
|
232 |
+
)
|
233 |
+
parser.add_argument(
|
234 |
+
"--ignore-annotation-dir",
|
235 |
+
type=str,
|
236 |
+
default=None,
|
237 |
+
help="If set, ignore annotations loaded from this directory.",
|
238 |
+
)
|
239 |
+
parser.add_argument(
|
240 |
+
"--combine-fragmented-spans-via-relation",
|
241 |
+
type=str,
|
242 |
+
default=None,
|
243 |
+
help="If set, combine fragmented spans via this relation type.",
|
244 |
+
)
|
245 |
+
parser.add_argument(
|
246 |
+
"--sort-arguments-of-relations",
|
247 |
+
type=str,
|
248 |
+
default=None,
|
249 |
+
nargs="+",
|
250 |
+
help="If set, sort the arguments of the relations with the given labels.",
|
251 |
+
)
|
252 |
+
parser.add_argument(
|
253 |
+
"--align-spans",
|
254 |
+
action="store_true",
|
255 |
+
help="If set, align the spans of the predicted annotations to the gold annotations.",
|
256 |
+
)
|
257 |
+
parser.add_argument(
|
258 |
+
"--per-file",
|
259 |
+
action="store_true",
|
260 |
+
help="If set, calculate IAA per file.",
|
261 |
+
)
|
262 |
+
args = parser.parse_args()
|
263 |
+
|
264 |
+
metric_values_series = calc_brat_iaas(
|
265 |
+
annotator_dirs=args.annotation_dirs,
|
266 |
+
ignore_annotation_dir=args.ignore_annotation_dir,
|
267 |
+
combine_fragmented_spans_via_relation=args.combine_fragmented_spans_via_relation,
|
268 |
+
sort_arguments_of_relations=args.sort_arguments_of_relations,
|
269 |
+
align_spans=args.align_spans,
|
270 |
+
per_file=args.per_file,
|
271 |
+
show_results=True,
|
272 |
+
)
|
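Note on the result structure: the pairwise scores are keyed by a "gold:i,predicted:j" string and nested per layer before being flattened into a pandas Series. A minimal sketch of that flattening and averaging step, using an inline copy of unnest_dict and made-up metric values (the real values come from F1Metric.compute):

import pandas as pd

def unnest_dict(d):
    # flatten a nested dict into {(key1, key2, ...): leaf_value}
    result = {}
    for key, value in d.items():
        if isinstance(value, dict):
            for k, v in unnest_dict(value).items():
                result[(key,) + k] = v
        else:
            result[(key,)] = value
    return result

# made-up metric values, only to show the resulting MultiIndex structure
metric_values = {
    "gold:0,predicted:1": {"spans": {"f1": 0.81}, "relations": {"f1": 0.55}},
    "gold:1,predicted:0": {"spans": {"f1": 0.79}, "relations": {"f1": 0.52}},
}
series = pd.Series(unnest_dict(metric_values))
# average over the annotator pairs, as done when show_results is enabled
print(series.unstack(0).mean(axis=1))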
src/data/construct_sciarg_abstracts_remaining_gold_retrieval.py
ADDED
@@ -0,0 +1,238 @@
|
import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    # dotenv=True,
)

import argparse
import logging
import os
from collections import defaultdict
from typing import List, Optional, Sequence, Tuple, TypeVar

import pandas as pd
from pie_datasets import load_dataset
from pie_datasets.builders.brat import BratDocument, BratDocumentWithMergedSpans
from pytorch_ie.annotations import LabeledMultiSpan
from pytorch_ie.documents import (
    TextBasedDocument,
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)

from src.document.processing import replace_substrings_in_text_with_spaces

logger = logging.getLogger(__name__)


def multi_span_is_in_span(multi_span: LabeledMultiSpan, range_span: Tuple[int, int]) -> bool:
    start, end = range_span
    starts, ends = zip(*multi_span.slices)
    return start <= min(starts) and max(ends) <= end


def filter_multi_spans(
    multi_spans: Sequence[LabeledMultiSpan], filter_span: Tuple[int, int]
) -> List[LabeledMultiSpan]:
    return [
        span
        for span in multi_spans
        if multi_span_is_in_span(multi_span=span, range_span=filter_span)
    ]


def shift_multi_span_slices(
    slices: Sequence[Tuple[int, int]], shift: int
) -> List[Tuple[int, int]]:
    return [(start + shift, end + shift) for start, end in slices]


def construct_gold_retrievals(
    doc: TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
    symmetric_relations: Optional[List[str]] = None,
    relation_label_whitelist: Optional[List[str]] = None,
) -> Optional[pd.DataFrame]:
    abstract_annotations = [
        span for span in doc.labeled_partitions if span.label.lower().strip() == "abstract"
    ]
    if len(abstract_annotations) != 1:
        logger.warning(
            f"Expected exactly one abstract annotation, found {len(abstract_annotations)}"
        )
        return None
    abstract_annotation = abstract_annotations[0]
    span_abstract = (abstract_annotation.start, abstract_annotation.end)
    span_remaining = (abstract_annotation.end, len(doc.text))
    labeled_multi_spans = list(doc.labeled_multi_spans)
    spans_in_abstract = set(
        span for span in labeled_multi_spans if multi_span_is_in_span(span, span_abstract)
    )
    spans_in_remaining = set(
        span for span in labeled_multi_spans if multi_span_is_in_span(span, span_remaining)
    )
    spans_not_covered = set(labeled_multi_spans) - spans_in_abstract - spans_in_remaining
    if len(spans_not_covered) > 0:
        logger.warning(
            f"Found {len(spans_not_covered)} spans not covered by abstract or remaining text"
        )

    rel_arg_and_label2other = defaultdict(list)
    for rel in doc.binary_relations:
        rel_arg_and_label2other[rel.head].append((rel.tail, rel.label))
        if symmetric_relations is not None and rel.label in symmetric_relations:
            label_reversed = rel.label
        else:
            label_reversed = f"{rel.label}_reversed"
        rel_arg_and_label2other[rel.tail].append((rel.head, label_reversed))

    result_rows = []
    for rel in doc.binary_relations:
        # we check all semantically_same relations whose head is in the abstract and whose tail
        # is in the remaining text ...
        if rel.label == "semantically_same":
            if rel.head in spans_in_abstract and rel.tail in spans_in_remaining:
                # ... and collect all relation partners of the remaining-side span (rel.tail)
                # as candidate retrieval targets
                candidate_spans_with_label = rel_arg_and_label2other[rel.tail]
                for candidate_span, rel_label in candidate_spans_with_label:
                    if (
                        relation_label_whitelist is not None
                        and rel_label not in relation_label_whitelist
                    ):
                        continue
                    result_row = {
                        "doc_id": f"{doc.id}.remaining.{span_remaining[0]}.txt",
                        "query_doc_id": f"{doc.id}.abstract.{span_abstract[0]}_{span_abstract[1]}.txt",
                        "span": shift_multi_span_slices(candidate_span.slices, -span_remaining[0]),
                        "query_span": shift_multi_span_slices(rel.head.slices, -span_abstract[0]),
                        "ref_span": shift_multi_span_slices(rel.tail.slices, -span_remaining[0]),
                        "type": rel_label,
                        "label": candidate_span.label,
                        "ref_label": rel.tail.label,
                    }
                    result_rows.append(result_row)

    if len(result_rows) > 0:
        return pd.DataFrame(result_rows)
    else:
        return None


D_text = TypeVar("D_text", bound=TextBasedDocument)


def clean_doc(doc: D_text) -> D_text:
    # remove xml tags. Note that we also remove the Abstract tag, in contrast to the preprocessing
    # pipeline (see configs/dataset/sciarg_cleaned.yaml). This is because there, the abstracts are
    # removed completely.
    doc = replace_substrings_in_text_with_spaces(
        doc,
        substrings=[
            "</H2>",
            "<H3>",
            "</Document>",
            "<H1>",
            "<H2>",
            "</H3>",
            "</H1>",
            "<Abstract>",
            "</Abstract>",
        ],
    )
    return doc


def main(
    data_dir: str,
    out_path: str,
    doc_id_whitelist: Optional[List[str]] = None,
    symmetric_relations: Optional[List[str]] = None,
    relation_label_whitelist: Optional[List[str]] = None,
) -> None:
    logger.info(f"Loading dataset from {data_dir}")
    sciarg_with_abstracts = load_dataset(
        "pie/sciarg",
        revision="171478ce3c13cc484be5d7c9bc8f66d7d2f1c210",
        base_dataset_kwargs={"data_dir": data_dir, "split_paths": None},
        name="resolve_parts_of_same",
        split="train",
    )
    if issubclass(sciarg_with_abstracts.document_type, BratDocument):
        ds_converted = sciarg_with_abstracts.to_document_type(
            TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
        )
    elif issubclass(sciarg_with_abstracts.document_type, BratDocumentWithMergedSpans):
        ds_converted = sciarg_with_abstracts.to_document_type(
            TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
        )
    else:
        raise ValueError(f"Unsupported document type {sciarg_with_abstracts.document_type}")

    ds_clean = ds_converted.map(clean_doc)
    if doc_id_whitelist is not None:
        num_before = len(ds_clean)
        ds_clean = [doc for doc in ds_clean if doc.id in doc_id_whitelist]
        logger.info(
            f"Filtered dataset from {num_before} to {len(ds_clean)} documents based on doc_id_whitelist"
        )

    results_per_doc = [
        construct_gold_retrievals(
            doc,
            symmetric_relations=symmetric_relations,
            relation_label_whitelist=relation_label_whitelist,
        )
        for doc in ds_clean
    ]
    results_per_doc_not_empty = [doc for doc in results_per_doc if doc is not None]
    if len(results_per_doc_not_empty) > 0:
        results = pd.concat(results_per_doc_not_empty, ignore_index=True)
        # sort to make the output deterministic
        results = results.sort_values(
            by=results.columns.tolist(), ignore_index=True, key=lambda s: s.apply(str)
        )
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        logger.info(f"Saving result ({len(results)}) to {out_path}")
        results.to_json(out_path)
    else:
        logger.warning("No results found")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Create gold retrievals for SciArg-abstracts-remaining in the same format as the retrieval results"
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="data/annotations/sciarg-with-abstracts-and-cross-section-rels",
        help="Path to the sciarg data directory",
    )
    parser.add_argument(
        "--out_path",
        type=str,
        default="data/retrieval_results/sciarg-with-abstracts-and-cross-section-rels/gold.json",
        help="Path to save the results",
    )
    parser.add_argument(
        "--symmetric_relations",
        type=str,
        nargs="+",
        default=None,
        help="Relations that are symmetric, i.e., if A is related to B, then B is related to A",
    )
    parser.add_argument(
        "--relation_label_whitelist",
        type=str,
        nargs="+",
        default=None,
        help="Only consider relations with these labels",
    )

    logging.basicConfig(level=logging.INFO)

    kwargs = vars(parser.parse_args())
    main(**kwargs)
    logger.info("Done")
|
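Note on the offset handling: the gold retrieval rows re-express every span relative to the text chunk it lives in, i.e. abstract spans are shifted by the abstract start and remaining-text spans by the abstract end. A small, self-contained sketch of that offset arithmetic with hypothetical offsets (not taken from a real document):

from typing import List, Sequence, Tuple

def shift_multi_span_slices(
    slices: Sequence[Tuple[int, int]], shift: int
) -> List[Tuple[int, int]]:
    # same logic as in the script: move every (start, end) pair by a fixed offset
    return [(start + shift, end + shift) for start, end in slices]

# hypothetical document layout: abstract covers [120, 980), remaining text starts at 980
abstract_start, abstract_end = 120, 980

# a fragmented span inside the remaining text, given in document coordinates
span_in_remaining = [(1005, 1032), (1040, 1068)]

# coordinates relative to the "remaining" text file written by split_sciarg_abstracts.py
print(shift_multi_span_slices(span_in_remaining, -abstract_end))
# -> [(25, 52), (60, 88)]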
src/data/prepare_sciarg_crosssection_annotations.py
ADDED
@@ -0,0 +1,398 @@
|
import argparse
import logging
import os
import re
import shutil
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import pandas as pd
from pie_datasets import Dataset, IterableDataset, load_dataset
from pie_datasets.builders.brat import BratDocumentWithMergedSpans

logger = logging.getLogger(__name__)


def find_span_idx(raw_text: str, span_string: str) -> Optional[List]:
    """
    Match span string to raw text (document).
    Return a list of (start, end) tuples: one element if the span matches exactly once,
    several if the match is ambiguous, and an empty list if there is no match.
    """
    # remove possibly accidentally added white spaces
    span_string = span_string.strip()
    # use raw text input as regex-safe pattern
    safe = re.escape(span_string)
    pattern = rf"{safe}"
    # find match(es)
    out = [(s.start(), s.end()) for s in re.finditer(pattern, raw_text)]
    return out


def append_spans_start_and_end(
    raw_text: str,
    pd_table: pd.DataFrame,
    input_cols: List[str],
    input_idx_cols: List[str],
    output_cols: List[str],
    doc_id_col: str = "doc ID",
) -> pd.DataFrame:
    """
    Create new column(s) for span indexes (i.e. start and end as Tuple) in pd.DataFrame from span strings.
    Warn if
    1) the span string does not match anything in the document -> None,
    2) the span string is not unique in the document -> List[Tuple].
    """
    pd_table = pd_table.join(pd.DataFrame(columns=output_cols))
    for idx, pd_row in pd_table.iterrows():
        for in_col, idx_col, out_col in zip(input_cols, input_idx_cols, output_cols):
            span_indices = find_span_idx(raw_text, pd_row[in_col])
            str_idx = pd_row[idx_col]
            span_idx = None
            if span_indices is None or len(span_indices) == 0:
                logger.warning(
                    f'The span "{pd_row[in_col]}" in Column "{in_col}" does not exist in {pd_row[doc_id_col]}.'
                )
            elif len(span_indices) == 1:
                # warn if the index column is not empty, although the span is unique
                # (str_idx == str_idx is False only for NaN)
                if str_idx == str_idx:
                    logger.warning(f'Column "{idx_col}" is not empty. It has value: {str_idx}.')
                span_idx = span_indices.pop()
            else:
                # warn if the span is not unique, but the index column is empty
                if str_idx != str_idx:
                    logger.warning(
                        f'The span "{pd_row[in_col]}" in Column "{in_col}" is not unique, '
                        f'but column "{idx_col}" is empty. '
                        f"Need a string index to specify the non-unique span."
                    )
                else:
                    span_idx = span_indices.pop(int(str_idx))

            if span_idx is not None:
                pd_table.at[idx, out_col] = span_idx

                # sanity check (NOTE: this should live in a test)
                search_string = pd_row[in_col]
                reconstructed_string = raw_text[span_idx[0] : span_idx[1]]
                if search_string != reconstructed_string:
                    raise ValueError(
                        f"Reconstructed string does not match the original string. "
                        f"Original: {search_string}, Reconstructed: {reconstructed_string}"
                    )
    return pd_table


def get_texts_from_pie_dataset(
    doc_ids: List[str], **dataset_kwargs
) -> Dict[str, BratDocumentWithMergedSpans]:
    """Get texts from a PIE dataset for a list of document IDs.

    :param doc_ids: list of document IDs
    :param dataset_kwargs: keyword arguments to pass to load_dataset

    :return: a dictionary with document IDs as keys and texts as values
    """

    text_based_dataset = load_dataset(**dataset_kwargs)
    if not isinstance(text_based_dataset, (Dataset, IterableDataset)):
        raise ValueError(
            f"Expected a PIE Dataset or PIE IterableDataset, but got a {type(text_based_dataset)} instead."
        )
    if not issubclass(text_based_dataset.document_type, BratDocumentWithMergedSpans):
        raise ValueError(
            f"Expected a PIE Dataset with BratDocumentWithMergedSpans as document type, "
            f"but got {text_based_dataset.document_type} instead."
        )
    doc_id2text = {doc.id: doc for doc in text_based_dataset}
    return {doc_id: doc_id2text[doc_id] for doc_id in doc_ids}


def set_span_annotation_ids(
    table: pd.DataFrame,
    doc_id2doc: Dict[str, BratDocumentWithMergedSpans],
    doc_id_col: str,
    span_idx_cols: List[str],
    span_id_cols: List[str],
) -> pd.DataFrame:
    """
    Create new column(s) for span annotation IDs in pd.DataFrame from span indexes. The annotation IDs are
    retrieved from the TextBasedDocument object using the span indexes.

    :param table: pd.DataFrame with span indexes, document IDs, and other columns
    :param doc_id2doc: dictionary with document IDs as keys and BratDocumentWithMergedSpans objects as values
    :param doc_id_col: column name that contains document IDs
    :param span_idx_cols: column names that contain span indexes
    :param span_id_cols: column names for new span ID columns

    :return: pd.DataFrame with new columns for span annotation IDs
    """
    table = table.join(pd.DataFrame(columns=span_id_cols))
    span2id: Dict[str, Dict[Tuple[int, int], str]] = defaultdict(dict)
    for doc_id, doc in doc_id2doc.items():
        for span_id, span in zip(doc.metadata["span_ids"], doc.spans):
            span2id[doc_id][(span.start, span.end)] = span_id

    for span_idx_col, span_id_col in zip(span_idx_cols, span_id_cols):
        table[span_id_col] = table.apply(
            lambda row: span2id[row[doc_id_col]][tuple(row[span_idx_col])], axis=1
        )

    return table


def set_relation_annotation_ids(
    table: pd.DataFrame,
    doc_id2doc: Dict[str, BratDocumentWithMergedSpans],
    doc_id_col: str,
    relation_id_col: str,
) -> pd.DataFrame:
    """Create a new column for relation annotation IDs in pd.DataFrame. They are simply new ids starting from
    the last relation annotation id in the document.

    :param table: pd.DataFrame with document IDs and other columns
    :param doc_id2doc: dictionary with document IDs as keys and BratDocumentWithMergedSpans objects as values
    :param doc_id_col: column name that contains document IDs
    :param relation_id_col: column name for new relation ID column

    :return: pd.DataFrame with new column for relation annotation IDs
    """

    table = table.join(pd.DataFrame(columns=[relation_id_col]))
    doc_id2highest_relation_id = defaultdict(int)

    for doc_id, doc in doc_id2doc.items():
        # relation ids are prefixed with "R" in the dataset
        doc_id2highest_relation_id[doc_id] = max(
            [int(relation_id[1:]) for relation_id in doc.metadata["relation_ids"]]
        )

    for idx, row in table.iterrows():
        doc_id = row[doc_id_col]
        doc_id2highest_relation_id[doc_id] += 1
        table.at[idx, relation_id_col] = f"R{doc_id2highest_relation_id[doc_id]}"

    return table


def main(
    input_path: str,
    output_path: str,
    brat_data_dir: str,
    input_encoding: str,
    include_unsure: bool = False,
    doc_id_col: str = "doc ID",
    unsure_col: str = "unsure",
    span_str_cols: List[str] = ["head argument string", "tail argument string"],
    str_idx_cols: List[str] = ["head string index", "tail string index"],
    span_idx_cols: List[str] = ["head_span_idx", "tail_span_idx"],
    span_id_cols: List[str] = ["head_span_id", "tail_span_id"],
    relation_id_col: str = "relation_id",
    set_annotation_ids: bool = False,
    relation_type: str = "relation",
) -> None:
    """
    Convert long dependency annotations from a CSV file to a JSON format. The input table should have
    columns for document IDs, argument span strings, and string indexes (required in the case that the
    span string occurs multiple times in the base text). The argument span strings are matched to the
    base text to get the start and end indexes of the span. The output JSON file will have the same
    columns as the input file, plus two additional columns for the start and end indexes of the spans.

    :param input_path: path to a CSV/Excel file that contains annotations
    :param output_path: path to save JSON output
    :param brat_data_dir: directory where the BRAT data (base texts and annotations) is located
    :param input_encoding: encoding of the input file. Only used for CSV files. Default: "cp1252"
    :param include_unsure: include annotations marked as unsure. Default: False
    :param doc_id_col: column name that contains document IDs. Default: "doc ID"
    :param unsure_col: column name that contains unsure annotations. Default: "unsure"
    :param span_str_cols: column names that contain span strings. Default: ["head argument string", "tail argument string"]
    :param str_idx_cols: column names that contain string indexes. Default: ["head string index", "tail string index"]
    :param span_idx_cols: column names for new span-index columns. Default: ["head_span_idx", "tail_span_idx"]
    :param span_id_cols: column names for new span-ID columns. Default: ["head_span_id", "tail_span_id"]
    :param relation_id_col: column name for new relation-ID column. Default: "relation_id"
    :param set_annotation_ids: set annotation IDs for the spans and relations. Default: False
    :param relation_type: specify the relation type for the BRAT output. Default: "relation"

    :return: None
    """
    # get annotations from a csv file
    if input_path.lower().endswith(".csv"):
        input_df = pd.read_csv(input_path, encoding=input_encoding)
    elif input_path.lower().endswith(".xlsx"):
        logger.warning(
            f"encoding parameter (--input-encoding={input_encoding}) is ignored for Excel files."
        )
        input_df = pd.read_excel(input_path)
    else:
        raise ValueError("Input file has unexpected format. Please provide a CSV or Excel file.")

    # remove unsure
    if not include_unsure:
        input_df = input_df[input_df[unsure_col].isna()]
    # remove all empty columns
    input_df = input_df.dropna(axis=1, how="all")

    # define output DataFrame
    result_df = pd.DataFrame(columns=[*input_df.columns, *span_idx_cols])

    # get unique document IDs
    doc_ids = list(input_df[doc_id_col].unique())

    # get base texts from a PIE SciArg dataset
    doc_id2doc = get_texts_from_pie_dataset(
        doc_ids=doc_ids,
        path="pie/brat",
        name="merge_fragmented_spans",
        split="train",
        revision="769a15e44e7d691148dd05e54ae2b058ceaed1f0",
        base_dataset_kwargs=dict(data_dir=brat_data_dir),
    )

    for doc_id in doc_ids:
        # iterate over each sub-df that contains annotations for a single document
        doc_df = input_df[input_df[doc_id_col] == doc_id]
        input_df = input_df.drop(doc_df.index)
        # get spans' start and end indexes as new columns
        doc_with_span_indices_df = append_spans_start_and_end(
            raw_text=doc_id2doc[doc_id].text,
            pd_table=doc_df,
            input_cols=span_str_cols,
            input_idx_cols=str_idx_cols,
            output_cols=span_idx_cols,
        )
        # append this sub-df (with spans' indexes columns) to result_df
        result_df = pd.concat(
            [result_df if not result_df.empty else None, doc_with_span_indices_df]
        )

    out_ext = os.path.splitext(output_path)[1]
    save_as_brat = out_ext == ""

    if set_annotation_ids or save_as_brat:
        result_df = set_span_annotation_ids(
            table=result_df,
            doc_id2doc=doc_id2doc,
            doc_id_col=doc_id_col,
            span_idx_cols=span_idx_cols,
            span_id_cols=span_id_cols,
        )
        result_df = set_relation_annotation_ids(
            table=result_df,
            doc_id2doc=doc_id2doc,
            doc_id_col=doc_id_col,
            relation_id_col=relation_id_col,
        )

    base_dir = os.path.dirname(output_path)
    os.makedirs(base_dir, exist_ok=True)

    if out_ext.lower() == ".json":
        logger.info(f"Saving output in JSON format to {output_path} ...")
        result_df.to_json(
            path_or_buf=output_path,
            orient="records",
            lines=True,
        )  # possible orient values: 'split', 'index', 'table', 'records', 'columns', 'values'
    elif save_as_brat:
        logger.info(f"Saving output in BRAT format to {output_path} ...")
        os.makedirs(output_path, exist_ok=True)
        for doc_id in doc_ids:
            # handle the base text file (just copy from the BRAT data directory)
            shutil.copy(
                src=os.path.join(brat_data_dir, f"{doc_id}.txt"),
                dst=os.path.join(output_path, f"{doc_id}.txt"),
            )

            # handle the annotation file
            # first, read the original annotation file
            input_ann_path = os.path.join(brat_data_dir, f"{doc_id}.ann")
            with open(input_ann_path, "r") as f:
                ann_lines = f.readlines()
            # then, append new relation annotations
            # The format for each line is (see https://brat.nlplab.org/standoff.html):
            # R{relation_id}\t{relation_type} Arg1:{span_id1} Arg2:{span_id2}
            doc_df = result_df[result_df[doc_id_col] == doc_id]
            logger.info(f"Adding {len(doc_df)} relation annotations to {doc_id}.ann ...")
            for idx, row in doc_df.iterrows():
                head_span_id = row[span_id_cols[0]]
                tail_span_id = row[span_id_cols[1]]
                relation_id = row[relation_id_col]
                ann_line = (
                    f"{relation_id}\t{relation_type} Arg1:{head_span_id} Arg2:{tail_span_id}\n"
                )
                ann_lines.append(ann_line)
            # finally, write the new annotation file
            output_ann_path = os.path.join(output_path, f"{doc_id}.ann")
            with open(output_ann_path, "w") as f:
                f.writelines(ann_lines)
    else:
        raise ValueError(
            "Output file has unexpected format. Please provide a JSON file or a directory."
        )

    logger.info("Done!")


if __name__ == "__main__":

    """
    example call:
    python src/data/prepare_sciarg_crosssection_annotations.py
    // or //
    python src/data/prepare_sciarg_crosssection_annotations.py \
        --input-path data/annotations/sciarg-cross-section/aligned_input.csv \
        --output-path data/annotations/sciarg-with-abstracts-and-cross-section-rels \
        --brat-data-dir data/annotations/sciarg-abstracts/v0.9.3/alisa
    """

    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(
        description="Read text files in a directory and a CSV file that contains cross-section annotations. "
        "Transform the CSV file to a JSON format and save at a specified output directory."
    )
    parser.add_argument(
        "--input-path",
        type=str,
        default="data/annotations/sciarg-cross-section/aligned_input.csv",
        help="Locate a CSV/Excel file.",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        default="data/annotations/sciarg-with-abstracts-and-cross-section-rels",
        help="Specify a path where output will be saved. Should be a JSON file or a directory for BRAT output.",
    )
    parser.add_argument(
        "--brat-data-dir",
        type=str,
        default="data/annotations/sciarg-abstracts/v0.9.3/alisa",
        help="Specify the directory where the BRAT data (base texts and annotations) is located.",
    )
    parser.add_argument(
        "--relation-type",
        type=str,
        default="semantically_same",
        help="Specify the relation type for the BRAT output.",
    )
    parser.add_argument(
        "--input-encoding",
        type=str,
        default="cp1252",
        help="Specify encoding for reading an input file.",
    )
    parser.add_argument(
        "--include-unsure",
        action="store_true",
        help="Include annotations marked as unsure.",
    )
    parser.add_argument(
        "--set-annotation-ids",
        action="store_true",
        help="Set BRAT annotation IDs for the spans and relations.",
    )
    args = parser.parse_args()
    kwargs = vars(args)

    main(**kwargs)
|
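Note on the span matching: span strings from the annotation table are located in the base text with an exact, regex-escaped search, and ambiguous matches have to be disambiguated via the string-index column. A quick sketch of that matching behaviour with a toy text (not taken from the dataset):

import re
from typing import List, Tuple

def find_span_idx(raw_text: str, span_string: str) -> List[Tuple[int, int]]:
    # exact search with a regex-escaped pattern, as in the script above
    safe = re.escape(span_string.strip())
    return [(m.start(), m.end()) for m in re.finditer(safe, raw_text)]

text = "We propose a new model. The new model outperforms the baseline."
print(find_span_idx(text, "new model"))    # two matches -> needs a string index
print(find_span_idx(text, "baseline"))     # unique match -> one (start, end) tuple
print(find_span_idx(text, "transformer"))  # no match -> empty list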
src/data/split_sciarg_abstracts.py
ADDED
@@ -0,0 +1,132 @@
|
import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

import argparse
import logging
import os
from typing import List, Optional, TypeVar

from pie_datasets import load_dataset
from pie_datasets.builders.brat import BratDocument, BratDocumentWithMergedSpans
from pytorch_ie.documents import (
    TextBasedDocument,
    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
    TextDocumentWithLabeledPartitions,
    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)

from src.document.processing import replace_substrings_in_text_with_spaces

logger = logging.getLogger(__name__)


def save_abstract_and_remaining_text(
    doc: TextDocumentWithLabeledPartitions, base_path: str
) -> None:
    abstract_annotations = [
        span for span in doc.labeled_partitions if span.label.lower().strip() == "abstract"
    ]
    if len(abstract_annotations) != 1:
        logger.warning(
            f"Expected exactly one abstract annotation, found {len(abstract_annotations)}"
        )
        return
    abstract_annotation = abstract_annotations[0]
    text_abstract = doc.text[abstract_annotation.start : abstract_annotation.end]
    text_remaining = doc.text[abstract_annotation.end :]
    with open(
        f"{base_path}.abstract.{abstract_annotation.start}_{abstract_annotation.end}.txt", "w"
    ) as f:
        f.write(text_abstract)
    with open(f"{base_path}.remaining.{abstract_annotation.end}.txt", "w") as f:
        f.write(text_remaining)


D_text = TypeVar("D_text", bound=TextBasedDocument)


def clean_doc(doc: D_text) -> D_text:
    # remove xml tags. Note that we also remove the Abstract tag, in contrast to the preprocessing
    # pipeline (see configs/dataset/sciarg_cleaned.yaml). This is because there, the abstracts are
    # removed completely.
    doc = replace_substrings_in_text_with_spaces(
        doc,
        substrings=[
            "</H2>",
            "<H3>",
            "</Document>",
            "<H1>",
            "<H2>",
            "</H3>",
            "</H1>",
            "<Abstract>",
            "</Abstract>",
        ],
    )
    return doc


def main(out_dir: str, doc_id_whitelist: Optional[List[str]] = None) -> None:
    logger.info("Loading dataset from pie/sciarg")
    sciarg_with_abstracts = load_dataset(
        "pie/sciarg",
        revision="171478ce3c13cc484be5d7c9bc8f66d7d2f1c210",
        split="train",
    )
    if issubclass(sciarg_with_abstracts.document_type, BratDocument):
        ds_converted = sciarg_with_abstracts.to_document_type(
            TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
        )
    elif issubclass(sciarg_with_abstracts.document_type, BratDocumentWithMergedSpans):
        ds_converted = sciarg_with_abstracts.to_document_type(
            TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
        )
    else:
        raise ValueError(f"Unsupported document type {sciarg_with_abstracts.document_type}")

    ds_clean = ds_converted.map(clean_doc)
    if doc_id_whitelist is not None:
        num_before = len(ds_clean)
        ds_clean = [doc for doc in ds_clean if doc.id in doc_id_whitelist]
        logger.info(
            f"Filtered dataset from {num_before} to {len(ds_clean)} documents based on doc_id_whitelist"
        )

    os.makedirs(out_dir, exist_ok=True)
    logger.info(f"Saving dataset to {out_dir}")
    for doc in ds_clean:
        save_abstract_and_remaining_text(doc, os.path.join(out_dir, doc.id))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Split SciArg dataset into abstract and remaining text"
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        default="data/datasets/sciarg/abstracts_and_remaining_text",
        help="Path to save the split data",
    )
    parser.add_argument(
        "--doc_id_whitelist",
        type=str,
        nargs="+",
        default=["A32", "A33", "A34", "A35", "A36", "A37", "A38", "A39", "A40"],
        help="List of document ids to include in the split",
    )

    logging.basicConfig(level=logging.INFO)

    kwargs = vars(parser.parse_args())
    # allow for "all" to include all documents
    if len(kwargs["doc_id_whitelist"]) == 1 and kwargs["doc_id_whitelist"][0].lower() == "all":
        kwargs["doc_id_whitelist"] = None
    main(**kwargs)
    logger.info("Done")
|
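Note on the file naming: the split files encode the character offsets in their names, which is what construct_sciarg_abstracts_remaining_gold_retrieval.py later relies on for its doc_id and query_doc_id columns. A sketch of the naming scheme with hypothetical offsets, ignoring the output directory prefix:

# hypothetical values, mirroring the f-strings in save_abstract_and_remaining_text
doc_id = "A32"
abstract_start, abstract_end = 120, 980

abstract_file = f"{doc_id}.abstract.{abstract_start}_{abstract_end}.txt"
remaining_file = f"{doc_id}.remaining.{abstract_end}.txt"
print(abstract_file)   # A32.abstract.120_980.txt
print(remaining_file)  # A32.remaining.980.txt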
src/demo/annotation_utils.py
CHANGED
@@ -1,7 +1,9 @@
+import json
 import logging
-from typing import Optional, Sequence, Union
+from typing import Iterable, Optional, Sequence, Union

 import gradio as gr
+from hydra.utils import instantiate
 from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger

 # this is required to dynamically load the PIE models
@@ -10,7 +12,6 @@ from pie_modules.taskmodules import *  # noqa: F403
 from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
 from pytorch_ie import Pipeline
 from pytorch_ie.annotations import LabeledSpan
-from pytorch_ie.auto import AutoPipeline
 from pytorch_ie.documents import (
     TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
     TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
@@ -20,9 +21,25 @@ from pytorch_ie.documents import (
 from pytorch_ie.models import *  # noqa: F403
 from pytorch_ie.taskmodules import *  # noqa: F403

+from src.utils import parse_config
+
 logger = logging.getLogger(__name__)


+def get_merger() -> SpansViaRelationMerger:
+    return SpansViaRelationMerger(
+        relation_layer="binary_relations",
+        link_relation_label="parts_of_same",
+        create_multi_spans=True,
+        result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+        result_field_mapping={
+            "labeled_spans": "labeled_multi_spans",
+            "binary_relations": "binary_relations",
+            "labeled_partitions": "labeled_partitions",
+        },
+    )
+
+
 def annotate_document(
     document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
     argumentation_model: Pipeline,
@@ -40,23 +57,40 @@ def annotate_document(
     """

     # execute prediction pipeline
-    argumentation_model(
+    result: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions = argumentation_model(
+        document, inplace=True
+    )

     if handle_parts_of_same:
-        merger =
-            link_relation_label="parts_of_same",
-            create_multi_spans=True,
-            result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
-            result_field_mapping={
-                "labeled_spans": "labeled_multi_spans",
-                "binary_relations": "binary_relations",
-                "labeled_partitions": "labeled_partitions",
-            },
-        )
-        document = merger(document)
+        merger = get_merger()
+        result = merger(result)

-    return
+    return result
+
+
+def annotate_documents(
+    documents: Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions],
+    argumentation_model: Pipeline,
+    handle_parts_of_same: bool = False,
+) -> Union[
+    Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions],
+    Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions],
+]:
+    """Annotate a sequence of documents with the provided pipeline.
+
+    Args:
+        documents: The documents to annotate.
+        argumentation_model: The pipeline to use for annotation.
+        handle_parts_of_same: Whether to merge spans that are part of the same entity into a single multi span.
+    """
+    # execute prediction pipeline
+    result = argumentation_model(documents, inplace=True)
+
+    if handle_parts_of_same:
+        merger = get_merger()
+        result = [merger(document) for document in result]
+
+    return result


 def create_document(
@@ -88,32 +122,45 @@ def create_document(
     return document


-def
+def create_documents(
+    texts: Iterable[str], doc_ids: Iterable[str], split_regex: Optional[str] = None
+) -> Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
+    """Create a sequence of TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
+    texts.
+
+    Parameters:
+        texts: The texts to process.
+        doc_ids: The IDs of the documents.
+        split_regex: A regular expression pattern to use for splitting the text into partitions.
+
+    Returns:
+        The processed documents.
+    """
+    return [
+        create_document(text=text, doc_id=doc_id, split_regex=split_regex)
+        for text, doc_id in zip(texts, doc_ids)
+    ]
+
+
+def load_argumentation_model(config_str: str, **kwargs) -> Pipeline:
     try:
-        )
-        gr.Info(
-            f"Loaded argumentation model: model_name={model_name}, revision={revision}, device={device}"
-        )
+        config = parse_config(config_str, format="yaml")
+
+        # for PIE AutoPipeline, we need to handle the revision separately for
+        # the taskmodule and the model
+        if (
+            config.get("_target_") == "pytorch_ie.auto.AutoPipeline.from_pretrained"
+            and "revision" in config
+        ):
+            revision = config.pop("revision")
+            if "taskmodule_kwargs" not in config:
+                config["taskmodule_kwargs"] = {}
+            config["taskmodule_kwargs"]["revision"] = revision
+            if "model_kwargs" not in config:
+                config["model_kwargs"] = {}
+            config["model_kwargs"]["revision"] = revision
+        model = instantiate(config, **kwargs)
+        gr.Info(f"Loaded argumentation model: {json.dumps({**config, **kwargs})}")
     except Exception as e:
         raise gr.Error(f"Failed to load argumentation model: {e}")
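Note on the new loading path: load_argumentation_model now takes a YAML config string instead of separate model_name/revision arguments, and for the AutoPipeline target it copies the revision into both taskmodule_kwargs and model_kwargs before instantiation. A rough sketch of that rewrite step, using yaml.safe_load as a stand-in for the project's parse_config and a made-up model path:

import yaml

config_str = """
_target_: pytorch_ie.auto.AutoPipeline.from_pretrained
pretrained_model_name_or_path: some-org/some-argumentation-model  # made-up path
revision: abc123
batch_size: 1
"""

config = yaml.safe_load(config_str)
if (
    config.get("_target_") == "pytorch_ie.auto.AutoPipeline.from_pretrained"
    and "revision" in config
):
    # move the revision into the taskmodule and model kwargs, as the demo code does
    revision = config.pop("revision")
    config.setdefault("taskmodule_kwargs", {})["revision"] = revision
    config.setdefault("model_kwargs", {})["revision"] = revision

print(config)
# the demo then passes this dict to hydra.utils.instantiate(config, ...)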
src/demo/backend_utils.py
CHANGED
@@ -2,17 +2,20 @@ import json
|
|
2 |
import logging
|
3 |
import os
|
4 |
import tempfile
|
|
|
5 |
from typing import Iterable, List, Optional, Sequence
|
6 |
|
7 |
import gradio as gr
|
8 |
import pandas as pd
|
|
|
9 |
from pie_datasets import Dataset, IterableDataset, load_dataset
|
10 |
from pytorch_ie import Pipeline
|
11 |
from pytorch_ie.documents import (
|
12 |
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
|
13 |
)
|
|
|
14 |
|
15 |
-
from src.demo.annotation_utils import
|
16 |
from src.demo.data_utils import load_text_from_arxiv
|
17 |
from src.demo.rendering_utils import (
|
18 |
RENDER_WITH_DISPLACY,
|
@@ -25,6 +28,8 @@ from src.langchain_modules import (
|
|
25 |
DocumentAwareSpanRetriever,
|
26 |
DocumentAwareSpanRetrieverWithRelations,
|
27 |
)
|
|
|
|
|
28 |
|
29 |
logger = logging.getLogger(__name__)
|
30 |
|
@@ -58,20 +63,18 @@ def process_texts(
|
|
58 |
# check that doc_ids are unique
|
59 |
if len(set(doc_ids)) != len(list(doc_ids)):
|
60 |
raise gr.Error("Document IDs must be unique.")
|
61 |
-
pie_documents =
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
65 |
if verbose:
|
66 |
gr.Info(f"Annotate {len(pie_documents)} documents...")
|
67 |
-
pie_documents =
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
)
|
73 |
-
for pie_document in pie_documents
|
74 |
-
]
|
75 |
add_annotated_pie_documents(
|
76 |
retriever=retriever,
|
77 |
pie_documents=pie_documents,
|
@@ -140,6 +143,94 @@ def process_uploaded_files(
|
|
140 |
return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)
|
141 |
|
142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
def wrapped_add_annotated_pie_documents_from_dataset(
|
144 |
retriever: DocumentAwareSpanRetriever, verbose: bool, layer_captions: dict[str, str], **kwargs
|
145 |
) -> pd.DataFrame:
|
@@ -193,6 +284,7 @@ def render_annotated_document(
|
|
193 |
document_id: str,
|
194 |
render_with: str,
|
195 |
render_kwargs_json: str,
|
|
|
196 |
) -> str:
|
197 |
text, spans, span_id2idx, relations = get_text_spans_and_relations_from_document(
|
198 |
retriever=retriever, document_id=document_id
|
@@ -213,6 +305,7 @@ def render_annotated_document(
|
|
213 |
spans=spans,
|
214 |
span_id2idx=span_id2idx,
|
215 |
binary_relations=relations,
|
|
|
216 |
**render_kwargs,
|
217 |
)
|
218 |
else:
|
|
|
2 |
import logging
|
3 |
import os
|
4 |
import tempfile
|
5 |
+
from pathlib import Path
|
6 |
from typing import Iterable, List, Optional, Sequence
|
7 |
|
8 |
import gradio as gr
|
9 |
import pandas as pd
|
10 |
+
from acl_anthology import Anthology
|
11 |
from pie_datasets import Dataset, IterableDataset, load_dataset
|
12 |
from pytorch_ie import Pipeline
|
13 |
from pytorch_ie.documents import (
|
14 |
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
|
15 |
)
|
16 |
+
from tqdm import tqdm
|
17 |
|
18 |
+
from src.demo.annotation_utils import annotate_documents, create_documents
|
19 |
from src.demo.data_utils import load_text_from_arxiv
|
20 |
from src.demo.rendering_utils import (
|
21 |
RENDER_WITH_DISPLACY,
|
|
|
28 |
DocumentAwareSpanRetriever,
|
29 |
DocumentAwareSpanRetrieverWithRelations,
|
30 |
)
|
31 |
+
from src.utils.pdf_utils.acl_anthology_utils import XML2RawPapers
|
32 |
+
from src.utils.pdf_utils.process_pdf import FulltextExtractor, PDFDownloader
|
33 |
|
34 |
logger = logging.getLogger(__name__)
|
35 |
|
|
|
63 |
# check that doc_ids are unique
|
64 |
if len(set(doc_ids)) != len(list(doc_ids)):
|
65 |
raise gr.Error("Document IDs must be unique.")
|
66 |
+
pie_documents = create_documents(
|
67 |
+
texts=texts,
|
68 |
+
doc_ids=doc_ids,
|
69 |
+
split_regex=split_regex_escaped,
|
70 |
+
)
|
71 |
if verbose:
|
72 |
gr.Info(f"Annotate {len(pie_documents)} documents...")
|
73 |
+
pie_documents = annotate_documents(
|
74 |
+
documents=pie_documents,
|
75 |
+
argumentation_model=argumentation_model,
|
76 |
+
handle_parts_of_same=handle_parts_of_same,
|
77 |
+
)
|
|
|
|
|
|
|
78 |
add_annotated_pie_documents(
|
79 |
retriever=retriever,
|
80 |
pie_documents=pie_documents,
|
|
|
143 |
     return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)


+def process_uploaded_pdf_files(
+    pdf_fulltext_extractor: Optional[FulltextExtractor],
+    file_names: List[str],
+    retriever: DocumentAwareSpanRetriever,
+    layer_captions: dict[str, str],
+    **kwargs,
+) -> pd.DataFrame:
+    try:
+        if pdf_fulltext_extractor is None:
+            raise gr.Error("PDF fulltext extractor is not available.")
+        doc_ids = []
+        texts = []
+        for file_name in file_names:
+            if file_name.lower().endswith(".pdf"):
+                # extract the fulltext from the pdf
+                text_and_extraction_data = pdf_fulltext_extractor(file_name)
+                if text_and_extraction_data is None:
+                    raise gr.Error(f"Failed to extract fulltext from PDF: {file_name}")
+                text, _ = text_and_extraction_data
+
+                base_file_name = os.path.basename(file_name)
+                doc_ids.append(base_file_name)
+                texts.append(text)
+
+            else:
+                raise gr.Error(f"Unsupported file format: {file_name}")
+        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
+    except Exception as e:
+        raise gr.Error(f"Failed to process uploaded files: {e}")
+
+    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)
+
+
+def load_acl_anthology_venues(
+    venues: List[str],
+    pdf_fulltext_extractor: Optional[FulltextExtractor],
+    retriever: DocumentAwareSpanRetriever,
+    layer_captions: dict[str, str],
+    acl_anthology_data_dir: Optional[str],
+    pdf_output_dir: Optional[str],
+    show_progress: bool = True,
+    **kwargs,
+) -> pd.DataFrame:
+    try:
+        if pdf_fulltext_extractor is None:
+            raise gr.Error("PDF fulltext extractor is not available.")
+        if acl_anthology_data_dir is None:
+            raise gr.Error("ACL Anthology data directory is not provided.")
+        if pdf_output_dir is None:
+            raise gr.Error("PDF output directory is not provided.")
+        xml2raw_papers = XML2RawPapers(
+            anthology=Anthology(datadir=Path(acl_anthology_data_dir)),
+            venue_id_whitelist=venues,
+            verbose=False,
+        )
+        pdf_downloader = PDFDownloader()
+        doc_ids = []
+        texts = []
+        os.makedirs(pdf_output_dir, exist_ok=True)
+        papers = xml2raw_papers()
+        if show_progress:
+            papers_list = list(papers)
+            papers = tqdm(papers_list, desc="extracting fulltext")
+            gr.Info(
+                f"Downloading and extracting fulltext from {len(papers_list)} papers in venues: {venues}"
+            )
+        for paper in papers:
+            if paper.url is not None:
+                pdf_save_path = pdf_downloader.download(
+                    paper.url, opath=Path(pdf_output_dir) / f"{paper.name}.pdf"
+                )
+                fulltext_extraction_output = pdf_fulltext_extractor(pdf_save_path)
+
+                if fulltext_extraction_output:
+                    text, _ = fulltext_extraction_output
+                    doc_id = f"aclanthology.org/{paper.name}"
+                    doc_ids.append(doc_id)
+                    texts.append(text)
+                else:
+                    gr.Warning(f"Failed to extract fulltext from PDF: {paper.url}")
+
+        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
+    except Exception as e:
+        raise gr.Error(f"Failed to process uploaded files: {e}")
+
+    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)
+
+
 def wrapped_add_annotated_pie_documents_from_dataset(
     retriever: DocumentAwareSpanRetriever, verbose: bool, layer_captions: dict[str, str], **kwargs
 ) -> pd.DataFrame:

     document_id: str,
     render_with: str,
     render_kwargs_json: str,
+    highlight_span_ids: Optional[List[str]] = None,
 ) -> str:
     text, spans, span_id2idx, relations = get_text_spans_and_relations_from_document(
         retriever=retriever, document_id=document_id

         spans=spans,
         span_id2idx=span_id2idx,
         binary_relations=relations,
+        highlight_span_ids=highlight_span_ids,
         **render_kwargs,
     )
     else:
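For orientation, here is a minimal sketch of how the new upload handler could be hooked into the Gradio demo. The component names (`pdf_upload`, `overview_df`) and the objects `fulltext_extractor`, `demo_retriever` and `LAYER_CAPTIONS` are illustrative assumptions, not taken from this PR; the actual wiring lives in src/start_demo.py.

# Hypothetical wiring sketch (all names below are placeholders, not from this PR).
import gradio as gr

with gr.Blocks() as demo:
    pdf_upload = gr.File(label="Upload PDFs", file_count="multiple", type="filepath")
    overview_df = gr.Dataframe(label="Document overview")

    # On upload, extract fulltexts, index them with the retriever and refresh the overview.
    pdf_upload.upload(
        fn=lambda files: process_uploaded_pdf_files(
            pdf_fulltext_extractor=fulltext_extractor,  # assumed to be created elsewhere
            file_names=files,
            retriever=demo_retriever,                   # assumed to be created elsewhere
            layer_captions=LAYER_CAPTIONS,              # assumed to be created elsewhere
        ),
        inputs=pdf_upload,
        outputs=overview_df,
    )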
src/demo/frontend_utils.py CHANGED
@@ -24,6 +24,18 @@ def close_accordion():
     return gr.Accordion(open=False)


+def open_accordion_with_stats(
+    overview: pd.DataFrame, base_label: str, caption2column: dict[str, str], total_column: str
+):
+    caption2value = {
+        caption: len(overview) if column == total_column else overview[column].sum()
+        for caption, column in caption2column.items()
+    }
+    stats_str = ", ".join([f"{value} {caption}" for caption, value in caption2value.items()])
+    label = f"{base_label} ({stats_str})"
+    return gr.Accordion(open=True, label=label)
+
+
 def change_tab(id: Union[int, str]):
     return gr.Tabs(selected=id)
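To illustrate what `open_accordion_with_stats` computes, the same statistics can be reproduced standalone; the DataFrame below mimics the overview returned by the docstore, with column names invented for the example.

import pandas as pd

overview = pd.DataFrame(
    {"doc_id": ["doc1", "doc2"], "num_adus": [12, 7], "num_relations": [5, 3]}
)
caption2column = {"documents": "doc_id", "ADUs": "num_adus", "relations": "num_relations"}
total_column = "doc_id"

caption2value = {
    caption: len(overview) if column == total_column else overview[column].sum()
    for caption, column in caption2column.items()
}
stats_str = ", ".join(f"{value} {caption}" for caption, value in caption2value.items())
print(f"Retrieved documents ({stats_str})")
# -> Retrieved documents (2 documents, 19 ADUs, 8 relations)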
src/demo/rendering_utils.py CHANGED
@@ -15,7 +15,7 @@ AVAILABLE_RENDER_MODES = [RENDER_WITH_DISPLACY, RENDER_WITH_PRETTY_TABLE]

 # adjusted from rendering_utils_displacy.TPL_ENT
 TPL_ENT_WITH_ID = """
-<mark class="entity" data-entity-id="{entity_id}" data-slice-idx="{slice_idx}" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
+<mark class="entity" data-entity-id="{entity_id}" data-slice-idx="{slice_idx}" data-highlight-mode="{highlight_mode}" style="background: {bg}; border-width: {border_width}; border-color: {border_color}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
 {text}
 <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
 </mark>
@@ -31,8 +31,12 @@ HIGHLIGHT_SPANS_JS = """
             color = colors[colorDictKey];
         } catch (e) {}
         if (color) {
+            //const highlightMode = entity.getAttribute('data-highlight-mode');
+            //if (highlightMode === 'fill') {
             entity.style.backgroundColor = color;
             entity.style.color = '#000';
+            //}
+            entity.style.borderColor = color;
         }
     }
@@ -42,6 +46,8 @@ HIGHLIGHT_SPANS_JS = """
     entities.forEach(entity => {
         const color = entity.getAttribute('data-color-original');
         entity.style.backgroundColor = color;
+        const borderColor = entity.getAttribute('data-border-color-original');
+        entity.style.borderColor = borderColor;
         entity.style.color = '';
     });
@@ -171,6 +177,7 @@ def render_displacy(
     spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
     span_id2idx: Dict[str, int],
     binary_relations: Sequence[BinaryRelation],
+    highlight_span_ids: Optional[List[str]] = None,
     inject_relations=True,
     colors_hover=None,
     entity_options={},
@@ -180,6 +187,9 @@ def render_displacy(
     ents: List[Dict[str, Any]] = []
     for entity_id, idx in span_id2idx.items():
         labeled_span = spans[idx]
+        highlight_mode = (
+            "fill" if highlight_span_ids is None or entity_id in highlight_span_ids else "border"
+        )
         # pass the ID as a parameter to the entity. The id is required to fetch the entity annotations
         # on hover and to inject the relation data.
         if isinstance(labeled_span, LabeledSpan):
@@ -188,7 +198,11 @@ def render_displacy(
                     "start": labeled_span.start,
                     "end": labeled_span.end,
                     "label": labeled_span.label,
+                    "params": {
+                        "entity_id": entity_id,
+                        "slice_idx": 0,
+                        "highlight_mode": highlight_mode,
+                    },
                 }
             )
         elif isinstance(labeled_span, LabeledMultiSpan):
@@ -198,7 +212,11 @@ def render_displacy(
                     "start": start,
                     "end": end,
                     "label": labeled_span.label,
+                    "params": {
+                        "entity_id": entity_id,
+                        "slice_idx": i,
+                        "highlight_mode": highlight_mode,
+                    },
                 }
             )
         else:
@@ -254,7 +272,9 @@ def inject_relation_data(
     entities = soup.find_all(class_="entity")
     for entity in entities:
         original_color = entity["style"].split("background:")[1].split(";")[0].strip()
+        original_border_color = entity["style"].split("border-color:")[1].split(";")[0].strip()
         entity["data-color-original"] = original_color
+        entity["data-border-color-original"] = original_border_color
         if additional_colors is not None:
             for key, color in additional_colors.items():
                 entity[f"data-color-{key}"] = (
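The span highlighting rule added to `render_displacy` boils down to a one-liner; restated in isolation (the helper name below is just for illustration, it is not part of the PR):

from typing import List, Optional

def choose_highlight_mode(entity_id: str, highlight_span_ids: Optional[List[str]]) -> str:
    # Mirrors the rule added in render_displacy: no list given -> everything stays filled.
    return "fill" if highlight_span_ids is None or entity_id in highlight_span_ids else "border"

assert choose_highlight_mode("span-1", None) == "fill"
assert choose_highlight_mode("span-1", ["span-1"]) == "fill"
assert choose_highlight_mode("span-2", ["span-1"]) == "border"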
src/demo/rendering_utils_displacy.py CHANGED
@@ -200,7 +200,18 @@ class EntityRenderer(object):
             markup += "<br/>"
         if self.ents is None or label.upper() in self.ents:
             color = self.colors.get(label.upper(), self.default_color)
+            ent_settings = {"label": label, "text": entity}
+            highlight_mode = additional_params.get("highlight_mode", "fill")
+            if highlight_mode == "fill":
+                ent_settings["bg"] = color
+                ent_settings["border_width"] = "0px"
+                ent_settings["border_color"] = color
+            elif highlight_mode == "border":
+                ent_settings["bg"] = "inherit"
+                ent_settings["border_width"] = "2px"
+                ent_settings["border_color"] = color
+            else:
+                raise ValueError(f"Invalid highlight_mode: {highlight_mode}")
             ent_settings.update(additional_params)
             markup += self.ent_template.format(**ent_settings)
         else:
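The mapping from highlight mode to the three template parameters `bg`, `border_width` and `border_color` can be summarized as a small standalone function; this is a restatement for readability, not code from the PR.

def style_for_mode(highlight_mode: str, color: str) -> dict:
    # Same branching as in EntityRenderer: "fill" paints the background,
    # "border" keeps the background and only draws a colored frame.
    if highlight_mode == "fill":
        return {"bg": color, "border_width": "0px", "border_color": color}
    if highlight_mode == "border":
        return {"bg": "inherit", "border_width": "2px", "border_color": color}
    raise ValueError(f"Invalid highlight_mode: {highlight_mode}")

print(style_for_mode("border", "#ffda33"))
# {'bg': 'inherit', 'border_width': '2px', 'border_color': '#ffda33'}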
src/demo/retrieve_and_dump_all_relevant.py CHANGED
@@ -9,6 +9,9 @@ root = pyrootutils.setup_root(

 import argparse
 import logging
+import os
+
+import pandas as pd

 from src.demo.retriever_utils import (
     retrieve_all_relevant_spans,
@@ -55,6 +58,29 @@ if __name__ == "__main__":
         default=None,
         help="If provided, retrieve all spans for only this query span.",
     )
+    parser.add_argument(
+        "--doc_id_whitelist",
+        type=str,
+        nargs="+",
+        default=None,
+        help="If provided, only consider documents with these IDs.",
+    )
+    parser.add_argument(
+        "--doc_id_blacklist",
+        type=str,
+        nargs="+",
+        default=None,
+        help="If provided, ignore documents with these IDs.",
+    )
+    parser.add_argument(
+        "--query_target_doc_id_pairs",
+        type=str,
+        nargs="+",
+        default=None,
+        help="One or more pairs of query and target document IDs "
+        '(each separated by ":") to retrieve spans for. If provided, '
+        "--query_doc_id and --query_span_id are ignored.",
+    )
     args = parser.parse_args()

     logging.basicConfig(
@@ -74,9 +100,41 @@ if __name__ == "__main__":
     retriever.load_from_disc(args.data_path)

     search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
+    if args.doc_id_whitelist is not None:
+        search_kwargs["doc_id_whitelist"] = args.doc_id_whitelist
+    if args.doc_id_blacklist is not None:
+        search_kwargs["doc_id_blacklist"] = args.doc_id_blacklist
     logger.info(f"use search_kwargs: {search_kwargs}")

-    if args.query_span_id is not None:
+    if args.query_target_doc_id_pairs is not None:
+        all_spans_for_all_documents = None
+        for doc_id_pair in args.query_target_doc_id_pairs:
+            query_doc_id, target_doc_id = doc_id_pair.split(":")
+            current_result = retrieve_all_relevant_spans(
+                retriever=retriever,
+                query_doc_id=query_doc_id,
+                doc_id_whitelist=[target_doc_id],
+                **search_kwargs,
+            )
+            if current_result is None:
+                logger.warning(
+                    f"no relevant spans found for query_doc_id={query_doc_id} and "
+                    f"target_doc_id={target_doc_id}"
+                )
+                continue
+            logger.info(
+                f"retrieved {len(current_result)} spans for query_doc_id={query_doc_id} "
+                f"and target_doc_id={target_doc_id}"
+            )
+            current_result["query_doc_id"] = query_doc_id
+            if all_spans_for_all_documents is None:
+                all_spans_for_all_documents = current_result
+            else:
+                all_spans_for_all_documents = pd.concat(
+                    [all_spans_for_all_documents, current_result], ignore_index=True
+                )
+
+    elif args.query_span_id is not None:
         logger.warning(f"retrieving results for single span: {args.query_span_id}")
         all_spans_for_all_documents = retrieve_relevant_spans(
             retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
@@ -95,7 +153,8 @@ if __name__ == "__main__":
         logger.warning("no relevant spans found in any document")
         exit(0)

-    logger.info(f"dumping results to {args.output_path}...")
+    logger.info(f"dumping results ({len(all_spans_for_all_documents)}) to {args.output_path}...")
+    os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
     all_spans_for_all_documents.to_json(args.output_path)

     logger.info("done")
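The new `--query_target_doc_id_pairs` option expects each item as `query_doc_id:target_doc_id`; the script splits every pair at the colon, so the document IDs themselves must not contain one. A quick illustration of the expected format (the IDs are invented):

# Example values for --query_target_doc_id_pairs and how the script splits them.
pairs = ["paper_A.pdf:paper_B.pdf", "doc1:doc2"]
for doc_id_pair in pairs:
    query_doc_id, target_doc_id = doc_id_pair.split(":")
    print(f"retrieve spans from {query_doc_id} against {target_doc_id}")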
src/demo/retriever_utils.py CHANGED
@@ -8,10 +8,8 @@ from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
 from typing_extensions import Protocol

 from src.langchain_modules import DocumentAwareSpanRetriever
-from src.langchain_modules.span_retriever import (
-    _parse_config,
-)
+from src.langchain_modules.span_retriever import DocumentAwareSpanRetrieverWithRelations
+from src.utils import parse_config

 logger = logging.getLogger(__name__)

@@ -22,13 +20,13 @@ def get_document_as_dict(retriever: DocumentAwareSpanRetriever, doc_id: str) ->


 def load_retriever(
+    config_str: str,
     config_format: str,
     device: str = "cpu",
     previous_retriever: Optional[DocumentAwareSpanRetrieverWithRelations] = None,
 ) -> DocumentAwareSpanRetrieverWithRelations:
     try:
+        retriever_config = parse_config(config_str, format=config_format)
         # set device for the embeddings pipeline
         retriever_config["vectorstore"]["embedding"]["pipeline_kwargs"]["device"] = device
         result = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config(retriever_config)
@@ -153,6 +151,7 @@ def _retrieve_for_all_spans(
     query_doc_id: str,
     retrieve_func: RetrieverCallable,
     query_span_id_column: str = "query_span_id",
+    query_span_text_column: Optional[str] = None,
     **kwargs,
 ) -> Optional[pd.DataFrame]:
     if not query_doc_id.strip():
@@ -177,6 +176,9 @@ def _retrieve_for_all_spans(
     # add column with query_span_id
     for query_span_id, query_span_result in span_results_not_empty.items():
         query_span_result[query_span_id_column] = query_span_id
+        if query_span_text_column is not None:
+            query_span = retriever.get_span_by_id(span_id=query_span_id)
+            query_span_result[query_span_text_column] = str(query_span)

     if len(span_results_not_empty) == 0:
         gr.Info(f"No results found for any ADU in document {query_doc_id}.")
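`parse_config` replaces the module-local `_parse_config` that this PR removes from src/langchain_modules/span_retriever.py (see further below). A minimal sketch of the expected behavior, mirroring that removed helper; the name `parse_config_sketch` is only used here to avoid confusion with the real shared utility in src/utils:

import json
import yaml

def parse_config_sketch(config_string: str, format: str) -> dict:
    # Parse a retriever configuration given as a JSON or YAML string.
    if format == "json":
        return json.loads(config_string)
    elif format == "yaml":
        return yaml.safe_load(config_string)
    raise ValueError(f"Unsupported format: {format}. Use 'json' or 'yaml'.")

print(parse_config_sketch("retriever:\n  k: 10\n", format="yaml"))
# {'retriever': {'k': 10}}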
src/document/processing.py CHANGED
@@ -1,13 +1,17 @@
 from __future__ import annotations

 import logging
-from typing import Any, Dict, Iterable, List,
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypeVar

-from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
+from pie_modules.document.processing.merge_spans_via_relation import _merge_spans_via_relation
 from pie_modules.documents import TextDocumentWithLabeledMultiSpansAndBinaryRelations
+from pie_modules.utils.span import have_overlap
 from pytorch_ie import AnnotationLayer
 from pytorch_ie.core import Document
+from pytorch_ie.core.document import Annotation, _enumerate_dependencies
+
+from src.utils import distance
+from src.utils.span_utils import get_overlap_len

 logger = logging.getLogger(__name__)

@@ -64,76 +68,7 @@ def remove_overlapping_entities(
     return new_doc


-def _merge_spans_via_relation(
-    spans: Sequence[LabeledSpan],
-    relations: Sequence[BinaryRelation],
-    link_relation_label: str,
-    create_multi_spans: bool = True,
-) -> Tuple[Union[Set[LabeledSpan], Set[LabeledMultiSpan]], Set[BinaryRelation]]:
-    # convert list of relations to a graph to easily calculate connected components to merge
-    g = nx.Graph()
-    link_relations = []
-    other_relations = []
-    for rel in relations:
-        if rel.label == link_relation_label:
-            link_relations.append(rel)
-            # never merge spans that have not the same label
-            if (
-                not (isinstance(rel.head, LabeledSpan) or isinstance(rel.tail, LabeledSpan))
-                or rel.head.label == rel.tail.label
-            ):
-                g.add_edge(rel.head, rel.tail)
-            else:
-                logger.debug(
-                    f"spans to merge do not have the same label, do not merge them: {rel.head}, {rel.tail}"
-                )
-        else:
-            other_relations.append(rel)
-
-    span_mapping = {}
-    connected_components: Set[LabeledSpan]
-    for connected_components in nx.connected_components(g):
-        # all spans in a connected component have the same label
-        label = list(span.label for span in connected_components)[0]
-        connected_components_sorted = sorted(connected_components, key=lambda span: span.start)
-        if create_multi_spans:
-            new_span = LabeledMultiSpan(
-                slices=tuple((span.start, span.end) for span in connected_components_sorted),
-                label=label,
-            )
-        else:
-            new_span = LabeledSpan(
-                start=min(span.start for span in connected_components_sorted),
-                end=max(span.end for span in connected_components_sorted),
-                label=label,
-            )
-        for span in connected_components_sorted:
-            span_mapping[span] = new_span
-    for span in spans:
-        if span not in span_mapping:
-            if create_multi_spans:
-                span_mapping[span] = LabeledMultiSpan(
-                    slices=((span.start, span.end),), label=span.label, score=span.score
-                )
-            else:
-                span_mapping[span] = LabeledSpan(
-                    start=span.start, end=span.end, label=span.label, score=span.score
-                )
-
-    new_spans = set(span_mapping.values())
-    new_relations = set(
-        BinaryRelation(
-            head=span_mapping[rel.head],
-            tail=span_mapping[rel.tail],
-            label=rel.label,
-            score=rel.score,
-        )
-        for rel in other_relations
-    )
-
-    return new_spans, new_relations
-
-
+# TODO: remove and use pie_modules.document.processing.SpansViaRelationMerger instead
 def merge_spans_via_relation(
     document: D,
     relation_layer: str,
@@ -186,15 +121,50 @@ def merge_spans_via_relation(


 def remove_partitions_by_labels(
-    document: D, partition_layer: str, label_blacklist: List[str]
+    document: D, partition_layer: str, label_blacklist: List[str], span_layer: Optional[str] = None
 ) -> D:
+    """Remove partitions with labels in the blacklist from a document.
+
+    Args:
+        document: The document to process.
+        partition_layer: The name of the partition layer.
+        label_blacklist: The list of labels to remove.
+        span_layer: The name of the span layer to remove spans from if they are not fully
+            contained in any remaining partition. Any dependent annotations will be removed as well.
+
+    Returns:
+        The processed document.
+    """
+
     document = document.copy()
+    p_layer: AnnotationLayer = document[partition_layer]
     new_partitions = []
+    for partition in p_layer.clear():
         if partition.label not in label_blacklist:
             new_partitions.append(partition)
+    p_layer.extend(new_partitions)
+
+    if span_layer is not None:
+        result = document.copy(with_annotations=False)
+        removed_span_ids = set()
+        for span in document[span_layer]:
+            # keep spans fully contained in any partition
+            if any(
+                partition.start <= span.start and span.end <= partition.end
+                for partition in new_partitions
+            ):
+                result[span_layer].append(span.copy())
+            else:
+                removed_span_ids.add(span._id)
+
+        result.add_all_annotations_from_other(
+            document,
+            removed_annotations={span_layer: removed_span_ids},
+            strict=False,
+            verbose=False,
+        )
+        document = result
+
     return document
@@ -221,3 +191,168 @@ def replace_substrings_in_text(
 def replace_substrings_in_text_with_spaces(document: D_text, substrings: Iterable[str]) -> D_text:
     replacements = {substring: " " * len(substring) for substring in substrings}
     return replace_substrings_in_text(document, replacements=replacements)
+
+
+def relabel_annotations(
+    document: D,
+    label_mapping: Dict[str, Dict[str, str]],
+) -> D:
+    """
+    Replace annotation labels in a document.
+
+    Args:
+        document: The document to process.
+        label_mapping: A mapping from layer names to mappings from old labels to new labels.
+
+    Returns:
+        The processed document.
+    """
+
+    dependency_ordered_fields: List[str] = []
+    _enumerate_dependencies(
+        dependency_ordered_fields,
+        dependency_graph=document._annotation_graph,
+        nodes=document._annotation_graph["_artificial_root"],
+    )
+    result = document.copy(with_annotations=False)
+    store: Dict[int, Annotation] = {}
+    # not yet used
+    invalid_annotation_ids: Set[int] = set()
+    for field_name in dependency_ordered_fields:
+        if field_name in document._annotation_fields:
+            layer = document[field_name]
+            for is_prediction, anns in [(False, layer), (True, layer.predictions)]:
+                for ann in anns:
+                    new_ann = ann.copy_with_store(
+                        override_annotation_store=store,
+                        invalid_annotation_ids=invalid_annotation_ids,
+                    )
+                    if field_name in label_mapping:
+                        if ann.label in label_mapping[field_name]:
+                            new_label = label_mapping[field_name][ann.label]
+                            new_ann = new_ann.copy(label=new_label)
+                        else:
+                            raise ValueError(
+                                f"Label {ann.label} not found in label mapping for {field_name}"
+                            )
+                    store[ann._id] = new_ann
+                    target_layer = result[field_name]
+                    if is_prediction:
+                        target_layer.predictions.append(new_ann)
+                    else:
+                        target_layer.append(new_ann)
+
+    return result
+
+
+DWithSpans = TypeVar("DWithSpans", bound=Document)
+
+
+def align_predicted_span_annotations(
+    document: DWithSpans, span_layer: str, distance_type: str = "center", verbose: bool = False
+) -> DWithSpans:
+    """
+    Aligns predicted span annotations with the closest gold spans in a document.
+
+    First, calculates the distance between each predicted span and each gold span. Then,
+    for each predicted span, the gold span with the smallest distance is selected. If the
+    predicted span and the gold span have an overlap of at least half of the maximum length
+    of the two spans, the predicted span is aligned with the gold span.
+
+    Args:
+        document: The document to process.
+        span_layer: The name of the span layer.
+        distance_type: The type of distance to calculate. One of: center, inner, outer
+        verbose: Whether to print debug information.
+
+    Returns:
+        The processed document.
+    """
+    gold_spans = document[span_layer]
+    if len(gold_spans) == 0:
+        return document.copy()
+
+    pred_spans = document[span_layer].predictions
+    old2new_pred_span = {}
+    span_id2gold_span = {}
+    for pred_span in pred_spans:
+
+        gold_spans_with_distance = [
+            (
+                gold_span,
+                distance(
+                    start_end=(pred_span.start, pred_span.end),
+                    other_start_end=(gold_span.start, gold_span.end),
+                    distance_type=distance_type,
+                ),
+            )
+            for gold_span in gold_spans
+        ]
+
+        closest_gold_span, min_distance = min(gold_spans_with_distance, key=lambda x: x[1])
+        # if the closest gold span is the same as the predicted span, we don't need to align
+        if min_distance == 0.0:
+            continue
+
+        if have_overlap(
+            start_end=(pred_span.start, pred_span.end),
+            other_start_end=(closest_gold_span.start, closest_gold_span.end),
+        ):
+            overlap_len = get_overlap_len(
+                (pred_span.start, pred_span.end), (closest_gold_span.start, closest_gold_span.end)
+            )
+            # get the maximum length of the two spans
+            l_max = max(
+                pred_span.end - pred_span.start, closest_gold_span.end - closest_gold_span.start
+            )
+            # if the overlap is at least half of the maximum length, we consider it a valid match for alignment
+            valid_match = overlap_len >= (l_max / 2)
+        else:
+            valid_match = False
+
+        if valid_match:
+            aligned_pred_span = pred_span.copy(
+                start=closest_gold_span.start, end=closest_gold_span.end
+            )
+            old2new_pred_span[pred_span._id] = aligned_pred_span
+            span_id2gold_span[pred_span._id] = closest_gold_span
+
+    result = document.copy(with_annotations=False)
+
+    # multiple predicted spans can be aligned with the same gold span,
+    # so we need to keep track of the added spans
+    added_pred_span_ids = dict()
+    for pred_span in pred_spans:
+        # just add the predicted span if it was not aligned with a gold span
+        if pred_span._id not in old2new_pred_span:
+            # if this was not added before (e.g. as aligned span), add it
+            if pred_span._id not in added_pred_span_ids:
+                keep_pred_span = pred_span.copy()
+                result[span_layer].predictions.append(keep_pred_span)
+                added_pred_span_ids[pred_span._id] = keep_pred_span
+            elif verbose:
+                print(f"Skipping duplicate predicted span. pred_span='{str(pred_span)}'")
+        else:
+            aligned_pred_span = old2new_pred_span[pred_span._id]
+            # if this was not added before (e.g. as aligned or original pred span), add it
+            if aligned_pred_span._id not in added_pred_span_ids:
+                result[span_layer].predictions.append(aligned_pred_span)
+                added_pred_span_ids[aligned_pred_span._id] = aligned_pred_span
+            elif verbose:
+                prev_pred_span = added_pred_span_ids[aligned_pred_span._id]
+                gold_span = span_id2gold_span[pred_span._id]
+                print(
+                    f"Skipping duplicate aligned predicted span. aligned gold_span='{str(gold_span)}', "
+                    f"prev_pred_span='{str(prev_pred_span)}', current_pred_span='{str(pred_span)}'"
+                )
+
+    result[span_layer].extend([span.copy() for span in gold_spans])
+
+    # add remaining gold and predicted spans (the result, _aligned_spans, is just for debugging)
+    _aligned_spans = result.add_all_annotations_from_other(
+        document, override_annotations={span_layer: old2new_pred_span}
+    )
+
+    return result
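The alignment test in `align_predicted_span_annotations` only snaps a predicted span to its closest gold span when the two overlap by at least half of the longer span. A standalone restatement of that criterion, with the overlap helper inlined for illustration (the module itself imports `have_overlap` and `get_overlap_len` instead):

from typing import Tuple

def overlap_len(a: Tuple[int, int], b: Tuple[int, int]) -> int:
    # length of the intersection of two [start, end) character spans
    return max(0, min(a[1], b[1]) - max(a[0], b[0]))

def is_valid_match(pred: Tuple[int, int], gold: Tuple[int, int]) -> bool:
    # overlap must cover at least half of the longer span
    l_max = max(pred[1] - pred[0], gold[1] - gold[0])
    return overlap_len(pred, gold) >= l_max / 2

assert is_valid_match((10, 20), (12, 22))      # overlap 8 >= 10 / 2
assert not is_valid_match((10, 20), (18, 40))  # overlap 2 < 22 / 2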
src/hydra_callbacks/save_job_return_value.py CHANGED
@@ -3,7 +3,7 @@ import logging
 import os
 import pickle
 from pathlib import Path
-from typing import Any, Dict, Generator, List, Tuple, Union
+from typing import Any, Dict, Generator, List, Optional, Tuple, Union

 import numpy as np
 import pandas as pd
@@ -33,38 +33,48 @@ def to_py_obj(obj):
 def list_of_dicts_to_dict_of_lists_recursive(list_of_dicts):
     """Convert a list of dicts to a dict of lists recursively.

+    Examples:
     # works with nested dicts
     >>> list_of_dicts_to_dict_of_lists_recursive([{"a": 1, "b": {"c": 2}}, {"a": 3, "b": {"c": 4}}])
+    {'a': [1, 3], 'b': {'c': [2, 4]}}
     # works with incomplete dicts
     >>> list_of_dicts_to_dict_of_lists_recursive([{"a": 1, "b": 2}, {"a": 3}])
+    {'a': [1, 3], 'b': [2, None]}
+
+    # works with nested incomplete dicts
+    >>> list_of_dicts_to_dict_of_lists_recursive([{"a": 1, "b": {"c": 2}}, {"a": 3}])
+    {'a': [1, 3], 'b': {'c': [2, None]}}
+
+    # works with nested incomplete dicts with None values
+    >>> list_of_dicts_to_dict_of_lists_recursive([{"a": 1, "b": {"c": 2}}, {"a": None}])
+    {'a': [1, None], 'b': {'c': [2, None]}}

     Args:
         list_of_dicts (List[dict]): A list of dicts.

     Returns:
+        dict: An arbitrarily nested dict of lists.
     """
+    if not list_of_dicts:
+        return {}
+
+    # Check if all elements are either None or dictionaries
+    if all(d is None or isinstance(d, dict) for d in list_of_dicts):
+        # Gather all keys from non-None dictionaries
+        keys = set()
+        for d in list_of_dicts:
+            if d is not None:
                 keys.update(d.keys())
+
+        # Build up the result recursively
+        return {
+            k: list_of_dicts_to_dict_of_lists_recursive(
+                [(d[k] if d is not None and k in d else None) for d in list_of_dicts]
+            )
+            for k in keys
+        }
     else:
+        # If items are not all dict/None, just return the list as is (base case).
         return list_of_dicts
@@ -77,20 +87,50 @@ def _flatten_dict_gen(d, parent_key: Tuple[str, ...] = ()) -> Generator:
         yield new_key, v


-def flatten_dict(d: Dict[str, Any]) -> Dict[Tuple[str, ...], Any]:
+def flatten_dict(d: Dict[str, Any], pad_keys: bool = True) -> Dict[Tuple[str, ...], Any]:
+    """Flattens a dictionary with nested keys. Per default, the keys are padded with np.nan to have
+    the same length.
+
+    Example:
+    >>> d = {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
+    >>> flatten_dict(d)
+    {('a', 'b', 'c'): 1, ('a', 'b', 'd'): 2, ('a', 'e', np.nan): 3}
+
+    # with padding the keys
+    >>> d = {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
+    >>> flatten_dict(d, pad_keys=False)
+    {('a', 'b', 'c'): 1, ('a', 'b', 'd'): 2, ('a', 'e'): 3}
+    """
+    result = dict(_flatten_dict_gen(d))
+    # pad the keys with np.nan to have the same length. We use np.nan to be pandas-friendly.
+    if pad_keys:
+        max_num_keys = max(len(k) for k in result.keys())
+        result = {
+            tuple(list(k) + [np.nan] * (max_num_keys - len(k))): v for k, v in result.items()
+        }
+    return result


+def unflatten_dict(
+    d: Dict[Tuple[str, ...], Any], unpad_keys: bool = True
+) -> Union[Dict[str, Any], Any]:
+    """Unflattens a dictionary with nested keys. Per default, the keys are unpadded by removing
+    np.nan values.

     Example:
     >>> d = {("a", "b", "c"): 1, ("a", "b", "d"): 2, ("a", "e"): 3}
     >>> unflatten_dict(d)
     {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
+
+    # with unpad the keys
+    >>> d = {("a", "b", "c"): 1, ("a", "b", "d"): 2, ("a", "e", np.nan): 3}
+    >>> unflatten_dict(d)
+    {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
     """
     result: Dict[str, Any] = {}
     for k, v in d.items():
+        if unpad_keys:
+            k = tuple([ki for ki in k if not pd.isna(ki)])
         if len(k) == 0:
             if len(result) > 1:
                 raise ValueError("Cannot unflatten dictionary with multiple root keys.")
@@ -152,30 +192,82 @@ class SaveJobReturnValueCallback(Callback):
     nested), where the keys are the keys of the job return-values and the values are lists of the corresponding
     values of all jobs. This is useful if you want to access specific values of all jobs in a multi-run all at once.
     Also, aggregated values (e.g. mean, min, max) are created for all numeric values and saved in another file.
+    multirun_aggregator_blacklist: List[str] (default: None)
+        A list of keys to exclude from the aggregation (of multirun results), such as "count" or "25%". If None,
+        all keys are included. See pd.DataFrame.describe() for possible aggregation keys.
+        For numeric values, it is recommended to use ["min", "25%", "50%", "75%", "max"]
+        which will result in keeping only the count, mean and std values.
+    multirun_create_ids_from_overrides: bool (default: True)
+        Create job identifiers from the overrides of the jobs in a multi-run. If False, the job index is used as
+        identifier.
+    markdown_round_digits: int (default: 3)
+        The number of digits to round the values in the markdown file. If None, no rounding is applied.
+    multirun_job_id_key: str (default: "job_id")
+        The key to use for the job identifiers in the integrated multi-run result.
+    paths_file: str (default: None)
+        The file to save the paths of the log directories to. If None, the paths are not saved.
+    path_id: str (default: None)
+        A prefix to add to each line in the paths_file separated by a colon. If None, no prefix is added.
+    multirun_paths_file: str (default: None)
+        The file to save the paths of the multi-run log directories to. If None, the paths are not saved.
+    multirun_path_id: str (default: None)
+        A prefix to add to each line in the multirun_paths_file separated by a colon. If None, no prefix is added.
     """

     def __init__(
         self,
         filenames: Union[str, List[str]] = "job_return_value.json",
         integrate_multirun_result: bool = False,
+        multirun_aggregator_blacklist: Optional[List[str]] = None,
+        multirun_create_ids_from_overrides: bool = True,
+        markdown_round_digits: Optional[int] = 3,
+        multirun_job_id_key: str = "job_id",
+        paths_file: Optional[str] = None,
+        path_id: Optional[str] = None,
+        multirun_paths_file: Optional[str] = None,
+        multirun_path_id: Optional[str] = None,
     ) -> None:
         self.log = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
         self.filenames = [filenames] if isinstance(filenames, str) else filenames
         self.integrate_multirun_result = integrate_multirun_result
         self.job_returns: List[JobReturn] = []
+        self.multirun_aggregator_blacklist = multirun_aggregator_blacklist
+        self.multirun_create_ids_from_overrides = multirun_create_ids_from_overrides
+        self.multirun_job_id_key = multirun_job_id_key
+        self.markdown_round_digits = markdown_round_digits
+        self.multirun_paths_file = multirun_paths_file
+        self.multirun_path_id = multirun_path_id
+        self.paths_file = paths_file
+        self.path_id = path_id

     def on_job_end(self, config: DictConfig, job_return: JobReturn, **kwargs: Any) -> None:
         self.job_returns.append(job_return)
         output_dir = Path(config.hydra.runtime.output_dir)  # / Path(config.hydra.output_subdir)
+        if self.paths_file is not None:
+            # append the output_dir to the file
+            with open(self.paths_file, "a") as file:
+                file.write(f"{output_dir}\n")
+
         for filename in self.filenames:
             self._save(obj=job_return.return_value, filename=filename, output_dir=output_dir)

     def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
+        job_ids: Union[List[str], List[int]]
+        if self.multirun_create_ids_from_overrides:
+            job_ids = overrides_to_identifiers([jr.overrides for jr in self.job_returns])
+        else:
+            job_ids = list(range(len(self.job_returns)))
+
         if self.integrate_multirun_result:
             # rearrange the job return-values of all jobs from a multi-run into a dict of lists (maybe nested),
             obj = list_of_dicts_to_dict_of_lists_recursive(
                 [jr.return_value for jr in self.job_returns]
             )
+            if not isinstance(obj, dict):
+                obj = {"value": obj}
+            if self.multirun_create_ids_from_overrides:
+                obj[self.multirun_job_id_key] = job_ids
+
             # also create an aggregated result
             # convert to python object to allow selecting numeric columns
             obj_py = to_py_obj(obj)
@@ -195,6 +287,11 @@ class SaveJobReturnValueCallback(Callback):
         else:
             # aggregate the numeric values
             df_described = df_numbers_only.describe()
+            # remove rows in the blacklist
+            if self.multirun_aggregator_blacklist is not None:
+                df_described = df_described.drop(
+                    self.multirun_aggregator_blacklist, errors="ignore", axis="index"
+                )
             # add the aggregation keys (e.g. mean, min, ...) as most inner keys and convert back to dict
             obj_flat_aggregated = df_described.T.stack().to_dict()
             # unflatten because _save() works better with nested dicts
@@ -202,25 +299,46 @@ class SaveJobReturnValueCallback(Callback):
         else:
             # create a dict of the job return-values of all jobs from a multi-run
             # (_save() works better with nested dicts)
+            obj = {
+                identifier: jr.return_value for identifier, jr in zip(job_ids, self.job_returns)
+            }
             obj_aggregated = None
         output_dir = Path(config.hydra.sweep.dir)
+        if self.multirun_paths_file is not None:
+            # append the output_dir to the file
+            line = f"{output_dir}\n"
+            if self.multirun_path_id is not None:
+                line = f"{self.multirun_path_id}:{line}"
+            with open(self.multirun_paths_file, "a") as file:
+                file.write(line)
+
         for filename in self.filenames:
             self._save(
                 obj=obj,
                 filename=filename,
                 output_dir=output_dir,
+                is_tabular_data=self.integrate_multirun_result,
             )
             # if available, also save the aggregated result
             if obj_aggregated is not None:
                 file_base_name, ext = os.path.splitext(filename)
                 filename_aggregated = f"{file_base_name}.aggregated{ext}"
+                self._save(
+                    obj=obj_aggregated,
+                    filename=filename_aggregated,
+                    output_dir=output_dir,
+                    # If we have aggregated (integrated multi-run) results, we unstack the last level,
+                    # i.e. the aggregation key.
+                    unstack_last_index_level=True,
+                )

     def _save(
+        self,
+        obj: Any,
+        filename: str,
+        output_dir: Path,
+        is_tabular_data: bool = False,
+        unstack_last_index_level: bool = False,
     ) -> None:
         self.log.info(f"Saving job_return in {output_dir / filename}")
         output_dir.mkdir(parents=True, exist_ok=True)
@@ -236,23 +354,43 @@ class SaveJobReturnValueCallback(Callback):
         elif filename.endswith(".md"):
             # Convert PyTorch tensors and numpy arrays to native python types
             obj_py = to_py_obj(obj)
+            if not isinstance(obj_py, dict):
+                obj_py = {"value": obj_py}
             obj_py_flat = flatten_dict(obj_py)

+            if is_tabular_data:
+                # In the case of (not aggregated) integrated multi-run result, we expect to have
+                # multiple values for each key. We therefore just convert the dict to a pandas DataFrame.
                 result = pd.DataFrame(obj_py_flat)
+                job_id_column = (self.multirun_job_id_key,) + (np.nan,) * (
+                    result.columns.nlevels - 1
+                )
+                if job_id_column in result.columns:
+                    result = result.set_index(job_id_column)
+                    result.index.name = self.multirun_job_id_key
             else:
+                # Otherwise, we have only one value for each key. We convert the dict to a pandas Series.
                 series = pd.Series(obj_py_flat)
+                # The series has a MultiIndex because flatten_dict() uses a tuple as key.
+                if len(series.index.levels) <= 1:
+                    # If there is only one level, we just use the first level values as index.
                     series.index = series.index.get_level_values(0)
                     result = series
+                else:
+                    # If there are multiple levels, we unstack the series to get a DataFrame
+                    # providing a better overview.
+                    if unstack_last_index_level:
+                        # If we have aggregated (integrated multi-run) results, we unstack the last level,
+                        # i.e. the aggregation key.
+                        result = series.unstack(-1)
+                    else:
+                        # Otherwise we have a default multi-run result and unstack the first level,
+                        # i.e. the identifier created from the overrides, and transpose the result
+                        # to have the individual jobs as rows.
+                        result = series.unstack(0).T
+
+            if self.markdown_round_digits is not None:
+                result = result.round(self.markdown_round_digits)

             with open(str(output_dir / filename), "w") as file:
                 file.write(result.to_markdown())
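A small worked example of the integrated multi-run path: per-job return values are merged into one dict of lists (missing keys become None), and the aggregation rows listed in `multirun_aggregator_blacklist` are dropped from the output of `DataFrame.describe()`. All values below are invented for illustration.

import pandas as pd

# what list_of_dicts_to_dict_of_lists_recursive produces for three jobs
job_returns = [{"val/f1": 0.81}, {"val/f1": 0.84}, {"val/f1": 0.83}]
integrated = {"val/f1": [jr.get("val/f1") for jr in job_returns]}

# aggregation with a blacklist, as done in on_multirun_end
df = pd.DataFrame(integrated)
described = df.describe().drop(["count", "25%", "50%", "75%"], errors="ignore", axis="index")
print(described.round(3))
# mean 0.827, std 0.015, min 0.810, max 0.840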
src/langchain_modules/pie_document_store.py CHANGED
@@ -75,7 +75,7 @@ class PieDocumentStore(SerializableStore, BaseStore[str, LCDocument], abc.ABC):
             caption: pie_document[layer_name] for layer_name, caption in layer_captions.items()
         }
         layer_sizes = {
+            f"num_{caption}": len(layer) + (len(layer.predictions) if use_predictions else 0)
             for caption, layer in layers.items()
         }
         rows.append({"doc_id": doc_id, **layer_sizes})
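The corrected overview entry counts gold annotations and, when `use_predictions=True`, the predicted ones on top. A conceptual restatement (the helper is illustrative, not part of the store):

def num_annotations(gold_count: int, predicted_count: int, use_predictions: bool) -> int:
    # num_<caption> column of the docstore overview
    return gold_count + (predicted_count if use_predictions else 0)

assert num_annotations(10, 4, use_predictions=False) == 10
assert num_annotations(10, 4, use_predictions=True) == 14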
src/langchain_modules/span_retriever.py
CHANGED
@@ -23,6 +23,7 @@ from pytorch_ie.documents import (
|
|
23 |
TextDocumentWithSpans,
|
24 |
)
|
25 |
|
|
|
26 |
from .pie_document_store import PieDocumentStore
|
27 |
from .serializable_store import SerializableStore
|
28 |
from .span_vectorstore import SpanVectorStore
|
@@ -30,20 +31,6 @@ from .span_vectorstore import SpanVectorStore
|
|
30 |
logger = logging.getLogger(__name__)
|
31 |
|
32 |
|
33 |
-
def _parse_config(config_string: str, format: str) -> Dict[str, Any]:
|
34 |
-
"""Parse a configuration string."""
|
35 |
-
if format == "json":
|
36 |
-
import json
|
37 |
-
|
38 |
-
return json.loads(config_string)
|
39 |
-
elif format == "yaml":
|
     TextDocumentWithSpans,
 )

+from ..utils import parse_config
 from .pie_document_store import PieDocumentStore
 from .serializable_store import SerializableStore
 from .span_vectorstore import SpanVectorStore

 logger = logging.getLogger(__name__)

-        import yaml
-
-        return yaml.safe_load(config_string)
-    else:
-        raise ValueError(f"Unsupported format: {format}. Use 'json' or 'yaml'.")
-
-
 METADATA_KEY_CHILD_ID2IDX = "child_id2idx"


@@ -136,7 +123,7 @@ class DocumentAwareSpanRetriever(BaseRetriever, SerializableStore):
     ) -> "DocumentAwareSpanRetriever":
         """Instantiate a retriever from a configuration string."""
         return cls.instantiate_from_config(
-
+            parse_config(config_string, format=format), overwrites=overwrites
         )

     @classmethod
@@ -725,6 +712,8 @@ class DocumentAwareSpanRetrieverWithRelations(DocumentAwareSpanRetriever):
     """The list of span labels to consider."""
     reversed_relations_suffix: Optional[str] = None
     """Whether to consider reverse relations as well."""
+    symmetric_relations: Optional[list[str]] = None
+    """The list of relation labels that are symmetric."""

     def get_relation_layer(
         self, pie_document: TextBasedDocument, use_predicted_annotations: bool
@@ -762,11 +751,19 @@ class DocumentAwareSpanRetrieverWithRelations(DocumentAwareSpanRetriever):
         )

         for relation in relations:
+            is_symmetric = (
+                self.symmetric_relations is not None
+                and relation.label in self.symmetric_relations
+            )
             if self.relation_labels is None or relation.label in self.relation_labels:
                 head2label2tails_with_scores[span2id[relation.head]][relation.label].append(
                     (span2id[relation.tail], relation.score)
                 )
+            if is_symmetric:
+                head2label2tails_with_scores[span2id[relation.tail]][
+                    relation.label
+                ].append((span2id[relation.head], relation.score))
-
+            if self.reversed_relations_suffix is not None and not is_symmetric:
                 reversed_label = f"{relation.label}{self.reversed_relations_suffix}"
                 if self.relation_labels is None or reversed_label in self.relation_labels:
                     head2label2tails_with_scores[span2id[relation.tail]][
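To make the retriever change above easier to follow, here is a minimal, self-contained sketch of the indexing behaviour it introduces. The tiny `Relation` class and `index_relations` helper are illustrative stand-ins, not code from this PR; the real implementation additionally filters by `relation_labels` and maps spans to ids via `span2id`.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Relation:  # illustrative stand-in for a PIE binary relation
    head: str
    tail: str
    label: str
    score: float


def index_relations(relations, symmetric_relations=None, reversed_suffix=None):
    """Build a head -> label -> [(tail, score)] mapping, mirroring the diff above."""
    head2label2tails = defaultdict(lambda: defaultdict(list))
    for rel in relations:
        is_symmetric = symmetric_relations is not None and rel.label in symmetric_relations
        head2label2tails[rel.head][rel.label].append((rel.tail, rel.score))
        if is_symmetric:
            # symmetric labels: also index tail -> head under the *same* label
            head2label2tails[rel.tail][rel.label].append((rel.head, rel.score))
        elif reversed_suffix is not None:
            # non-symmetric labels: index tail -> head under a suffixed label
            head2label2tails[rel.tail][f"{rel.label}{reversed_suffix}"].append((rel.head, rel.score))
    return head2label2tails


rels = [Relation("a1", "a2", "supports", 0.9), Relation("a2", "a3", "parts_of_same", 0.8)]
index = index_relations(rels, symmetric_relations=["parts_of_same"], reversed_suffix="_reversed")
assert ("a2", 0.8) in index["a3"]["parts_of_same"]
assert ("a1", 0.9) in index["a2"]["supports_reversed"]
```

Symmetric labels are indexed once per direction under the same label, while non-symmetric labels get a second entry under a suffixed label so that reverse neighbours remain retrievable.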
src/pipeline/ner_re_pipeline.py
CHANGED
@@ -2,7 +2,18 @@ from __future__ import annotations

 import logging
 from functools import partial
-from typing import
+from typing import (
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Type,
+    TypeVar,
+    Union,
+    overload,
+)

 from pie_modules.utils import resolve_type
 from pytorch_ie import AutoPipeline, WithDocumentTypeMixin
@@ -72,11 +83,13 @@ def add_annotations_from_other_documents(
 def process_pipeline_steps(
     documents: Sequence[Document],
     processors: Dict[str, Callable[[Sequence[Document]], Optional[Sequence[Document]]]],
+    verbose: bool = False,
 ) -> Sequence[Document]:

     # call the processors in the order they are provided
     for step_name, processor in processors.items():
-
+        if verbose:
+            logger.info(f"process {step_name} ...")
         processed_documents = processor(documents)
         if processed_documents is not None:
             documents = processed_documents
@@ -120,6 +133,7 @@ class NerRePipeline:
         batch_size: Optional[int] = None,
         show_progress_bar: Optional[bool] = None,
         document_type: Optional[Union[Type[Document], str]] = None,
+        verbose: bool = True,
         **processor_kwargs,
     ):
         self.taskmodule = DummyTaskmodule(document_type)
@@ -128,6 +142,7 @@ class NerRePipeline:
         self.processor_kwargs = processor_kwargs or {}
         self.entity_layer = entity_layer
         self.relation_layer = relation_layer
+        self.verbose = verbose
         # set some values for the inference processors, if provided
         for inference_pipeline in ["ner_pipeline", "re_pipeline"]:
             if inference_pipeline not in self.processor_kwargs:
@@ -145,7 +160,29 @@ class NerRePipeline:
             ):
                 self.processor_kwargs[inference_pipeline]["show_progress_bar"] = show_progress_bar

-
+        self.ner_pipeline = AutoPipeline.from_pretrained(
+            self.ner_model_path, **self.processor_kwargs.get("ner_pipeline", {})
+        )
+        self.re_pipeline = AutoPipeline.from_pretrained(
+            self.re_model_path, **self.processor_kwargs.get("re_pipeline", {})
+        )
+
+    @overload
+    def __call__(
+        self, documents: Sequence[Document], inplace: bool = False
+    ) -> Sequence[Document]: ...
+
+    @overload
+    def __call__(self, documents: Document, inplace: bool = False) -> Document: ...
+
+    def __call__(
+        self, documents: Union[Sequence[Document], Document], inplace: bool = False
+    ) -> Union[Sequence[Document], Document]:
+
+        is_single_doc = False
+        if isinstance(documents, Document):
+            documents = [documents]
+            is_single_doc = True

         input_docs: Sequence[Document]
         # we need to keep the original documents to add the gold data back
@@ -166,24 +203,14 @@ class NerRePipeline:
                 layer_names=[self.entity_layer, self.relation_layer],
                 **self.processor_kwargs.get("clear_annotations", {}),
             ),
-            "ner_pipeline":
-                self.ner_model_path, **self.processor_kwargs.get("ner_pipeline", {})
-            ),
+            "ner_pipeline": self.ner_pipeline,
             "use_predicted_entities": partial(
                 process_documents,
                 processor=move_annotations_from_predictions,
                 layer_names=[self.entity_layer],
                 **self.processor_kwargs.get("use_predicted_entities", {}),
             ),
-
-            #     process_documents,
-            #     processor=CandidateRelationAdder(
-            #         **self.processor_kwargs.get("create_candidate_relations", {})
-            #     ),
-            # ),
-            "re_pipeline": AutoPipeline.from_pretrained(
-                self.re_model_path, **self.processor_kwargs.get("re_pipeline", {})
-            ),
+            "re_pipeline": self.re_pipeline,
             # otherwise we can not move the entities back to predictions
             "clear_candidate_relations": partial(
                 process_documents,
@@ -204,5 +231,8 @@ class NerRePipeline:
                 **self.processor_kwargs.get("re_add_gold_data", {}),
             ),
         },
+        verbose=self.verbose,
     )
+    if is_single_doc:
+        return docs_with_predictions[0]
     return docs_with_predictions
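The `@overload`-based single-vs-batch handling added to `NerRePipeline.__call__` follows a common pattern; below is a runnable toy version of that pattern (plain strings instead of PIE documents, so it can be executed without any trained models).

```python
from typing import List, Sequence, Union, overload


class EchoPipeline:
    """Toy illustration of the single-vs-batch __call__ pattern used by NerRePipeline."""

    @overload
    def __call__(self, documents: str) -> str: ...

    @overload
    def __call__(self, documents: Sequence[str]) -> List[str]: ...

    def __call__(self, documents: Union[str, Sequence[str]]) -> Union[str, List[str]]:
        is_single = isinstance(documents, str)
        batch = [documents] if is_single else list(documents)
        processed = [doc.upper() for doc in batch]  # stand-in for NER + RE inference
        return processed[0] if is_single else processed


pipe = EchoPipeline()
assert pipe("hello") == "HELLO"
assert pipe(["a", "b"]) == ["A", "B"]
```

Building the NER and RE pipelines once in `__init__` (instead of inside every call) also means repeated invocations no longer re-load the models.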
src/predict.py
CHANGED
@@ -106,8 +106,12 @@ def predict(cfg: DictConfig) -> Tuple[dict, dict]:
     # Per default, the model is loaded with .from_pretrained() which already loads the weights.
     # However, ckpt_path can be used to load different weights from any checkpoint.
     if cfg.ckpt_path is not None:
-
-
+        log.info(f"Loading model weights from checkpoint: {cfg.ckpt_path}")
+        pipeline.model = (
+            type(pipeline.model)
+            .load_from_checkpoint(checkpoint_path=cfg.ckpt_path)
+            .to(pipeline.device)
+            .to(dtype=pipeline.model.dtype)
         )

     # auto-convert the dataset if the metric specifies a document type
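The chained `.to(pipeline.device).to(dtype=pipeline.model.dtype)` is there because a model re-loaded from a checkpoint comes back on the default device and precision; a small, hedged helper that captures the same idea for a generic torch module (not code from this PR):

```python
import torch


def match_device_and_dtype(reference: torch.nn.Module, reloaded: torch.nn.Module) -> torch.nn.Module:
    """Move a freshly loaded model to the device and dtype of an existing reference model."""
    ref_param = next(reference.parameters())
    return reloaded.to(ref_param.device).to(dtype=ref_param.dtype)
```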
src/start_demo.py
CHANGED
@@ -19,8 +19,10 @@ import yaml
 from src.demo.annotation_utils import load_argumentation_model
 from src.demo.backend_utils import (
     download_processed_documents,
+    load_acl_anthology_venues,
     process_text_from_arxiv,
     process_uploaded_files,
+    process_uploaded_pdf_files,
     render_annotated_document,
     upload_processed_documents,
     wrapped_add_annotated_pie_documents_from_dataset,
@@ -31,6 +33,7 @@ from src.demo.frontend_utils import (
     escape_regex,
     get_cell_for_fixed_column_from_df,
     open_accordion,
+    open_accordion_with_stats,
     unescape_regex,
 )
 from src.demo.rendering_utils import AVAILABLE_RENDER_MODES, HIGHLIGHT_SPANS_JS
@@ -67,12 +70,10 @@ def main(cfg: DictConfig) -> None:

     example_text = cfg["example_text"]

-    default_device = "cuda" if torch.cuda.is_available() else "cpu"
+    default_device = "cuda:0" if torch.cuda.is_available() else "cpu"

-    default_retriever_config_str =
-
-    default_model_name = cfg["default_model_name"]
-    default_model_revision = cfg["default_model_revision"]
+    default_retriever_config_str = yaml.dump(cfg["retriever"])
+    default_argumentation_model_config_str = yaml.dump(cfg["argumentation_model"])
     handle_parts_of_same = cfg["handle_parts_of_same"]

     default_arxiv_id = cfg["default_arxiv_id"]
@@ -97,19 +98,32 @@ def main(cfg: DictConfig) -> None:
     }
     render_caption2mode = {v: k for k, v in render_mode2caption.items()}
     default_min_similarity = cfg["default_min_similarity"]
+    default_top_k = cfg["default_top_k"]
     layer_caption_mapping = cfg["layer_caption_mapping"]
     relation_name_mapping = cfg["relation_name_mapping"]

+    indexed_documents_label = "Indexed Documents"
+    indexed_documents_caption2column = {
+        "documents": "TOTAL",
+        "ADUs": "num_adus",
+        "Relations": "num_relations",
+    }
+
     gr.Info("Loading models ...")
     argumentation_model = load_argumentation_model(
-
-        revision=default_model_revision,
+        config_str=default_argumentation_model_config_str,
         device=default_device,
     )
     retriever = load_retriever(
-        default_retriever_config_str, device=default_device, config_format="yaml"
+        config_str=default_retriever_config_str, device=default_device, config_format="yaml"
     )

+    if cfg.get("pdf_fulltext_extractor"):
+        gr.Info("Loading PDF fulltext extractor ...")
+        pdf_fulltext_extractor = hydra.utils.instantiate(cfg["pdf_fulltext_extractor"])
+    else:
+        pdf_fulltext_extractor = None
+
     with gr.Blocks() as demo:
         # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
         # models_state = gr.State((argumentation_model, embedding_model))
@@ -131,18 +145,16 @@ def main(cfg: DictConfig) -> None:

         with gr.Accordion("Model Configuration", open=False):
             with gr.Accordion("argumentation structure", open=True):
-
-
-
-
-
-                    label="Model Revision",
-                    value=default_model_revision,
+                argumentation_model_config_str = gr.Code(
+                    language="yaml",
+                    label="Argumentation Model Configuration",
+                    value=default_argumentation_model_config_str,
+                    lines=len(default_argumentation_model_config_str.split("\n")),
                 )
                 load_arg_model_btn = gr.Button("Load Argumentation Model")

             with gr.Accordion("retriever", open=True):
-
+                retriever_config_str = gr.Code(
                     language="yaml",
                     label="Retriever Configuration",
                     value=default_retriever_config_str,
@@ -155,26 +167,25 @@ def main(cfg: DictConfig) -> None:
                 value=default_device,
             )
             load_arg_model_btn.click(
-                fn=lambda
+                fn=lambda _argumentation_model_config_str, _device: (
                     load_argumentation_model(
-
-                        revision=_model_revision,
+                        config_str=_argumentation_model_config_str,
                         device=_device,
                     ),
                 ),
-                inputs=[
+                inputs=[argumentation_model_config_str, device],
                 outputs=argumentation_model_state,
             )
             load_retriever_btn.click(
                 fn=lambda _retriever_config, _device, _previous_retriever: (
                     load_retriever(
-
+                        config_str=_retriever_config,
                         device=_device,
                         previous_retriever=_previous_retriever[0],
                         config_format="yaml",
                     ),
                 ),
-                inputs=[
+                inputs=[retriever_config_str, device, retriever_state],
                 outputs=retriever_state,
             )

@@ -213,7 +224,7 @@ def main(cfg: DictConfig) -> None:
         with gr.Tabs() as right_tabs:
             with gr.Tab("Retrieval", id="retrieval") as retrieval_tab:
                 with gr.Accordion(
-
+                    indexed_documents_label, open=False
                 ) as processed_documents_accordion:
                     processed_documents_df = gr.DataFrame(
                         headers=["id", "num_adus", "num_relations"],
@@ -274,7 +285,7 @@ def main(cfg: DictConfig) -> None:
                         minimum=2,
                         maximum=50,
                         step=1,
-                        value=
+                        value=default_top_k,
                     )
                     retrieve_similar_adus_btn = gr.Button(
                         "Retrieve *similar* ADUs for *selected* ADU"
@@ -293,8 +304,10 @@ def main(cfg: DictConfig) -> None:
                         "Retrieve *relevant* ADUs for *all* ADUs in the document"
                     )
                     all_relevant_adus_df = gr.DataFrame(
-                        headers=["doc_id", "adu_id", "score", "text"],
+                        headers=["doc_id", "adu_id", "score", "text", "query_span_id"],
+                        interactive=False,
                     )
+                    all_relevant_adus_query_doc_id = gr.Textbox(visible=False)

             with gr.Tab("Import Documents", id="import_documents") as import_documents_tab:
                 upload_btn = gr.UploadButton(
@@ -303,6 +316,28 @@ def main(cfg: DictConfig) -> None:
                     file_count="multiple",
                 )

+                upload_pdf_btn = gr.UploadButton(
+                    "Batch Analyse PDFs",
+                    # file_types=["pdf"],
+                    file_count="multiple",
+                    visible=pdf_fulltext_extractor is not None,
+                )
+
+                enable_acl_venue_loading = pdf_fulltext_extractor is not None and cfg.get(
+                    "acl_anthology_pdf_dir"
+                )
+                acl_anthology_venues = gr.Textbox(
+                    label="ACL Anthology Venues",
+                    value="wiesp",
+                    max_lines=1,
+                    visible=enable_acl_venue_loading,
+                )
+                load_acl_anthology_venues_btn = gr.Button(
+                    "Import from ACL Anthology",
+                    variant="secondary",
+                    visible=enable_acl_venue_loading,
+                )
+
                 with gr.Accordion("Import text from arXiv", open=False):
                     arxiv_id = gr.Textbox(
                         label="arXiv paper ID",
@@ -326,13 +361,25 @@ def main(cfg: DictConfig) -> None:
                 load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")

         render_event_kwargs = dict(
-            fn=lambda _retriever, _document_id, _render_as, _render_kwargs: render_annotated_document(
+            fn=lambda _retriever, _document_id, _render_as, _render_kwargs, _all_relevant_adus_df, _all_relevant_adus_query_doc_id: render_annotated_document(
                 retriever=_retriever[0],
                 document_id=_document_id,
                 render_with=render_caption2mode[_render_as],
                 render_kwargs_json=_render_kwargs,
+                highlight_span_ids=(
+                    _all_relevant_adus_df["query_span_id"].tolist()
+                    if _document_id == _all_relevant_adus_query_doc_id
+                    else None
+                ),
             ),
-            inputs=[
+            inputs=[
+                retriever_state,
+                selected_document_id,
+                render_as,
+                render_kwargs,
+                all_relevant_adus_df,
+                all_relevant_adus_query_doc_id,
+            ],
             outputs=rendered_output,
         )

@@ -343,6 +390,16 @@ def main(cfg: DictConfig) -> None:
            inputs=[retriever_state],
            outputs=[processed_documents_df],
         )
+        show_stats_kwargs = dict(
+            fn=lambda _processed_documents_df: open_accordion_with_stats(
+                _processed_documents_df,
+                base_label=indexed_documents_label,
+                caption2column=indexed_documents_caption2column,
+                total_column="TOTAL",
+            ),
+            inputs=[processed_documents_df],
+            outputs=[processed_documents_accordion],
+        )
         predict_btn.click(
             fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
         ).then(
@@ -367,6 +424,8 @@ def main(cfg: DictConfig) -> None:
             api_name="predict",
         ).success(
             **show_overview_kwargs
+        ).success(
+            **show_stats_kwargs
         ).success(
             **render_event_kwargs
         )
@@ -396,6 +455,8 @@ def main(cfg: DictConfig) -> None:
             api_name="predict",
         ).success(
             **show_overview_kwargs
+        ).success(
+            **show_stats_kwargs
         )

         load_pie_dataset_btn.click(
@@ -409,6 +470,8 @@ def main(cfg: DictConfig) -> None:
             ),
             inputs=[retriever_state, load_pie_dataset_kwargs_str],
             outputs=[processed_documents_df],
+        ).success(
+            **show_stats_kwargs
         )

         selected_document_id.change(
@@ -430,7 +493,9 @@ def main(cfg: DictConfig) -> None:
                 file_names=_file_names,
                 argumentation_model=_argumentation_model[0],
                 retriever=_retriever[0],
-                split_regex_escaped=
+                split_regex_escaped=(
+                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
+                ),
                 handle_parts_of_same=handle_parts_of_same,
                 layer_captions=layer_caption_mapping,
             ),
@@ -441,7 +506,61 @@ def main(cfg: DictConfig) -> None:
                 split_regex_escaped,
             ],
             outputs=[processed_documents_df],
+        ).success(
+            **show_stats_kwargs
+        )
+        upload_pdf_btn.upload(
+            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
+        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
+            fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_pdf_files(
+                file_names=_file_names,
+                argumentation_model=_argumentation_model[0],
+                retriever=_retriever[0],
+                split_regex_escaped=(
+                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
+                ),
+                handle_parts_of_same=handle_parts_of_same,
+                layer_captions=layer_caption_mapping,
+                pdf_fulltext_extractor=pdf_fulltext_extractor,
+            ),
+            inputs=[
+                upload_pdf_btn,
+                argumentation_model_state,
+                retriever_state,
+                split_regex_escaped,
+            ],
+            outputs=[processed_documents_df],
+        ).success(
+            **show_stats_kwargs
+        )
+
+        load_acl_anthology_venues_btn.click(
+            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
+        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
+            fn=lambda _acl_anthology_venues, _argumentation_model, _retriever, _split_regex_escaped: load_acl_anthology_venues(
+                pdf_fulltext_extractor=pdf_fulltext_extractor,
+                venues=[venue.strip() for venue in _acl_anthology_venues.split(",")],
+                argumentation_model=_argumentation_model[0],
+                retriever=_retriever[0],
+                split_regex_escaped=(
+                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
+                ),
+                handle_parts_of_same=handle_parts_of_same,
+                layer_captions=layer_caption_mapping,
+                acl_anthology_data_dir=cfg.get("acl_anthology_data_dir"),
+                pdf_output_dir=cfg.get("acl_anthology_pdf_dir"),
+            ),
+            inputs=[
+                acl_anthology_venues,
+                argumentation_model_state,
+                retriever_state,
+                split_regex_escaped,
+            ],
+            outputs=[processed_documents_df],
+        ).success(
+            **show_stats_kwargs
         )
+
         processed_documents_df.select(
             fn=get_cell_for_fixed_column_from_df,
             inputs=[processed_documents_df, gr.State("doc_id")],
@@ -461,7 +580,7 @@ def main(cfg: DictConfig) -> None:
             ),
             inputs=[upload_processed_documents_btn, retriever_state],
             outputs=[processed_documents_df],
-        )
+        ).success(**show_stats_kwargs)

         retrieve_relevant_adus_event_kwargs = dict(
             fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_relevant_spans(
@@ -533,12 +652,16 @@ def main(cfg: DictConfig) -> None:
         )

         retrieve_all_relevant_adus_btn.click(
-            fn=lambda _retriever, _document_id, _min_similarity, _tok_k:
-
-
-
-
-
+            fn=lambda _retriever, _document_id, _min_similarity, _tok_k: (
+                retrieve_all_relevant_spans(
+                    retriever=_retriever[0],
+                    query_doc_id=_document_id,
+                    k=_tok_k,
+                    score_threshold=_min_similarity,
+                    query_span_id_column="query_span_id",
+                    query_span_text_column="query_span_text",
+                ),
+                _document_id,
             ),
             inputs=[
                 retriever_state,
@@ -546,9 +669,11 @@ def main(cfg: DictConfig) -> None:
                 min_similarity,
                 top_k,
             ],
-            outputs=[all_relevant_adus_df],
+            outputs=[all_relevant_adus_df, all_relevant_adus_query_doc_id],
         )

+        all_relevant_adus_df.change(**render_event_kwargs)
+
         # select query span id from the "retrieve all" result data frames
         all_similar_adus_df.select(
             fn=get_cell_for_fixed_column_from_df,
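A side note on the Gradio wiring above: `show_stats_kwargs` (like `show_overview_kwargs` and `render_event_kwargs`) is a plain dict of event arguments that gets re-used across several `.success(...)` chains. A minimal, self-contained sketch of that pattern with toy components (not the demo's actual ones):

```python
import gradio as gr

with gr.Blocks() as demo:
    counter = gr.Number(value=0, label="count")
    status = gr.Textbox(label="status")
    btn = gr.Button("increment")

    # one kwargs dict, reused wherever the status should be refreshed
    show_status_kwargs = dict(
        fn=lambda _count: f"count is now {_count}",
        inputs=[counter],
        outputs=[status],
    )

    btn.click(fn=lambda _count: _count + 1, inputs=[counter], outputs=[counter]).success(
        **show_status_kwargs
    )

if __name__ == "__main__":
    demo.launch()
```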
src/train.py
CHANGED
@@ -220,6 +220,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
     train_metrics = trainer.callback_metrics

     best_ckpt_path = trainer.checkpoint_callback.best_model_path
+    best_epoch = None
     if best_ckpt_path != "":
         log.info(f"Best ckpt path: {best_ckpt_path}")
         best_checkpoint_file = os.path.basename(best_ckpt_path)
@@ -228,6 +229,14 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
             best_checkpoint=best_checkpoint_file,
             checkpoint_dir=trainer.checkpoint_callback.dirpath,
         )
+        # get epoch from best_checkpoint_file (e.g. "epoch_078.ckpt")
+        try:
+            best_epoch = int(os.path.splitext(best_checkpoint_file)[0].split("_")[-1])
+        except Exception as e:
+            log.warning(
+                f'Could not retrieve epoch from best checkpoint file name: "{e}". '
+                f"Expected format: " + '"epoch_{best_epoch}.ckpt"'
+            )

     if not cfg.trainer.get("fast_dev_run"):
         if cfg.model_save_dir is not None:
@@ -259,6 +268,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
     trainer.test(model=model, datamodule=datamodule, ckpt_path=best_ckpt_path or None)

     test_metrics = trainer.callback_metrics
+    test_metrics["best_epoch"] = best_epoch

     # merge train and test metrics
     metric_dict = {**train_metrics, **test_metrics}
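The epoch extraction above splits the checkpoint file name on underscores; if checkpoint names ever carry extra underscores, a regex-based variant (an alternative sketch, not what this PR uses) is a bit more forgiving:

```python
import re
from typing import Optional


def epoch_from_checkpoint_name(filename: str) -> Optional[int]:
    """Extract the epoch number from names like 'epoch_078.ckpt' or 'best_epoch_7.ckpt'."""
    match = re.search(r"epoch[_=]?(\d+)", filename)
    return int(match.group(1)) if match else None


assert epoch_from_checkpoint_name("epoch_078.ckpt") == 78
assert epoch_from_checkpoint_name("last.ckpt") is None
```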
src/utils/__init__.py
CHANGED
@@ -1,4 +1,9 @@
-from .config_utils import
+from .config_utils import (
+    execute_pipeline,
+    instantiate_dict_entries,
+    parse_config,
+    prepare_omegaconf,
+)
 from .data_utils import download_and_unzip, filter_dataframe_and_get_column
 from .logging_utils import close_loggers, get_pylogger, log_hyperparameters
 from .rich_utils import enforce_tags, print_config_tree
src/utils/config_utils.py
CHANGED
@@ -1,5 +1,5 @@
 from copy import copy
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional

 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf
@@ -69,3 +69,17 @@ def prepare_omegaconf():
         OmegaConf.register_new_resolver("replace", lambda s, x, y: s.replace(x, y))
     else:
         logger.warning("OmegaConf resolver 'replace' is already registered")
+
+
+def parse_config(config_string: str, format: str) -> Dict[str, Any]:
+    """Parse a configuration string."""
+    if format == "json":
+        import json
+
+        return json.loads(config_string)
+    elif format == "yaml":
+        import yaml
+
+        return yaml.safe_load(config_string)
+    else:
+        raise ValueError(f"Unsupported format: {format}. Use 'json' or 'yaml'.")
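Usage of the new helper is straightforward (assuming the repository root is on `PYTHONPATH` so that `src.utils` is importable):

```python
from src.utils import parse_config

yaml_cfg = parse_config("retriever:\n  top_k: 10\n", format="yaml")
json_cfg = parse_config('{"retriever": {"top_k": 10}}', format="json")
assert yaml_cfg == json_cfg == {"retriever": {"top_k": 10}}
```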
src/utils/pdf_utils/README.MD
ADDED
@@ -0,0 +1,35 @@
+# Generate paper json files from a collection xml file, with fulltext extraction.
+
+This is a slightly re-arranged version of Sotaro Takeshita's code, which is available at https://github.com/gengo-proj/data-factory.
+
+## Requirements
+
+- Docker
+- Python>=3.10
+- python packages:
+  - acl-anthology-py>=0.4.3
+  - bs4
+  - jsonschema
+
+## Setup
+
+Start Grobid Docker container
+
+```bash
+docker run --rm --init --ulimit core=0 -p 8070:8070 lfoppiano/grobid:0.8.0
+```
+
+Get the meta data from ACL Anthology
+
+```bash
+git clone git@github.com:acl-org/acl-anthology.git
+```
+
+## Usage
+
+```bash
+python src/data/acl_anthology_crawler.py \
+    --base-output-dir <path/to/save/raw-paper.json> \
+    --pdf-output-dir <path/to/save/downloaded/paper.pdf> \
+    --anthology-data-dir ./acl-anthology/data/
+```
src/utils/pdf_utils/__init__.py
ADDED
File without changes
|
src/utils/pdf_utils/acl_anthology_utils.py
ADDED
@@ -0,0 +1,77 @@
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Iterator
|
3 |
+
|
4 |
+
from acl_anthology import Anthology
|
5 |
+
|
6 |
+
from .process_pdf import paper_url_to_uuid
|
7 |
+
from .raw_paper import RawPaper
|
8 |
+
|
9 |
+
|
10 |
+
@dataclass
|
11 |
+
class XML2RawPapers:
|
12 |
+
anthology: Anthology
|
13 |
+
collection_id_filters: list[str] | None = None
|
14 |
+
venue_id_whitelist: list[str] | None = None
|
15 |
+
verbose: bool = True
|
16 |
+
|
17 |
+
def __call__(self, *args, **kwargs) -> Iterator[RawPaper]:
|
18 |
+
|
19 |
+
for collection_id, collection in self.anthology.collections.items():
|
20 |
+
if self.collection_id_filters is not None:
|
21 |
+
if not any(
|
22 |
+
[
|
23 |
+
collection_id.find(filter_str) != -1
|
24 |
+
for filter_str in self.collection_id_filters
|
25 |
+
]
|
26 |
+
):
|
27 |
+
continue
|
28 |
+
if self.verbose:
|
29 |
+
print(f"Processing collection: {collection_id}")
|
30 |
+
for volume in collection.volumes():
|
31 |
+
if self.venue_id_whitelist is not None:
|
32 |
+
if not any(
|
33 |
+
[venue_id in volume.venue_ids for venue_id in self.venue_id_whitelist]
|
34 |
+
):
|
35 |
+
continue
|
36 |
+
|
37 |
+
volume_id = f"{collection_id}-{volume.id}"
|
38 |
+
|
39 |
+
for paper in volume.papers():
|
40 |
+
fulltext, abstract = None, None
|
41 |
+
if (
|
42 |
+
paper.pdf is not None
|
43 |
+
and paper.pdf.name is not None
|
44 |
+
and paper.pdf.name.find("http") == -1
|
45 |
+
):
|
46 |
+
name = paper.pdf.name
|
47 |
+
else:
|
48 |
+
name = (
|
49 |
+
f"{volume_id}.{paper.id.rjust(3, '0')}"
|
50 |
+
if len(collection_id) == 1
|
51 |
+
else f"{volume_id}.{paper.id}"
|
52 |
+
)
|
53 |
+
|
54 |
+
paper_uuid = paper_url_to_uuid(name)
|
55 |
+
raw_paper = RawPaper(
|
56 |
+
paper_uuid=str(paper_uuid),
|
57 |
+
name=name,
|
58 |
+
collection_id=collection_id,
|
59 |
+
collection_acronym=volume.venues()[0].acronym,
|
60 |
+
volume_id=volume_id,
|
61 |
+
booktitle=volume.title.as_text(),
|
62 |
+
paper_id=int(paper.id),
|
63 |
+
year=int(paper.year),
|
64 |
+
paper_title=paper.title.as_text(),
|
65 |
+
authors=[
|
66 |
+
{"first": author.first, "last": author.last}
|
67 |
+
for author in paper.authors
|
68 |
+
],
|
69 |
+
abstract=(
|
70 |
+
paper.abstract.as_text() if paper.abstract is not None else abstract
|
71 |
+
),
|
72 |
+
url=paper.pdf.url if paper.pdf is not None else None,
|
73 |
+
bibkey=paper.bibkey if paper.bibkey is not None else None,
|
74 |
+
doi=paper.doi if paper.doi is not None else None,
|
75 |
+
fulltext=fulltext,
|
76 |
+
)
|
77 |
+
yield raw_paper
|
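A hedged usage sketch for the `XML2RawPapers` helper defined above. It assumes the `acl-anthology-py` package's `Anthology(datadir=...)` constructor and the ACL Anthology checkout from the README's Setup step; the venue id is just an example.

```python
from acl_anthology import Anthology

from src.utils.pdf_utils.acl_anthology_utils import XML2RawPapers

anthology = Anthology(datadir="./acl-anthology/data")
xml2raw_papers = XML2RawPapers(anthology=anthology, venue_id_whitelist=["wiesp"], verbose=True)

for raw_paper in xml2raw_papers():
    # RawPaper bundles the Anthology metadata (title, authors, abstract, pdf url, ...)
    print(raw_paper.name, raw_paper.paper_title)
```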
src/utils/pdf_utils/client.py
ADDED
@@ -0,0 +1,193 @@
|
1 |
+
""" Generic API Client """
|
2 |
+
|
3 |
+
import json
|
4 |
+
from copy import deepcopy
|
5 |
+
|
6 |
+
import requests
|
7 |
+
|
8 |
+
try:
|
9 |
+
from urlparse import urljoin
|
10 |
+
except ImportError:
|
11 |
+
from urllib.parse import urljoin
|
12 |
+
|
13 |
+
|
14 |
+
class ApiClient(object):
|
15 |
+
"""Client to interact with a generic Rest API.
|
16 |
+
|
17 |
+
Subclasses should implement functionality accordingly with the provided
|
18 |
+
service methods, i.e. ``get``, ``post``, ``put`` and ``delete``.
|
19 |
+
"""
|
20 |
+
|
21 |
+
accept_type = "application/xml"
|
22 |
+
api_base = None
|
23 |
+
|
24 |
+
def __init__(self, base_url, username=None, api_key=None, status_endpoint=None, timeout=60):
|
25 |
+
"""Initialise client.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
base_url (str): The base URL to the service being used.
|
29 |
+
username (str): The username to authenticate with.
|
30 |
+
api_key (str): The API key to authenticate with.
|
31 |
+
timeout (int): Maximum time before timing out.
|
32 |
+
"""
|
33 |
+
self.base_url = base_url
|
34 |
+
self.username = username
|
35 |
+
self.api_key = api_key
|
36 |
+
self.status_endpoint = urljoin(self.base_url, status_endpoint)
|
37 |
+
self.timeout = timeout
|
38 |
+
|
39 |
+
@staticmethod
|
40 |
+
def encode(request, data):
|
41 |
+
"""Add request content data to request body, set Content-type header.
|
42 |
+
|
43 |
+
Should be overridden by subclasses if not using JSON encoding.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
request (HTTPRequest): The request object.
|
47 |
+
data (dict, None): Data to be encoded.
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
HTTPRequest: The request object.
|
51 |
+
"""
|
52 |
+
if data is None:
|
53 |
+
return request
|
54 |
+
|
55 |
+
request.add_header("Content-Type", "application/json")
|
56 |
+
request.data = json.dumps(data)
|
57 |
+
|
58 |
+
return request
|
59 |
+
|
60 |
+
@staticmethod
|
61 |
+
def decode(response):
|
62 |
+
"""Decode the returned data in the response.
|
63 |
+
|
64 |
+
Should be overridden by subclasses if something else than JSON is
|
65 |
+
expected.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
response (HTTPResponse): The response object.
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
dict or None.
|
72 |
+
"""
|
73 |
+
try:
|
74 |
+
return response.json()
|
75 |
+
except ValueError as e:
|
76 |
+
return e.message
|
77 |
+
|
78 |
+
def get_credentials(self):
|
79 |
+
"""Returns parameters to be added to authenticate the request.
|
80 |
+
|
81 |
+
This lives on its own to make it easier to re-implement it if needed.
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
dict: A dictionary containing the credentials.
|
85 |
+
"""
|
86 |
+
return {"username": self.username, "api_key": self.api_key}
|
87 |
+
|
88 |
+
def call_api(
|
89 |
+
self,
|
90 |
+
method,
|
91 |
+
url,
|
92 |
+
headers=None,
|
93 |
+
params=None,
|
94 |
+
data=None,
|
95 |
+
files=None,
|
96 |
+
timeout=None,
|
97 |
+
):
|
98 |
+
"""Call API.
|
99 |
+
|
100 |
+
This returns object containing data, with error details if applicable.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
method (str): The HTTP method to use.
|
104 |
+
url (str): Resource location relative to the base URL.
|
105 |
+
headers (dict or None): Extra request headers to set.
|
106 |
+
params (dict or None): Query-string parameters.
|
107 |
+
data (dict or None): Request body contents for POST or PUT requests.
|
108 |
+
files (dict or None: Files to be passed to the request.
|
109 |
+
timeout (int): Maximum time before timing out.
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
ResultParser or ErrorParser.
|
113 |
+
"""
|
114 |
+
headers = deepcopy(headers) or {}
|
115 |
+
headers["Accept"] = self.accept_type
|
116 |
+
params = deepcopy(params) or {}
|
117 |
+
data = data or {}
|
118 |
+
files = files or {}
|
119 |
+
# if self.username is not None and self.api_key is not None:
|
120 |
+
# params.update(self.get_credentials())
|
121 |
+
r = requests.request(
|
122 |
+
method,
|
123 |
+
url,
|
124 |
+
headers=headers,
|
125 |
+
params=params,
|
126 |
+
files=files,
|
127 |
+
data=data,
|
128 |
+
timeout=timeout,
|
129 |
+
)
|
130 |
+
|
131 |
+
return r, r.status_code
|
132 |
+
|
133 |
+
def get(self, url, params=None, **kwargs):
|
134 |
+
"""Call the API with a GET request.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
url (str): Resource location relative to the base URL.
|
138 |
+
params (dict or None): Query-string parameters.
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
ResultParser or ErrorParser.
|
142 |
+
"""
|
143 |
+
return self.call_api("GET", url, params=params, **kwargs)
|
144 |
+
|
145 |
+
def delete(self, url, params=None, **kwargs):
|
146 |
+
"""Call the API with a DELETE request.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
url (str): Resource location relative to the base URL.
|
150 |
+
params (dict or None): Query-string parameters.
|
151 |
+
|
152 |
+
Returns:
|
153 |
+
ResultParser or ErrorParser.
|
154 |
+
"""
|
155 |
+
return self.call_api("DELETE", url, params=params, **kwargs)
|
156 |
+
|
157 |
+
def put(self, url, params=None, data=None, files=None, **kwargs):
|
158 |
+
"""Call the API with a PUT request.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
url (str): Resource location relative to the base URL.
|
162 |
+
params (dict or None): Query-string parameters.
|
163 |
+
data (dict or None): Request body contents.
|
164 |
+
files (dict or None: Files to be passed to the request.
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
An instance of ResultParser or ErrorParser.
|
168 |
+
"""
|
169 |
+
return self.call_api("PUT", url, params=params, data=data, files=files, **kwargs)
|
170 |
+
|
171 |
+
def post(self, url, params=None, data=None, files=None, **kwargs):
|
172 |
+
"""Call the API with a POST request.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
url (str): Resource location relative to the base URL.
|
176 |
+
params (dict or None): Query-string parameters.
|
177 |
+
data (dict or None): Request body contents.
|
178 |
+
files (dict or None: Files to be passed to the request.
|
179 |
+
|
180 |
+
Returns:
|
181 |
+
An instance of ResultParser or ErrorParser.
|
182 |
+
"""
|
183 |
+
return self.call_api(
|
184 |
+
method="POST", url=url, params=params, data=data, files=files, **kwargs
|
185 |
+
)
|
186 |
+
|
187 |
+
def service_status(self, **kwargs):
|
188 |
+
"""Call the API to get the status of the service.
|
189 |
+
|
190 |
+
Returns:
|
191 |
+
An instance of ResultParser or ErrorParser.
|
192 |
+
"""
|
193 |
+
return self.call_api("GET", self.status_endpoint, params={"format": "json"}, **kwargs)
|
src/utils/pdf_utils/grobid_client.py
ADDED
@@ -0,0 +1,203 @@
|
1 |
+
import glob
|
2 |
+
import io
|
3 |
+
import ntpath
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
from typing import List, Optional
|
7 |
+
|
8 |
+
from .client import ApiClient
|
9 |
+
|
10 |
+
# This version uses the standard ProcessPoolExecutor for parallelizing the concurrent calls to the GROBID services.
|
11 |
+
# Given the limits of ThreadPoolExecutor (input stored in memory, blocking Executor.map until the whole input
|
12 |
+
# is acquired), it works with batches of PDF of a size indicated in the config.json file (default is 1000 entries).
|
13 |
+
# We are moving from first batch to the second one only when the first is entirely processed - which means it is
|
14 |
+
# slightly sub-optimal, but should scale better. However acquiring a list of million of files in directories would
|
15 |
+
# require something scalable too, which is not implemented for the moment.
|
16 |
+
DEFAULT_GROBID_CONFIG = {
|
17 |
+
"grobid_server": "localhost",
|
18 |
+
"grobid_port": "8070",
|
19 |
+
"batch_size": 1000,
|
20 |
+
"sleep_time": 5,
|
21 |
+
"generateIDs": False,
|
22 |
+
"consolidate_header": False,
|
23 |
+
"consolidate_citations": False,
|
24 |
+
"include_raw_citations": True,
|
25 |
+
"include_raw_affiliations": False,
|
26 |
+
"max_workers": 2,
|
27 |
+
}
|
28 |
+
|
29 |
+
|
30 |
+
class GrobidClient(ApiClient):
|
31 |
+
def __init__(self, config=None):
|
32 |
+
self.config = config or DEFAULT_GROBID_CONFIG
|
33 |
+
self.generate_ids = self.config["generateIDs"]
|
34 |
+
self.consolidate_header = self.config["consolidate_header"]
|
35 |
+
self.consolidate_citations = self.config["consolidate_citations"]
|
36 |
+
self.include_raw_citations = self.config["include_raw_citations"]
|
37 |
+
self.include_raw_affiliations = self.config["include_raw_affiliations"]
|
38 |
+
self.max_workers = self.config["max_workers"]
|
39 |
+
self.grobid_server = self.config["grobid_server"]
|
40 |
+
self.grobid_port = str(self.config["grobid_port"])
|
41 |
+
self.sleep_time = self.config["sleep_time"]
|
42 |
+
|
43 |
+
def process(self, input: str, output: str, service: str):
|
44 |
+
batch_size_pdf = self.config["batch_size"]
|
45 |
+
pdf_files = []
|
46 |
+
|
47 |
+
for pdf_file in glob.glob(input + "/*.pdf"):
|
48 |
+
pdf_files.append(pdf_file)
|
49 |
+
|
50 |
+
if len(pdf_files) == batch_size_pdf:
|
51 |
+
self.process_batch(pdf_files, output, service)
|
52 |
+
pdf_files = []
|
53 |
+
|
54 |
+
# last batch
|
55 |
+
if len(pdf_files) > 0:
|
56 |
+
self.process_batch(pdf_files, output, service)
|
57 |
+
|
58 |
+
def process_batch(self, pdf_files: List[str], output: str, service: str) -> None:
|
59 |
+
print(len(pdf_files), "PDF files to process")
|
60 |
+
for pdf_file in pdf_files:
|
61 |
+
self.process_pdf(pdf_file, output, service)
|
62 |
+
|
63 |
+
def process_pdf_stream(self, pdf_file: str, pdf_strm: bytes, output: str, service: str) -> str:
|
64 |
+
# process the stream
|
65 |
+
files = {"input": (pdf_file, pdf_strm, "application/pdf", {"Expires": "0"})}
|
66 |
+
|
67 |
+
the_url = "http://" + self.grobid_server
|
68 |
+
the_url += ":" + self.grobid_port
|
69 |
+
the_url += "/api/" + service
|
70 |
+
|
71 |
+
# set the GROBID parameters
|
72 |
+
the_data = {}
|
73 |
+
if self.generate_ids:
|
74 |
+
the_data["generateIDs"] = "1"
|
75 |
+
else:
|
76 |
+
the_data["generateIDs"] = "0"
|
77 |
+
|
78 |
+
if self.consolidate_header:
|
79 |
+
the_data["consolidateHeader"] = "1"
|
80 |
+
else:
|
81 |
+
the_data["consolidateHeader"] = "0"
|
82 |
+
|
83 |
+
if self.consolidate_citations:
|
84 |
+
the_data["consolidateCitations"] = "1"
|
85 |
+
else:
|
86 |
+
the_data["consolidateCitations"] = "0"
|
87 |
+
|
88 |
+
if self.include_raw_affiliations:
|
89 |
+
the_data["includeRawAffiliations"] = "1"
|
90 |
+
else:
|
91 |
+
the_data["includeRawAffiliations"] = "0"
|
92 |
+
|
93 |
+
if self.include_raw_citations:
|
94 |
+
the_data["includeRawCitations"] = "1"
|
95 |
+
else:
|
96 |
+
the_data["includeRawCitations"] = "0"
|
97 |
+
|
98 |
+
res, status = self.post(
|
99 |
+
url=the_url, files=files, data=the_data, headers={"Accept": "text/plain"}
|
100 |
+
)
|
101 |
+
|
102 |
+
if status == 503:
|
103 |
+
time.sleep(self.sleep_time)
|
104 |
+
# TODO: check if simply passing output as output is correct
|
105 |
+
return self.process_pdf_stream(
|
106 |
+
pdf_file=pdf_file, pdf_strm=pdf_strm, service=service, output=output
|
107 |
+
)
|
108 |
+
elif status != 200:
|
109 |
+
with open(os.path.join(output, "failed.log"), "a+") as failed:
|
110 |
+
failed.write(pdf_file.strip(".pdf") + "\n")
|
111 |
+
print("Processing failed with error " + str(status))
|
112 |
+
return ""
|
113 |
+
else:
|
114 |
+
return res.text
|
115 |
+
|
116 |
+
def process_pdf(self, pdf_file: str, output: str, service: str) -> None:
|
117 |
+
# check if TEI file is already produced
|
118 |
+
# we use ntpath here to be sure it will work on Windows too
|
119 |
+
pdf_file_name = ntpath.basename(pdf_file)
|
120 |
+
filename = os.path.join(output, os.path.splitext(pdf_file_name)[0] + ".tei.xml")
|
121 |
+
if os.path.isfile(filename):
|
122 |
+
return
|
123 |
+
|
124 |
+
print(pdf_file)
|
125 |
+
pdf_strm = open(pdf_file, "rb").read()
|
126 |
+
tei_text = self.process_pdf_stream(pdf_file, pdf_strm, output, service)
|
127 |
+
|
128 |
+
# writing TEI file
|
129 |
+
if tei_text:
|
130 |
+
with io.open(filename, "w+", encoding="utf8") as tei_file:
|
131 |
+
tei_file.write(tei_text)
|
132 |
+
|
133 |
+
def process_citation(self, bib_string: str, log_file: str) -> Optional[str]:
|
134 |
+
# process citation raw string and return corresponding dict
|
135 |
+
the_data = {"citations": bib_string, "consolidateCitations": "0"}
|
136 |
+
|
137 |
+
the_url = "http://" + self.grobid_server
|
138 |
+
the_url += ":" + self.grobid_port
|
139 |
+
the_url += "/api/processCitation"
|
140 |
+
|
141 |
+
for _ in range(5):
|
142 |
+
try:
|
143 |
+
res, status = self.post(
|
144 |
+
url=the_url, data=the_data, headers={"Accept": "text/plain"}
|
145 |
+
)
|
146 |
+
if status == 503:
|
147 |
+
time.sleep(self.sleep_time)
|
148 |
+
continue
|
149 |
+
elif status != 200:
|
150 |
+
with open(log_file, "a+") as failed:
|
151 |
+
failed.write("-- BIBSTR --\n")
|
152 |
+
failed.write(bib_string + "\n\n")
|
153 |
+
break
|
154 |
+
else:
|
155 |
+
return res.text
|
156 |
+
except Exception:
|
157 |
+
continue
|
158 |
+
|
159 |
+
return None
|
160 |
+
|
161 |
+
def process_header_names(self, header_string: str, log_file: str) -> Optional[str]:
|
162 |
+
# process author names from header string
|
163 |
+
the_data = {"names": header_string}
|
164 |
+
|
165 |
+
the_url = "http://" + self.grobid_server
|
166 |
+
the_url += ":" + self.grobid_port
|
167 |
+
the_url += "/api/processHeaderNames"
|
168 |
+
|
169 |
+
res, status = self.post(url=the_url, data=the_data, headers={"Accept": "text/plain"})
|
170 |
+
|
171 |
+
if status == 503:
|
172 |
+
time.sleep(self.sleep_time)
|
173 |
+
return self.process_header_names(header_string, log_file)
|
174 |
+
elif status != 200:
|
175 |
+
with open(log_file, "a+") as failed:
|
176 |
+
failed.write("-- AUTHOR --\n")
|
177 |
+
failed.write(header_string + "\n\n")
|
178 |
+
else:
|
179 |
+
return res.text
|
180 |
+
|
181 |
+
return None
|
182 |
+
|
183 |
+
def process_affiliations(self, aff_string: str, log_file: str) -> Optional[str]:
|
184 |
+
# process affiliation from input string
|
185 |
+
the_data = {"affiliations": aff_string}
|
186 |
+
|
187 |
+
the_url = "http://" + self.grobid_server
|
188 |
+
the_url += ":" + self.grobid_port
|
189 |
+
the_url += "/api/processAffiliations"
|
190 |
+
|
191 |
+
res, status = self.post(url=the_url, data=the_data, headers={"Accept": "text/plain"})
|
192 |
+
|
193 |
+
if status == 503:
|
194 |
+
time.sleep(self.sleep_time)
|
195 |
+
return self.process_affiliations(aff_string, log_file)
|
196 |
+
elif status != 200:
|
197 |
+
with open(log_file, "a+") as failed:
|
198 |
+
failed.write("-- AFFILIATION --\n")
|
199 |
+
failed.write(aff_string + "\n\n")
|
200 |
+
else:
|
201 |
+
return res.text
|
202 |
+
|
203 |
+
return None
|
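To tie the pieces together, a hedged sketch of driving the `GrobidClient` defined above against the Grobid container from the pdf_utils README (paths are hypothetical; `processFulltextDocument` is the standard Grobid fulltext service):

```python
from src.utils.pdf_utils.grobid_client import GrobidClient

client = GrobidClient()  # falls back to DEFAULT_GROBID_CONFIG (localhost:8070)
client.process_pdf(
    pdf_file="papers/example.pdf",      # hypothetical input PDF
    output="grobid_tei/",               # TEI XML is written here; failures go to failed.log
    service="processFulltextDocument",
)
```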
src/utils/pdf_utils/grobid_util.py
ADDED
@@ -0,0 +1,413 @@
|
1 |
+
import re
|
2 |
+
from collections import defaultdict
|
3 |
+
from typing import Dict, List, Optional, Union
|
4 |
+
|
5 |
+
import bs4
|
6 |
+
from bs4 import BeautifulSoup
|
7 |
+
|
8 |
+
SUBSTITUTE_TAGS = {"persName", "orgName", "publicationStmt", "titleStmt", "biblScope"}
|
9 |
+
|
10 |
+
|
11 |
+
def clean_tags(el: bs4.element.Tag):
|
12 |
+
"""
|
13 |
+
Replace all tags with lowercase version
|
14 |
+
:param el:
|
15 |
+
:return:
|
16 |
+
"""
|
17 |
+
for sub_tag in SUBSTITUTE_TAGS:
|
18 |
+
for sub_el in el.find_all(sub_tag):
|
19 |
+
sub_el.name = sub_tag.lower()
|
20 |
+
|
21 |
+
|
22 |
+
def soup_from_path(file_path: str):
|
23 |
+
"""
|
24 |
+
Read XML file
|
25 |
+
:param file_path:
|
26 |
+
:return:
|
27 |
+
"""
|
28 |
+
return BeautifulSoup(open(file_path, "rb").read(), "xml")
|
29 |
+
|
30 |
+
|
31 |
+
def get_title_from_grobid_xml(raw_xml: BeautifulSoup) -> str:
|
32 |
+
"""
|
33 |
+
Returns title
|
34 |
+
:return:
|
35 |
+
"""
|
36 |
+
for title_entry in raw_xml.find_all("title"):
|
37 |
+
if title_entry.has_attr("level") and title_entry["level"] == "a":
|
38 |
+
return title_entry.text
|
39 |
+
try:
|
40 |
+
return raw_xml.title.text
|
41 |
+
except AttributeError:
|
42 |
+
return ""
|
43 |
+
|
44 |
+
|
45 |
+
def get_author_names_from_grobid_xml(
|
46 |
+
raw_xml: BeautifulSoup,
|
47 |
+
) -> List[Dict[str, Union[str, List[str]]]]:
|
48 |
+
"""
|
49 |
+
Returns a list of dictionaries, one for each author,
|
50 |
+
containing the first and last names.
|
51 |
+
|
52 |
+
e.g.
|
53 |
+
{
|
54 |
+
"first": first,
|
55 |
+
"middle": middle,
|
56 |
+
"last": last,
|
57 |
+
"suffix": suffix
|
58 |
+
}
|
59 |
+
"""
|
60 |
+
names = []
|
61 |
+
|
62 |
+
for author in raw_xml.find_all("author"):
|
63 |
+
if not author.persname:
|
64 |
+
continue
|
65 |
+
|
66 |
+
# forenames include first and middle names
|
67 |
+
forenames = author.persname.find_all("forename")
|
68 |
+
|
69 |
+
# surnames include last names
|
70 |
+
surnames = author.persname.find_all("surname")
|
71 |
+
|
72 |
+
# name suffixes
|
73 |
+
suffixes = author.persname.find_all("suffix")
|
74 |
+
|
75 |
+
first = ""
|
76 |
+
middle = []
|
77 |
+
last = ""
|
78 |
+
suffix = ""
|
79 |
+
|
80 |
+
for forename in forenames:
|
81 |
+
if forename["type"] == "first":
|
82 |
+
if not first:
|
83 |
+
first = forename.text
|
84 |
+
else:
|
85 |
+
middle.append(forename.text)
|
86 |
+
elif forename["type"] == "middle":
|
87 |
+
middle.append(forename.text)
|
88 |
+
|
89 |
+
if len(surnames) > 1:
|
90 |
+
for surname in surnames[:-1]:
|
91 |
+
middle.append(surname.text)
|
92 |
+
last = surnames[-1].text
|
93 |
+
elif len(surnames) == 1:
|
94 |
+
last = surnames[0].text
|
95 |
+
|
96 |
+
if len(suffixes) >= 1:
|
97 |
+
suffix = " ".join([suff.text for suff in suffixes])
|
98 |
+
|
99 |
+
names_dict: Dict[str, Union[str, List[str]]] = {
|
100 |
+
"first": first,
|
101 |
+
"middle": middle,
|
102 |
+
"last": last,
|
103 |
+
"suffix": suffix,
|
104 |
+
}
|
105 |
+
|
106 |
+
names.append(names_dict)
|
107 |
+
return names
|
108 |
+
|
109 |
+
|
110 |
+
def get_affiliation_from_grobid_xml(raw_xml: BeautifulSoup) -> Dict:
|
111 |
+
"""
|
112 |
+
Get affiliation from grobid xml
|
113 |
+
:param raw_xml:
|
114 |
+
:return:
|
115 |
+
"""
|
116 |
+
location_dict = dict()
|
117 |
+
laboratory_name = ""
|
118 |
+
institution_name = ""
|
119 |
+
|
120 |
+
if raw_xml and raw_xml.affiliation:
|
121 |
+
for child in raw_xml.affiliation:
|
122 |
+
if child.name == "orgname":
|
123 |
+
if child.has_attr("type"):
|
124 |
+
if child["type"] == "laboratory":
|
125 |
+
laboratory_name = child.text
|
126 |
+
elif child["type"] == "institution":
|
127 |
+
institution_name = child.text
|
128 |
+
elif child.name == "address":
|
129 |
+
for grandchild in child:
|
130 |
+
if grandchild.name and grandchild.text:
|
131 |
+
location_dict[grandchild.name] = grandchild.text
|
132 |
+
|
133 |
+
if laboratory_name or institution_name:
|
134 |
+
return {
|
135 |
+
"laboratory": laboratory_name,
|
136 |
+
"institution": institution_name,
|
137 |
+
"location": location_dict,
|
138 |
+
}
|
139 |
+
|
140 |
+
return {}
|
141 |
+
|
142 |
+
|
143 |
+
def get_author_data_from_grobid_xml(raw_xml: BeautifulSoup) -> List[Dict]:
|
144 |
+
"""
|
145 |
+
Returns a list of dictionaries, one for each author,
|
146 |
+
containing the first and last names.
|
147 |
+
|
148 |
+
e.g.
|
149 |
+
{
|
150 |
+
"first": first,
|
151 |
+
"middle": middle,
|
152 |
+
"last": last,
|
153 |
+
"suffix": suffix,
|
154 |
+
"affiliation": {
|
155 |
+
"laboratory": "",
|
156 |
+
"institution": "",
|
157 |
+
"location": "",
|
158 |
+
},
|
159 |
+
"email": ""
|
160 |
+
}
|
161 |
+
"""
|
162 |
+
authors = []
|
163 |
+
|
164 |
+
for author in raw_xml.find_all("author"):
|
165 |
+
|
166 |
+
first = ""
|
167 |
+
middle = []
|
168 |
+
last = ""
|
169 |
+
suffix = ""
|
170 |
+
|
171 |
+
if author.persname:
|
172 |
+
# forenames include first and middle names
|
173 |
+
forenames = author.persname.find_all("forename")
|
174 |
+
|
175 |
+
# surnames include last names
|
176 |
+
surnames = author.persname.find_all("surname")
|
177 |
+
|
178 |
+
# name suffixes
|
179 |
+
suffixes = author.persname.find_all("suffix")
|
180 |
+
|
181 |
+
for forename in forenames:
|
182 |
+
if forename.has_attr("type"):
|
183 |
+
if forename["type"] == "first":
|
184 |
+
if not first:
|
185 |
+
first = forename.text
|
186 |
+
else:
|
187 |
+
middle.append(forename.text)
|
188 |
+
elif forename["type"] == "middle":
|
189 |
+
middle.append(forename.text)
|
190 |
+
|
191 |
+
if len(surnames) > 1:
|
192 |
+
for surname in surnames[:-1]:
|
193 |
+
middle.append(surname.text)
|
194 |
+
last = surnames[-1].text
|
195 |
+
elif len(surnames) == 1:
|
196 |
+
last = surnames[0].text
|
197 |
+
|
198 |
+
if len(suffixes) >= 1:
|
199 |
+
suffix = " ".join([suffix.text for suffix in suffixes])
|
200 |
+
|
201 |
+
affiliation = get_affiliation_from_grobid_xml(author)
|
202 |
+
|
203 |
+
email = ""
|
204 |
+
if author.email:
|
205 |
+
email = author.email.text
|
206 |
+
|
207 |
+
author_dict = {
|
208 |
+
"first": first,
|
209 |
+
"middle": middle,
|
210 |
+
"last": last,
|
211 |
+
"suffix": suffix,
|
212 |
+
"affiliation": affiliation,
|
213 |
+
"email": email,
|
214 |
+
}
|
215 |
+
|
216 |
+
authors.append(author_dict)
|
217 |
+
|
218 |
+
return authors
|
219 |
+
|
220 |
+
|
221 |
+
def get_year_from_grobid_xml(raw_xml: BeautifulSoup) -> Optional[int]:
|
222 |
+
"""
|
223 |
+
Returns date published if exists
|
224 |
+
:return:
|
225 |
+
"""
|
226 |
+
if raw_xml.date and raw_xml.date.has_attr("when"):
|
227 |
+
# match year in date text (which is in some unspecified date format)
|
228 |
+
year_match = re.match(r"((19|20)\d{2})", raw_xml.date["when"])
|
229 |
+
if year_match:
|
230 |
+
year = year_match.group(0)
|
231 |
+
if year and year.isnumeric() and len(year) == 4:
|
232 |
+
return int(year)
|
233 |
+
return None
|
234 |
+
|
235 |
+
|
236 |
+
def get_venue_from_grobid_xml(raw_xml: BeautifulSoup, title_text: str) -> str:
|
237 |
+
"""
|
238 |
+
Returns venue/journal/publisher of bib entry
|
239 |
+
Grobid ref documentation: https://grobid.readthedocs.io/en/latest/training/Bibliographical-references/
|
240 |
+
level="j": journal title
|
241 |
+
level="m": "non journal bibliographical item holding the cited article"
|
242 |
+
level="s": series title
|
243 |
+
:return:
|
244 |
+
"""
|
245 |
+
title_names = []
|
246 |
+
keep_types = ["j", "m", "s"]
|
247 |
+
# get all titles of the above types
|
248 |
+
for title_entry in raw_xml.find_all("title"):
|
249 |
+
if (
|
250 |
+
title_entry.has_attr("level")
|
251 |
+
and title_entry["level"] in keep_types
|
252 |
+
and title_entry.text != title_text
|
253 |
+
):
|
254 |
+
title_names.append((title_entry["level"], title_entry.text))
|
255 |
+
# return the title name that most likely belongs to the journal or publication venue
|
256 |
+
if title_names:
|
257 |
+
title_names.sort(key=lambda x: keep_types.index(x[0]))
|
258 |
+
return title_names[0][1]
|
259 |
+
return ""
|
260 |
+
|
261 |
+
|
262 |
+
def get_volume_from_grobid_xml(raw_xml: BeautifulSoup) -> str:
|
263 |
+
"""
|
264 |
+
Returns the volume number of grobid bib entry
|
265 |
+
Grobid <biblscope unit="volume">
|
266 |
+
:return:
|
267 |
+
"""
|
268 |
+
for bibl_entry in raw_xml.find_all("biblscope"):
|
269 |
+
if bibl_entry.has_attr("unit") and bibl_entry["unit"] == "volume":
|
270 |
+
return bibl_entry.text
|
271 |
+
return ""
|
272 |
+
|
273 |
+
|
274 |
+
def get_issue_from_grobid_xml(raw_xml: BeautifulSoup) -> str:
|
275 |
+
"""
|
276 |
+
Returns the issue number of grobid bib entry
|
277 |
+
Grobid <biblscope unit="issue">
|
278 |
+
:return:
|
279 |
+
"""
|
280 |
+
for bibl_entry in raw_xml.find_all("biblscope"):
|
281 |
+
if bibl_entry.has_attr("unit") and bibl_entry["unit"] == "issue":
|
282 |
+
return bibl_entry.text
|
283 |
+
return ""
|
284 |
+
|
285 |
+
|
286 |
+
def get_pages_from_grobid_xml(raw_xml: BeautifulSoup) -> str:
|
287 |
+
"""
|
288 |
+
Returns the page numbers of grobid bib entry
|
289 |
+
Grobid <biblscope unit="page">
|
290 |
+
:return:
|
291 |
+
"""
|
292 |
+
for bibl_entry in raw_xml.find_all("biblscope"):
|
293 |
+
if (
|
294 |
+
bibl_entry.has_attr("unit")
|
295 |
+
and bibl_entry["unit"] == "page"
|
296 |
+
and bibl_entry.has_attr("from")
|
297 |
+
):
|
298 |
+
from_page = bibl_entry["from"]
|
299 |
+
if bibl_entry.has_attr("to"):
|
300 |
+
to_page = bibl_entry["to"]
|
301 |
+
return f"{from_page}--{to_page}"
|
302 |
+
else:
|
303 |
+
return from_page
|
304 |
+
return ""
|
305 |
+
|
306 |
+
|
307 |
+
def get_other_ids_from_grobid_xml(raw_xml: BeautifulSoup) -> Dict[str, List]:
|
308 |
+
"""
|
309 |
+
Returns a dictionary of other identifiers from grobid bib entry (arxiv, pubmed, doi)
|
310 |
+
:param raw_xml:
|
311 |
+
:return:
|
312 |
+
"""
|
313 |
+
other_ids = defaultdict(list)
|
314 |
+
|
315 |
+
for idno_entry in raw_xml.find_all("idno"):
|
316 |
+
if idno_entry.has_attr("type") and idno_entry.text:
|
317 |
+
other_ids[idno_entry["type"]].append(idno_entry.text)
|
318 |
+
|
319 |
+
return other_ids
|
320 |
+
|
321 |
+
|
322 |
+
def get_raw_bib_text_from_grobid_xml(raw_xml: BeautifulSoup) -> str:
|
323 |
+
"""
|
324 |
+
Returns the raw bibliography string
|
325 |
+
:param raw_xml:
|
326 |
+
:return:
|
327 |
+
"""
|
328 |
+
for note in raw_xml.find_all("note"):
|
329 |
+
if note.has_attr("type") and note["type"] == "raw_reference":
|
330 |
+
return note.text
|
331 |
+
return ""
|
332 |
+
|
333 |
+
|
334 |
+
def get_publication_datetime_from_grobid_xml(raw_xml: BeautifulSoup) -> str:
|
335 |
+
"""
|
336 |
+
Finds and returns the publication datetime if it exists
|
337 |
+
:param raw_xml:
|
338 |
+
:return:
|
339 |
+
"""
|
340 |
+
if raw_xml.publicationstmt:
|
341 |
+
for child in raw_xml.publicationstmt:
|
342 |
+
if (
|
343 |
+
child.name == "date"
|
344 |
+
and child.has_attr("type")
|
345 |
+
and child["type"] == "published"
|
346 |
+
and child.has_attr("when")
|
347 |
+
):
|
348 |
+
return child["when"]
|
349 |
+
return ""
|
350 |
+
|
351 |
+
|
352 |
+
def parse_bib_entry(bib_entry: BeautifulSoup) -> Dict:
|
353 |
+
"""
|
354 |
+
Parse one bib entry
|
355 |
+
:param bib_entry:
|
356 |
+
:return:
|
357 |
+
"""
|
358 |
+
clean_tags(bib_entry)
|
359 |
+
title = get_title_from_grobid_xml(bib_entry)
|
360 |
+
return {
|
361 |
+
"ref_id": bib_entry.attrs.get("xml:id", None),
|
362 |
+
"title": title,
|
363 |
+
"authors": get_author_names_from_grobid_xml(bib_entry),
|
364 |
+
"year": get_year_from_grobid_xml(bib_entry),
|
365 |
+
"venue": get_venue_from_grobid_xml(bib_entry, title),
|
366 |
+
"volume": get_volume_from_grobid_xml(bib_entry),
|
367 |
+
"issue": get_issue_from_grobid_xml(bib_entry),
|
368 |
+
"pages": get_pages_from_grobid_xml(bib_entry),
|
369 |
+
"other_ids": get_other_ids_from_grobid_xml(bib_entry),
|
370 |
+
"raw_text": get_raw_bib_text_from_grobid_xml(bib_entry),
|
371 |
+
"urls": [],
|
372 |
+
}
|
373 |
+
|
374 |
+
|
375 |
+
def is_reference_tag(tag: bs4.element.Tag) -> bool:
|
376 |
+
return tag.name == "ref" and tag.attrs.get("type", "") == "bibr"
|
377 |
+
|
378 |
+
|
379 |
+
def extract_paper_metadata_from_grobid_xml(tag: bs4.element.Tag) -> Dict:
|
380 |
+
"""
|
381 |
+
Extract paper metadata (title, authors, affiliation, year) from grobid xml
|
382 |
+
:param tag:
|
383 |
+
:return:
|
384 |
+
"""
|
385 |
+
clean_tags(tag)
|
386 |
+
paper_metadata = {
|
387 |
+
"title": tag.titlestmt.title.text,
|
388 |
+
"authors": get_author_data_from_grobid_xml(tag),
|
389 |
+
"year": get_publication_datetime_from_grobid_xml(tag),
|
390 |
+
}
|
391 |
+
return paper_metadata
|
392 |
+
|
393 |
+
|
394 |
+
def parse_bibliography(soup: BeautifulSoup) -> List[Dict]:
|
395 |
+
"""
|
396 |
+
Finds all bibliography entries in a grobid xml.
|
397 |
+
"""
|
398 |
+
bibliography = soup.listBibl
|
399 |
+
if bibliography is None:
|
400 |
+
return []
|
401 |
+
|
402 |
+
entries = bibliography.find_all("biblStruct")
|
403 |
+
|
404 |
+
structured_entries = []
|
405 |
+
for entry in entries:
|
406 |
+
bib_entry = parse_bib_entry(entry)
|
407 |
+
# add bib entry only if it has a title
|
408 |
+
if bib_entry["title"]:
|
409 |
+
structured_entries.append(bib_entry)
|
410 |
+
|
411 |
+
bibliography.decompose()
|
412 |
+
|
413 |
+
return structured_entries
|
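A minimal usage sketch for the helpers above (not part of the diff; the TEI path and its contents are hypothetical, and a Grobid TEI file is assumed to exist at that location):

from src.utils.pdf_utils.grobid_util import (
    extract_paper_metadata_from_grobid_xml,
    parse_bibliography,
    soup_from_path,
)

# hypothetical path to a Grobid TEI XML file produced for one paper
soup = soup_from_path("./grobid/temp/example.tei.xml")

# paper-level metadata (title, authors, publication date) lives under <fileDesc>
metadata = extract_paper_metadata_from_grobid_xml(soup.fileDesc)
print(metadata["title"], len(metadata["authors"]))

# structured bibliography entries; entries without a title are dropped
for bib in parse_bibliography(soup):
    print(bib["ref_id"], bib["year"], bib["title"])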
src/utils/pdf_utils/process_pdf.py
ADDED
@@ -0,0 +1,276 @@
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import uuid
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Dict, Optional
|
7 |
+
|
8 |
+
import requests
|
9 |
+
from bs4 import BeautifulSoup
|
10 |
+
|
11 |
+
from .grobid_client import GrobidClient
|
12 |
+
from .grobid_util import extract_paper_metadata_from_grobid_xml, parse_bibliography
|
13 |
+
from .s2orc_paper import Paper
|
14 |
+
from .utils import (
|
15 |
+
_clean_empty_and_duplicate_authors_from_grobid_parse,
|
16 |
+
check_if_citations_are_bracket_style,
|
17 |
+
extract_abstract_from_tei_xml,
|
18 |
+
extract_back_matter_from_tei_xml,
|
19 |
+
extract_body_text_from_tei_xml,
|
20 |
+
extract_figures_and_tables_from_tei_xml,
|
21 |
+
normalize_grobid_id,
|
22 |
+
sub_all_note_tags,
|
23 |
+
)
|
24 |
+
|
25 |
+
BASE_TEMP_DIR = "./grobid/temp"
|
26 |
+
BASE_OUTPUT_DIR = "./grobid/output"
|
27 |
+
BASE_LOG_DIR = "./grobid/log"
|
28 |
+
|
29 |
+
|
30 |
+
def convert_tei_xml_soup_to_s2orc_json(soup: BeautifulSoup, paper_id: str, pdf_hash: str) -> Paper:
|
31 |
+
"""
|
32 |
+
Convert Grobid TEI XML to S2ORC json format
|
33 |
+
:param soup: BeautifulSoup of XML file content
|
34 |
+
:param paper_id: name of file
|
35 |
+
:param pdf_hash: hash of PDF
|
36 |
+
:return:
|
37 |
+
"""
|
38 |
+
# extract metadata
|
39 |
+
metadata = extract_paper_metadata_from_grobid_xml(soup.fileDesc)
|
40 |
+
# clean metadata authors (remove dupes etc)
|
41 |
+
metadata["authors"] = _clean_empty_and_duplicate_authors_from_grobid_parse(metadata["authors"])
|
42 |
+
|
43 |
+
# parse bibliography entries (removes empty bib entries)
|
44 |
+
biblio_entries = parse_bibliography(soup)
|
45 |
+
bibkey_map = {normalize_grobid_id(bib["ref_id"]): bib for bib in biblio_entries}
|
46 |
+
|
47 |
+
# # process formulas and replace with text
|
48 |
+
# extract_formulas_from_tei_xml(soup)
|
49 |
+
|
50 |
+
# extract figure and table captions
|
51 |
+
refkey_map = extract_figures_and_tables_from_tei_xml(soup)
|
52 |
+
|
53 |
+
# get bracket style
|
54 |
+
is_bracket_style = check_if_citations_are_bracket_style(soup)
|
55 |
+
|
56 |
+
# substitute all note tags with p tags
|
57 |
+
soup = sub_all_note_tags(soup)
|
58 |
+
|
59 |
+
# process abstract if possible
|
60 |
+
abstract_entries = extract_abstract_from_tei_xml(
|
61 |
+
soup, bibkey_map, refkey_map, is_bracket_style
|
62 |
+
)
|
63 |
+
|
64 |
+
# process body text
|
65 |
+
body_entries = extract_body_text_from_tei_xml(soup, bibkey_map, refkey_map, is_bracket_style)
|
66 |
+
|
67 |
+
# parse back matter (acks, author statements, competing interests, abbrevs etc)
|
68 |
+
back_matter = extract_back_matter_from_tei_xml(soup, bibkey_map, refkey_map, is_bracket_style)
|
69 |
+
|
70 |
+
# form final paper entry
|
71 |
+
return Paper(
|
72 |
+
paper_id=paper_id,
|
73 |
+
pdf_hash=pdf_hash,
|
74 |
+
metadata=metadata,
|
75 |
+
abstract=abstract_entries,
|
76 |
+
body_text=body_entries,
|
77 |
+
back_matter=back_matter,
|
78 |
+
bib_entries=bibkey_map,
|
79 |
+
ref_entries=refkey_map,
|
80 |
+
)
|
81 |
+
|
82 |
+
|
83 |
+
def convert_tei_xml_file_to_s2orc_json(tei_file: str, pdf_hash: str = "") -> Paper:
|
84 |
+
"""
|
85 |
+
Convert a TEI XML file to S2ORC JSON
|
86 |
+
:param tei_file:
|
87 |
+
:param pdf_hash:
|
88 |
+
:return:
|
89 |
+
"""
|
90 |
+
if not os.path.exists(tei_file):
|
91 |
+
raise FileNotFoundError("Input TEI XML file doesn't exist")
|
92 |
+
paper_id = tei_file.split("/")[-1].split(".")[0]
|
93 |
+
soup = BeautifulSoup(open(tei_file, "rb").read(), "xml")
|
94 |
+
paper = convert_tei_xml_soup_to_s2orc_json(soup, paper_id, pdf_hash)
|
95 |
+
return paper
|
96 |
+
|
97 |
+
|
98 |
+
def process_pdf_stream(
|
99 |
+
input_file: str, sha: str, input_stream: bytes, grobid_config: Optional[Dict] = None
|
100 |
+
) -> Dict:
|
101 |
+
"""
|
102 |
+
Process PDF stream
|
103 |
+
:param input_file:
|
104 |
+
:param sha:
|
105 |
+
:param input_stream:
|
106 |
+
:return:
|
107 |
+
"""
|
108 |
+
# process PDF through Grobid -> TEI.XML
|
109 |
+
client = GrobidClient(grobid_config)
|
110 |
+
tei_text = client.process_pdf_stream(
|
111 |
+
input_file, input_stream, "temp", "processFulltextDocument"
|
112 |
+
)
|
113 |
+
|
114 |
+
# make soup
|
115 |
+
soup = BeautifulSoup(tei_text, "xml")
|
116 |
+
|
117 |
+
# get paper
|
118 |
+
paper = convert_tei_xml_soup_to_s2orc_json(soup, input_file, sha)
|
119 |
+
|
120 |
+
return paper.release_json("pdf")
|
121 |
+
|
122 |
+
|
123 |
+
def process_pdf_file(
|
124 |
+
input_file: str,
|
125 |
+
temp_dir: str = BASE_TEMP_DIR,
|
126 |
+
output_dir: str = BASE_OUTPUT_DIR,
|
127 |
+
grobid_config: Optional[Dict] = None,
|
128 |
+
verbose: bool = True,
|
129 |
+
) -> str:
|
130 |
+
"""
|
131 |
+
Process a PDF file and get JSON representation
|
132 |
+
:param input_file:
|
133 |
+
:param temp_dir:
|
134 |
+
:param output_dir:
|
135 |
+
:return:
|
136 |
+
"""
|
137 |
+
os.makedirs(temp_dir, exist_ok=True)
|
138 |
+
os.makedirs(output_dir, exist_ok=True)
|
139 |
+
|
140 |
+
# get paper id as the name of the file
|
141 |
+
paper_id = ".".join(input_file.split("/")[-1].split(".")[:-1])
|
142 |
+
tei_file = os.path.join(temp_dir, f"{paper_id}.tei.xml")
|
143 |
+
output_file = os.path.join(output_dir, f"{paper_id}.json")
|
144 |
+
|
145 |
+
# check if input file exists and output file doesn't
|
146 |
+
if not os.path.exists(input_file):
|
147 |
+
raise FileNotFoundError(f"{input_file} doesn't exist")
|
148 |
+
if os.path.exists(output_file):
|
149 |
+
if verbose:
|
150 |
+
print(f"{output_file} already exists!")
|
151 |
+
return output_file
|
152 |
+
|
153 |
+
# process PDF through Grobid -> TEI.XML
|
154 |
+
client = GrobidClient(grobid_config)
|
155 |
+
# TODO: compute PDF hash
|
156 |
+
# TODO: add grobid version number to output
|
157 |
+
client.process_pdf(input_file, temp_dir, "processFulltextDocument")
|
158 |
+
|
159 |
+
# process TEI.XML -> JSON
|
160 |
+
assert os.path.exists(tei_file)
|
161 |
+
paper = convert_tei_xml_file_to_s2orc_json(tei_file)
|
162 |
+
|
163 |
+
# write to file
|
164 |
+
with open(output_file, "w") as outf:
|
165 |
+
json.dump(paper.release_json(), outf, indent=4, sort_keys=False)
|
166 |
+
|
167 |
+
return output_file
|
168 |
+
|
169 |
+
|
170 |
+
UUID_NAMESPACE = uuid.UUID("bab08d37-ac12-40c4-847a-20ca337742fd")
|
171 |
+
|
172 |
+
|
173 |
+
def paper_url_to_uuid(paper_url: str) -> "uuid.UUID":
|
174 |
+
return uuid.uuid5(UUID_NAMESPACE, paper_url)
|
175 |
+
|
176 |
+
|
177 |
+
@dataclass
|
178 |
+
class PDFDownloader:
|
179 |
+
verbose: bool = True
|
180 |
+
|
181 |
+
def download(self, url: str, opath: str | Path) -> Path:
|
182 |
+
"""Download a pdf file from URL and save locally.
|
183 |
+
Skip if there is a file at `opath` already.
|
184 |
+
|
185 |
+
Parameters
|
186 |
+
----------
|
187 |
+
url : str
|
188 |
+
URL of the target PDF file
|
189 |
+
opath : str
|
190 |
+
Path to save downloaded PDF data.
|
191 |
+
"""
|
192 |
+
if os.path.exists(opath):
|
193 |
+
return Path(opath)
|
194 |
+
|
195 |
+
if not os.path.exists(os.path.dirname(opath)):
|
196 |
+
os.makedirs(os.path.dirname(opath), exist_ok=True)
|
197 |
+
|
198 |
+
if self.verbose:
|
199 |
+
print(f"Downloading {url} into {opath}")
|
200 |
+
with open(opath, "wb") as f:
|
201 |
+
res = requests.get(url)
|
202 |
+
f.write(res.content)
|
203 |
+
|
204 |
+
return Path(opath)
|
205 |
+
|
206 |
+
|
207 |
+
@dataclass
|
208 |
+
class FulltextExtractor:
|
209 |
+
|
210 |
+
def __call__(self, pdf_file_path: Path | str) -> tuple[str, dict] | None:
|
211 |
+
"""Extract plain text from a PDf file"""
|
212 |
+
raise NotImplementedError
|
213 |
+
|
214 |
+
|
215 |
+
@dataclass
|
216 |
+
class GrobidFulltextExtractor(FulltextExtractor):
|
217 |
+
tmp_dir: str = "./tmp/grobid"
|
218 |
+
grobid_config: Optional[Dict] = None
|
219 |
+
section_seperator: str = "\n\n"
|
220 |
+
paragraph_seperator: str = "\n"
|
221 |
+
verbose: bool = True
|
222 |
+
|
223 |
+
def construct_plain_text(self, extraction_result: dict) -> str:
|
224 |
+
|
225 |
+
section_strings = []
|
226 |
+
|
227 |
+
# add the title, if available (consider it as the first section)
|
228 |
+
title = extraction_result.get("title")
|
229 |
+
if title and title.strip():
|
230 |
+
section_strings.append(title.strip())
|
231 |
+
|
232 |
+
section_paragraphs: dict[str, list[str]] = extraction_result["sections"]
|
233 |
+
section_strings.extend(
|
234 |
+
self.paragraph_seperator.join(
|
235 |
+
# consider the section title as the first paragraph and
|
236 |
+
# remove empty paragraphs
|
237 |
+
filter(lambda s: len(s) > 0, map(lambda s: s.strip(), [section_name] + paragraphs))
|
238 |
+
)
|
239 |
+
for section_name, paragraphs in section_paragraphs.items()
|
240 |
+
)
|
241 |
+
|
242 |
+
return self.section_seperator.join(section_strings)
|
243 |
+
|
244 |
+
def postprocess_extraction_result(self, extraction_result: dict) -> dict:
|
245 |
+
|
246 |
+
# add sections
|
247 |
+
sections: dict[str, list[str]] = {}
|
248 |
+
for body_text in extraction_result["pdf_parse"]["body_text"]:
|
249 |
+
section_name = body_text["section"]
|
250 |
+
|
251 |
+
if section_name not in sections.keys():
|
252 |
+
sections[section_name] = []
|
253 |
+
sections[section_name] += [body_text["text"]]
|
254 |
+
extraction_result = {**extraction_result, "sections": sections}
|
255 |
+
|
256 |
+
return extraction_result
|
257 |
+
|
258 |
+
def __call__(self, pdf_file_path: Path | str) -> tuple[str, dict] | None:
|
259 |
+
"""Extract plain text from a PDf file"""
|
260 |
+
try:
|
261 |
+
extraction_fpath = process_pdf_file(
|
262 |
+
str(pdf_file_path),
|
263 |
+
temp_dir=self.tmp_dir,
|
264 |
+
output_dir=self.tmp_dir,
|
265 |
+
grobid_config=self.grobid_config,
|
266 |
+
verbose=self.verbose,
|
267 |
+
)
|
268 |
+
with open(extraction_fpath, "r") as f:
|
269 |
+
extraction_result = json.load(f)
|
270 |
+
|
271 |
+
processed_extraction_result = self.postprocess_extraction_result(extraction_result)
|
272 |
+
plain_text = self.construct_plain_text(processed_extraction_result)
|
273 |
+
return plain_text, extraction_result
|
274 |
+
except AssertionError:
|
275 |
+
print("Grobid failed to parse this document.")
|
276 |
+
return None
|
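For reference, a sketch of how the downloader and extractor defined above might be wired together (not part of the diff; the URL, the local paths, and the availability of a running Grobid service are assumptions):

from pathlib import Path

from src.utils.pdf_utils.process_pdf import GrobidFulltextExtractor, PDFDownloader

# hypothetical paper URL; any locally available PDF path works as well
pdf_path = PDFDownloader().download(
    "https://example.org/paper.pdf", Path("./tmp/pdfs/paper.pdf")
)

# assumes a Grobid service is reachable with the default client configuration
extractor = GrobidFulltextExtractor(tmp_dir="./tmp/grobid")
result = extractor(pdf_path)
if result is not None:
    plain_text, raw_extraction = result
    print(plain_text[:500])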
src/utils/pdf_utils/raw_paper.py
ADDED
@@ -0,0 +1,90 @@
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from dataclasses import asdict, dataclass
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Any
|
6 |
+
|
7 |
+
from jsonschema import validate
|
8 |
+
|
9 |
+
# TODO: load from file
|
10 |
+
schema = {
|
11 |
+
"title": "RawPaper",
|
12 |
+
"type": "object",
|
13 |
+
"properties": {
|
14 |
+
"paper_uuid": {"type": "string"},
|
15 |
+
"name": {"type": "string"},
|
16 |
+
"collection_id": {"type": "string"},
|
17 |
+
"collection_acronym": {"type": "string"},
|
18 |
+
"volume_id": {"type": "string"},
|
19 |
+
"booktitle": {"type": "string"},
|
20 |
+
"paper_id": {"type": "integer"},
|
21 |
+
"year": {"type": ["integer", "null"]},
|
22 |
+
"paper_title": {"type": "string"},
|
23 |
+
"authors": {
|
24 |
+
"type": "array",
|
25 |
+
"items": {
|
26 |
+
"type": "object",
|
27 |
+
"items": {
|
28 |
+
"first": {"type": ["string", "null"]},
|
29 |
+
"last": {"type": ["string", "null"]},
|
30 |
+
},
|
31 |
+
},
|
32 |
+
},
|
33 |
+
"abstract": {"type": ["string", "null"]},
|
34 |
+
"url": {"type": "string"},
|
35 |
+
"bibkey": {"type": ["string", "null"]},
|
36 |
+
"doi": {"type": ["string", "null"]},
|
37 |
+
"fulltext": {
|
38 |
+
"type": ["object", "null"],
|
39 |
+
"patternProperties": {"^.*$": {"type": "array", "items": {"type": "string"}}},
|
40 |
+
},
|
41 |
+
},
|
42 |
+
}
|
43 |
+
|
44 |
+
assert isinstance(schema, dict)
|
45 |
+
|
46 |
+
|
47 |
+
@dataclass
|
48 |
+
class RawPaper:
|
49 |
+
paper_uuid: str
|
50 |
+
name: str
|
51 |
+
|
52 |
+
collection_id: str
|
53 |
+
collection_acronym: str
|
54 |
+
volume_id: str
|
55 |
+
booktitle: str
|
56 |
+
paper_id: int
|
57 |
+
year: int | None
|
58 |
+
|
59 |
+
paper_title: str
|
60 |
+
authors: list[dict[str, str | None]]
|
61 |
+
abstract: str | None
|
62 |
+
url: str | None
|
63 |
+
bibkey: str
|
64 |
+
doi: str | None
|
65 |
+
fulltext: dict[str, list[str]] | None
|
66 |
+
|
67 |
+
@classmethod
|
68 |
+
def load_from_json(cls, fpath: str | Path) -> "RawPaper":
|
69 |
+
fpath = fpath if not isinstance(fpath, Path) else str(fpath)
|
70 |
+
# return cls(**sienna.load(fpath))
|
71 |
+
with open(fpath, "r") as f:
|
72 |
+
data = cls(**json.load(f))
|
73 |
+
return data
|
74 |
+
|
75 |
+
def get_fname(self) -> str:
|
76 |
+
return f"{self.name}.json"
|
77 |
+
|
78 |
+
def dumps(self) -> dict[str, Any]:
|
79 |
+
return asdict(self)
|
80 |
+
|
81 |
+
def validate(self) -> None:
|
82 |
+
validate(self.dumps(), schema=schema)
|
83 |
+
|
84 |
+
def save(self, odir: str) -> None:
|
85 |
+
self.validate()
|
86 |
+
if not os.path.exists(odir):
|
87 |
+
os.makedirs(odir, exist_ok=True)
|
88 |
+
opath = os.path.join(odir, self.get_fname())
|
89 |
+
with open(opath, "w") as f:
|
90 |
+
f.write(json.dumps(self.dumps(), indent=2))
|
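A sketch of how a RawPaper record could be built and persisted (not part of the diff; all field values below are made up for illustration):

from src.utils.pdf_utils.process_pdf import paper_url_to_uuid
from src.utils.pdf_utils.raw_paper import RawPaper

url = "https://example.org/anthology/2023.xyz-long.1"  # hypothetical
paper = RawPaper(
    paper_uuid=str(paper_url_to_uuid(url)),
    name="2023.xyz-long.1",
    collection_id="2023.xyz",
    collection_acronym="XYZ",
    volume_id="2023.xyz-long",
    booktitle="Proceedings of XYZ 2023",
    paper_id=1,
    year=2023,
    paper_title="An Example Paper Title",
    authors=[{"first": "Ada", "last": "Lovelace"}],
    abstract="A one-sentence abstract.",
    url=url,
    bibkey="lovelace-2023-example",
    doi=None,
    fulltext={"Introduction": ["First paragraph.", "Second paragraph."]},
)
paper.validate()  # raises jsonschema.ValidationError if the record violates the schema
paper.save("./data/raw_papers")  # writes ./data/raw_papers/2023.xyz-long.1.json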
src/utils/pdf_utils/s2orc_paper.py
ADDED
@@ -0,0 +1,478 @@
1 |
+
from datetime import datetime
|
2 |
+
from typing import Any, Dict, List, Optional
|
3 |
+
|
4 |
+
S2ORC_NAME_STRING = "S2ORC"
|
5 |
+
S2ORC_VERSION_STRING = "1.0.0"
|
6 |
+
|
7 |
+
CORRECT_KEYS = {"issn": "issue", "type": "type_str"}
|
8 |
+
|
9 |
+
SKIP_KEYS = {"link", "bib_id"}
|
10 |
+
|
11 |
+
REFERENCE_OUTPUT_KEYS = {
|
12 |
+
"figure": {"text", "type_str", "uris", "num", "fig_num"},
|
13 |
+
"table": {"text", "type_str", "content", "num", "html"},
|
14 |
+
"footnote": {"text", "type_str", "num"},
|
15 |
+
"section": {"text", "type_str", "num", "parent"},
|
16 |
+
"equation": {"text", "type_str", "latex", "mathml", "num"},
|
17 |
+
}
|
18 |
+
|
19 |
+
METADATA_KEYS = {"title", "authors", "year", "venue", "identifiers"}
|
20 |
+
|
21 |
+
|
22 |
+
class ReferenceEntry:
|
23 |
+
"""
|
24 |
+
Class for representing S2ORC figure and table references
|
25 |
+
|
26 |
+
An example json representation (values are examples, not accurate):
|
27 |
+
|
28 |
+
{
|
29 |
+
"FIGREF0": {
|
30 |
+
"text": "FIG. 2. Depth profiles of...",
|
31 |
+
"latex": null,
|
32 |
+
"type": "figure"
|
33 |
+
},
|
34 |
+
"TABREF2": {
|
35 |
+
"text": "Diversity indices of...",
|
36 |
+
"latex": null,
|
37 |
+
"type": "table",
|
38 |
+
"content": "",
|
39 |
+
"html": ""
|
40 |
+
}
|
41 |
+
}
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
ref_id: str,
|
47 |
+
text: str,
|
48 |
+
type_str: str,
|
49 |
+
latex: Optional[str] = None,
|
50 |
+
mathml: Optional[str] = None,
|
51 |
+
content: Optional[str] = None,
|
52 |
+
html: Optional[str] = None,
|
53 |
+
uris: Optional[List[str]] = None,
|
54 |
+
num: Optional[str] = None,
|
55 |
+
parent: Optional[str] = None,
|
56 |
+
fig_num: Optional[str] = None,
|
57 |
+
):
|
58 |
+
self.ref_id = ref_id
|
59 |
+
self.text = text
|
60 |
+
self.type_str = type_str
|
61 |
+
self.latex = latex
|
62 |
+
self.mathml = mathml
|
63 |
+
self.content = content
|
64 |
+
self.html = html
|
65 |
+
self.uris = uris
|
66 |
+
self.num = num
|
67 |
+
self.parent = parent
|
68 |
+
self.fig_num = fig_num
|
69 |
+
|
70 |
+
def as_json(self):
|
71 |
+
keep_keys = REFERENCE_OUTPUT_KEYS.get(self.type_str, None)
|
72 |
+
if keep_keys:
|
73 |
+
return {k: self.__getattribute__(k) for k in keep_keys}
|
74 |
+
else:
|
75 |
+
return {
|
76 |
+
"text": self.text,
|
77 |
+
"type": self.type_str,
|
78 |
+
"latex": self.latex,
|
79 |
+
"mathml": self.mathml,
|
80 |
+
"content": self.content,
|
81 |
+
"html": self.html,
|
82 |
+
"uris": self.uris,
|
83 |
+
"num": self.num,
|
84 |
+
"parent": self.parent,
|
85 |
+
"fig_num": self.fig_num,
|
86 |
+
}
|
87 |
+
|
88 |
+
|
89 |
+
class BibliographyEntry:
|
90 |
+
"""
|
91 |
+
Class for representing S2ORC parsed bibliography entries
|
92 |
+
|
93 |
+
An example json representation (values are examples, not accurate):
|
94 |
+
|
95 |
+
{
|
96 |
+
"title": "Mobility Reports...",
|
97 |
+
"authors": [
|
98 |
+
{
|
99 |
+
"first": "A",
|
100 |
+
"middle": ["A"],
|
101 |
+
"last": "Haija",
|
102 |
+
"suffix": ""
|
103 |
+
}
|
104 |
+
],
|
105 |
+
"year": 2015,
|
106 |
+
"venue": "IEEE Wireless Commune Mag",
|
107 |
+
"volume": "42",
|
108 |
+
"issn": "9",
|
109 |
+
"pages": "80--92",
|
110 |
+
"other_ids": {
|
111 |
+
"doi": [
|
112 |
+
"10.1109/TWC.2014.2360196"
|
113 |
+
],
|
114 |
+
|
115 |
+
}
|
116 |
+
}
|
117 |
+
|
118 |
+
"""
|
119 |
+
|
120 |
+
def __init__(
|
121 |
+
self,
|
122 |
+
bib_id: str,
|
123 |
+
title: str,
|
124 |
+
authors: List[Dict[str, str]],
|
125 |
+
ref_id: Optional[str] = None,
|
126 |
+
year: Optional[int] = None,
|
127 |
+
venue: Optional[str] = None,
|
128 |
+
volume: Optional[str] = None,
|
129 |
+
issue: Optional[str] = None,
|
130 |
+
pages: Optional[str] = None,
|
131 |
+
other_ids: Optional[Dict[str, List]] = None,
|
132 |
+
num: Optional[int] = None,
|
133 |
+
urls: Optional[List] = None,
|
134 |
+
raw_text: Optional[str] = None,
|
135 |
+
links: Optional[List] = None,
|
136 |
+
):
|
137 |
+
self.bib_id = bib_id
|
138 |
+
self.ref_id = ref_id
|
139 |
+
self.title = title
|
140 |
+
self.authors = authors
|
141 |
+
self.year = year
|
142 |
+
self.venue = venue
|
143 |
+
self.volume = volume
|
144 |
+
self.issue = issue
|
145 |
+
self.pages = pages
|
146 |
+
self.other_ids = other_ids
|
147 |
+
self.num = num
|
148 |
+
self.urls = urls
|
149 |
+
self.raw_text = raw_text
|
150 |
+
self.links = links
|
151 |
+
|
152 |
+
def as_json(self):
|
153 |
+
return {
|
154 |
+
"ref_id": self.ref_id,
|
155 |
+
"title": self.title,
|
156 |
+
"authors": self.authors,
|
157 |
+
"year": self.year,
|
158 |
+
"venue": self.venue,
|
159 |
+
"volume": self.volume,
|
160 |
+
"issue": self.issue,
|
161 |
+
"pages": self.pages,
|
162 |
+
"other_ids": self.other_ids,
|
163 |
+
"num": self.num,
|
164 |
+
"urls": self.urls,
|
165 |
+
"raw_text": self.raw_text,
|
166 |
+
"links": self.links,
|
167 |
+
}
|
168 |
+
|
169 |
+
|
170 |
+
class Affiliation:
|
171 |
+
"""
|
172 |
+
Class for representing affiliation info
|
173 |
+
|
174 |
+
Example:
|
175 |
+
{
|
176 |
+
"laboratory": "Key Laboratory of Urban Environment and Health",
|
177 |
+
"institution": "Chinese Academy of Sciences",
|
178 |
+
"location": {
|
179 |
+
"postCode": "361021",
|
180 |
+
"settlement": "Xiamen",
|
181 |
+
"country": "People's Republic of China"
|
182 |
+
}
|
183 |
+
"""
|
184 |
+
|
185 |
+
def __init__(self, laboratory: str, institution: str, location: Dict):
|
186 |
+
self.laboratory = laboratory
|
187 |
+
self.institution = institution
|
188 |
+
self.location = location
|
189 |
+
|
190 |
+
def as_json(self):
|
191 |
+
return {
|
192 |
+
"laboratory": self.laboratory,
|
193 |
+
"institution": self.institution,
|
194 |
+
"location": self.location,
|
195 |
+
}
|
196 |
+
|
197 |
+
|
198 |
+
class Author:
|
199 |
+
"""
|
200 |
+
Class for representing paper authors
|
201 |
+
|
202 |
+
Example:
|
203 |
+
|
204 |
+
{
|
205 |
+
"first": "Anyi",
|
206 |
+
"middle": [],
|
207 |
+
"last": "Hu",
|
208 |
+
"suffix": "",
|
209 |
+
"affiliation": {
|
210 |
+
"laboratory": "Key Laboratory of Urban Environment and Health",
|
211 |
+
"institution": "Chinese Academy of Sciences",
|
212 |
+
"location": {
|
213 |
+
"postCode": "361021",
|
214 |
+
"settlement": "Xiamen",
|
215 |
+
"country": "People's Republic of China"
|
216 |
+
}
|
217 |
+
},
|
218 |
+
"email": ""
|
219 |
+
}
|
220 |
+
"""
|
221 |
+
|
222 |
+
def __init__(
|
223 |
+
self,
|
224 |
+
first: str,
|
225 |
+
middle: List[str],
|
226 |
+
last: str,
|
227 |
+
suffix: str,
|
228 |
+
affiliation: Optional[Dict] = None,
|
229 |
+
email: Optional[str] = None,
|
230 |
+
):
|
231 |
+
self.first = first
|
232 |
+
self.middle = middle
|
233 |
+
self.last = last
|
234 |
+
self.suffix = suffix
|
235 |
+
self.affiliation = Affiliation(**affiliation) if affiliation else {}
|
236 |
+
self.email = email
|
237 |
+
|
238 |
+
def as_json(self):
|
239 |
+
return {
|
240 |
+
"first": self.first,
|
241 |
+
"middle": self.middle,
|
242 |
+
"last": self.last,
|
243 |
+
"suffix": self.suffix,
|
244 |
+
"affiliation": self.affiliation.as_json() if self.affiliation else {},
|
245 |
+
"email": self.email,
|
246 |
+
}
|
247 |
+
|
248 |
+
|
249 |
+
class Metadata:
|
250 |
+
"""
|
251 |
+
Class for representing paper metadata
|
252 |
+
|
253 |
+
Example:
|
254 |
+
{
|
255 |
+
"title": "Niche Partitioning...",
|
256 |
+
"authors": [
|
257 |
+
{
|
258 |
+
"first": "Anyi",
|
259 |
+
"middle": [],
|
260 |
+
"last": "Hu",
|
261 |
+
"suffix": "",
|
262 |
+
"affiliation": {
|
263 |
+
"laboratory": "Key Laboratory of Urban Environment and Health",
|
264 |
+
"institution": "Chinese Academy of Sciences",
|
265 |
+
"location": {
|
266 |
+
"postCode": "361021",
|
267 |
+
"settlement": "Xiamen",
|
268 |
+
"country": "People's Republic of China"
|
269 |
+
}
|
270 |
+
},
|
271 |
+
"email": ""
|
272 |
+
}
|
273 |
+
],
|
274 |
+
"year": "2011-11"
|
275 |
+
}
|
276 |
+
"""
|
277 |
+
|
278 |
+
def __init__(
|
279 |
+
self,
|
280 |
+
title: str,
|
281 |
+
authors: List[Dict],
|
282 |
+
year: Optional[str] = None,
|
283 |
+
venue: Optional[str] = None,
|
284 |
+
identifiers: Optional[Dict] = {},
|
285 |
+
):
|
286 |
+
self.title = title
|
287 |
+
self.authors = [Author(**author) for author in authors]
|
288 |
+
self.year = year
|
289 |
+
self.venue = venue
|
290 |
+
self.identifiers = identifiers
|
291 |
+
|
292 |
+
def as_json(self):
|
293 |
+
return {
|
294 |
+
"title": self.title,
|
295 |
+
"authors": [author.as_json() for author in self.authors],
|
296 |
+
"year": self.year,
|
297 |
+
"venue": self.venue,
|
298 |
+
"identifiers": self.identifiers,
|
299 |
+
}
|
300 |
+
|
301 |
+
|
302 |
+
class Paragraph:
|
303 |
+
"""
|
304 |
+
Class for representing a parsed paragraph from Grobid xml
|
305 |
+
All xml tags are removed from the paragraph text, all figures, equations, and tables are replaced
|
306 |
+
with a special token that maps to a reference identifier
|
307 |
+
Citation mention spans and section header are extracted
|
308 |
+
|
309 |
+
An example json representation (values are examples, not accurate):
|
310 |
+
|
311 |
+
{
|
312 |
+
"text": "Formal language techniques BID1 may be used to study FORMULA0 (see REF0)...",
|
313 |
+
"mention_spans": [
|
314 |
+
{
|
315 |
+
"start": 27,
|
316 |
+
"end": 31,
|
317 |
+
"text": "[1]")
|
318 |
+
],
|
319 |
+
"ref_spans": [
|
320 |
+
{
|
321 |
+
"start": ,
|
322 |
+
"end": ,
|
323 |
+
"text": "Fig. 1"
|
324 |
+
}
|
325 |
+
],
|
326 |
+
"eq_spans": [
|
327 |
+
{
|
328 |
+
"start": 53,
|
329 |
+
"end": 61,
|
330 |
+
"text": "α = 1",
|
331 |
+
"latex": "\\alpha = 1",
|
332 |
+
"ref_id": null
|
333 |
+
}
|
334 |
+
],
|
335 |
+
"section": "Abstract"
|
336 |
+
}
|
337 |
+
"""
|
338 |
+
|
339 |
+
def __init__(
|
340 |
+
self,
|
341 |
+
text: str,
|
342 |
+
cite_spans: List[Dict],
|
343 |
+
ref_spans: List[Dict],
|
344 |
+
eq_spans: Optional[List[Dict]] = [],
|
345 |
+
section: Optional[Any] = None,
|
346 |
+
sec_num: Optional[Any] = None,
|
347 |
+
):
|
348 |
+
self.text = text
|
349 |
+
self.cite_spans = cite_spans
|
350 |
+
self.ref_spans = ref_spans
|
351 |
+
self.eq_spans = eq_spans
|
352 |
+
if type(section) is str:
|
353 |
+
if section:
|
354 |
+
sec_parts = section.split("::")
|
355 |
+
section_list = [[None, sec_name] for sec_name in sec_parts]
|
356 |
+
else:
|
357 |
+
section_list = None
|
358 |
+
if section_list and sec_num:
|
359 |
+
section_list[-1][0] = sec_num
|
360 |
+
else:
|
361 |
+
section_list = section
|
362 |
+
self.section = section_list
|
363 |
+
|
364 |
+
def as_json(self):
|
365 |
+
return {
|
366 |
+
"text": self.text,
|
367 |
+
"cite_spans": self.cite_spans,
|
368 |
+
"ref_spans": self.ref_spans,
|
369 |
+
"eq_spans": self.eq_spans,
|
370 |
+
"section": "::".join([sec[1] for sec in self.section]) if self.section else "",
|
371 |
+
"sec_num": self.section[-1][0] if self.section else None,
|
372 |
+
}
|
373 |
+
|
374 |
+
|
375 |
+
class Paper:
|
376 |
+
"""
|
377 |
+
Class for representing a parsed S2ORC paper
|
378 |
+
"""
|
379 |
+
|
380 |
+
def __init__(
|
381 |
+
self,
|
382 |
+
paper_id: str,
|
383 |
+
pdf_hash: str,
|
384 |
+
metadata: Dict,
|
385 |
+
abstract: List[Dict],
|
386 |
+
body_text: List[Dict],
|
387 |
+
back_matter: List[Dict],
|
388 |
+
bib_entries: Dict,
|
389 |
+
ref_entries: Dict,
|
390 |
+
):
|
391 |
+
self.paper_id = paper_id
|
392 |
+
self.pdf_hash = pdf_hash
|
393 |
+
self.metadata = Metadata(**metadata)
|
394 |
+
self.abstract = [Paragraph(**para) for para in abstract]
|
395 |
+
self.body_text = [Paragraph(**para) for para in body_text]
|
396 |
+
self.back_matter = [Paragraph(**para) for para in back_matter]
|
397 |
+
self.bib_entries = [
|
398 |
+
BibliographyEntry(
|
399 |
+
bib_id=key,
|
400 |
+
**{
|
401 |
+
CORRECT_KEYS[k] if k in CORRECT_KEYS else k: v
|
402 |
+
for k, v in bib.items()
|
403 |
+
if k not in SKIP_KEYS
|
404 |
+
},
|
405 |
+
)
|
406 |
+
for key, bib in bib_entries.items()
|
407 |
+
]
|
408 |
+
self.ref_entries = [
|
409 |
+
ReferenceEntry(
|
410 |
+
ref_id=key,
|
411 |
+
**{
|
412 |
+
CORRECT_KEYS[k] if k in CORRECT_KEYS else k: v
|
413 |
+
for k, v in ref.items()
|
414 |
+
if k != "ref_id"
|
415 |
+
},
|
416 |
+
)
|
417 |
+
for key, ref in ref_entries.items()
|
418 |
+
]
|
419 |
+
|
420 |
+
def as_json(self):
|
421 |
+
return {
|
422 |
+
"paper_id": self.paper_id,
|
423 |
+
"pdf_hash": self.pdf_hash,
|
424 |
+
"metadata": self.metadata.as_json(),
|
425 |
+
"abstract": [para.as_json() for para in self.abstract],
|
426 |
+
"body_text": [para.as_json() for para in self.body_text],
|
427 |
+
"back_matter": [para.as_json() for para in self.back_matter],
|
428 |
+
"bib_entries": {bib.bib_id: bib.as_json() for bib in self.bib_entries},
|
429 |
+
"ref_entries": {ref.ref_id: ref.as_json() for ref in self.ref_entries},
|
430 |
+
}
|
431 |
+
|
432 |
+
@property
|
433 |
+
def raw_abstract_text(self) -> str:
|
434 |
+
"""
|
435 |
+
Get all the abstract text joined by a newline
|
436 |
+
:return:
|
437 |
+
"""
|
438 |
+
return "\n".join([para.text for para in self.abstract])
|
439 |
+
|
440 |
+
@property
|
441 |
+
def raw_body_text(self) -> str:
|
442 |
+
"""
|
443 |
+
Get all the body text joined by a newline
|
444 |
+
:return:
|
445 |
+
"""
|
446 |
+
return "\n".join([para.text for para in self.body_text])
|
447 |
+
|
448 |
+
def release_json(self, doc_type: str = "pdf") -> Dict:
|
449 |
+
"""
|
450 |
+
Return in release JSON format
|
451 |
+
:return:
|
452 |
+
"""
|
453 |
+
# TODO: not fully implemented; metadata format is not right; extra keys in some places
|
454 |
+
release_dict: Dict = {"paper_id": self.paper_id}
|
455 |
+
release_dict.update(
|
456 |
+
{
|
457 |
+
"header": {
|
458 |
+
"generated_with": f"{S2ORC_NAME_STRING} {S2ORC_VERSION_STRING}",
|
459 |
+
"date_generated": datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
|
460 |
+
}
|
461 |
+
}
|
462 |
+
)
|
463 |
+
release_dict.update(self.metadata.as_json())
|
464 |
+
release_dict.update({"abstract": self.raw_abstract_text})
|
465 |
+
release_dict.update(
|
466 |
+
{
|
467 |
+
f"{doc_type}_parse": {
|
468 |
+
"paper_id": self.paper_id,
|
469 |
+
"_pdf_hash": self.pdf_hash,
|
470 |
+
"abstract": [para.as_json() for para in self.abstract],
|
471 |
+
"body_text": [para.as_json() for para in self.body_text],
|
472 |
+
"back_matter": [para.as_json() for para in self.back_matter],
|
473 |
+
"bib_entries": {bib.bib_id: bib.as_json() for bib in self.bib_entries},
|
474 |
+
"ref_entries": {ref.ref_id: ref.as_json() for ref in self.ref_entries},
|
475 |
+
}
|
476 |
+
}
|
477 |
+
)
|
478 |
+
return release_dict
|
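To illustrate how the classes above fit together, a minimal Paper instance and its release format (not part of the diff; the contents are invented and heavily truncated):

from src.utils.pdf_utils.s2orc_paper import Paper

paper = Paper(
    paper_id="example",
    pdf_hash="",
    metadata={"title": "An Example Paper", "authors": [], "year": "2023"},
    abstract=[{"text": "We study X.", "cite_spans": [], "ref_spans": [], "section": "Abstract"}],
    body_text=[{"text": "As shown in [1] ...", "cite_spans": [], "ref_spans": [], "section": "Introduction"}],
    back_matter=[],
    # "link" is listed in SKIP_KEYS and is dropped from the bib entry
    bib_entries={"BIBREF0": {"title": "Cited Work", "authors": [], "year": 2020, "link": None}},
    # "type" is renamed to "type_str" via CORRECT_KEYS
    ref_entries={"FIGREF0": {"text": "Figure 1: caption text", "type": "figure"}},
)
release = paper.release_json("pdf")
print(sorted(release["pdf_parse"].keys()))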
src/utils/pdf_utils/s2orc_utils.py
ADDED
@@ -0,0 +1,61 @@
1 |
+
from typing import Any, Dict
|
2 |
+
|
3 |
+
from .s2orc_paper import METADATA_KEYS, Paper
|
4 |
+
|
5 |
+
|
6 |
+
def load_s2orc(paper_dict: Dict[str, Any]) -> Paper:
|
7 |
+
"""
|
8 |
+
Load release S2ORC into Paper class
|
9 |
+
:param paper_dict:
|
10 |
+
:return:
|
11 |
+
"""
|
12 |
+
paper_id = paper_dict["paper_id"]
|
13 |
+
pdf_hash = paper_dict.get("_pdf_hash", paper_dict.get("s2_pdf_hash", None))
|
14 |
+
|
15 |
+
# 2019 gorc parses
|
16 |
+
grobid_parse = paper_dict.get("grobid_parse")
|
17 |
+
if grobid_parse:
|
18 |
+
metadata = {k: v for k, v in paper_dict["metadata"].items() if k in METADATA_KEYS}
|
19 |
+
abstract = grobid_parse.get("abstract", [])
|
20 |
+
body_text = grobid_parse.get("body_text", [])
|
21 |
+
back_matter = grobid_parse.get("back_matter", [])
|
22 |
+
bib_entries = grobid_parse.get("bib_entries", {})
|
23 |
+
for k, v in bib_entries.items():
|
24 |
+
if "link" in v:
|
25 |
+
v["links"] = [v["link"]]
|
26 |
+
ref_entries = grobid_parse.get("ref_entries", {})
|
27 |
+
# current and 2020 s2orc release_json
|
28 |
+
elif ("pdf_parse" in paper_dict and paper_dict.get("pdf_parse")) or (
|
29 |
+
"body_text" in paper_dict and paper_dict.get("body_text")
|
30 |
+
):
|
31 |
+
if "pdf_parse" in paper_dict:
|
32 |
+
paper_dict = paper_dict["pdf_parse"]
|
33 |
+
if paper_dict.get("metadata"):
|
34 |
+
metadata = {
|
35 |
+
k: v for k, v in paper_dict.get("metadata", {}).items() if k in METADATA_KEYS
|
36 |
+
}
|
37 |
+
# 2020 s2orc releases (metadata is separate)
|
38 |
+
else:
|
39 |
+
metadata = {"title": None, "authors": [], "year": None}
|
40 |
+
abstract = paper_dict.get("abstract", [])
|
41 |
+
body_text = paper_dict.get("body_text", [])
|
42 |
+
back_matter = paper_dict.get("back_matter", [])
|
43 |
+
bib_entries = paper_dict.get("bib_entries", {})
|
44 |
+
for k, v in bib_entries.items():
|
45 |
+
if "link" in v:
|
46 |
+
v["links"] = [v["link"]]
|
47 |
+
ref_entries = paper_dict.get("ref_entries", {})
|
48 |
+
else:
|
49 |
+
print(paper_id)
|
50 |
+
raise NotImplementedError("Unknown S2ORC file type!")
|
51 |
+
|
52 |
+
return Paper(
|
53 |
+
paper_id=paper_id,
|
54 |
+
pdf_hash=pdf_hash,
|
55 |
+
metadata=metadata,
|
56 |
+
abstract=abstract,
|
57 |
+
body_text=body_text,
|
58 |
+
back_matter=back_matter,
|
59 |
+
bib_entries=bib_entries,
|
60 |
+
ref_entries=ref_entries,
|
61 |
+
)
|
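A sketch of the intended round trip (not part of the diff; the JSON path is a hypothetical output of process_pdf_file from process_pdf.py):

import json

from src.utils.pdf_utils.s2orc_utils import load_s2orc

with open("./grobid/output/example.json") as f:
    paper_dict = json.load(f)

# load_s2orc detects the release layout ("grobid_parse", "pdf_parse", or bare "body_text")
paper = load_s2orc(paper_dict)
print(paper.paper_id)
print(paper.raw_body_text[:300])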
src/utils/pdf_utils/utils.py
ADDED
@@ -0,0 +1,904 @@
1 |
+
import re
|
2 |
+
from typing import Dict, List, Tuple
|
3 |
+
|
4 |
+
import bs4
|
5 |
+
|
6 |
+
|
7 |
+
def replace_refspans(
|
8 |
+
spans_to_replace: List[Tuple[int, int, str, str]],
|
9 |
+
full_string: str,
|
10 |
+
pre_padding: str = "",
|
11 |
+
post_padding: str = "",
|
12 |
+
btwn_padding: str = ", ",
|
13 |
+
) -> str:
|
14 |
+
"""
|
15 |
+
For each span within the full string, replace that span with new text
|
16 |
+
:param spans_to_replace: list of tuples of form (start_ind, end_ind, span_text, new_substring)
|
17 |
+
:param full_string:
|
18 |
+
:param pre_padding:
|
19 |
+
:param post_padding:
|
20 |
+
:param btwn_padding:
|
21 |
+
:return:
|
22 |
+
"""
|
23 |
+
# assert all spans are equal to full_text span
|
24 |
+
assert all([full_string[start:end] == span for start, end, span, _ in spans_to_replace])
|
25 |
+
|
26 |
+
# assert none of the spans start with the same start ind
|
27 |
+
start_inds = [rep[0] for rep in spans_to_replace]
|
28 |
+
assert len(set(start_inds)) == len(start_inds)
|
29 |
+
|
30 |
+
# sort by start index
|
31 |
+
spans_to_replace.sort(key=lambda x: x[0])
|
32 |
+
|
33 |
+
# form strings for each span group
|
34 |
+
for i, entry in enumerate(spans_to_replace):
|
35 |
+
start, end, span, new_string = entry
|
36 |
+
|
37 |
+
# skip empties
|
38 |
+
if end <= 0:
|
39 |
+
continue
|
40 |
+
|
41 |
+
# compute shift amount
|
42 |
+
shift_amount = len(new_string) - len(span) + len(pre_padding) + len(post_padding)
|
43 |
+
|
44 |
+
# shift remaining appropriately
|
45 |
+
for ind in range(i + 1, len(spans_to_replace)):
|
46 |
+
next_start, next_end, next_span, next_string = spans_to_replace[ind]
|
47 |
+
# skip empties
|
48 |
+
if next_end <= 0:
|
49 |
+
continue
|
50 |
+
# if overlap between ref span and current ref span, remove from replacement
|
51 |
+
if next_start < end:
|
52 |
+
next_start = 0
|
53 |
+
next_end = 0
|
54 |
+
next_string = ""
|
55 |
+
# if ref span abuts previous reference span
|
56 |
+
elif next_start == end:
|
57 |
+
next_start += shift_amount
|
58 |
+
next_end += shift_amount
|
59 |
+
next_string = btwn_padding + pre_padding + next_string + post_padding
|
60 |
+
# if ref span starts after, shift starts and ends
|
61 |
+
elif next_start > end:
|
62 |
+
next_start += shift_amount
|
63 |
+
next_end += shift_amount
|
64 |
+
next_string = pre_padding + next_string + post_padding
|
65 |
+
# save adjusted span
|
66 |
+
spans_to_replace[ind] = (next_start, next_end, next_span, next_string)
|
67 |
+
|
68 |
+
spans_to_replace = [entry for entry in spans_to_replace if entry[1] > 0]
|
69 |
+
spans_to_replace.sort(key=lambda x: x[0])
|
70 |
+
|
71 |
+
# apply shifts in series
|
72 |
+
for start, end, span, new_string in spans_to_replace:
|
73 |
+
assert full_string[start:end] == span
|
74 |
+
full_string = full_string[:start] + new_string + full_string[end:]
|
75 |
+
|
76 |
+
return full_string
|
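For intuition, a tiny worked example of replace_refspans (not part of the diff; the string and spans are made up):

from src.utils.pdf_utils.utils import replace_refspans

s = "see [1] and [2]"
# each tuple is (start, end, original_span, replacement)
out = replace_refspans([(4, 7, "[1]", "BIBREF0"), (12, 15, "[2]", "BIBREF1")], s)
assert out == "see BIBREF0 and BIBREF1"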
77 |
+
|
78 |
+
|
79 |
+
BRACKET_REGEX = re.compile(r"\[[1-9]\d{0,2}([,;\-\s]+[1-9]\d{0,2})*;?\]")
|
80 |
+
BRACKET_STYLE_THRESHOLD = 5
|
81 |
+
|
82 |
+
SINGLE_BRACKET_REGEX = re.compile(r"\[([1-9]\d{0,2})\]")
|
83 |
+
EXPANSION_CHARS = {"-", "–"}
|
84 |
+
|
85 |
+
REPLACE_TABLE_TOKS = {
|
86 |
+
"<row>": "<tr>",
|
87 |
+
"<row/>": "<tr/>",
|
88 |
+
"</row>": "</tr>",
|
89 |
+
"<cell>": "<td>",
|
90 |
+
"<cell/>": "<td/>",
|
91 |
+
"</cell>": "</td>",
|
92 |
+
"<cell ": "<td ",
|
93 |
+
"cols=": "colspan=",
|
94 |
+
}
|
95 |
+
|
96 |
+
|
97 |
+
def span_already_added(sub_start: int, sub_end: int, span_indices: List[Tuple[int, int]]) -> bool:
|
98 |
+
"""
|
99 |
+
Check if span is a subspan of existing span
|
100 |
+
:param sub_start:
|
101 |
+
:param sub_end:
|
102 |
+
:param span_indices:
|
103 |
+
:return:
|
104 |
+
"""
|
105 |
+
for span_start, span_end in span_indices:
|
106 |
+
if sub_start >= span_start and sub_end <= span_end:
|
107 |
+
return True
|
108 |
+
return False
|
109 |
+
|
110 |
+
|
111 |
+
def is_expansion_string(between_string: str) -> bool:
|
112 |
+
"""
|
113 |
+
Check if the string between two refs is an expansion string
|
114 |
+
:param between_string:
|
115 |
+
:return:
|
116 |
+
"""
|
117 |
+
if (
|
118 |
+
len(between_string) <= 2
|
119 |
+
and any([c in EXPANSION_CHARS for c in between_string])
|
120 |
+
and all([c in EXPANSION_CHARS.union({" "}) for c in between_string])
|
121 |
+
):
|
122 |
+
return True
|
123 |
+
return False
|
124 |
+
|
125 |
+
|
126 |
+
# TODO: still cases like `09bcee03baceb509d4fcf736fa1322cb8adf507f` w/ dups like ['L Jung', 'R Hessler', 'Louis Jung', 'Roland Hessler']
|
127 |
+
# example paper that has empties & duplicates: `09bce26cc7e825e15a4469e3e78b7a54898bb97f`
|
128 |
+
def _clean_empty_and_duplicate_authors_from_grobid_parse(
|
129 |
+
authors: List[Dict],
|
130 |
+
) -> List[Dict]:
|
131 |
+
"""
|
132 |
+
Within affiliation, `location` is a dict with fields <settlement>, <region>, <country>, <postCode>, etc.
|
133 |
+
Too much hassle, so just take the first one that's not empty.
|
134 |
+
"""
|
135 |
+
# stripping empties
|
136 |
+
clean_authors_list = []
|
137 |
+
for author in authors:
|
138 |
+
clean_first = author["first"].strip()
|
139 |
+
clean_last = author["last"].strip()
|
140 |
+
clean_middle = [m.strip() for m in author["middle"]]
|
141 |
+
clean_suffix = author["suffix"].strip()
|
142 |
+
if clean_first or clean_last or clean_middle:
|
143 |
+
author["first"] = clean_first
|
144 |
+
author["last"] = clean_last
|
145 |
+
author["middle"] = clean_middle
|
146 |
+
author["suffix"] = clean_suffix
|
147 |
+
clean_authors_list.append(author)
|
148 |
+
# combining duplicates (preserve first occurrence of author name as position)
|
149 |
+
key_to_author_blobs = {}
|
150 |
+
ordered_keys_by_author_pos = []
|
151 |
+
for author in clean_authors_list:
|
152 |
+
key = (
|
153 |
+
author["first"],
|
154 |
+
author["last"],
|
155 |
+
" ".join(author["middle"]),
|
156 |
+
author["suffix"],
|
157 |
+
)
|
158 |
+
if key not in key_to_author_blobs:
|
159 |
+
key_to_author_blobs[key] = author
|
160 |
+
ordered_keys_by_author_pos.append(key)
|
161 |
+
else:
|
162 |
+
if author["email"]:
|
163 |
+
key_to_author_blobs[key]["email"] = author["email"]
|
164 |
+
if author["affiliation"] and (
|
165 |
+
author["affiliation"]["institution"]
|
166 |
+
or author["affiliation"]["laboratory"]
|
167 |
+
or author["affiliation"]["location"]
|
168 |
+
):
|
169 |
+
key_to_author_blobs[key]["affiliation"] = author["affiliation"]
|
170 |
+
dedup_authors_list = [key_to_author_blobs[key] for key in ordered_keys_by_author_pos]
|
171 |
+
return dedup_authors_list
|
172 |
+
|
173 |
+
|
174 |
+
def sub_spans_and_update_indices(
|
175 |
+
spans_to_replace: List[Tuple[int, int, str, str]], full_string: str
|
176 |
+
) -> Tuple[str, List]:
|
177 |
+
"""
|
178 |
+
Replace all spans and recompute indices
|
179 |
+
:param spans_to_replace:
|
180 |
+
:param full_string:
|
181 |
+
:return:
|
182 |
+
"""
|
183 |
+
# TODO: check no spans overlapping
|
184 |
+
# TODO: check all spans well-formed
|
185 |
+
|
186 |
+
# assert all spans are equal to full_text span
|
187 |
+
assert all([full_string[start:end] == token for start, end, token, _ in spans_to_replace])
|
188 |
+
|
189 |
+
# assert none of the spans start with the same start ind
|
190 |
+
start_inds = [rep[0] for rep in spans_to_replace]
|
191 |
+
assert len(set(start_inds)) == len(start_inds)
|
192 |
+
|
193 |
+
# sort by start index
|
194 |
+
spans_to_replace.sort(key=lambda x: x[0])
|
195 |
+
|
196 |
+
# compute offsets for each span
|
197 |
+
new_spans = [
|
198 |
+
(start, end, token, surface, 0) for start, end, token, surface in spans_to_replace
|
199 |
+
]
|
200 |
+
for i, entry in enumerate(spans_to_replace):
|
201 |
+
start, end, token, surface = entry
|
202 |
+
new_end = start + len(surface)
|
203 |
+
offset = new_end - end
|
204 |
+
# new_spans[i][1] += offset
|
205 |
+
new_spans[i] = (
|
206 |
+
new_spans[i][0],
|
207 |
+
new_spans[i][1] + offset,
|
208 |
+
new_spans[i][2],
|
209 |
+
new_spans[i][3],
|
210 |
+
new_spans[i][4],
|
211 |
+
)
|
212 |
+
# for new_span_entry in new_spans[i + 1 :]:
|
213 |
+
# new_span_entry[4] += offset
|
214 |
+
for j in range(i + 1, len(new_spans)):
|
215 |
+
new_spans[j] = (
|
216 |
+
new_spans[j][0],
|
217 |
+
new_spans[j][1],
|
218 |
+
new_spans[j][2],
|
219 |
+
new_spans[j][3],
|
220 |
+
new_spans[j][4] + offset,
|
221 |
+
)
|
222 |
+
|
223 |
+
# generate new text and create final spans
|
224 |
+
new_text = replace_refspans(spans_to_replace, full_string, btwn_padding="")
|
225 |
+
result = [
|
226 |
+
(start + offset, end + offset, token, surface)
|
227 |
+
for start, end, token, surface, offset in new_spans
|
228 |
+
]
|
229 |
+
|
230 |
+
return new_text, result
|
231 |
+
|
232 |
+
|
233 |
+
class UniqTokenGenerator:
    """
    Generate unique token
    """

    def __init__(self, tok_string):
        self.tok_string = tok_string
        self.ind = 0

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        new_token = f"{self.tok_string}{self.ind}"
        self.ind += 1
        return new_token


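This generator mints the CITETOKEN/REFTOKEN placeholders used further down; a quick sketch of its behaviour (not part of the diff):

    tokgen = UniqTokenGenerator("REFTOKEN")
    tokgen.next()  # -> "REFTOKEN0"
    tokgen.next()  # -> "REFTOKEN1"
    next(tokgen)   # iterator protocol works too -> "REFTOKEN2"
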
def normalize_grobid_id(grobid_id: str):
    """
    Normalize grobid object identifiers
    :param grobid_id:
    :return:
    """
    str_norm = grobid_id.upper().replace("_", "").replace("#", "")
    if str_norm.startswith("B"):
        return str_norm.replace("B", "BIBREF")
    if str_norm.startswith("TAB"):
        return str_norm.replace("TAB", "TABREF")
    if str_norm.startswith("FIG"):
        return str_norm.replace("FIG", "FIGREF")
    if str_norm.startswith("FORMULA"):
        return str_norm.replace("FORMULA", "EQREF")
    return str_norm


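For reference, this normalization maps GROBID identifiers to the ref ids used throughout the rest of this file (examples added for illustration, derived directly from the code above):

    normalize_grobid_id("#b2")         # -> "BIBREF2"
    normalize_grobid_id("tab_0")       # -> "TABREF0"
    normalize_grobid_id("fig_1")       # -> "FIGREF1"
    normalize_grobid_id("#formula_3")  # -> "EQREF3"
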
def extract_formulas_from_tei_xml(sp: bs4.BeautifulSoup) -> None:
    """
    Replace all formulas with the text
    :param sp:
    :return:
    """
    for eq in sp.find_all("formula"):
        eq.replace_with(sp.new_string(eq.text.strip()))


def table_to_html(table: bs4.element.Tag) -> str:
    """
    Sub table tags with html table tags
    :param table:
    :return:
    """
    for tag in table:
        if tag.name != "row":
            print(f"Unknown table subtag: {tag.name}")
            tag.decompose()
    table_str = str(table)
    for token, subtoken in REPLACE_TABLE_TOKS.items():
        table_str = table_str.replace(token, subtoken)
    return table_str


def extract_figures_and_tables_from_tei_xml(sp: bs4.BeautifulSoup) -> Dict[str, Dict]:
    """
    Generate figure and table dicts
    :param sp:
    :return:
    """
    ref_map = dict()

    for fig in sp.find_all("figure"):
        try:
            if fig.name and fig.get("xml:id"):
                if fig.get("type") == "table":
                    ref_map[normalize_grobid_id(fig.get("xml:id"))] = {
                        "text": (
                            fig.figDesc.text.strip()
                            if fig.figDesc
                            else fig.head.text.strip() if fig.head else ""
                        ),
                        "latex": None,
                        "type": "table",
                        "content": table_to_html(fig.table),
                        "fig_num": fig.get("xml:id"),
                    }
                else:
                    if True in [char.isdigit() for char in fig.findNext("head").findNext("label")]:
                        fig_num = fig.findNext("head").findNext("label").contents[0]
                    else:
                        fig_num = None
                    ref_map[normalize_grobid_id(fig.get("xml:id"))] = {
                        "text": fig.figDesc.text.strip() if fig.figDesc else "",
                        "latex": None,
                        "type": "figure",
                        "content": "",
                        "fig_num": fig_num,
                    }
        except AttributeError:
            continue
        fig.decompose()

    return ref_map


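To illustrate the structure built above: a table with xml:id "tab_0" and a figure with xml:id "fig_1" would end up roughly as follows (illustration only, not part of the diff; captions and the HTML content are made-up placeholders, and the actual markup depends on table_to_html and REPLACE_TABLE_TOKS defined elsewhere in this module):

    ref_map = {
        "TABREF0": {
            "text": "Table 1: Results on the test set",  # caption from <figDesc>
            "latex": None,
            "type": "table",
            "content": "<table>...</table>",             # html-ified rows (placeholder)
            "fig_num": "tab_0",
        },
        "FIGREF1": {
            "text": "Figure 2: Model architecture",      # caption from <figDesc>
            "latex": None,
            "type": "figure",
            "content": "",
            "fig_num": "2",                              # from the <label>, if numeric
        },
    }
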
def check_if_citations_are_bracket_style(sp: bs4.BeautifulSoup) -> bool:
    """
    Check if the document has bracket style citations
    :param sp:
    :return:
    """
    cite_strings = []
    if sp.body:
        for div in sp.body.find_all("div"):
            if div.head:
                continue
            for rtag in div.find_all("ref"):
                ref_type = rtag.get("type")
                if ref_type == "bibr":
                    cite_strings.append(rtag.text.strip())

        # check how many match bracket style
        bracket_style = [bool(BRACKET_REGEX.match(cite_str)) for cite_str in cite_strings]

        # return true if the number of bracket-style citations exceeds the threshold
        if sum(bracket_style) > BRACKET_STYLE_THRESHOLD:
            return True

    return False


def sub_all_note_tags(sp: bs4.BeautifulSoup) -> bs4.BeautifulSoup:
    """
    Sub all note tags with p tags
    :param sp:
    :return:
    """
    for ntag in sp.find_all("note"):
        p_tag = sp.new_tag("p")
        p_tag.string = ntag.text.strip()
        ntag.replace_with(p_tag)
    return sp


def process_formulas_in_paragraph(para_el: bs4.BeautifulSoup, sp: bs4.BeautifulSoup) -> None:
    """
    Process all formulas in paragraph and replace with text and label
    :param para_el:
    :param sp:
    :return:
    """
    for ftag in para_el.find_all("formula"):
        # get label if exists and insert a space between formula and label
        if ftag.label:
            label = " " + ftag.label.text
            ftag.label.decompose()
        else:
            label = ""
        ftag.replace_with(sp.new_string(f"{ftag.text.strip()}{label}"))


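A small sketch of the formula handling (not part of the diff; it assumes the function above is in scope and that lxml is installed so bs4 can use the "xml" parser):

    import bs4

    sp = bs4.BeautifulSoup(
        "<p>Loss: <formula>L = -log p(y|x)<label>(2)</label></formula></p>", "xml"
    )
    process_formulas_in_paragraph(sp.p, sp)
    print(sp.p.text)  # "Loss: L = -log p(y|x) (2)"
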
def process_references_in_paragraph(
    para_el: bs4.BeautifulSoup, sp: bs4.BeautifulSoup, refs: Dict
) -> Dict:
    """
    Process all references in paragraph and generate a dict that maps a placeholder token to
    (ref_id, surface_form, ref_type)
    :param para_el:
    :param sp:
    :param refs:
    :return:
    """
    tokgen = UniqTokenGenerator("REFTOKEN")
    ref_dict = dict()
    for rtag in para_el.find_all("ref"):
        try:
            ref_type = rtag.get("type")
            # skip if citation
            if ref_type == "bibr":
                continue
            if ref_type == "table" or ref_type == "figure":
                ref_id = rtag.get("target")
                if ref_id and normalize_grobid_id(ref_id) in refs:
                    # normalize reference string
                    rtag_string = normalize_grobid_id(ref_id)
                else:
                    rtag_string = None
                # add to ref set
                ref_key = tokgen.next()
                ref_dict[ref_key] = (rtag_string, rtag.text.strip(), ref_type)
                rtag.replace_with(sp.new_string(f" {ref_key} "))
            else:
                # replace with surface form
                rtag.replace_with(sp.new_string(rtag.text.strip()))
        except AttributeError:
            continue
    return ref_dict


def process_citations_in_paragraph(
    para_el: bs4.BeautifulSoup, sp: bs4.BeautifulSoup, bibs: Dict, bracket: bool
) -> Dict:
    """
    Process all citations in paragraph and generate a dict for surface forms
    :param para_el:
    :param sp:
    :param bibs:
    :param bracket:
    :return:
    """

    # CHECK if range between two surface forms is appropriate for bracket style expansion
    def _get_surface_range(start_surface, end_surface):
        span1_match = SINGLE_BRACKET_REGEX.match(start_surface)
        span2_match = SINGLE_BRACKET_REGEX.match(end_surface)
        if span1_match and span2_match:
            # get numbers corresponding to citations
            span1_num = int(span1_match.group(1))
            span2_num = int(span2_match.group(1))
            # expand if range is between 1 and 20
            if 1 < span2_num - span1_num < 20:
                return span1_num, span2_num
        return None

    # CREATE BIBREF range between two reference ids, e.g. BIBREF1-BIBREF4 -> BIBREF1 BIBREF2 BIBREF3 BIBREF4
    def _create_ref_id_range(start_ref_id, end_ref_id):
        start_ref_num = int(start_ref_id[6:])
        end_ref_num = int(end_ref_id[6:])
        return [f"BIBREF{curr_ref_num}" for curr_ref_num in range(start_ref_num, end_ref_num + 1)]

    # CREATE surface form range between two bracket strings, e.g. [1]-[4] -> [1] [2] [3] [4]
    def _create_surface_range(start_number, end_number):
        return [f"[{n}]" for n in range(start_number, end_number + 1)]

    # create citation dict with keywords
    cite_map = dict()
    tokgen = UniqTokenGenerator("CITETOKEN")

    for rtag in para_el.find_all("ref"):
        try:
            # get surface span, e.g. [3]
            surface_span = rtag.text.strip()

            # check if target is available (#b2 -> BID2)
            if rtag.get("target"):
                # normalize reference string
                rtag_ref_id = normalize_grobid_id(rtag.get("target"))

                # skip if rtag ref_id not in bibliography
                if rtag_ref_id not in bibs:
                    cite_key = tokgen.next()
                    rtag.replace_with(sp.new_string(f" {cite_key} "))
                    cite_map[cite_key] = (None, surface_span)
                    continue

                # if bracket style, only keep if surface form is bracket
                if bracket:
                    # valid bracket span
                    if surface_span and (
                        surface_span[0] == "["
                        or surface_span[-1] == "]"
                        or surface_span[-1] == ","
                    ):
                        pass
                    # invalid, replace tag with surface form and continue to next ref tag
                    else:
                        rtag.replace_with(sp.new_string(f" {surface_span} "))
                        continue
                # not bracket, add cite span and move on
                else:
                    cite_key = tokgen.next()
                    rtag.replace_with(sp.new_string(f" {cite_key} "))
                    cite_map[cite_key] = (rtag_ref_id, surface_span)
                    continue

                # EXTRA PROCESSING FOR BRACKET STYLE CITATIONS; EXPAND RANGES ###
                # look backward for range marker, e.g. [1]-*[3]*
                backward_between_span = ""
                for sib in rtag.previous_siblings:
                    if sib.name == "ref":
                        break
                    elif type(sib) is bs4.NavigableString:
                        backward_between_span += sib
                    else:
                        break

                # check if there's a backwards expansion, e.g. need to expand [1]-[3] -> [1] [2] [3]
                if is_expansion_string(backward_between_span):
                    # get surface number range
                    surface_num_range = _get_surface_range(
                        rtag.find_previous_sibling("ref").text.strip(), surface_span
                    )
                    # if the surface number range is reasonable (range < 20, in order), EXPAND
                    if surface_num_range:
                        # delete previous ref tag and anything in between (i.e. delete "-" and extra spaces)
                        for sib in rtag.previous_siblings:
                            if sib.name == "ref":
                                break
                            elif type(sib) is bs4.NavigableString:
                                sib.replace_with(sp.new_string(""))
                            else:
                                break

                        # get ref id of previous ref, e.g. [1] (#b0 -> BID0)
                        previous_rtag = rtag.find_previous_sibling("ref")
                        previous_rtag_ref_id = normalize_grobid_id(previous_rtag.get("target"))
                        previous_rtag.decompose()

                        # replace this ref tag with the full range expansion, e.g. [3] (#b2 -> BID1 BID2)
                        id_range = _create_ref_id_range(previous_rtag_ref_id, rtag_ref_id)
                        surface_range = _create_surface_range(
                            surface_num_range[0], surface_num_range[1]
                        )
                        replace_string = ""
                        for range_ref_id, range_surface_form in zip(id_range, surface_range):
                            # only replace if ref id is in bibliography, else add none
                            if range_ref_id in bibs:
                                cite_key = tokgen.next()
                                cite_map[cite_key] = (range_ref_id, range_surface_form)
                            else:
                                cite_key = tokgen.next()
                                cite_map[cite_key] = (None, range_surface_form)
                            replace_string += cite_key + " "
                        rtag.replace_with(sp.new_string(f" {replace_string} "))
                    # ELSE do not expand backwards and replace previous and current rtag with appropriate ref id
                    else:
                        # add mapping between ref id and surface form for previous ref tag
                        previous_rtag = rtag.find_previous_sibling("ref")
                        previous_rtag_ref_id = normalize_grobid_id(previous_rtag.get("target"))
                        previous_rtag_surface = previous_rtag.text.strip()
                        cite_key = tokgen.next()
                        previous_rtag.replace_with(sp.new_string(f" {cite_key} "))
                        cite_map[cite_key] = (
                            previous_rtag_ref_id,
                            previous_rtag_surface,
                        )

                        # add mapping between ref id and surface form for current reftag
                        cite_key = tokgen.next()
                        rtag.replace_with(sp.new_string(f" {cite_key} "))
                        cite_map[cite_key] = (rtag_ref_id, surface_span)
                else:
                    # look forward and see if expansion string, e.g. *[1]*-[3]
                    forward_between_span = ""
                    for sib in rtag.next_siblings:
                        if sib.name == "ref":
                            break
                        elif type(sib) is bs4.NavigableString:
                            forward_between_span += sib
                        else:
                            break
                    # look forward for range marker (if is a range, continue -- range will be expanded
                    # when we get to the second value)
                    if is_expansion_string(forward_between_span):
                        continue
                    # else treat like normal reference
                    else:
                        cite_key = tokgen.next()
                        rtag.replace_with(sp.new_string(f" {cite_key} "))
                        cite_map[cite_key] = (rtag_ref_id, surface_span)

            else:
                cite_key = tokgen.next()
                rtag.replace_with(sp.new_string(f" {cite_key} "))
                cite_map[cite_key] = (None, surface_span)
        except AttributeError:
            continue

    return cite_map


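The net effect of the range handling above, assuming SINGLE_BRACKET_REGEX matches numeric brackets like "[1]", is_expansion_string accepts a bare "-" between two refs (both are defined elsewhere in this module), and BIBREF0-BIBREF2 are present in bibs, is roughly the following (illustration only, not part of the diff):

    # TEI input:      ... <ref type="bibr" target="#b0">[1]</ref>-<ref type="bibr" target="#b2">[3]</ref> ...
    # paragraph text: ... CITETOKEN0 CITETOKEN1 CITETOKEN2 ...
    # cite_map:       {"CITETOKEN0": ("BIBREF0", "[1]"),
    #                  "CITETOKEN1": ("BIBREF1", "[2]"),
    #                  "CITETOKEN2": ("BIBREF2", "[3]")}
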
def process_paragraph(
    sp: bs4.BeautifulSoup,
    para_el: bs4.element.Tag,
    section_names: List[Tuple],
    bib_dict: Dict,
    ref_dict: Dict,
    bracket: bool,
) -> Dict:
    """
    Process one paragraph
    :param sp:
    :param para_el:
    :param section_names:
    :param bib_dict:
    :param ref_dict:
    :param bracket: if bracket style, expand and clean up citations
    :return:
    """
    # return empty paragraph if no text
    if not para_el.text:
        return {
            "text": "",
            "cite_spans": [],
            "ref_spans": [],
            "eq_spans": [],
            "section": section_names,
        }

    # replace formulas with formula text
    process_formulas_in_paragraph(para_el, sp)

    # get references to tables and figures
    ref_map = process_references_in_paragraph(para_el, sp, ref_dict)

    # generate citation map for paragraph element (keep only cite spans with bib entry or unlinked)
    cite_map = process_citations_in_paragraph(para_el, sp, bib_dict, bracket)

    # substitute space characters
    para_text = re.sub(r"\s+", " ", para_el.text)
    para_text = re.sub(r"\s", " ", para_text)

    # get all cite and ref spans
    all_spans_to_replace = []
    for span in re.finditer(r"(CITETOKEN\d+)", para_text):
        uniq_token = span.group()
        ref_id, surface_text = cite_map[uniq_token]
        all_spans_to_replace.append(
            (span.start(), span.start() + len(uniq_token), uniq_token, surface_text)
        )
    for span in re.finditer(r"(REFTOKEN\d+)", para_text):
        uniq_token = span.group()
        ref_id, surface_text, ref_type = ref_map[uniq_token]
        all_spans_to_replace.append(
            (span.start(), span.start() + len(uniq_token), uniq_token, surface_text)
        )

    # replace cite and ref spans and create json blobs
    para_text, all_spans_to_replace = sub_spans_and_update_indices(all_spans_to_replace, para_text)

    cite_span_blobs = [
        {"start": start, "end": end, "text": surface, "ref_id": cite_map[token][0]}
        for start, end, token, surface in all_spans_to_replace
        if token.startswith("CITETOKEN")
    ]

    ref_span_blobs = [
        {"start": start, "end": end, "text": surface, "ref_id": ref_map[token][0]}
        for start, end, token, surface in all_spans_to_replace
        if token.startswith("REFTOKEN")
    ]

    for cite_blob in cite_span_blobs:
        assert para_text[cite_blob["start"] : cite_blob["end"]] == cite_blob["text"]

    for ref_blob in ref_span_blobs:
        assert para_text[ref_blob["start"] : ref_blob["end"]] == ref_blob["text"]

    return {
        "text": para_text,
        "cite_spans": cite_span_blobs,
        "ref_spans": ref_span_blobs,
        "eq_spans": [],
        "section": section_names,
    }


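For orientation, a processed paragraph comes back as a dict of this shape (values are made-up placeholders; the offsets are consistent with the asserts above):

    {
        "text": "Prior work [1] reports similar gains (see Table 1).",
        "cite_spans": [{"start": 11, "end": 14, "text": "[1]", "ref_id": "BIBREF0"}],
        "ref_spans": [{"start": 42, "end": 49, "text": "Table 1", "ref_id": "TABREF0"}],
        "eq_spans": [],
        "section": [("1", "Introduction")],
    }
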
def extract_abstract_from_tei_xml(
    sp: bs4.BeautifulSoup, bib_dict: Dict, ref_dict: Dict, cleanup_bracket: bool
) -> List[Dict]:
    """
    Parse abstract from soup
    :param sp:
    :param bib_dict:
    :param ref_dict:
    :param cleanup_bracket:
    :return:
    """
    abstract_text = []
    if sp.abstract:
        # process all divs
        if sp.abstract.div:
            for div in sp.abstract.find_all("div"):
                if div.text:
                    if div.p:
                        for para in div.find_all("p"):
                            if para.text:
                                abstract_text.append(
                                    process_paragraph(
                                        sp,
                                        para,
                                        [(None, "Abstract")],
                                        bib_dict,
                                        ref_dict,
                                        cleanup_bracket,
                                    )
                                )
                    else:
                        if div.text:
                            abstract_text.append(
                                process_paragraph(
                                    sp,
                                    div,
                                    [(None, "Abstract")],
                                    bib_dict,
                                    ref_dict,
                                    cleanup_bracket,
                                )
                            )
        # process all paragraphs
        elif sp.abstract.p:
            for para in sp.abstract.find_all("p"):
                if para.text:
                    abstract_text.append(
                        process_paragraph(
                            sp,
                            para,
                            [(None, "Abstract")],
                            bib_dict,
                            ref_dict,
                            cleanup_bracket,
                        )
                    )
        # else just try to get the text
        else:
            if sp.abstract.text:
                abstract_text.append(
                    process_paragraph(
                        sp,
                        sp.abstract,
                        [(None, "Abstract")],
                        bib_dict,
                        ref_dict,
                        cleanup_bracket,
                    )
                )
        sp.abstract.decompose()
    return abstract_text


def extract_body_text_from_div(
    sp: bs4.BeautifulSoup,
    div: bs4.element.Tag,
    sections: List[Tuple],
    bib_dict: Dict,
    ref_dict: Dict,
    cleanup_bracket: bool,
) -> List[Dict]:
    """
    Parse body text from soup
    :param sp:
    :param div:
    :param sections:
    :param bib_dict:
    :param ref_dict:
    :param cleanup_bracket:
    :return:
    """
    chunks = []
    # check if nested divs; recursively process
    if div.div:
        for subdiv in div.find_all("div"):
            # has header, add to section list and process
            if subdiv.head:
                chunks += extract_body_text_from_div(
                    sp,
                    subdiv,
                    sections + [(subdiv.head.get("n", None), subdiv.head.text.strip())],
                    bib_dict,
                    ref_dict,
                    cleanup_bracket,
                )
                subdiv.head.decompose()
            # no header, process with same section list
            else:
                chunks += extract_body_text_from_div(
                    sp, subdiv, sections, bib_dict, ref_dict, cleanup_bracket
                )
    # process tags individually
    for tag in div:
        try:
            if tag.name == "p":
                if tag.text:
                    chunks.append(
                        process_paragraph(sp, tag, sections, bib_dict, ref_dict, cleanup_bracket)
                    )
            elif tag.name == "formula":
                # e.g. <formula xml:id="formula_0">Y = W T X.<label>(1)</label></formula>
                label = tag.label.text
                tag.label.decompose()
                eq_text = tag.text
                chunks.append(
                    {
                        "text": "EQUATION",
                        "cite_spans": [],
                        "ref_spans": [],
                        "eq_spans": [
                            {
                                "start": 0,
                                "end": 8,
                                "text": "EQUATION",
                                "ref_id": "EQREF",
                                "raw_str": eq_text,
                                "eq_num": label,
                            }
                        ],
                        "section": sections,
                    }
                )
        except AttributeError:
            if tag.text:
                chunks.append(
                    process_paragraph(sp, tag, sections, bib_dict, ref_dict, cleanup_bracket)
                )

    return chunks


def extract_body_text_from_tei_xml(
    sp: bs4.BeautifulSoup, bib_dict: Dict, ref_dict: Dict, cleanup_bracket: bool
) -> List[Dict]:
    """
    Parse body text from soup
    :param sp:
    :param bib_dict:
    :param ref_dict:
    :param cleanup_bracket:
    :return:
    """
    body_text = []
    if sp.body:
        body_text = extract_body_text_from_div(
            sp, sp.body, [], bib_dict, ref_dict, cleanup_bracket
        )
        sp.body.decompose()
    return body_text


def extract_back_matter_from_tei_xml(
    sp: bs4.BeautifulSoup, bib_dict: Dict, ref_dict: Dict, cleanup_bracket: bool
) -> List[Dict]:
    """
    Parse back matter from soup
    :param sp:
    :param bib_dict:
    :param ref_dict:
    :param cleanup_bracket:
    :return:
    """
    back_text = []

    if sp.back:
        for div in sp.back.find_all("div"):
            if div.get("type"):
                section_type = div.get("type")
            else:
                section_type = ""

            for child_div in div.find_all("div"):
                if child_div.head:
                    section_title = child_div.head.text.strip()
                    section_num = child_div.head.get("n", None)
                    child_div.head.decompose()
                else:
                    section_title = section_type
                    section_num = None
                if child_div.text:
                    back_text.append(
                        process_paragraph(
                            sp,
                            child_div,
                            [(section_num, section_title)],
                            bib_dict,
                            ref_dict,
                            cleanup_bracket,
                        )
                    )
        sp.back.decompose()
    return back_text
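Taken together, the extractors above are meant to be driven from a parsed GROBID TEI document. The sketch below is illustrative only and not part of the diff: the file name is hypothetical, the bibliography dict is normally built by other helpers in this module and is left empty here just to show the call order, and lxml is assumed so bs4 can use the "xml" parser.

    import bs4

    with open("paper.tei.xml", "rb") as f:
        sp = bs4.BeautifulSoup(f, "xml")

    bracket = check_if_citations_are_bracket_style(sp)
    sub_all_note_tags(sp)

    ref_dict = extract_figures_and_tables_from_tei_xml(sp)
    bib_dict = {}  # placeholder; normally keyed by BIBREF ids from the bibliography

    abstract = extract_abstract_from_tei_xml(sp, bib_dict, ref_dict, bracket)
    body_text = extract_body_text_from_tei_xml(sp, bib_dict, ref_dict, bracket)
    back_matter = extract_back_matter_from_tei_xml(sp, bib_dict, ref_dict, bracket)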