ArneBinder committed
Commit ced4316 · verified · 1 Parent(s): aed01f0

update from https://github.com/ArneBinder/pie-document-level/pull/397

Files changed (50)
  1. argumentation_model/_joint.yaml +4 -0
  2. argumentation_model/_pipelined.yaml +17 -0
  3. argumentation_model/joint.yaml +10 -0
  4. argumentation_model/joint_hps.yaml +7 -0
  5. argumentation_model/pipelined.yaml +8 -0
  6. argumentation_model/pipelined_deprecated.yaml +9 -0
  7. argumentation_model/pipelined_hps.yaml +8 -0
  8. argumentation_model/pipelined_new.yaml +14 -0
  9. demo.yaml +85 -0
  10. pdf_fulltext_extractor/grobid_local.yaml +18 -0
  11. pdf_fulltext_extractor/none.yaml +0 -0
  12. requirements.txt +8 -2
  13. retriever/related_span_retriever_with_relations_from_other_docs.yaml +49 -0
  14. src/analysis/__init__.py +0 -0
  15. src/analysis/combine_job_returns.py +169 -0
  16. src/analysis/common.py +47 -0
  17. src/analysis/compare_job_returns.py +407 -0
  18. src/data/acl_anthology_crawler.py +117 -0
  19. src/data/calc_iaa_for_brat.py +272 -0
  20. src/data/construct_sciarg_abstracts_remaining_gold_retrieval.py +238 -0
  21. src/data/prepare_sciarg_crosssection_annotations.py +398 -0
  22. src/data/split_sciarg_abstracts.py +132 -0
  23. src/demo/annotation_utils.py +88 -41
  24. src/demo/backend_utils.py +106 -13
  25. src/demo/frontend_utils.py +12 -0
  26. src/demo/rendering_utils.py +23 -3
  27. src/demo/rendering_utils_displacy.py +12 -1
  28. src/demo/retrieve_and_dump_all_relevant.py +61 -2
  29. src/demo/retriever_utils.py +8 -6
  30. src/document/processing.py +212 -77
  31. src/hydra_callbacks/save_job_return_value.py +178 -40
  32. src/langchain_modules/pie_document_store.py +1 -1
  33. src/langchain_modules/span_retriever.py +13 -16
  34. src/pipeline/ner_re_pipeline.py +45 -15
  35. src/predict.py +6 -2
  36. src/start_demo.py +161 -36
  37. src/train.py +10 -0
  38. src/utils/__init__.py +6 -1
  39. src/utils/config_utils.py +15 -1
  40. src/utils/pdf_utils/README.MD +35 -0
  41. src/utils/pdf_utils/__init__.py +0 -0
  42. src/utils/pdf_utils/acl_anthology_utils.py +77 -0
  43. src/utils/pdf_utils/client.py +193 -0
  44. src/utils/pdf_utils/grobid_client.py +203 -0
  45. src/utils/pdf_utils/grobid_util.py +413 -0
  46. src/utils/pdf_utils/process_pdf.py +276 -0
  47. src/utils/pdf_utils/raw_paper.py +90 -0
  48. src/utils/pdf_utils/s2orc_paper.py +478 -0
  49. src/utils/pdf_utils/s2orc_utils.py +61 -0
  50. src/utils/pdf_utils/utils.py +904 -0
argumentation_model/_joint.yaml ADDED
@@ -0,0 +1,4 @@
+ _target_: pytorch_ie.auto.AutoPipeline.from_pretrained
+ pretrained_model_name_or_path: ???
+ # this batch_size works well (fastest) on a single RTX2080Ti (11GB) (see https://github.com/ArneBinder/pie-document-level/issues/334#issuecomment-2613232344)
+ batch_size: 1
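For orientation, a minimal sketch (not part of the commit) of what this base config resolves to once the mandatory `???` value is filled in; the model name, revision, and batch size are the values set in _joint.yaml and joint.yaml, normally injected by Hydra via the config group.

```python
# Minimal sketch, not part of the commit: resolving _joint.yaml / joint.yaml by hand.
from pytorch_ie.auto import AutoPipeline

# values taken from joint.yaml; in the demo they are injected by Hydra
pipeline = AutoPipeline.from_pretrained(
    "ArneBinder/sam-pointer-bart-base-v0.4",
    revision="0445c69bafa31f8153aaeafc1767fad84919926a",
    batch_size=1,  # fastest setting on a single RTX2080Ti (11GB), see the linked issue
)
```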
argumentation_model/_pipelined.yaml ADDED
@@ -0,0 +1,17 @@
+ _target_: src.pipeline.NerRePipeline
+ ner_model_path: ???
+ re_model_path: ???
+ entity_layer: labeled_spans
+ relation_layer: binary_relations
+ # this works well on a single RTX2080Ti (11GB)
+ ner_pipeline:
+ batch_size: 256
+ re_pipeline:
+ batch_size: 64
+ # convert the RE model to half precision for mixed precision inference (speedup approx. 4x)
+ half_precision_model: true
+ taskmodule_kwargs:
+ # don't show statistics after encoding
+ collect_statistics: false
+ # don't show pipeline steps
+ verbose: false
argumentation_model/joint.yaml ADDED
@@ -0,0 +1,10 @@
+ defaults:
+ - _joint
+
+ # best model based on the validation set (see https://github.com/ArneBinder/pie-document-level/issues/334#issuecomment-2613232344 for details)
+ # i.e. models from https://github.com/ArneBinder/pie-document-level/issues/334#issuecomment-2578422544, but with last checkpoint (instead of best validation checkpoint)
+ # model_name_or_path: models/dataset-sciarg/task-ner_re/v0.4/2025-01-09_01-50-53
+ # ckpt_path: logs/training/multiruns/dataset-sciarg/task-ner_re/v0.4/2025-01-09_01-50-52/2/checkpoints/last.ckpt
+ # w&b run (for the loaded checkpoint): [icy-glitter-5](https://wandb.ai/arne/dataset-sciarg-task-ner_re-v0.4-training/runs/it5toj6w)
+ pretrained_model_name_or_path: "ArneBinder/sam-pointer-bart-base-v0.4"
+ revision: "0445c69bafa31f8153aaeafc1767fad84919926a"
argumentation_model/joint_hps.yaml ADDED
@@ -0,0 +1,7 @@
+ defaults:
+ - _joint
+
+ # from: hparams_search for all datasets
+ # see https://github.com/ArneBinder/pie-document-level/pull/381#issuecomment-2682711151
+ # THESE ARE LOCAL PATHS, NOT HUGGINGFACE MODELS!
+ pretrained_model_name_or_path: models/dataset-sciarg/task-ner_re/2025-02-23_05-16-45
argumentation_model/pipelined.yaml ADDED
@@ -0,0 +1,8 @@
+ defaults:
+ - _pipelined
+
+ # from: train pipeline models with bigger train set,
+ # see https://github.com/ArneBinder/pie-document-level/issues/355#issuecomment-2612958658
+ # THESE ARE LOCAL PATHS, NOT HUGGINGFACE MODELS!
+ ner_model_path: models/dataset-sciarg/task-adus/v0.4/2025-01-20_05-50-00
+ re_model_path: models/dataset-sciarg/task-relations/v0.4/2025-01-22_20-36-23
argumentation_model/pipelined_deprecated.yaml ADDED
@@ -0,0 +1,9 @@
+ defaults:
+ - _pipelined
+
+ # from: train pipeline models with bigger train set, but with strange choice of models,
+ # see edit history of https://github.com/ArneBinder/pie-document-level/issues/355#issuecomment-2612958658
+ # NOTE: these were originally in the pipelined.yaml
+ # THESE ARE LOCAL PATHS, NOT HUGGINGFACE MODELS!
+ ner_model_path: models/dataset-sciarg/task-adus/v0.4/2025-01-20_09-09-11
+ re_model_path: models/dataset-sciarg/task-relations/v0.4/2025-01-22_12-44-51
argumentation_model/pipelined_hps.yaml ADDED
@@ -0,0 +1,8 @@
+ defaults:
+ - _pipelined
+
+ # from: hparams_search for all datasets,
+ # see https://github.com/ArneBinder/pie-document-level/pull/381#issuecomment-2684865102
+ # THESE ARE LOCAL PATHS, NOT HUGGINGFACE MODELS!
+ ner_model_path: models/dataset-sciarg/task-adur/2025-02-26_07-14-59
+ re_model_path: models/dataset-sciarg/task-are/2025-02-20_18-09-25
argumentation_model/pipelined_new.yaml ADDED
@@ -0,0 +1,14 @@
+ defaults:
+ - _pipelined
+
+ # from: Update scientific ARE experiment configs,
+ # see https://github.com/ArneBinder/pie-document-level/pull/379#issuecomment-2651669398
+ # i.e. the models are now on Hugging Face
+ # ner_model_path: models/dataset-sciarg/task-adur/2025-02-09_23-08-37
+ # re_model_path: models/dataset-sciarg/task-are/2025-02-10_19-24-52
+ ner_model_path: ArneBinder/sam-adur-sciarg
+ ner_pipeline:
+ revision: bcbef4e585a5f637009ff702661cf824abede6b0
+ re_model_path: ArneBinder/sam-are-sciarg
+ re_pipeline:
+ revision: 93024388330c58daf20963c2020e08f54553e74c
demo.yaml ADDED
@@ -0,0 +1,85 @@
+ defaults:
+ - _self_
+ # default retriever, see subfolder retriever for more details
+ - retriever: related_span_retriever_with_relations_from_other_docs
+ # default argumentation model, see subfolder argumentation_model for more details
+ - argumentation_model: pipelined_new
+ # since this requires a running GROBID server, we disable it by default
+ - pdf_fulltext_extractor: none
+
+ # Whether to handle segmented entities in the document. If True, labeled_spans are converted
+ # to labeled_multi_spans and binary_relations with label "parts_of_same" are used to merge them.
+ # This requires the networkx package to be installed.
+ handle_parts_of_same: true
+ # Split the document text into sections that are processed separately.
+ default_split_regex: "\n\n\n+"
+
+ # retriever details (query parameters)
+ default_min_similarity: 0.95
+ default_top_k: 10
+
+ # data import details
+ default_arxiv_id: "1706.03762"
+ default_load_pie_dataset_kwargs:
+ path: "pie/sciarg"
+ name: "resolve_parts_of_same"
+ split: "train"
+
+ # set to the data directory of https://github.com/acl-org/acl-anthology
+ # to enable ACL venue PDF import (also requires a valid pdf_fulltext_extractor)
+ # acl_anthology_data_dir=../acl-anthology/data
+ # temporary directory to store downloaded PDFs
+ acl_anthology_pdf_dir: "data/acl-anthology/pdf"
+
+ # for better readability in the UI
+ render_mode_captions:
+ displacy: "displaCy + highlighted arguments"
+ pretty_table: "Pretty Table"
+ layer_caption_mapping:
+ labeled_multi_spans: "adus"
+ binary_relations: "relations"
+ labeled_partitions: "partitions"
+ relation_name_mapping:
+ supports_reversed: "supported by"
+ contradicts_reversed: "contradicts"
+
+ default_render_mode: "displacy"
+ default_render_kwargs:
+ entity_options:
+ # we need to have the keys as uppercase because the spaCy rendering function converts the labels to uppercase
+ colors:
+ OWN_CLAIM: "#009933"
+ BACKGROUND_CLAIM: "#99ccff"
+ DATA: "#993399"
+ colors_hover:
+ selected: "#ffa"
+ # tail options for relationships
+ tail:
+ # green
+ supports: "#9f9"
+ # red
+ contradicts: "#f99"
+ # do not highlight
+ parts_of_same: null
+ head: null # "#faf"
+ other: null
+
+ example_text: >
+ Scholarly Argumentation Mining (SAM) has recently gained attention due to its
+ potential to help scholars with the rapid growth of published scientific literature.
+ It comprises two subtasks: argumentative discourse unit recognition (ADUR) and
+ argumentative relation extraction (ARE), both of which are challenging since they
+ require e.g. the integration of domain knowledge, the detection of implicit statements,
+ and the disambiguation of argument structure.
+
+ While previous work focused on dataset construction and baseline methods for
+ specific document sections, such as abstract or results, full-text scholarly argumentation
+ mining has seen little progress. In this work, we introduce a sequential pipeline model
+ combining ADUR and ARE for full-text SAM, and provide a first analysis of the
+ performance of pretrained language models (PLMs) on both subtasks.
+
+ We establish a new SotA for ADUR on the Sci-Arg corpus, outperforming the previous best
+ reported result by a large margin (+7% F1). We also present the first results for ARE, and
+ thus for the full AM pipeline, on this benchmark dataset. Our detailed error analysis reveals
+ that non-contiguous ADUs as well as the interpretation of discourse connectors pose major
+ challenges and that data annotation needs to be more consistent.
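A hedged sketch (not part of the commit) of how a Hydra config like this is typically composed and its groups instantiated; the relative config path and the override value are assumptions, not something this commit pins down.

```python
# Hedged sketch, not part of the commit. Assumes demo.yaml lives in a Hydra
# config directory named "configs" relative to this script; adjust as needed.
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(version_base=None, config_path="configs"):
    cfg = compose(
        config_name="demo",
        # e.g. switch to the joint argumentation model defined above
        overrides=["argumentation_model=joint"],
    )

retriever = instantiate(cfg.retriever)
argumentation_model = instantiate(cfg.argumentation_model)
print(cfg.default_split_regex, cfg.default_min_similarity, cfg.default_top_k)
```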
pdf_fulltext_extractor/grobid_local.yaml ADDED
@@ -0,0 +1,18 @@
+ # This requires a running GROBID server. To start the server via Docker, run:
+ # docker run --rm --init --ulimit core=0 -p 8070:8070 lfoppiano/grobid:0.8.0
+
+ _target_: src.utils.pdf_utils.process_pdf.GrobidFulltextExtractor
+ section_seperator: "\n\n\n"
+ paragraph_seperator: "\n\n"
+ grobid_config:
+ grobid_server: localhost
+ grobid_port: 8070
+ batch_size: 1000
+ sleep_time: 5
+ generateIDs: false
+ consolidate_header: false
+ consolidate_citations: false
+ include_raw_citations: true
+ include_raw_affiliations: false
+ max_workers: 2
+ verbose: false
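The extractor configured here is used as a plain callable in src/data/acl_anthology_crawler.py (added below): it takes a PDF path and, on success, returns the plain text together with a dict holding "abstract" and "sections". A hedged usage sketch, not part of the commit; the PDF path is a placeholder and the constructor arguments simply mirror this config.

```python
# Hedged sketch, not part of the commit; "paper.pdf" is a placeholder path.
# Requires a running GROBID server, e.g. the docker command in the comment above.
from src.utils.pdf_utils.process_pdf import GrobidFulltextExtractor

extractor = GrobidFulltextExtractor(
    section_seperator="\n\n\n",    # keyword names as spelled in this config
    paragraph_seperator="\n\n",
)
output = extractor("paper.pdf")
if output:
    plain_text, extraction_data = output
    print(extraction_data.get("abstract"))
    print(len(extraction_data.get("sections") or []))
```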
pdf_fulltext_extractor/none.yaml ADDED
File without changes
requirements.txt CHANGED
@@ -1,7 +1,11 @@
+ # -------- dl backend -------- #
+ torch==2.0.0
+ pytorch-lightning==2.1.2
+
  # --------- pytorch-ie --------- #
- pytorch-ie>=0.29.6,<0.32.0
+ pytorch-ie>=0.31.4,<0.32.0
  pie-datasets>=0.10.5,<0.11.0
- pie-modules>=0.14.0,<0.15.0
+ pie-modules>=0.14.2,<0.15.0

  # --------- models -------- #
  adapters>=0.1.2,<0.2.0
@@ -17,6 +21,8 @@ qdrant-client>=1.12.0,<2.0.0
  # --------- demo -------- #
  gradio~=5.5.0
  arxiv~=2.1.3
+ # data preparation
+ acl-anthology-py>=0.4.3

  # --------- hydra --------- #
  hydra-core>=1.3.0
retriever/related_span_retriever_with_relations_from_other_docs.yaml ADDED
@@ -0,0 +1,49 @@
+ _target_: src.langchain_modules.DocumentAwareSpanRetrieverWithRelations
+ symmetric_relations:
+ - contradicts
+ reversed_relations_suffix: _reversed
+ relation_labels:
+ - supports_reversed
+ - contradicts
+ retrieve_from_same_document: false
+ retrieve_from_different_documents: true
+ pie_document_type:
+ _target_: pie_modules.utils.resolve_type
+ type_or_str: pytorch_ie.documents.TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
+ docstore:
+ _target_: src.langchain_modules.DatasetsPieDocumentStore
+ search_kwargs:
+ k: 10
+ search_type: similarity_score_threshold
+ vectorstore:
+ _target_: src.langchain_modules.QdrantSpanVectorStore
+ embedding:
+ _target_: src.langchain_modules.HuggingFaceSpanEmbeddings
+ model:
+ _target_: src.models.utils.load_model_with_adapter
+ model_kwargs:
+ pretrained_model_name_or_path: allenai/specter2_base
+ adapter_kwargs:
+ adapter_name_or_path: allenai/specter2
+ load_as: proximity
+ source: hf
+ pipeline_kwargs:
+ tokenizer: allenai/specter2_base
+ stride: 64
+ batch_size: 32
+ model_max_length: 512
+ client:
+ _target_: qdrant_client.QdrantClient
+ location: ":memory:"
+ collection_name: adus
+ vector_params:
+ distance:
+ _target_: qdrant_client.http.models.Distance
+ value: Cosine
+ label_mapping:
+ background_claim:
+ - background_claim
+ - own_claim
+ own_claim:
+ - background_claim
+ - own_claim
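Like the other config groups, this file is meant to be instantiated by Hydra. A hedged sketch (not part of the commit) of doing that by hand with the standard OmegaConf/Hydra utilities; the file path is the one added in this commit.

```python
# Hedged sketch, not part of the commit: instantiate the retriever config directly.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("retriever/related_span_retriever_with_relations_from_other_docs.yaml")
retriever = instantiate(cfg)
# Note: the Qdrant client is configured with location=":memory:", so the span
# index lives in-process and has to be re-populated on every start.
```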
src/analysis/__init__.py ADDED
File without changes
src/analysis/combine_job_returns.py ADDED
@@ -0,0 +1,169 @@
1
+ import pyrootutils
2
+
3
+ root = pyrootutils.setup_root(
4
+ search_from=__file__,
5
+ indicator=[".project-root"],
6
+ pythonpath=True,
7
+ dotenv=False,
8
+ )
9
+
10
+ import argparse
11
+ import os
12
+
13
+ import pandas as pd
14
+
15
+ from src.analysis.common import read_nested_jsons
16
+
17
+
18
+ def separate_path_and_id(path_and_maybe_id: str, separator: str = ":") -> tuple[str | None, str]:
19
+ parts = path_and_maybe_id.split(separator, 1)
20
+ if len(parts) == 1:
21
+ return None, parts[0]
22
+ return parts[0], parts[1]
23
+
24
+
25
+ def get_file_paths(paths_file: str, file_name: str, use_aggregated: bool) -> dict[str, str]:
26
+ with open(paths_file, "r") as f:
27
+ paths_maybe_with_ids = f.readlines()
28
+ ids, paths = zip(*[separate_path_and_id(path.strip()) for path in paths_maybe_with_ids])
29
+
30
+ if use_aggregated:
31
+ file_base_name, ext = os.path.splitext(file_name)
32
+ file_name = f"{file_base_name}.aggregated{ext}"
33
+ file_paths = [os.path.join(path, file_name) for path in paths]
34
+ return {
35
+ id if id is not None else f"idx={idx}": path
36
+ for idx, (id, path) in enumerate(zip(ids, file_paths))
37
+ }
38
+
39
+
40
+ def main(
41
+ paths_file: str,
42
+ file_name: str,
43
+ use_aggregated: bool,
44
+ columns: list[str] | None,
45
+ round_precision: int | None,
46
+ format: str,
47
+ transpose: bool = False,
48
+ unpack_multirun_results: bool = False,
49
+ in_percent: bool = False,
50
+ ):
51
+ file_paths = get_file_paths(
52
+ paths_file=paths_file, file_name=file_name, use_aggregated=use_aggregated
53
+ )
54
+ data = read_nested_jsons(json_paths=file_paths)
55
+
56
+ if columns is not None:
57
+ columns_multi_index = [tuple(col.split("/")) for col in columns]
58
+ try:
59
+ data_series = [data[col] for col in columns_multi_index]
60
+ except KeyError as e:
61
+ print(
62
+ f"Columns {columns_multi_index} not found in the data. Available columns are {list(data.columns)}."
63
+ )
64
+ raise e
65
+ data = pd.concat(data_series, axis=1)
66
+
67
+ # drop rows that are all NaN
68
+ data = data.dropna(how="all")
69
+
70
+ # if more than one data point, drop the index levels that are everywhere the same
71
+ if len(data) > 1:
72
+ unique_levels = [
73
+ idx
74
+ for idx, level in enumerate(data.index.levels)
75
+ if len(data.index.get_level_values(idx).unique()) == 1
76
+ ]
77
+ for level in sorted(unique_levels, reverse=True):
78
+ data.index = data.index.droplevel(level)
79
+
80
+ # if more than one column, drop the columns that are everywhere the same
81
+ if len(data.columns) > 1:
82
+ unique_column_levels = [
83
+ idx
84
+ for idx, level in enumerate(data.columns.levels)
85
+ if len(data.columns.get_level_values(idx).unique()) == 1
86
+ ]
87
+ for level in sorted(unique_column_levels, reverse=True):
88
+ data.columns = data.columns.droplevel(level)
89
+
90
+ if unpack_multirun_results:
91
+ index_names = list(data.index.names)
92
+ data_series_lists = data.unstack()
93
+ data = pd.DataFrame.from_records(
94
+ data_series_lists.values, index=data_series_lists.index
95
+ ).stack()
96
+ for _, index_name in enumerate(index_names):
97
+ data = data.unstack(index_name)
98
+ data = data.T
99
+
100
+ if transpose:
101
+ data = data.T
102
+
103
+ # needs to happen before rounding, otherwise the rounding will be off
104
+ if in_percent:
105
+ data = data * 100
106
+
107
+ if round_precision is not None:
108
+ data = data.round(round_precision)
109
+
110
+ if format == "markdown":
111
+ print(data.to_markdown())
112
+ elif format == "markdown_mean_and_std":
113
+ if transpose:
114
+ data = data.T
115
+ if "mean" not in data.columns or "std" not in data.columns:
116
+ raise ValueError("Columns 'mean' and 'std' are required for this format.")
117
+ # create a single column with mean and std in the format: mean ± std
118
+ data = pd.DataFrame(
119
+ data["mean"].astype(str) + " ± " + data["std"].astype(str), columns=["mean ± std"]
120
+ )
121
+ if transpose:
122
+ data = data.T
123
+ print(data.to_markdown())
124
+ elif format == "json":
125
+ print(data.to_json())
126
+ else:
127
+ raise ValueError(f"Invalid format: {format}. Use 'markdown' or 'json'.")
128
+
129
+
130
+ if __name__ == "__main__":
131
+ parser = argparse.ArgumentParser(description="Combine job returns and show as Markdown table")
132
+ parser.add_argument(
133
+ "--paths-file", type=str, help="Path to the file containing the paths to the job returns"
134
+ )
135
+ parser.add_argument(
136
+ "--use-aggregated", action="store_true", help="Whether to use the aggregated job returns"
137
+ )
138
+ parser.add_argument(
139
+ "--file-name",
140
+ type=str,
141
+ default="job_return_value.json",
142
+ help="Name of the file to write the aggregated job returns to",
143
+ )
144
+ parser.add_argument(
145
+ "--columns", type=str, nargs="+", help="Columns to select from the combined job returns"
146
+ )
147
+ parser.add_argument(
148
+ "--unpack-multirun-results", action="store_true", help="Unpack multirun results"
149
+ )
150
+ parser.add_argument("--transpose", action="store_true", help="Transpose the table")
151
+ parser.add_argument(
152
+ "--round-precision",
153
+ type=int,
154
+ help="Round the values in the combined job returns to the specified precision",
155
+ )
156
+ parser.add_argument(
157
+ "--in-percent", action="store_true", help="Show the values in percent (multiply by 100)"
158
+ )
159
+ parser.add_argument(
160
+ "--format",
161
+ type=str,
162
+ default="markdown",
163
+ choices=["markdown", "markdown_mean_and_std", "json"],
164
+ help="Format to output the combined job returns",
165
+ )
166
+
167
+ args = parser.parse_args()
168
+ kwargs = vars(args)
169
+ main(**kwargs)
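Per separate_path_and_id and get_file_paths above, each line of the --paths-file names one job output directory, optionally prefixed with an identifier and a colon. A hedged sketch (not part of the commit) of such a file; the directories are taken from the CONFIGS in compare_job_returns.py below and may not exist in your checkout.

```python
# Hedged sketch, not part of the commit; the log directories are examples only.
from pathlib import Path

# one "<identifier>:<job output dir>" entry per line; the identifier is optional
Path("paths.txt").write_text(
    "epochs=75:logs/document_evaluation/multiruns/default/2025-01-12_13-28-54\n"
    "epochs=150:logs/document_evaluation/multiruns/default/2025-01-15_16-02-04\n"
)
# then, for example:
#   python src/analysis/combine_job_returns.py --paths-file paths.txt --use-aggregated --format markdown
```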
src/analysis/common.py ADDED
@@ -0,0 +1,47 @@
+ import json
+ from typing import Dict, List, Optional
+
+ import pandas as pd
+
+
+ def parse_identifier(
+     identifier_str, defaults: Dict[str, str], parts_sep: str = ",", key_val_sep: str = "="
+ ) -> Dict[str, str]:
+     parts = [
+         part.split(key_val_sep)
+         for part in identifier_str.strip().split(parts_sep)
+         if key_val_sep in part
+     ]
+     parts_dict = dict(parts)
+     return {**defaults, **parts_dict}
+
+
+ def read_nested_json(path: str) -> pd.DataFrame:
+     # Read the nested JSON data into a pandas DataFrame
+     with open(path, "r") as f:
+         data = json.load(f)
+     result = pd.json_normalize(data, sep="/")
+     result.index.name = "entry"
+     return result
+
+
+ def read_nested_jsons(
+     json_paths: Dict[str, str],
+     default_key_values: Optional[Dict[str, str]] = None,
+     column_level_names: Optional[List[str]] = None,
+ ) -> pd.DataFrame:
+     identifier_strings = json_paths.keys()
+     dfs = [read_nested_json(json_paths[identifier_str]) for identifier_str in identifier_strings]
+     new_index_levels = pd.MultiIndex.from_frame(
+         pd.DataFrame(
+             [
+                 parse_identifier(identifier_str, default_key_values or {})
+                 for identifier_str in identifier_strings
+             ]
+         )
+     )
+     dfs_concat = pd.concat(dfs, keys=list(new_index_levels), names=new_index_levels.names, axis=0)
+     dfs_concat.columns = pd.MultiIndex.from_tuples(
+         [col.split("/") for col in dfs_concat.columns], names=column_level_names
+     )
+     return dfs_concat
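The keys of json_paths are parsed by parse_identifier into index levels from comma-separated "key=value" pairs, the same identifier format used in compare_job_returns.py below. A small hedged example, not part of the commit; the JSON paths are placeholders and the column level names are the defaults used in compare_job_returns.py.

```python
# Hedged sketch, not part of the commit; the JSON paths are placeholders.
from src.analysis.common import read_nested_jsons

df = read_nested_jsons(
    json_paths={
        "epochs=75,checkpoint=last": "logs/run_a/job_return_value.aggregated.json",
        "epochs=150,checkpoint=last": "logs/run_b/job_return_value.aggregated.json",
    },
    default_key_values={"checkpoint": "best_val"},
    column_level_names=["split", "average", "metric", "aggr"],
)
# rows: MultiIndex built from the "key=value" identifiers;
# columns: one level per "/"-separated segment of the flattened JSON keys
print(df.head())
```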
src/analysis/compare_job_returns.py ADDED
@@ -0,0 +1,407 @@
1
+ import pyrootutils
2
+
3
+ root = pyrootutils.setup_root(
4
+ search_from=__file__,
5
+ indicator=[".project-root"],
6
+ pythonpath=True,
7
+ dotenv=False,
8
+ )
9
+
10
+ import argparse
11
+ import io
12
+ import re
13
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
14
+
15
+ import pandas as pd
16
+ import plotly.graph_objects as go
17
+
18
+ from src.analysis.common import parse_identifier, read_nested_jsons
19
+
20
+
21
+ def read_markdown_table(
22
+ markdown_data: str,
23
+ default_key_values: Optional[Dict[str, str]] = None,
24
+ column_level_names: Optional[List[str]] = None,
25
+ ) -> pd.DataFrame:
26
+ # Read the markdown data into a pandas DataFrame
27
+ df = pd.read_csv(io.StringIO(markdown_data), sep="|", engine="python", skiprows=1)
28
+
29
+ # Clean up the DataFrame
30
+ # drop the first and last columns
31
+ df = df.drop(columns=[df.columns[0], df.columns[-1]])
32
+ # drop the first row
33
+ df = df.drop(0)
34
+ # make the index from the first column: parse the string and extract the values
35
+ df.index = pd.MultiIndex.from_tuples(
36
+ [tuple(x.strip()[2:-2].split("', '")) for x in df[df.columns[0]]]
37
+ )
38
+ # drop the first column
39
+ df = df.drop(columns=[df.columns[0]])
40
+ # parse the column names and create a MultiIndex
41
+ columns = pd.DataFrame(
42
+ [parse_identifier(col, defaults=default_key_values or {}) for col in df.columns]
43
+ )
44
+ df.columns = pd.MultiIndex.from_frame(columns)
45
+
46
+ # Function to parse the values and errors
47
+ def parse_value_error(value_error_str: str) -> Tuple[float, float]:
48
+ match = re.match(r"([0-9.]+) \(?± ?([0-9.]+)\)?", value_error_str.strip())
49
+ if match:
50
+ return float(match.group(1)), float(match.group(2))
51
+ raise ValueError(f"Invalid value error string: {value_error_str}")
52
+
53
+ df_mean_and_std_cells = df.map(lambda x: parse_value_error(x))
54
+ # make a new DataFrame with the mean and std values as new rows
55
+ result = pd.concat(
56
+ {
57
+ "mean": df_mean_and_std_cells.map(lambda x: x[0]),
58
+ "std": df_mean_and_std_cells.map(lambda x: x[1]),
59
+ },
60
+ axis=0,
61
+ )
62
+ # transpose the DataFrame
63
+ result = result.T
64
+ # move new column index level to the most inner level
65
+ result.columns = pd.MultiIndex.from_tuples(
66
+ [col[1:] + (col[0],) for col in result.columns], names=column_level_names
67
+ )
68
+ return result
69
+
70
+
71
+ def rearrange_for_plotting(
72
+ data: Union[pd.DataFrame, pd.Series], x_axis: str, x_is_numeric: bool
73
+ ) -> pd.DataFrame:
74
+ # rearrange the DataFrame for plotting
75
+ while not isinstance(data, pd.Series):
76
+ data = data.unstack()
77
+ result = data.unstack(x_axis)
78
+ if x_is_numeric:
79
+ result.columns = result.columns.astype(float)
80
+ return result
81
+
82
+
83
+ # Function to create plots
84
+ def create_plot(
85
+ title,
86
+ x_axis: str,
87
+ data: pd.DataFrame,
88
+ data_err: Optional[pd.DataFrame] = None,
89
+ x_is_numeric: bool = False,
90
+ marker_getter: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
91
+ line_getter: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
92
+ ):
93
+ data = rearrange_for_plotting(data, x_axis, x_is_numeric)
94
+ # sort the columns by the x_axis values
95
+ data = data[data.columns.sort_values()]
96
+
97
+ if data_err is not None:
98
+ data_err = rearrange_for_plotting(data_err, x_axis, x_is_numeric)
99
+ data_err = data_err[data.columns]
100
+
101
+ fig = go.Figure()
102
+ for trace_idx, row_data_mean in data.iterrows():
103
+ trace_meta = dict(zip(data.index.names, trace_idx))
104
+ if data_err is not None:
105
+ error_y = dict(type="data", array=data_err.loc[trace_idx])
106
+ else:
107
+ error_y = None
108
+ fig.add_trace(
109
+ go.Scatter(
110
+ x=row_data_mean.index,
111
+ y=row_data_mean,
112
+ error_y=error_y,
113
+ mode="lines+markers",
114
+ marker=marker_getter(trace_meta) if marker_getter is not None else None,
115
+ line=line_getter(trace_meta) if line_getter is not None else None,
116
+ name=", ".join(trace_idx),
117
+ )
118
+ )
119
+ fig.update_layout(title=title, xaxis_title=x_axis, yaxis_title="Values")
120
+ fig.show()
121
+
122
+
123
+ def prepare_for_markdown(
124
+ df: pd.DataFrame,
125
+ aggregation_column_level: Optional[str] = None,
126
+ round_precision: Optional[int] = None,
127
+ ) -> pd.DataFrame:
128
+ result = df.copy()
129
+ # simplify index: create single index from all levels in the format "level1_name=level1_val,level2_name=level2_val,..."
130
+ if isinstance(result.index, pd.MultiIndex):
131
+ result.index = [
132
+ ",".join([f"{name}={val}" for name, val in zip(result.index.names, idx)])
133
+ for idx in result.index
134
+ ]
135
+ else:
136
+ result.index = [f"{result.index.name}={idx}" for idx in result.index]
137
+ result = result.T
138
+ if round_precision is not None:
139
+ result = result.round(round_precision)
140
+ if aggregation_column_level is not None:
141
+ result_mean = result.xs("mean", level=aggregation_column_level, axis="index")
142
+ result_std = result.xs("std", level=aggregation_column_level, axis="index")
143
+ # combine each cell with mean and std into a single string
144
+ result = pd.DataFrame(
145
+ {
146
+ col: [f"{mean} (±{std})" for mean, std in zip(result_mean[col], result_std[col])]
147
+ for col in result_mean.columns
148
+ },
149
+ index=result_mean.index,
150
+ )
151
+
152
+ return result
153
+
154
+
155
+ def combine_job_returns_and_plot(
156
+ x_axis: str,
157
+ plot_column_level: str,
158
+ job_return_paths: Optional[Dict[str, str]] = None,
159
+ markdown_str: Optional[str] = None,
160
+ default_key_values: Optional[Dict[str, str]] = None,
161
+ column_level_names: Optional[List[str]] = None,
162
+ drop_columns: Optional[Dict[str, str]] = None,
163
+ aggregation_column_level: Optional[str] = None,
164
+ title_prefix: Optional[str] = None,
165
+ x_is_not_numeric: bool = False,
166
+ show_as: str = "plot",
167
+ markdown_round_precision: Optional[int] = None,
168
+ marker_getter: Optional[Callable] = None,
169
+ line_getter: Optional[Callable] = None,
170
+ # placeholder to allow description in CONFIGS
171
+ description: Optional[str] = None,
172
+ ):
173
+
174
+ if job_return_paths is not None:
175
+ df_all = read_nested_jsons(
176
+ json_paths=job_return_paths,
177
+ default_key_values=default_key_values,
178
+ column_level_names=column_level_names,
179
+ )
180
+ elif markdown_str is not None:
181
+ df_all = read_markdown_table(
182
+ markdown_data=markdown_str,
183
+ default_key_values=default_key_values,
184
+ column_level_names=column_level_names,
185
+ )
186
+ else:
187
+ raise ValueError("Either job_return_paths or markdown_str must be provided")
188
+
189
+ for metric, value in (drop_columns or {}).items():
190
+ df_all = df_all.drop(columns=value, level=metric)
191
+
192
+ # drop index levels where all values are the same
193
+ index_levels_to_drop = [
194
+ i
195
+ for i in range(df_all.index.nlevels)
196
+ if len(df_all.index.get_level_values(i).unique()) == 1
197
+ ]
198
+ dropped_index_levels = {
199
+ df_all.index.names[i]: df_all.index.get_level_values(i).unique()[0]
200
+ for i in index_levels_to_drop
201
+ }
202
+ if len(index_levels_to_drop) > 0:
203
+ print(f"Drop index levels: {dropped_index_levels}")
204
+ df_all = df_all.droplevel(index_levels_to_drop, axis="index")
205
+ # drop column levels where all values are the same
206
+ column_levels_to_drop = [
207
+ i
208
+ for i in range(df_all.columns.nlevels)
209
+ if len(df_all.columns.get_level_values(i).unique()) == 1
210
+ ]
211
+ dropped_column_levels = {
212
+ df_all.columns.names[i]: df_all.columns.get_level_values(i).unique()[0]
213
+ for i in column_levels_to_drop
214
+ }
215
+ if len(column_levels_to_drop) > 0:
216
+ print(f"Drop column levels: {dropped_column_levels}")
217
+ df_all = df_all.droplevel(column_levels_to_drop, axis="columns")
218
+
219
+ if show_as == "markdown":
220
+ print(
221
+ prepare_for_markdown(
222
+ df_all,
223
+ aggregation_column_level=aggregation_column_level,
224
+ round_precision=markdown_round_precision,
225
+ ).to_markdown()
226
+ )
227
+ elif show_as == "plots":
228
+ # create plots for each "average" value, i.e. MACRO, MICRO, but also label specific values
229
+ plot_names = df_all.columns.get_level_values(plot_column_level).unique()
230
+ for plot_name in plot_names:
231
+ data_plot = df_all.xs(plot_name, level=plot_column_level, axis="columns")
232
+ data_err = None
233
+ if aggregation_column_level is not None:
234
+ data_err = data_plot.xs("std", level=aggregation_column_level, axis="columns")
235
+ data_plot = data_plot.xs("mean", level=aggregation_column_level, axis="columns")
236
+
237
+ # Create plot for MACRO values
238
+ if title_prefix is not None:
239
+ plot_name = f"{title_prefix}: {plot_name}"
240
+ create_plot(
241
+ title=plot_name,
242
+ data=data_plot,
243
+ data_err=data_err,
244
+ x_axis=x_axis,
245
+ x_is_numeric=not x_is_not_numeric,
246
+ marker_getter=marker_getter,
247
+ line_getter=line_getter,
248
+ )
249
+ else:
250
+ raise ValueError(f"Invalid show_as: {show_as}")
251
+
252
+
253
+ CONFIGS = {
254
+ "joint model (adus) - last vs best val checkpoint @test": dict(
255
+ job_return_paths={
256
+ "epochs=75": "logs/document_evaluation/multiruns/default/2025-01-12_13-28-54/job_return_value.aggregated.json",
257
+ "epochs=75,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_18-36-35/job_return_value.aggregated.json",
258
+ "epochs=150": "logs/document_evaluation/multiruns/default/2025-01-15_16-02-04/job_return_value.aggregated.json",
259
+ "epochs=150,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_22-07-14/job_return_value.aggregated.json",
260
+ "epochs=300": "logs/document_evaluation/multiruns/default/2025-01-16_18-50-43/job_return_value.aggregated.json",
261
+ "epochs=300,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_23-12-52/job_return_value.aggregated.json",
262
+ },
263
+ x_axis="epochs",
264
+ default_key_values={"checkpoint": "best_val"},
265
+ description="data from https://github.com/ArneBinder/pie-document-level/issues/334",
266
+ ),
267
+ "joint model (relations) - last vs best val checkpoint @test": dict(
268
+ job_return_paths={
269
+ "epochs=75": "logs/document_evaluation/multiruns/default/2025-01-12_13-30-25/job_return_value.aggregated.json",
270
+ "epochs=75,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_18-38-55/job_return_value.aggregated.json",
271
+ "epochs=150": "logs/document_evaluation/multiruns/default/2025-01-15_13-32-33/job_return_value.aggregated.json",
272
+ "epochs=150,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_22-08-43/job_return_value.aggregated.json",
273
+ "epochs=300": "logs/document_evaluation/multiruns/default/2025-01-11_16-42-17/job_return_value.aggregated.json",
274
+ "epochs=300,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_23-14-13/job_return_value.aggregated.json",
275
+ },
276
+ x_axis="epochs",
277
+ default_key_values={"checkpoint": "best_val"},
278
+ description="data from https://github.com/ArneBinder/pie-document-level/issues/334",
279
+ ),
280
+ "joint model (adus) - last vs best val checkpoint @val": dict(
281
+ job_return_paths={
282
+ "epochs=75": "logs/document_evaluation/multiruns/default/2025-01-17_17-13-46/job_return_value.aggregated.json",
283
+ "epochs=75,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-17_22-26-17/job_return_value.aggregated.json",
284
+ "epochs=150": "logs/document_evaluation/multiruns/default/2025-01-17_19-41-17/job_return_value.aggregated.json",
285
+ "epochs=150,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-17_22-40-54/job_return_value.aggregated.json",
286
+ "epochs=300": "logs/document_evaluation/multiruns/default/2025-01-17_20-00-01/job_return_value.aggregated.json",
287
+ "epochs=300,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-17_22-51-51/job_return_value.aggregated.json",
288
+ },
289
+ x_axis="epochs",
290
+ default_key_values={"checkpoint": "best_val"},
291
+ description="data from https://github.com/ArneBinder/pie-document-level/issues/334",
292
+ ),
293
+ "joint model (relations) - last vs best val checkpoint @val": dict(
294
+ job_return_paths={
295
+ "epochs=75": "logs/document_evaluation/multiruns/default/2025-01-17_17-16-01/job_return_value.aggregated.json",
296
+ "epochs=75,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-17_22-28-24/job_return_value.aggregated.json",
297
+ "epochs=150": "logs/document_evaluation/multiruns/default/2025-01-17_19-42-59/job_return_value.aggregated.json",
298
+ "epochs=150,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-17_22-41-19/job_return_value.aggregated.json",
299
+ "epochs=300": "logs/document_evaluation/multiruns/default/2025-01-17_20-01-16/job_return_value.aggregated.json",
300
+ "epochs=300,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-17_22-52-16/job_return_value.aggregated.json",
301
+ },
302
+ x_axis="epochs",
303
+ default_key_values={"checkpoint": "best_val"},
304
+ description="data from https://github.com/ArneBinder/pie-document-level/issues/334",
305
+ ),
306
+ "joint model (adus) - 27 vs 31 train docs @test": dict(
307
+ job_return_paths={
308
+ "num_train_docs=27,epochs=75,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_18-36-35/job_return_value.aggregated.json",
309
+ "num_train_docs=31,epochs=75,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_11-35-12/job_return_value.aggregated.json",
310
+ },
311
+ x_axis="num_train_docs",
312
+ description="data from https://github.com/ArneBinder/pie-document-level/issues/334",
313
+ ),
314
+ "joint model (relations) - 27 vs 31 train docs @test": dict(
315
+ job_return_paths={
316
+ "num_train_docs=27,epochs=75,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_18-38-55/job_return_value.aggregated.json",
317
+ "num_train_docs=31,epochs=75,checkpoint=last": "logs/document_evaluation/multiruns/default/2025-01-16_11-36-52/job_return_value.aggregated.json",
318
+ },
319
+ x_axis="num_train_docs",
320
+ description="data from https://github.com/ArneBinder/pie-document-level/issues/334",
321
+ ),
322
+ }
323
+
324
+ DEFAULT_KWARGS = {
325
+ "column_level_names": ["split", "average", "metric", "aggr"],
326
+ "plot_column_level": "average",
327
+ "x_is_not_numeric": False,
328
+ "aggregation_column_level": "aggr",
329
+ "drop_columns": {"metric": "s"},
330
+ "show_as": "plots",
331
+ "markdown_round_precision": 3,
332
+ }
333
+
334
+ if __name__ == "__main__":
335
+
336
+ parser = argparse.ArgumentParser(
337
+ description="Compare multiple job results for predefined setups (see positional choice argument) "
338
+ "by creating plots or a markdown table."
339
+ )
340
+ parser.add_argument(
341
+ "config",
342
+ type=str,
343
+ help="config name (will also be the title prefix)",
344
+ choices=CONFIGS.keys(),
345
+ )
346
+ parser.add_argument(
347
+ "--column-level-names",
348
+ type=lambda x: x.strip().split(","),
349
+ help="comma separated list of column level names. Note that column levels are "
350
+ "created for each nesting level in the JSON data",
351
+ )
352
+ parser.add_argument("--plot-column-level", type=str, help="column level to create plots for")
353
+ parser.add_argument("--x-axis", type=str, help="column level to use as x-axis")
354
+ parser.add_argument("--x-is-not-numeric", help="set if x-axis is not numeric")
355
+ parser.add_argument(
356
+ "--aggregation-column-level",
357
+ type=str,
358
+ help="column level that contains the aggregation type (e.g. mean, std)",
359
+ )
360
+ parser.add_argument(
361
+ "--drop-columns",
362
+ type=lambda x: dict(part.split(":") for part in x.strip().split(",")),
363
+ help="a comma separated list of key-value pairs in the format level_name=level_value to "
364
+ "drop columns with the specific level values",
365
+ )
366
+ parser.add_argument(
367
+ "--show-as",
368
+ type=str,
369
+ help="show the data as 'plots' or 'markdown'",
370
+ )
371
+ parser.add_argument(
372
+ "--markdown-round-precision", type=int, help="round precision for show markdown"
373
+ )
374
+ args = parser.parse_args()
375
+
376
+ user_kwargs = vars(args)
377
+ config_name = user_kwargs.pop("config")
378
+
379
+ def get_marker(trace_meta):
380
+ checkpoint2marker_style = {"best_val": "circle-open", "last": "x"}
381
+ return dict(symbol=checkpoint2marker_style[trace_meta.get("checkpoint", "last")], size=12)
382
+
383
+ def get_line(trace_meta):
384
+ metric2checkpoint2color = {
385
+ "f1": {"best_val": "lightblue", "last": "blue"},
386
+ "f": {"best_val": "lightblue", "last": "blue"},
387
+ "p": {"best_val": "lightgreen", "last": "green"},
388
+ "r": {"best_val": "lightcoral", "last": "red"},
389
+ }
390
+ return dict(
391
+ color=metric2checkpoint2color[trace_meta["metric"]][
392
+ trace_meta.get("checkpoint", "last")
393
+ ],
394
+ width=2,
395
+ )
396
+
397
+ kwargs = {
398
+ "title_prefix": config_name,
399
+ "line_getter": get_line,
400
+ "marker_getter": get_marker,
401
+ **DEFAULT_KWARGS,
402
+ **CONFIGS[config_name],
403
+ }
404
+ for key, value in user_kwargs.items():
405
+ if value is not None:
406
+ kwargs[key] = value
407
+ combine_job_returns_and_plot(**kwargs)
src/data/acl_anthology_crawler.py ADDED
@@ -0,0 +1,117 @@
1
+ import pyrootutils
2
+
3
+ root = pyrootutils.setup_root(
4
+ search_from=__file__,
5
+ indicator=[".project-root"],
6
+ pythonpath=True,
7
+ dotenv=True,
8
+ )
9
+
10
+ import os
11
+ from argparse import ArgumentParser, RawTextHelpFormatter
12
+ from dataclasses import dataclass, field
13
+ from pathlib import Path
14
+
15
+ from acl_anthology import Anthology
16
+ from tqdm import tqdm
17
+
18
+ from src.utils.pdf_utils.acl_anthology_utils import XML2RawPapers
19
+ from src.utils.pdf_utils.process_pdf import (
20
+ FulltextExtractor,
21
+ GrobidFulltextExtractor,
22
+ PDFDownloader,
23
+ )
24
+
25
+ HELP_MSG = """
26
+ Generate paper json files from an ACL Anthology collection, with fulltext extraction.
27
+
28
+ Iterate over entries in the ACL Anthology metadata, and for each entry:
29
+ 1. extract relevant paper info from the xml entry
30
+ 2. download pdf file
31
+ 3. extract fulltext
32
+ 4. format a json file and save
33
+
34
+ pre-requisites:
35
+ - Install the requirements: pip install acl-anthology-py>=0.4.3 bs4 jsonschema
36
+ - Get the meta data from ACL Anthology: git clone git@github.com:acl-org/acl-anthology.git
37
+ - Start Grobid Docker container: docker run --rm --init --ulimit core=0 -p 8070:8070 lfoppiano/grobid:0.8.0
38
+ """
39
+
40
+
41
+ @dataclass
42
+ class XML2Jsons:
43
+ base_output_dir: Path
44
+ pdf_output_dir: Path
45
+
46
+ xml2raw_papers: XML2RawPapers
47
+ pdf_downloader: PDFDownloader = field(default_factory=PDFDownloader)
48
+ fulltext_extractor: FulltextExtractor = field(default_factory=GrobidFulltextExtractor)
49
+ show_progress: bool = True
50
+
51
+ @classmethod
52
+ def from_cli(cls) -> "XML2Jsons":
53
+ parser = ArgumentParser(description=HELP_MSG, formatter_class=RawTextHelpFormatter)
54
+ parser.add_argument(
55
+ "--base-output-dir", type=str, help="Directory to save all the paper json files"
56
+ )
57
+ parser.add_argument(
58
+ "--pdf-output-dir", type=str, help="Directory to save all the downloaded pdf files"
59
+ )
60
+ parser.add_argument(
61
+ "--anthology-data-dir",
62
+ type=str,
63
+ help="Path to ACL Anthology metadata directory, e.g., /path/to/acl-anthology-repo/data. "
64
+ "You can obtain the data via: git clone git@github.com:acl-org/acl-anthology.git",
65
+ )
66
+ parser.add_argument(
67
+ "--collection-id-filters",
68
+ nargs="+",
69
+ type=str,
70
+ default=None,
71
+ help="If provided, only papers from the collections whose id (Anthology ID) contains the "
72
+ "specified strings will be processed.",
73
+ )
74
+ parser.add_argument(
75
+ "--venue-id-whitelist",
76
+ nargs="+",
77
+ type=str,
78
+ default=None,
79
+ help="If provided, only papers from the specified venues will be processed. See here for "
80
+ "the list of venues: https://aclanthology.org/venues",
81
+ )
82
+ args = parser.parse_args()
83
+
84
+ return cls(
85
+ base_output_dir=Path(args.base_output_dir),
86
+ pdf_output_dir=Path(args.pdf_output_dir),
87
+ xml2raw_papers=XML2RawPapers(
88
+ anthology=Anthology(datadir=args.anthology_data_dir),
89
+ collection_id_filters=args.collection_id_filters,
90
+ venue_id_whitelist=args.venue_id_whitelist,
91
+ ),
92
+ )
93
+
94
+ def run(self):
95
+ os.makedirs(self.pdf_output_dir, exist_ok=True)
96
+ papers = self.xml2raw_papers()
97
+ if self.show_progress:
98
+ papers = tqdm(list(papers), desc="extracting fulltext")
99
+ for paper in papers:
100
+ volume_dir = self.base_output_dir / paper.volume_id
101
+ if paper.url is not None:
102
+ pdf_save_path = self.pdf_downloader.download(
103
+ paper.url, opath=self.pdf_output_dir / f"{paper.name}.pdf"
104
+ )
105
+ fulltext_extraction_output = self.fulltext_extractor(pdf_save_path)
106
+
107
+ if fulltext_extraction_output:
108
+ plain_text, extraction_data = fulltext_extraction_output
109
+ paper.fulltext = extraction_data.get("sections")
110
+ if not paper.abstract:
111
+ paper.abstract = extraction_data.get("abstract")
112
+ paper.save(str(volume_dir))
113
+
114
+
115
+ if __name__ == "__main__":
116
+ xml2jsons = XML2Jsons.from_cli()
117
+ xml2jsons.run()
src/data/calc_iaa_for_brat.py ADDED
@@ -0,0 +1,272 @@
1
+ from collections.abc import Iterable
2
+
3
+ import pyrootutils
4
+ from pytorch_ie import Document
5
+
6
+ root = pyrootutils.setup_root(
7
+ search_from=__file__,
8
+ indicator=[".project-root"],
9
+ pythonpath=True,
10
+ dotenv=True,
11
+ )
12
+
13
+ import argparse
14
+ from functools import partial
15
+ from typing import Callable, List, Optional, Union
16
+
17
+ import pandas as pd
18
+ from pie_datasets import Dataset, load_dataset
19
+ from pie_datasets.builders.brat import BratDocument, BratDocumentWithMergedSpans
20
+ from pie_modules.document.processing import RelationArgumentSorter, SpansViaRelationMerger
21
+ from pytorch_ie.metrics import F1Metric
22
+
23
+ from src.document.processing import align_predicted_span_annotations
24
+
25
+
26
+ def add_annotations_as_predictions(document: BratDocument, other: BratDocument) -> BratDocument:
27
+ document = document.copy()
28
+ other = other.copy()
29
+ document.spans.predictions.extend(other.spans.clear())
30
+ gold2gold_span_mapping = {span: span for span in document.spans}
31
+ predicted2maybe_gold_span = {}
32
+ for span in document.spans.predictions:
33
+ predicted2maybe_gold_span[span] = gold2gold_span_mapping.get(span, span)
34
+ predicted_relations = [
35
+ rel.copy(
36
+ head=predicted2maybe_gold_span[rel.head], tail=predicted2maybe_gold_span[rel.tail]
37
+ )
38
+ for rel in other.relations.clear()
39
+ ]
40
+ document.relations.predictions.extend(predicted_relations)
41
+ return document
42
+
43
+
44
+ def remove_annotations_existing_in_other(
45
+ document: BratDocumentWithMergedSpans, other: BratDocumentWithMergedSpans
46
+ ) -> BratDocumentWithMergedSpans:
47
+ result = document.copy(with_annotations=False)
48
+ document = document.copy()
49
+ other = other.copy()
50
+
51
+ spans = set(document.spans.clear()) - set(other.spans.clear())
52
+ relations = set(document.relations.clear()) - set(other.relations.clear())
53
+ result.spans.extend(spans)
54
+ result.relations.extend(relations)
55
+
56
+ return result
57
+
58
+
59
+ def unnest_dict(d):
60
+ result = {}
61
+ for key, value in d.items():
62
+ if isinstance(value, dict):
63
+ unnested = unnest_dict(value)
64
+ for k, v in unnested.items():
65
+ result[(key,) + k] = v
66
+ else:
67
+ result[(key,)] = value
68
+ return result
69
+
70
+
71
+ def calc_brat_iaas(
72
+ annotator_dirs: List[str],
73
+ ignore_annotation_dir: Optional[str] = None,
74
+ combine_fragmented_spans_via_relation: Optional[str] = None,
75
+ sort_arguments_of_relations: Optional[List[str]] = None,
76
+ align_spans: bool = False,
77
+ show_results: bool = False,
78
+ per_file: bool = False,
79
+ ) -> Union[pd.Series, List[pd.Series]]:
80
+ if len(annotator_dirs) < 2:
81
+ raise ValueError("At least two annotation dirs must be provided")
82
+
83
+ span_aligner = None
84
+ if align_spans:
85
+ span_aligner = partial(align_predicted_span_annotations, span_layer="spans")
86
+
87
+ if combine_fragmented_spans_via_relation is not None:
88
+ print(f"Combine fragmented spans via {combine_fragmented_spans_via_relation} relations")
89
+ merger = SpansViaRelationMerger(
90
+ relation_layer="relations",
91
+ link_relation_label=combine_fragmented_spans_via_relation,
92
+ create_multi_spans=True,
93
+ result_document_type=BratDocument,
94
+ result_field_mapping={"spans": "spans", "relations": "relations"},
95
+ )
96
+ else:
97
+ merger = None
98
+
99
+ if sort_arguments_of_relations is not None and len(sort_arguments_of_relations) > 0:
100
+ print(f"Sort arguments of relations with labels {sort_arguments_of_relations}")
101
+ relation_argument_sorter = RelationArgumentSorter(
102
+ relation_layer="relations",
103
+ label_whitelist=sort_arguments_of_relations, # ["parts_of_same", "semantically_same", "contradicts"],
104
+ )
105
+ else:
106
+ relation_argument_sorter = None
107
+
108
+ all_docs = [
109
+ load_dataset(
110
+ "pie/brat",
111
+ name="merge_fragmented_spans",
112
+ base_dataset_kwargs=dict(data_dir=annotation_dir),
113
+ split="train",
114
+ ).map(lambda doc: doc.deduplicate_annotations())
115
+ for annotation_dir in annotator_dirs
116
+ ]
117
+
118
+ if ignore_annotation_dir is not None:
119
+ print(f"Ignoring annotations loaded from {ignore_annotation_dir}")
120
+ ignore_annotation_docs = load_dataset(
121
+ "pie/brat",
122
+ name="merge_fragmented_spans",
123
+ base_dataset_kwargs=dict(data_dir=ignore_annotation_dir),
124
+ split="train",
125
+ )
126
+ ignore_annotation_docs_dict = {doc.id: doc for doc in ignore_annotation_docs}
127
+ all_docs = [
128
+ docs.map(
129
+ lambda doc: remove_annotations_existing_in_other(
130
+ doc, other=ignore_annotation_docs_dict[doc.id]
131
+ )
132
+ )
133
+ for docs in all_docs
134
+ ]
135
+
136
+ if relation_argument_sorter is not None:
137
+ all_docs = [docs.map(relation_argument_sorter) for docs in all_docs]
138
+
139
+ if per_file:
140
+ results_per_doc = []
141
+ for docs_tuple in zip(*all_docs):
142
+ if show_results:
143
+ print(f"\ncalculate scores for document id={docs_tuple[0].id} ...")
144
+ docs = [Dataset.from_documents([doc]) for doc in docs_tuple]
145
+ result_per_doc = calc_brat_iaas_for_docs(
146
+ docs, span_aligner=span_aligner, merger=merger, show_results=show_results
147
+ )
148
+ results_per_doc.append(result_per_doc)
149
+ return results_per_doc
150
+
151
+ else:
152
+ return calc_brat_iaas_for_docs(
153
+ all_docs, span_aligner=span_aligner, merger=merger, show_results=show_results
154
+ )
155
+
156
+
157
+ def calc_brat_iaas_for_docs(
158
+ all_docs: List[Dataset],
159
+ span_aligner: Optional[Callable] = None,
160
+ merger: Optional[Callable] = None,
161
+ show_results: bool = False,
162
+ ) -> pd.Series:
163
+ num_annotators = len(all_docs)
164
+ all_docs_dict = [{doc.id: doc for doc in docs} for docs in all_docs]
165
+ gold_predicted = {}
166
+ for gold_annotator_idx in range(num_annotators):
167
+ gold = all_docs[gold_annotator_idx]
168
+ for predicted_annotator_idx in range(num_annotators):
169
+ if gold_annotator_idx == predicted_annotator_idx:
170
+ continue
171
+ predicted_dict = all_docs_dict[predicted_annotator_idx]
172
+ gold_predicted[(gold_annotator_idx, predicted_annotator_idx)] = gold.map(
173
+ lambda doc: add_annotations_as_predictions(doc, other=predicted_dict[doc.id])
174
+ )
175
+
176
+ spans_metric = F1Metric(layer="spans", labels="INFERRED", show_as_markdown=True)
177
+ relations_metric = F1Metric(layer="relations", labels="INFERRED", show_as_markdown=True)
178
+
179
+ metric_values = {}
180
+ for gold_annotator_idx, predicted_annotator_idx in gold_predicted:
181
+ print(
182
+ f"calculate scores for annotations {gold_annotator_idx} -> {predicted_annotator_idx}"
183
+ )
184
+ for doc in gold_predicted[(gold_annotator_idx, predicted_annotator_idx)]:
185
+ if span_aligner is not None:
186
+ doc = span_aligner(doc)
187
+ if merger is not None:
188
+ doc = merger(doc)
189
+ spans_metric(doc)
190
+ relations_metric(doc)
191
+ metric_id = f"gold:{gold_annotator_idx},predicted:{predicted_annotator_idx}"
192
+ metric_values[metric_id] = {
193
+ "spans": spans_metric.compute(reset=True),
194
+ "relations": relations_metric.compute(reset=True),
195
+ }
196
+
197
+ result = pd.Series(unnest_dict(metric_values))
198
+ if show_results:
199
+ metric_values_series_mean = result.unstack(0).mean(axis=1)
200
+ metric_values_relations = metric_values_series_mean.xs("relations").unstack()
201
+ metric_values_spans = metric_values_series_mean.xs("spans").unstack()
202
+
203
+ print("\nspans:")
204
+ print(metric_values_spans.round(decimals=3).to_markdown())
205
+
206
+ print("\nrelations:")
207
+ print(metric_values_relations.round(decimals=3).to_markdown())
208
+
209
+ return result
210
+
211
+
212
+ if __name__ == "__main__":
213
+
214
+ """
215
+ example call:
216
+ python calc_iaa_for_brat.py \
217
+ --annotation-dirs annotations/sciarg/v0.9/with_abstracts_rin annotations/sciarg/v0.9/with_abstracts_alisa \
218
+ --ignore-annotation-dir annotations/sciarg/v0.9/original
219
+ """
220
+
221
+ parser = argparse.ArgumentParser(
222
+ description="Calculate inter-annotator agreement for spans and relations in means of F1 "
223
+ "(exact match, i.e. offsets / arguments and labels must match) for two or more BRAT "
224
+ "annotation directories."
225
+ )
226
+ parser.add_argument(
227
+ "--annotation-dirs",
228
+ type=str,
229
+ required=True,
230
+ nargs="+",
231
+ help="List of annotation directories. At least two directories must be provided.",
232
+ )
233
+ parser.add_argument(
234
+ "--ignore-annotation-dir",
235
+ type=str,
236
+ default=None,
237
+ help="If set, ignore annotations loaded from this directory.",
238
+ )
239
+ parser.add_argument(
240
+ "--combine-fragmented-spans-via-relation",
241
+ type=str,
242
+ default=None,
243
+ help="If set, combine fragmented spans via this relation type.",
244
+ )
245
+ parser.add_argument(
246
+ "--sort-arguments-of-relations",
247
+ type=str,
248
+ default=None,
249
+ nargs="+",
250
+ help="If set, sort the arguments of the relations with the given labels.",
251
+ )
252
+ parser.add_argument(
253
+ "--align-spans",
254
+ action="store_true",
255
+ help="If set, align the spans of the predicted annotations to the gold annotations.",
256
+ )
257
+ parser.add_argument(
258
+ "--per-file",
259
+ action="store_true",
260
+ help="If set, calculate IAA per file.",
261
+ )
262
+ args = parser.parse_args()
263
+
264
+ metric_values_series = calc_brat_iaas(
265
+ annotator_dirs=args.annotation_dirs,
266
+ ignore_annotation_dir=args.ignore_annotation_dir,
267
+ combine_fragmented_spans_via_relation=args.combine_fragmented_spans_via_relation,
268
+ sort_arguments_of_relations=args.sort_arguments_of_relations,
269
+ align_spans=args.align_spans,
270
+ per_file=args.per_file,
271
+ show_results=True,
272
+ )
src/data/construct_sciarg_abstracts_remaining_gold_retrieval.py ADDED
@@ -0,0 +1,238 @@
1
+ import pyrootutils
2
+
3
+ root = pyrootutils.setup_root(
4
+ search_from=__file__,
5
+ indicator=[".project-root"],
6
+ pythonpath=True,
7
+ # dotenv=True,
8
+ )
9
+
10
+ import argparse
11
+ import logging
12
+ import os
13
+ from collections import defaultdict
14
+ from typing import List, Optional, Sequence, Tuple, TypeVar
15
+
16
+ import pandas as pd
17
+ from pie_datasets import load_dataset
18
+ from pie_datasets.builders.brat import BratDocument, BratDocumentWithMergedSpans
19
+ from pytorch_ie.annotations import LabeledMultiSpan
20
+ from pytorch_ie.documents import (
21
+ TextBasedDocument,
22
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
23
+ TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
24
+ )
25
+
26
+ from src.document.processing import replace_substrings_in_text_with_spaces
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ def multi_span_is_in_span(multi_span: LabeledMultiSpan, range_span: Tuple[int, int]) -> bool:
32
+ start, end = range_span
33
+ starts, ends = zip(*multi_span.slices)
34
+ return start <= min(starts) and max(ends) <= end
35
+
36
+
37
+ def filter_multi_spans(
38
+ multi_spans: Sequence[LabeledMultiSpan], filter_span: Tuple[int, int]
39
+ ) -> List[LabeledMultiSpan]:
40
+ return [
41
+ span
42
+ for span in multi_spans
43
+ if multi_span_is_in_span(multi_span=span, range_span=filter_span)
44
+ ]
45
+
46
+
47
+ def shift_multi_span_slices(
48
+ slices: Sequence[Tuple[int, int]], shift: int
49
+ ) -> List[Tuple[int, int]]:
50
+ return [(start + shift, end + shift) for start, end in slices]
51
+
52
+
53
+ def construct_gold_retrievals(
54
+ doc: TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
55
+ symmetric_relations: Optional[List[str]] = None,
56
+ relation_label_whitelist: Optional[List[str]] = None,
57
+ ) -> Optional[pd.DataFrame]:
58
+ abstract_annotations = [
59
+ span for span in doc.labeled_partitions if span.label.lower().strip() == "abstract"
60
+ ]
61
+ if len(abstract_annotations) != 1:
62
+ logger.warning(
63
+ f"Expected exactly one abstract annotation, found {len(abstract_annotations)}"
64
+ )
65
+ return None
66
+ abstract_annotation = abstract_annotations[0]
67
+ span_abstract = (abstract_annotation.start, abstract_annotation.end)
68
+ span_remaining = (abstract_annotation.end, len(doc.text))
69
+ labeled_multi_spans = list(doc.labeled_multi_spans)
70
+ spans_in_abstract = set(
71
+ span for span in labeled_multi_spans if multi_span_is_in_span(span, span_abstract)
72
+ )
73
+ spans_in_remaining = set(
74
+ span for span in labeled_multi_spans if multi_span_is_in_span(span, span_remaining)
75
+ )
76
+ spans_not_covered = set(labeled_multi_spans) - spans_in_abstract - spans_in_remaining
77
+ if len(spans_not_covered) > 0:
78
+ logger.warning(
79
+ f"Found {len(spans_not_covered)} spans not covered by abstract or remaining text"
80
+ )
81
+
82
+ rel_arg_and_label2other = defaultdict(list)
83
+ for rel in doc.binary_relations:
84
+ rel_arg_and_label2other[rel.head].append((rel.tail, rel.label))
85
+ if symmetric_relations is not None and rel.label in symmetric_relations:
86
+ label_reversed = rel.label
87
+ else:
88
+ label_reversed = f"{rel.label}_reversed"
89
+ rel_arg_and_label2other[rel.tail].append((rel.head, label_reversed))
90
+
91
+ result_rows = []
92
+ for rel in doc.binary_relations:
93
+ # we check all semantically_same relations that link a span in the abstract (head) to a span in the remaining text (tail) ...
94
+ if rel.label == "semantically_same":
95
+ if rel.head in spans_in_abstract and rel.tail in spans_in_remaining:
96
+ # ... and, starting from the tail (the matching span in the remaining text), collect all related spans as retrieval candidates
97
+ # candidate_query_span = rel.tail
98
+ candidate_spans_with_label = rel_arg_and_label2other[rel.tail]
99
+ for candidate_span, rel_label in candidate_spans_with_label:
100
+ if (
101
+ relation_label_whitelist is not None
102
+ and rel_label not in relation_label_whitelist
103
+ ):
104
+ continue
105
+ result_row = {
106
+ "doc_id": f"{doc.id}.remaining.{span_remaining[0]}.txt",
107
+ "query_doc_id": f"{doc.id}.abstract.{span_abstract[0]}_{span_abstract[1]}.txt",
108
+ "span": shift_multi_span_slices(candidate_span.slices, -span_remaining[0]),
109
+ "query_span": shift_multi_span_slices(rel.head.slices, -span_abstract[0]),
110
+ "ref_span": shift_multi_span_slices(rel.tail.slices, -span_remaining[0]),
111
+ "type": rel_label,
112
+ "label": candidate_span.label,
113
+ "ref_label": rel.tail.label,
114
+ }
115
+ result_rows.append(result_row)
116
+
117
+ if len(result_rows) > 0:
118
+ return pd.DataFrame(result_rows)
119
+ else:
120
+ return None
121
+
122
+
123
+ D_text = TypeVar("D_text", bound=TextBasedDocument)
124
+
125
+
126
+ def clean_doc(doc: D_text) -> D_text:
127
+ # remove xml tags. Note that we also remove the Abstract tag, in contrast to the preprocessing
128
+ # pipeline (see configs/dataset/sciarg_cleaned.yaml). This is because there, the abstracts are
129
+ # removed completely.
130
+ doc = replace_substrings_in_text_with_spaces(
131
+ doc,
132
+ substrings=[
133
+ "</H2>",
134
+ "<H3>",
135
+ "</Document>",
136
+ "<H1>",
137
+ "<H2>",
138
+ "</H3>",
139
+ "</H1>",
140
+ "<Abstract>",
141
+ "</Abstract>",
142
+ ],
143
+ )
144
+ return doc
145
+
146
+
147
+ def main(
148
+ data_dir: str,
149
+ out_path: str,
150
+ doc_id_whitelist: Optional[List[str]] = None,
151
+ symmetric_relations: Optional[List[str]] = None,
152
+ relation_label_whitelist: Optional[List[str]] = None,
153
+ ) -> None:
154
+ logger.info(f"Loading dataset from {data_dir}")
155
+ sciarg_with_abstracts = load_dataset(
156
+ "pie/sciarg",
157
+ revision="171478ce3c13cc484be5d7c9bc8f66d7d2f1c210",
158
+ base_dataset_kwargs={"data_dir": data_dir, "split_paths": None},
159
+ name="resolve_parts_of_same",
160
+ split="train",
161
+ )
162
+ if issubclass(sciarg_with_abstracts.document_type, BratDocument):
163
+ ds_converted = sciarg_with_abstracts.to_document_type(
164
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
165
+ )
166
+ elif issubclass(sciarg_with_abstracts.document_type, BratDocumentWithMergedSpans):
167
+ ds_converted = sciarg_with_abstracts.to_document_type(
168
+ TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
169
+ )
170
+ else:
171
+ raise ValueError(f"Unsupported document type {sciarg_with_abstracts.document_type}")
172
+
173
+ ds_clean = ds_converted.map(clean_doc)
174
+ if doc_id_whitelist is not None:
175
+ num_before = len(ds_clean)
176
+ ds_clean = [doc for doc in ds_clean if doc.id in doc_id_whitelist]
177
+ logger.info(
178
+ f"Filtered dataset from {num_before} to {len(ds_clean)} documents based on doc_id_whitelist"
179
+ )
180
+
181
+ results_per_doc = [
182
+ construct_gold_retrievals(
183
+ doc,
184
+ symmetric_relations=symmetric_relations,
185
+ relation_label_whitelist=relation_label_whitelist,
186
+ )
187
+ for doc in ds_clean
188
+ ]
189
+ results_per_doc_not_empty = [doc for doc in results_per_doc if doc is not None]
190
+ if len(results_per_doc_not_empty) > 0:
191
+ results = pd.concat(results_per_doc_not_empty, ignore_index=True)
192
+ # sort to make the output deterministic
193
+ results = results.sort_values(
194
+ by=results.columns.tolist(), ignore_index=True, key=lambda s: s.apply(str)
195
+ )
196
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
197
+ logger.info(f"Saving result ({len(results)}) to {out_path}")
198
+ results.to_json(out_path)
199
+ else:
200
+ logger.warning("No results found")
201
+
202
+
203
+ if __name__ == "__main__":
204
+ parser = argparse.ArgumentParser(
205
+ description="Create gold retrievals for SciArg-abstracts-remaining in the same format as the retrieval results"
206
+ )
207
+ parser.add_argument(
208
+ "--data_dir",
209
+ type=str,
210
+ default="data/annotations/sciarg-with-abstracts-and-cross-section-rels",
211
+ help="Path to the sciarg data directory",
212
+ )
213
+ parser.add_argument(
214
+ "--out_path",
215
+ type=str,
216
+ default="data/retrieval_results/sciarg-with-abstracts-and-cross-section-rels/gold.json",
217
+ help="Path to save the results",
218
+ )
219
+ parser.add_argument(
220
+ "--symmetric_relations",
221
+ type=str,
222
+ nargs="+",
223
+ default=None,
224
+ help="Relations that are symmetric, i.e., if A is related to B, then B is related to A",
225
+ )
226
+ parser.add_argument(
227
+ "--relation_label_whitelist",
228
+ type=str,
229
+ nargs="+",
230
+ default=None,
231
+ help="Only consider relations with these labels",
232
+ )
233
+
234
+ logging.basicConfig(level=logging.INFO)
235
+
236
+ kwargs = vars(parser.parse_args())
237
+ main(**kwargs)
238
+ logger.info("Done")
src/data/prepare_sciarg_crosssection_annotations.py ADDED
@@ -0,0 +1,398 @@
 
 
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import re
5
+ import shutil
6
+ from collections import defaultdict
7
+ from typing import Dict, List, Optional, Tuple
8
+
9
+ import pandas as pd
10
+ from pie_datasets import Dataset, IterableDataset, load_dataset
11
+ from pie_datasets.builders.brat import BratDocumentWithMergedSpans
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ def find_span_idx(raw_text: str, span_string: str) -> Optional[List]:
17
+ """
18
+ Match span string to raw text (document).
19
+ Return a list of (start, end) tuples, one per occurrence of the span string in the raw text
20
+ (an empty list if the span string does not occur).
21
+ """
22
+ # remove accidentally added leading/trailing whitespace
23
+ span_string = span_string.strip()
24
+ # use raw text input as regex-safe pattern
25
+ safe = re.escape(span_string)
26
+ pattern = rf"{safe}"
27
+ # find match(es)
28
+ out = [(s.start(), s.end()) for s in re.finditer(pattern, raw_text)]
29
+ return out
30
+
31
+
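A quick illustration of the return value of find_span_idx (e.g. in an interactive session after importing this module):

    find_span_idx("the cat sat on the mat", "the")  # [(0, 3), (15, 18)]
    find_span_idx("the cat sat on the mat", "dog")  # []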
32
+ def append_spans_start_and_end(
33
+ raw_text: str,
34
+ pd_table: pd.DataFrame,
35
+ input_cols: List[str],
36
+ input_idx_cols: List[str],
37
+ output_cols: List[str],
38
+ doc_id_col: str = "doc ID",
39
+ ) -> pd.DataFrame:
40
+ """
41
+ Create new column(s) for span indexes (i.e. start and end as Tuple) in pd.DataFrame from span strings.
42
+ Warn if
43
+ 1) span string does not match anything in document -> None,
44
+ 2) span string is not unique in the document -> List[Tuple].
45
+ """
46
+ pd_table = pd_table.join(pd.DataFrame(columns=output_cols))
47
+ for idx, pd_row in pd_table.iterrows():
48
+ for in_col, idx_col, out_col in zip(input_cols, input_idx_cols, output_cols):
49
+ span_indices = find_span_idx(raw_text, pd_row[in_col])
50
+ str_idx = pd_row[idx_col]
51
+ span_idx = None
52
+ if span_indices is None or len(span_indices) == 0:
53
+ logger.warning(
54
+ f'The span "{pd_row[in_col]}" in Column "{in_col}" does not exist in {pd_row[doc_id_col]}.'
55
+ )
56
+ elif len(span_indices) == 1:
57
+ # warn if the string-index column is not empty (str_idx == str_idx is False only for NaN), although the span is unique
58
+ if str_idx == str_idx:
59
+ logger.warning(f'Column "{idx_col}" is not empty. It has value: {str_idx}.')
60
+ span_idx = span_indices.pop()
61
+ else:
62
+ # warn if the span is not unique but the string-index column is empty (str_idx != str_idx only for NaN)
63
+ if str_idx != str_idx:
64
+ logger.warning(
65
+ f'The span "{pd_row[in_col]}" in Column "{in_col}" is not unique,'
66
+ f'but, column "{idx_col}" is empty. '
67
+ f"Need a string index to specify the non-unique span."
68
+ )
69
+ else:
70
+ span_idx = span_indices.pop(int(str_idx))
71
+
72
+ if span_idx is not None:
73
+ pd_table.at[idx, out_col] = span_idx
74
+
75
+ # sanity check (NOTE: this should live in a test)
76
+ search_string = pd_row[in_col]
77
+ reconstructed_string = raw_text[span_idx[0] : span_idx[1]]
78
+ if search_string != reconstructed_string:
79
+ raise ValueError(
80
+ f"Reconstructed string does not match the original string. "
81
+ f"Original: {search_string}, Reconstructed: {reconstructed_string}"
82
+ )
83
+ return pd_table
84
+
85
+
86
+ def get_texts_from_pie_dataset(
87
+ doc_ids: List[str], **dataset_kwargs
88
+ ) -> Dict[str, BratDocumentWithMergedSpans]:
89
+ """Get texts from a PIE dataset for a list of document IDs.
90
+
91
+ :param doc_ids: list of document IDs
92
+ :param dataset_kwargs: keyword arguments to pass to load_dataset
93
+
94
+ :return: a dictionary with document IDs as keys and texts as values
95
+ """
96
+
97
+ text_based_dataset = load_dataset(**dataset_kwargs)
98
+ if not isinstance(text_based_dataset, (Dataset, IterableDataset)):
99
+ raise ValueError(
100
+ f"Expected a PIE Dataset or PIE IterableDataset, but got a {type(text_based_dataset)} instead."
101
+ )
102
+ if not issubclass(text_based_dataset.document_type, BratDocumentWithMergedSpans):
103
+ raise ValueError(
104
+ f"Expected a PIE Dataset with BratDocumentWithMergedSpans as document type, "
105
+ f"but got {text_based_dataset.document_type} instead."
106
+ )
107
+ doc_id2text = {doc.id: doc for doc in text_based_dataset}
108
+ return {doc_id: doc_id2text[doc_id] for doc_id in doc_ids}
109
+
110
+
111
+ def set_span_annotation_ids(
112
+ table: pd.DataFrame,
113
+ doc_id2doc: Dict[str, BratDocumentWithMergedSpans],
114
+ doc_id_col: str,
115
+ span_idx_cols: List[str],
116
+ span_id_cols: List[str],
117
+ ) -> pd.DataFrame:
118
+ """
119
+ Create new column(s) for span annotation IDs in pd.DataFrame from span indexes. The annotation IDs are
120
+ retrieved from the TextBasedDocument object using the span indexes.
121
+
122
+ :param table: pd.DataFrame with span indexes, document IDs, and other columns
123
+ :param doc_id2doc: dictionary with document IDs as keys and BratDocumentWithMergedSpans objects as values
124
+ :param doc_id_col: column name that contains document IDs
125
+ :param span_idx_cols: column names that contain span indexes
126
+ :param span_id_cols: column names for new span ID columns
127
+
128
+ :return: pd.DataFrame with new columns for span annotation IDs
129
+ """
130
+ table = table.join(pd.DataFrame(columns=span_id_cols))
131
+ span2id: Dict[str, Dict[Tuple[int, int], str]] = defaultdict(dict)
132
+ for doc_id, doc in doc_id2doc.items():
133
+ for span_id, span in zip(doc.metadata["span_ids"], doc.spans):
134
+ span2id[doc_id][(span.start, span.end)] = span_id
135
+
136
+ for span_idx_col, span_id_col in zip(span_idx_cols, span_id_cols):
137
+ table[span_id_col] = table.apply(
138
+ lambda row: span2id[row[doc_id_col]][tuple(row[span_idx_col])], axis=1
139
+ )
140
+
141
+ return table
142
+
143
+
144
+ def set_relation_annotation_ids(
145
+ table: pd.DataFrame,
146
+ doc_id2doc: Dict[str, BratDocumentWithMergedSpans],
147
+ doc_id_col: str,
148
+ relation_id_col: str,
149
+ ) -> pd.DataFrame:
150
+ """create new column for relation annotation IDs in pd.DataFrame. They are simply new ids starting from the last
151
+ relation annotation id in the document.
152
+
153
+ :param table: pd.DataFrame with document IDs and other columns
154
+ :param doc_id2doc: dictionary with document IDs as keys and BratDocumentWithMergedSpans objects as values
155
+ :param doc_id_col: column name that contains document IDs
156
+ :param relation_id_col: column name for new relation ID column
157
+
158
+ :return: pd.DataFrame with new column for relation annotation IDs
159
+ """
160
+
161
+ table = table.join(pd.DataFrame(columns=[relation_id_col]))
162
+ doc_id2highest_relation_id = defaultdict(int)
163
+
164
+ for doc_id, doc in doc_id2doc.items():
165
+ # relation ids are prefixed with "R" in the dataset
166
+ doc_id2highest_relation_id[doc_id] = max(
167
+ [int(relation_id[1:]) for relation_id in doc.metadata["relation_ids"]]
168
+ )
169
+
170
+ for idx, row in table.iterrows():
171
+ doc_id = row[doc_id_col]
172
+ doc_id2highest_relation_id[doc_id] += 1
173
+ table.at[idx, relation_id_col] = f"R{doc_id2highest_relation_id[doc_id]}"
174
+
175
+ return table
176
+
177
+
178
+ def main(
179
+ input_path: str,
180
+ output_path: str,
181
+ brat_data_dir: str,
182
+ input_encoding: str,
183
+ include_unsure: bool = False,
184
+ doc_id_col: str = "doc ID",
185
+ unsure_col: str = "unsure",
186
+ span_str_cols: List[str] = ["head argument string", "tail argument string"],
187
+ str_idx_cols: List[str] = ["head string index", "tail string index"],
188
+ span_idx_cols: List[str] = ["head_span_idx", "tail_span_idx"],
189
+ span_id_cols: List[str] = ["head_span_id", "tail_span_id"],
190
+ relation_id_col: str = "relation_id",
191
+ set_annotation_ids: bool = False,
192
+ relation_type: str = "relation",
193
+ ) -> None:
194
+ """
195
+ Convert cross-section (long-distance) relation annotations from a CSV file to a JSON format. The input table should have
196
+ columns for document IDs, argument span strings, and string indexes (required in the case that the
197
+ span string occurs multiple times in the base text). The argument span strings are matched to the
198
+ base text to get the start and end indexes of the span. The output JSON file will have the same
199
+ columns as the input file, plus two additional columns holding the (start, end) indexes of the head and tail spans.
200
+
201
+ :param input_path: path to a CSV/Excel file that contains annotations
202
+ :param output_path: path to save JSON output
203
+ :param brat_data_dir: directory where the BRAT data (base texts and annotations) is located
204
+ :param input_encoding: encoding of the input file. Only used for CSV files. Default: "cp1252"
205
+ :param include_unsure: include annotations marked as unsure. Default: False
206
+ :param doc_id_col: column name that contains document IDs. Default: "doc ID"
207
+ :param unsure_col: column name that contains unsure annotations. Default: "unsure"
208
+ :param span_str_cols: column names that contain span strings. Default: ["head argument string", "tail argument string"]
209
+ :param str_idx_cols: column names that contain string indexes. Default: ["head string index", "tail string index"]
210
+ :param span_idx_cols: column names for new span-index columns. Default: ["head_span_idx", "tail_span_idx"]
211
+ :param span_id_cols: column names for new span-ID columns. Default: ["head_span_id", "tail_span_id"]
212
+ :param relation_id_col: column name for new relation-ID column. Default: "relation_id"
213
+ :param set_annotation_ids: set annotation IDs for the spans and relations. Default: False
214
+ :param relation_type: specify the relation type for the BRAT output. Default: "relation"
215
+
216
+ :return: None
217
+ """
218
+ # get annotations from a csv file
219
+ if input_path.lower().endswith(".csv"):
220
+ input_df = pd.read_csv(input_path, encoding=input_encoding)
221
+ elif input_path.lower().endswith(".xlsx"):
222
+ logger.warning(
223
+ f"encoding parameter (--input-encoding={input_encoding}) is ignored for Excel files."
224
+ )
225
+ input_df = pd.read_excel(input_path)
226
+ else:
227
+ raise ValueError("Input file has unexpected format. Please provide a CSV or Excel file.")
228
+
229
+ # remove unsure
230
+ if not include_unsure:
231
+ input_df = input_df[input_df[unsure_col].isna()]
232
+ # remove all empty columns
233
+ input_df = input_df.dropna(axis=1, how="all")
234
+
235
+ # define output DataFrame
236
+ result_df = pd.DataFrame(columns=[*input_df.columns, *span_idx_cols])
237
+
238
+ # get unique document IDs
239
+ doc_ids = list(input_df[doc_id_col].unique())
240
+
241
+ # get base texts from a PIE SciArg dataset
242
+ doc_id2doc = get_texts_from_pie_dataset(
243
+ doc_ids=doc_ids,
244
+ path="pie/brat",
245
+ name="merge_fragmented_spans",
246
+ split="train",
247
+ revision="769a15e44e7d691148dd05e54ae2b058ceaed1f0",
248
+ base_dataset_kwargs=dict(data_dir=brat_data_dir),
249
+ )
250
+
251
+ for doc_id in doc_ids:
252
+
253
+ # iterate over each sub-df that contains annotations for a single document
254
+ doc_df = input_df[input_df[doc_id_col] == doc_id]
255
+ input_df = input_df.drop(doc_df.index)
256
+ # get spans' start and end indexes as new columns
257
+ doc_with_span_indices_df = append_spans_start_and_end(
258
+ raw_text=doc_id2doc[doc_id].text,
259
+ pd_table=doc_df,
260
+ input_cols=span_str_cols,
261
+ input_idx_cols=str_idx_cols,
262
+ output_cols=span_idx_cols,
263
+ )
264
+ # append this sub-df (with spans' indexes columns) to result_df
265
+ result_df = pd.concat(
266
+ [result_df if not result_df.empty else None, doc_with_span_indices_df]
267
+ )
268
+
269
+ out_ext = os.path.splitext(output_path)[1]
270
+ save_as_brat = out_ext == ""
271
+
272
+ if set_annotation_ids or save_as_brat:
273
+ result_df = set_span_annotation_ids(
274
+ table=result_df,
275
+ doc_id2doc=doc_id2doc,
276
+ doc_id_col=doc_id_col,
277
+ span_idx_cols=span_idx_cols,
278
+ span_id_cols=span_id_cols,
279
+ )
280
+ result_df = set_relation_annotation_ids(
281
+ table=result_df,
282
+ doc_id2doc=doc_id2doc,
283
+ doc_id_col=doc_id_col,
284
+ relation_id_col=relation_id_col,
285
+ )
286
+
287
+ base_dir = os.path.dirname(output_path)
288
+ os.makedirs(base_dir, exist_ok=True)
289
+
290
+ if out_ext.lower() == ".json":
291
+ logger.info(f"Saving output in JSON format to {output_path} ...")
292
+ result_df.to_json(
293
+ path_or_buf=output_path,
294
+ orient="records",
295
+ lines=True,
296
+ ) # possible orient values: 'split','index', 'table','records', 'columns', 'values'
297
+ elif save_as_brat:
298
+ logger.info(f"Saving output in BRAT format to {output_path} ...")
299
+ os.makedirs(output_path, exist_ok=True)
300
+ for doc_id in doc_ids:
301
+ # handle the base text file (just copy from the BRAT data directory)
302
+ shutil.copy(
303
+ src=os.path.join(brat_data_dir, f"{doc_id}.txt"),
304
+ dst=os.path.join(output_path, f"{doc_id}.txt"),
305
+ )
306
+
307
+ # handle the annotation file
308
+ # first, read the original annotation file
309
+ input_ann_path = os.path.join(brat_data_dir, f"{doc_id}.ann")
310
+ with open(input_ann_path, "r") as f:
311
+ ann_lines = f.readlines()
312
+ # then, append new relation annotations
313
+ # The format for each line is (see https://brat.nlplab.org/standoff.html):
314
+ # R{relation_id}\t{relation_type} Arg1:{span_id1} Arg2:{span_id2}
315
+ doc_df = result_df[result_df[doc_id_col] == doc_id]
316
+ logger.info(f"Adding {len(doc_df)} relation annotations to {doc_id}.ann ...")
317
+ for idx, row in doc_df.iterrows():
318
+ head_span_id = row[span_id_cols[0]]
319
+ tail_span_id = row[span_id_cols[1]]
320
+ relation_id = row[relation_id_col]
321
+ ann_line = (
322
+ f"{relation_id}\t{relation_type} Arg1:{head_span_id} Arg2:{tail_span_id}\n"
323
+ )
324
+ ann_lines.append(ann_line)
325
+ # finally, write the new annotation file
326
+ output_ann_path = os.path.join(output_path, f"{doc_id}.ann")
327
+ with open(output_ann_path, "w") as f:
328
+ f.writelines(ann_lines)
329
+ else:
330
+ raise ValueError(
331
+ "Output file has unexpected format. Please provide a JSON file or a directory."
332
+ )
333
+
334
+ logger.info("Done!")
335
+
336
+
337
+ if __name__ == "__main__":
338
+
339
+ """
340
+ example call:
341
+ python src/data/prepare_sciarg_crosssection_annotations.py
342
+ // or //
343
+ python src/data/prepare_sciarg_crosssection_annotations.py \
344
+ --input-path data/annotations/sciarg-cross-section/aligned_input.csv \
345
+ --output-path data/annotations/sciarg-with-abstracts-and-cross-section-rels \
346
+ --brat-data-dir data/annotations/sciarg-abstracts/v0.9.3/alisa
347
+ """
348
+
349
+ logging.basicConfig(level=logging.INFO)
350
+
351
+ parser = argparse.ArgumentParser(
352
+ description="Read text files in a directory and a CSV file that contains cross-section annotations. "
353
+ "Transform the CSV file to a JSON format and save at a specified output directory."
354
+ )
355
+ parser.add_argument(
356
+ "--input-path",
357
+ type=str,
358
+ default="data/annotations/sciarg-cross-section/aligned_input.csv",
359
+ help="Locate a CSV/Excel file.",
360
+ )
361
+ parser.add_argument(
362
+ "--output-path",
363
+ type=str,
364
+ default="data/annotations/sciarg-with-abstracts-and-cross-section-rels",
365
+ help="Specify a path where output will be saved. Should be a JSON file or a directory for BRAT output.",
366
+ )
367
+ parser.add_argument(
368
+ "--brat-data-dir",
369
+ type=str,
370
+ default="data/annotations/sciarg-abstracts/v0.9.3/alisa",
371
+ help="Specify the directory where the BRAT data (base texts and annotations) is located.",
372
+ )
373
+ parser.add_argument(
374
+ "--relation-type",
375
+ type=str,
376
+ default="semantically_same",
377
+ help="Specify the relation type for the BRAT output.",
378
+ )
379
+ parser.add_argument(
380
+ "--input-encoding",
381
+ type=str,
382
+ default="cp1252",
383
+ help="Specify encoding for reading an input file.",
384
+ )
385
+ parser.add_argument(
386
+ "--include-unsure",
387
+ action="store_true",
388
+ help="Include annotations marked as unsure.",
389
+ )
390
+ parser.add_argument(
391
+ "--set-annotation-ids",
392
+ action="store_true",
393
+ help="Set BRAT annotation IDs for the spans and relations.",
394
+ )
395
+ args = parser.parse_args()
396
+ kwargs = vars(args)
397
+
398
+ main(**kwargs)
src/data/split_sciarg_abstracts.py ADDED
@@ -0,0 +1,132 @@
 
 
1
+ import pyrootutils
2
+
3
+ root = pyrootutils.setup_root(
4
+ search_from=__file__,
5
+ indicator=[".project-root"],
6
+ pythonpath=True,
7
+ dotenv=True,
8
+ )
9
+
10
+ import argparse
11
+ import logging
12
+ import os
13
+ from typing import List, Optional, TypeVar
14
+
15
+ from pie_datasets import load_dataset
16
+ from pie_datasets.builders.brat import BratDocument, BratDocumentWithMergedSpans
17
+ from pytorch_ie.documents import (
18
+ TextBasedDocument,
19
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
20
+ TextDocumentWithLabeledPartitions,
21
+ TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
22
+ )
23
+
24
+ from src.document.processing import replace_substrings_in_text_with_spaces
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ def save_abstract_and_remaining_text(
30
+ doc: TextDocumentWithLabeledPartitions, base_path: str
31
+ ) -> None:
32
+ abstract_annotations = [
33
+ span for span in doc.labeled_partitions if span.label.lower().strip() == "abstract"
34
+ ]
35
+ if len(abstract_annotations) != 1:
36
+ logger.warning(
37
+ f"Expected exactly one abstract annotation, found {len(abstract_annotations)}"
38
+ )
39
+ return
40
+ abstract_annotation = abstract_annotations[0]
41
+ text_abstract = doc.text[abstract_annotation.start : abstract_annotation.end]
42
+ text_remaining = doc.text[abstract_annotation.end :]
43
+ with open(
44
+ f"{base_path}.abstract.{abstract_annotation.start}_{abstract_annotation.end}.txt", "w"
45
+ ) as f:
46
+ f.write(text_abstract)
47
+ with open(f"{base_path}.remaining.{abstract_annotation.end}.txt", "w") as f:
48
+ f.write(text_remaining)
49
+
50
+
51
+ D_text = TypeVar("D_text", bound=TextBasedDocument)
52
+
53
+
54
+ def clean_doc(doc: D_text) -> D_text:
55
+ # remove xml tags. Note that we also remove the Abstract tag, in contrast to the preprocessing
56
+ # pipeline (see configs/dataset/sciarg_cleaned.yaml). This is because there, the abstracts are
57
+ # removed completely.
58
+ doc = replace_substrings_in_text_with_spaces(
59
+ doc,
60
+ substrings=[
61
+ "</H2>",
62
+ "<H3>",
63
+ "</Document>",
64
+ "<H1>",
65
+ "<H2>",
66
+ "</H3>",
67
+ "</H1>",
68
+ "<Abstract>",
69
+ "</Abstract>",
70
+ ],
71
+ )
72
+ return doc
73
+
74
+
75
+ def main(out_dir: str, doc_id_whitelist: Optional[List[str]] = None) -> None:
76
+ logger.info("Loading dataset from pie/sciarg")
77
+ sciarg_with_abstracts = load_dataset(
78
+ "pie/sciarg",
79
+ revision="171478ce3c13cc484be5d7c9bc8f66d7d2f1c210",
80
+ split="train",
81
+ )
82
+ if issubclass(sciarg_with_abstracts.document_type, BratDocument):
83
+ ds_converted = sciarg_with_abstracts.to_document_type(
84
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
85
+ )
86
+ elif issubclass(sciarg_with_abstracts.document_type, BratDocumentWithMergedSpans):
87
+ ds_converted = sciarg_with_abstracts.to_document_type(
88
+ TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
89
+ )
90
+ else:
91
+ raise ValueError(f"Unsupported document type {sciarg_with_abstracts.document_type}")
92
+
93
+ ds_clean = ds_converted.map(clean_doc)
94
+ if doc_id_whitelist is not None:
95
+ num_before = len(ds_clean)
96
+ ds_clean = [doc for doc in ds_clean if doc.id in doc_id_whitelist]
97
+ logger.info(
98
+ f"Filtered dataset from {num_before} to {len(ds_clean)} documents based on doc_id_whitelist"
99
+ )
100
+
101
+ os.makedirs(out_dir, exist_ok=True)
102
+ logger.info(f"Saving dataset to {out_dir}")
103
+ for doc in ds_clean:
104
+ save_abstract_and_remaining_text(doc, os.path.join(out_dir, doc.id))
105
+
106
+
107
+ if __name__ == "__main__":
108
+ parser = argparse.ArgumentParser(
109
+ description="Split SciArg dataset into abstract and remaining text"
110
+ )
111
+ parser.add_argument(
112
+ "--out_dir",
113
+ type=str,
114
+ default="data/datasets/sciarg/abstracts_and_remaining_text",
115
+ help="Path to save the split data",
116
+ )
117
+ parser.add_argument(
118
+ "--doc_id_whitelist",
119
+ type=str,
120
+ nargs="+",
121
+ default=["A32", "A33", "A34", "A35", "A36", "A37", "A38", "A39", "A40"],
122
+ help="List of document ids to include in the split",
123
+ )
124
+
125
+ logging.basicConfig(level=logging.INFO)
126
+
127
+ kwargs = vars(parser.parse_args())
128
+ # allow for "all" to include all documents
129
+ if len(kwargs["doc_id_whitelist"]) == 1 and kwargs["doc_id_whitelist"][0].lower() == "all":
130
+ kwargs["doc_id_whitelist"] = None
131
+ main(**kwargs)
132
+ logger.info("Done")
src/demo/annotation_utils.py CHANGED
@@ -1,7 +1,9 @@
 
1
  import logging
2
- from typing import Optional, Sequence, Union
3
 
4
  import gradio as gr
 
5
  from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger
6
 
7
  # this is required to dynamically load the PIE models
@@ -10,7 +12,6 @@ from pie_modules.taskmodules import * # noqa: F403
10
  from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
11
  from pytorch_ie import Pipeline
12
  from pytorch_ie.annotations import LabeledSpan
13
- from pytorch_ie.auto import AutoPipeline
14
  from pytorch_ie.documents import (
15
  TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
16
  TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
@@ -20,9 +21,25 @@ from pytorch_ie.documents import (
20
  from pytorch_ie.models import * # noqa: F403
21
  from pytorch_ie.taskmodules import * # noqa: F403
22
 
 
 
23
  logger = logging.getLogger(__name__)
24
 
25
 
 
 
26
  def annotate_document(
27
  document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
28
  argumentation_model: Pipeline,
@@ -40,23 +57,40 @@ def annotate_document(
40
  """
41
 
42
  # execute prediction pipeline
43
- argumentation_model(document)
 
 
44
 
45
  if handle_parts_of_same:
46
- merger = SpansViaRelationMerger(
47
- relation_layer="binary_relations",
48
- link_relation_label="parts_of_same",
49
- create_multi_spans=True,
50
- result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
51
- result_field_mapping={
52
- "labeled_spans": "labeled_multi_spans",
53
- "binary_relations": "binary_relations",
54
- "labeled_partitions": "labeled_partitions",
55
- },
56
- )
57
- document = merger(document)
58
 
59
- return document
 
 
60
 
61
 
62
  def create_document(
@@ -88,32 +122,45 @@ def create_document(
88
  return document
89
 
90
 
91
- def load_argumentation_model(
92
- model_name: str,
93
- revision: Optional[str] = None,
94
- device: str = "cpu",
95
- ) -> Pipeline:
 
 
96
  try:
97
- # the Pipeline class expects an integer for the device
98
- if device == "cuda":
99
- pipeline_device = 0
100
- elif device.startswith("cuda:"):
101
- pipeline_device = int(device.split(":")[1])
102
- elif device == "cpu":
103
- pipeline_device = -1
104
- else:
105
- raise gr.Error(f"Invalid device: {device}")
106
-
107
- model = AutoPipeline.from_pretrained(
108
- model_name,
109
- device=pipeline_device,
110
- num_workers=0,
111
- taskmodule_kwargs=dict(revision=revision),
112
- model_kwargs=dict(revision=revision),
113
- )
114
- gr.Info(
115
- f"Loaded argumentation model: model_name={model_name}, revision={revision}, device={device}"
116
- )
117
  except Exception as e:
118
  raise gr.Error(f"Failed to load argumentation model: {e}")
119
 
 
1
+ import json
2
  import logging
3
+ from typing import Iterable, Optional, Sequence, Union
4
 
5
  import gradio as gr
6
+ from hydra.utils import instantiate
7
  from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger
8
 
9
  # this is required to dynamically load the PIE models
 
12
  from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
13
  from pytorch_ie import Pipeline
14
  from pytorch_ie.annotations import LabeledSpan
 
15
  from pytorch_ie.documents import (
16
  TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
17
  TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
 
21
  from pytorch_ie.models import * # noqa: F403
22
  from pytorch_ie.taskmodules import * # noqa: F403
23
 
24
+ from src.utils import parse_config
25
+
26
  logger = logging.getLogger(__name__)
27
 
28
 
29
+ def get_merger() -> SpansViaRelationMerger:
30
+ return SpansViaRelationMerger(
31
+ relation_layer="binary_relations",
32
+ link_relation_label="parts_of_same",
33
+ create_multi_spans=True,
34
+ result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
35
+ result_field_mapping={
36
+ "labeled_spans": "labeled_multi_spans",
37
+ "binary_relations": "binary_relations",
38
+ "labeled_partitions": "labeled_partitions",
39
+ },
40
+ )
41
+
42
+
43
  def annotate_document(
44
  document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
45
  argumentation_model: Pipeline,
 
57
  """
58
 
59
  # execute prediction pipeline
60
+ result: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions = argumentation_model(
61
+ document, inplace=True
62
+ )
63
 
64
  if handle_parts_of_same:
65
+ merger = get_merger()
66
+ result = merger(result)
 
 
 
 
 
 
 
 
 
 
67
 
68
+ return result
69
+
70
+
71
+ def annotate_documents(
72
+ documents: Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions],
73
+ argumentation_model: Pipeline,
74
+ handle_parts_of_same: bool = False,
75
+ ) -> Union[
76
+ Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions],
77
+ Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions],
78
+ ]:
79
+ """Annotate a sequence of documents with the provided pipeline.
80
+
81
+ Args:
82
+ documents: The documents to annotate.
83
+ argumentation_model: The pipeline to use for annotation.
84
+ handle_parts_of_same: Whether to merge spans that are part of the same entity into a single multi span.
85
+ """
86
+ # execute prediction pipeline
87
+ result = argumentation_model(documents, inplace=True)
88
+
89
+ if handle_parts_of_same:
90
+ merger = get_merger()
91
+ result = [merger(document) for document in result]
92
+
93
+ return result
94
 
95
 
96
  def create_document(
 
122
  return document
123
 
124
 
125
+ def create_documents(
126
+ texts: Iterable[str], doc_ids: Iterable[str], split_regex: Optional[str] = None
127
+ ) -> Sequence[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
128
+ """Create a sequence of TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
129
+ texts.
130
+
131
+ Parameters:
132
+ texts: The texts to process.
133
+ doc_ids: The IDs of the documents.
134
+ split_regex: A regular expression pattern to use for splitting the text into partitions.
135
+
136
+ Returns:
137
+ The processed documents.
138
+ """
139
+ return [
140
+ create_document(text=text, doc_id=doc_id, split_regex=split_regex)
141
+ for text, doc_id in zip(texts, doc_ids)
142
+ ]
143
+
144
+
145
+ def load_argumentation_model(config_str: str, **kwargs) -> Pipeline:
146
  try:
147
+ config = parse_config(config_str, format="yaml")
148
+
149
+ # for PIE AutoPipeline, we need to handle the revision separately for
150
+ # the taskmodule and the model
151
+ if (
152
+ config.get("_target_") == "pytorch_ie.auto.AutoPipeline.from_pretrained"
153
+ and "revision" in config
154
+ ):
155
+ revision = config.pop("revision")
156
+ if "taskmodule_kwargs" not in config:
157
+ config["taskmodule_kwargs"] = {}
158
+ config["taskmodule_kwargs"]["revision"] = revision
159
+ if "model_kwargs" not in config:
160
+ config["model_kwargs"] = {}
161
+ config["model_kwargs"]["revision"] = revision
162
+ model = instantiate(config, **kwargs)
163
+ gr.Info(f"Loaded argumentation model: {json.dumps({**config, **kwargs})}")
 
 
 
164
  except Exception as e:
165
  raise gr.Error(f"Failed to load argumentation model: {e}")
166
 
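To illustrate the revision handling in load_argumentation_model above, a hedged sketch of the config transformation (model name and revision are placeholder values):

    config = {
        "_target_": "pytorch_ie.auto.AutoPipeline.from_pretrained",
        "pretrained_model_name_or_path": "my-org/my-adu-re-model",  # placeholder
        "revision": "abc123",  # placeholder
    }
    # after the revision handling, the top-level revision has been moved into both kwargs dicts:
    # {
    #     "_target_": "pytorch_ie.auto.AutoPipeline.from_pretrained",
    #     "pretrained_model_name_or_path": "my-org/my-adu-re-model",
    #     "taskmodule_kwargs": {"revision": "abc123"},
    #     "model_kwargs": {"revision": "abc123"},
    # }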
src/demo/backend_utils.py CHANGED
@@ -2,17 +2,20 @@ import json
2
  import logging
3
  import os
4
  import tempfile
 
5
  from typing import Iterable, List, Optional, Sequence
6
 
7
  import gradio as gr
8
  import pandas as pd
 
9
  from pie_datasets import Dataset, IterableDataset, load_dataset
10
  from pytorch_ie import Pipeline
11
  from pytorch_ie.documents import (
12
  TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
13
  )
 
14
 
15
- from src.demo.annotation_utils import annotate_document, create_document
16
  from src.demo.data_utils import load_text_from_arxiv
17
  from src.demo.rendering_utils import (
18
  RENDER_WITH_DISPLACY,
@@ -25,6 +28,8 @@ from src.langchain_modules import (
25
  DocumentAwareSpanRetriever,
26
  DocumentAwareSpanRetrieverWithRelations,
27
  )
 
 
28
 
29
  logger = logging.getLogger(__name__)
30
 
@@ -58,20 +63,18 @@ def process_texts(
58
  # check that doc_ids are unique
59
  if len(set(doc_ids)) != len(list(doc_ids)):
60
  raise gr.Error("Document IDs must be unique.")
61
- pie_documents = [
62
- create_document(text=text, doc_id=doc_id, split_regex=split_regex_escaped)
63
- for text, doc_id in zip(texts, doc_ids)
64
- ]
 
65
  if verbose:
66
  gr.Info(f"Annotate {len(pie_documents)} documents...")
67
- pie_documents = [
68
- annotate_document(
69
- document=pie_document,
70
- argumentation_model=argumentation_model,
71
- handle_parts_of_same=handle_parts_of_same,
72
- )
73
- for pie_document in pie_documents
74
- ]
75
  add_annotated_pie_documents(
76
  retriever=retriever,
77
  pie_documents=pie_documents,
@@ -140,6 +143,94 @@ def process_uploaded_files(
140
  return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)
141
 
142
 
 
 
143
  def wrapped_add_annotated_pie_documents_from_dataset(
144
  retriever: DocumentAwareSpanRetriever, verbose: bool, layer_captions: dict[str, str], **kwargs
145
  ) -> pd.DataFrame:
@@ -193,6 +284,7 @@ def render_annotated_document(
193
  document_id: str,
194
  render_with: str,
195
  render_kwargs_json: str,
 
196
  ) -> str:
197
  text, spans, span_id2idx, relations = get_text_spans_and_relations_from_document(
198
  retriever=retriever, document_id=document_id
@@ -213,6 +305,7 @@ def render_annotated_document(
213
  spans=spans,
214
  span_id2idx=span_id2idx,
215
  binary_relations=relations,
 
216
  **render_kwargs,
217
  )
218
  else:
 
2
  import logging
3
  import os
4
  import tempfile
5
+ from pathlib import Path
6
  from typing import Iterable, List, Optional, Sequence
7
 
8
  import gradio as gr
9
  import pandas as pd
10
+ from acl_anthology import Anthology
11
  from pie_datasets import Dataset, IterableDataset, load_dataset
12
  from pytorch_ie import Pipeline
13
  from pytorch_ie.documents import (
14
  TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
15
  )
16
+ from tqdm import tqdm
17
 
18
+ from src.demo.annotation_utils import annotate_documents, create_documents
19
  from src.demo.data_utils import load_text_from_arxiv
20
  from src.demo.rendering_utils import (
21
  RENDER_WITH_DISPLACY,
 
28
  DocumentAwareSpanRetriever,
29
  DocumentAwareSpanRetrieverWithRelations,
30
  )
31
+ from src.utils.pdf_utils.acl_anthology_utils import XML2RawPapers
32
+ from src.utils.pdf_utils.process_pdf import FulltextExtractor, PDFDownloader
33
 
34
  logger = logging.getLogger(__name__)
35
 
 
63
  # check that doc_ids are unique
64
  if len(set(doc_ids)) != len(list(doc_ids)):
65
  raise gr.Error("Document IDs must be unique.")
66
+ pie_documents = create_documents(
67
+ texts=texts,
68
+ doc_ids=doc_ids,
69
+ split_regex=split_regex_escaped,
70
+ )
71
  if verbose:
72
  gr.Info(f"Annotate {len(pie_documents)} documents...")
73
+ pie_documents = annotate_documents(
74
+ documents=pie_documents,
75
+ argumentation_model=argumentation_model,
76
+ handle_parts_of_same=handle_parts_of_same,
77
+ )
 
 
 
78
  add_annotated_pie_documents(
79
  retriever=retriever,
80
  pie_documents=pie_documents,
 
143
  return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)
144
 
145
 
146
+ def process_uploaded_pdf_files(
147
+ pdf_fulltext_extractor: Optional[FulltextExtractor],
148
+ file_names: List[str],
149
+ retriever: DocumentAwareSpanRetriever,
150
+ layer_captions: dict[str, str],
151
+ **kwargs,
152
+ ) -> pd.DataFrame:
153
+ try:
154
+ if pdf_fulltext_extractor is None:
155
+ raise gr.Error("PDF fulltext extractor is not available.")
156
+ doc_ids = []
157
+ texts = []
158
+ for file_name in file_names:
159
+ if file_name.lower().endswith(".pdf"):
160
+ # extract the fulltext from the pdf
161
+ text_and_extraction_data = pdf_fulltext_extractor(file_name)
162
+ if text_and_extraction_data is None:
163
+ raise gr.Error(f"Failed to extract fulltext from PDF: {file_name}")
164
+ text, _ = text_and_extraction_data
165
+
166
+ base_file_name = os.path.basename(file_name)
167
+ doc_ids.append(base_file_name)
168
+ texts.append(text)
169
+
170
+ else:
171
+ raise gr.Error(f"Unsupported file format: {file_name}")
172
+ process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
173
+ except Exception as e:
174
+ raise gr.Error(f"Failed to process uploaded files: {e}")
175
+
176
+ return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)
177
+
178
+
179
+ def load_acl_anthology_venues(
180
+ venues: List[str],
181
+ pdf_fulltext_extractor: Optional[FulltextExtractor],
182
+ retriever: DocumentAwareSpanRetriever,
183
+ layer_captions: dict[str, str],
184
+ acl_anthology_data_dir: Optional[str],
185
+ pdf_output_dir: Optional[str],
186
+ show_progress: bool = True,
187
+ **kwargs,
188
+ ) -> pd.DataFrame:
189
+ try:
190
+ if pdf_fulltext_extractor is None:
191
+ raise gr.Error("PDF fulltext extractor is not available.")
192
+ if acl_anthology_data_dir is None:
193
+ raise gr.Error("ACL Anthology data directory is not provided.")
194
+ if pdf_output_dir is None:
195
+ raise gr.Error("PDF output directory is not provided.")
196
+ xml2raw_papers = XML2RawPapers(
197
+ anthology=Anthology(datadir=Path(acl_anthology_data_dir)),
198
+ venue_id_whitelist=venues,
199
+ verbose=False,
200
+ )
201
+ pdf_downloader = PDFDownloader()
202
+ doc_ids = []
203
+ texts = []
204
+ os.makedirs(pdf_output_dir, exist_ok=True)
205
+ papers = xml2raw_papers()
206
+ if show_progress:
207
+ papers_list = list(papers)
208
+ papers = tqdm(papers_list, desc="extracting fulltext")
209
+ gr.Info(
210
+ f"Downloading and extracting fulltext from {len(papers_list)} papers in venues: {venues}"
211
+ )
212
+ for paper in papers:
213
+ if paper.url is not None:
214
+ pdf_save_path = pdf_downloader.download(
215
+ paper.url, opath=Path(pdf_output_dir) / f"{paper.name}.pdf"
216
+ )
217
+ fulltext_extraction_output = pdf_fulltext_extractor(pdf_save_path)
218
+
219
+ if fulltext_extraction_output:
220
+ text, _ = fulltext_extraction_output
221
+ doc_id = f"aclanthology.org/{paper.name}"
222
+ doc_ids.append(doc_id)
223
+ texts.append(text)
224
+ else:
225
+ gr.Warning(f"Failed to extract fulltext from PDF: {paper.url}")
226
+
227
+ process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
228
+ except Exception as e:
229
+ raise gr.Error(f"Failed to process uploaded files: {e}")
230
+
231
+ return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)
232
+
233
+
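A minimal standalone sketch of the download-and-extract flow used by the two functions above (URL and paths are placeholders; the extractor is assumed to be any configured FulltextExtractor that returns a (text, extraction_data) tuple or None, as used here):

    from pathlib import Path

    from src.utils.pdf_utils.process_pdf import PDFDownloader

    downloader = PDFDownloader()
    pdf_path = downloader.download(
        "https://aclanthology.org/2024.acl-long.1.pdf",  # placeholder URL
        opath=Path("data/pdfs/2024.acl-long.1.pdf"),
    )
    extraction_output = pdf_fulltext_extractor(pdf_path)  # some configured FulltextExtractor instance
    if extraction_output is not None:
        text, extraction_data = extraction_output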
234
  def wrapped_add_annotated_pie_documents_from_dataset(
235
  retriever: DocumentAwareSpanRetriever, verbose: bool, layer_captions: dict[str, str], **kwargs
236
  ) -> pd.DataFrame:
 
284
  document_id: str,
285
  render_with: str,
286
  render_kwargs_json: str,
287
+ highlight_span_ids: Optional[List[str]] = None,
288
  ) -> str:
289
  text, spans, span_id2idx, relations = get_text_spans_and_relations_from_document(
290
  retriever=retriever, document_id=document_id
 
305
  spans=spans,
306
  span_id2idx=span_id2idx,
307
  binary_relations=relations,
308
+ highlight_span_ids=highlight_span_ids,
309
  **render_kwargs,
310
  )
311
  else:
src/demo/frontend_utils.py CHANGED
@@ -24,6 +24,18 @@ def close_accordion():
24
  return gr.Accordion(open=False)
25
 
26
 
 
 
27
  def change_tab(id: Union[int, str]):
28
  return gr.Tabs(selected=id)
29
 
 
24
  return gr.Accordion(open=False)
25
 
26
 
27
+ def open_accordion_with_stats(
28
+ overview: pd.DataFrame, base_label: str, caption2column: dict[str, str], total_column: str
29
+ ):
30
+ caption2value = {
31
+ caption: len(overview) if column == total_column else overview[column].sum()
32
+ for caption, column in caption2column.items()
33
+ }
34
+ stats_str = ", ".join([f"{value} {caption}" for caption, value in caption2value.items()])
35
+ label = f"{base_label} ({stats_str})"
36
+ return gr.Accordion(open=True, label=label)
37
+
38
+
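A small usage sketch for open_accordion_with_stats (hypothetical overview columns):

    import pandas as pd

    overview = pd.DataFrame({"doc_id": ["A32", "A33"], "num_adus": [12, 9], "num_relations": [7, 5]})
    open_accordion_with_stats(
        overview,
        base_label="Processed documents",
        caption2column={"documents": "doc_id", "ADUs": "num_adus", "relations": "num_relations"},
        total_column="doc_id",
    )
    # -> accordion labeled "Processed documents (2 documents, 21 ADUs, 12 relations)"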
39
  def change_tab(id: Union[int, str]):
40
  return gr.Tabs(selected=id)
41
 
src/demo/rendering_utils.py CHANGED
@@ -15,7 +15,7 @@ AVAILABLE_RENDER_MODES = [RENDER_WITH_DISPLACY, RENDER_WITH_PRETTY_TABLE]
15
 
16
  # adjusted from rendering_utils_displacy.TPL_ENT
17
  TPL_ENT_WITH_ID = """
18
- <mark class="entity" data-entity-id="{entity_id}" data-slice-idx="{slice_idx}" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
19
  {text}
20
  <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
21
  </mark>
@@ -31,8 +31,12 @@ HIGHLIGHT_SPANS_JS = """
31
  color = colors[colorDictKey];
32
  } catch (e) {}
33
  if (color) {
 
 
34
  entity.style.backgroundColor = color;
35
  entity.style.color = '#000';
 
 
36
  }
37
  }
38
 
@@ -42,6 +46,8 @@ HIGHLIGHT_SPANS_JS = """
42
  entities.forEach(entity => {
43
  const color = entity.getAttribute('data-color-original');
44
  entity.style.backgroundColor = color;
 
 
45
  entity.style.color = '';
46
  });
47
 
@@ -171,6 +177,7 @@ def render_displacy(
171
  spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
172
  span_id2idx: Dict[str, int],
173
  binary_relations: Sequence[BinaryRelation],
 
174
  inject_relations=True,
175
  colors_hover=None,
176
  entity_options={},
@@ -180,6 +187,9 @@ def render_displacy(
180
  ents: List[Dict[str, Any]] = []
181
  for entity_id, idx in span_id2idx.items():
182
  labeled_span = spans[idx]
 
 
 
183
  # pass the ID as a parameter to the entity. The id is required to fetch the entity annotations
184
  # on hover and to inject the relation data.
185
  if isinstance(labeled_span, LabeledSpan):
@@ -188,7 +198,11 @@ def render_displacy(
188
  "start": labeled_span.start,
189
  "end": labeled_span.end,
190
  "label": labeled_span.label,
191
- "params": {"entity_id": entity_id, "slice_idx": 0},
 
 
 
 
192
  }
193
  )
194
  elif isinstance(labeled_span, LabeledMultiSpan):
@@ -198,7 +212,11 @@ def render_displacy(
198
  "start": start,
199
  "end": end,
200
  "label": labeled_span.label,
201
- "params": {"entity_id": entity_id, "slice_idx": i},
 
 
 
 
202
  }
203
  )
204
  else:
@@ -254,7 +272,9 @@ def inject_relation_data(
254
  entities = soup.find_all(class_="entity")
255
  for entity in entities:
256
  original_color = entity["style"].split("background:")[1].split(";")[0].strip()
 
257
  entity["data-color-original"] = original_color
 
258
  if additional_colors is not None:
259
  for key, color in additional_colors.items():
260
  entity[f"data-color-{key}"] = (
 
15
 
16
  # adjusted from rendering_utils_displacy.TPL_ENT
17
  TPL_ENT_WITH_ID = """
18
+ <mark class="entity" data-entity-id="{entity_id}" data-slice-idx="{slice_idx}" data-highlight-mode="{highlight_mode}" style="background: {bg}; border-width: {border_width}; border-color: {border_color}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
19
  {text}
20
  <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
21
  </mark>
 
31
  color = colors[colorDictKey];
32
  } catch (e) {}
33
  if (color) {
34
+ //const highlightMode = entity.getAttribute('data-highlight-mode');
35
+ //if (highlightMode === 'fill') {
36
  entity.style.backgroundColor = color;
37
  entity.style.color = '#000';
38
+ //}
39
+ entity.style.borderColor = color;
40
  }
41
  }
42
 
 
46
  entities.forEach(entity => {
47
  const color = entity.getAttribute('data-color-original');
48
  entity.style.backgroundColor = color;
49
+ const borderColor = entity.getAttribute('data-border-color-original');
50
+ entity.style.borderColor = borderColor;
51
  entity.style.color = '';
52
  });
53
 
 
177
  spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
178
  span_id2idx: Dict[str, int],
179
  binary_relations: Sequence[BinaryRelation],
180
+ highlight_span_ids: Optional[List[str]] = None,
181
  inject_relations=True,
182
  colors_hover=None,
183
  entity_options={},
 
187
  ents: List[Dict[str, Any]] = []
188
  for entity_id, idx in span_id2idx.items():
189
  labeled_span = spans[idx]
190
+ highlight_mode = (
191
+ "fill" if highlight_span_ids is None or entity_id in highlight_span_ids else "border"
192
+ )
193
  # pass the ID as a parameter to the entity. The id is required to fetch the entity annotations
194
  # on hover and to inject the relation data.
195
  if isinstance(labeled_span, LabeledSpan):
 
198
  "start": labeled_span.start,
199
  "end": labeled_span.end,
200
  "label": labeled_span.label,
201
+ "params": {
202
+ "entity_id": entity_id,
203
+ "slice_idx": 0,
204
+ "highlight_mode": highlight_mode,
205
+ },
206
  }
207
  )
208
  elif isinstance(labeled_span, LabeledMultiSpan):
 
212
  "start": start,
213
  "end": end,
214
  "label": labeled_span.label,
215
+ "params": {
216
+ "entity_id": entity_id,
217
+ "slice_idx": i,
218
+ "highlight_mode": highlight_mode,
219
+ },
220
  }
221
  )
222
  else:
 
272
  entities = soup.find_all(class_="entity")
273
  for entity in entities:
274
  original_color = entity["style"].split("background:")[1].split(";")[0].strip()
275
+ original_border_color = entity["style"].split("border-color:")[1].split(";")[0].strip()
276
  entity["data-color-original"] = original_color
277
+ entity["data-border-color-original"] = original_border_color
278
  if additional_colors is not None:
279
  for key, color in additional_colors.items():
280
  entity[f"data-color-{key}"] = (
src/demo/rendering_utils_displacy.py CHANGED
@@ -200,7 +200,18 @@ class EntityRenderer(object):
200
  markup += "<br/>"
201
  if self.ents is None or label.upper() in self.ents:
202
  color = self.colors.get(label.upper(), self.default_color)
203
- ent_settings = {"label": label, "text": entity, "bg": color}
 
 
 
 
 
 
 
 
 
 
 
204
  ent_settings.update(additional_params)
205
  markup += self.ent_template.format(**ent_settings)
206
  else:
 
200
  markup += "<br/>"
201
  if self.ents is None or label.upper() in self.ents:
202
  color = self.colors.get(label.upper(), self.default_color)
203
+ ent_settings = {"label": label, "text": entity}
204
+ highlight_mode = additional_params.get("highlight_mode", "fill")
205
+ if highlight_mode == "fill":
206
+ ent_settings["bg"] = color
207
+ ent_settings["border_width"] = "0px"
208
+ ent_settings["border_color"] = color
209
+ elif highlight_mode == "border":
210
+ ent_settings["bg"] = "inherit"
211
+ ent_settings["border_width"] = "2px"
212
+ ent_settings["border_color"] = color
213
+ else:
214
+ raise ValueError(f"Invalid highlight_mode: {highlight_mode}")
215
  ent_settings.update(additional_params)
216
  markup += self.ent_template.format(**ent_settings)
217
  else:
src/demo/retrieve_and_dump_all_relevant.py CHANGED
@@ -9,6 +9,9 @@ root = pyrootutils.setup_root(
9
 
10
  import argparse
11
  import logging
 
 
 
12
 
13
  from src.demo.retriever_utils import (
14
  retrieve_all_relevant_spans,
@@ -55,6 +58,29 @@ if __name__ == "__main__":
55
  default=None,
56
  help="If provided, retrieve all spans for only this query span.",
57
  )
 
 
58
  args = parser.parse_args()
59
 
60
  logging.basicConfig(
@@ -74,9 +100,41 @@ if __name__ == "__main__":
74
  retriever.load_from_disc(args.data_path)
75
 
76
  search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
 
 
 
 
77
  logger.info(f"use search_kwargs: {search_kwargs}")
78
 
79
- if args.query_span_id is not None:
 
 
 
80
  logger.warning(f"retrieving results for single span: {args.query_span_id}")
81
  all_spans_for_all_documents = retrieve_relevant_spans(
82
  retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
@@ -95,7 +153,8 @@ if __name__ == "__main__":
95
  logger.warning("no relevant spans found in any document")
96
  exit(0)
97
 
98
- logger.info(f"dumping results to {args.output_path}...")
 
99
  all_spans_for_all_documents.to_json(args.output_path)
100
 
101
  logger.info("done")
 
9
 
10
  import argparse
11
  import logging
12
+ import os
13
+
14
+ import pandas as pd
15
 
16
  from src.demo.retriever_utils import (
17
  retrieve_all_relevant_spans,
 
58
  default=None,
59
  help="If provided, retrieve all spans for only this query span.",
60
  )
61
+ parser.add_argument(
62
+ "--doc_id_whitelist",
63
+ type=str,
64
+ nargs="+",
65
+ default=None,
66
+ help="If provided, only consider documents with these IDs.",
67
+ )
68
+ parser.add_argument(
69
+ "--doc_id_blacklist",
70
+ type=str,
71
+ nargs="+",
72
+ default=None,
73
+ help="If provided, ignore documents with these IDs.",
74
+ )
75
+ parser.add_argument(
76
+ "--query_target_doc_id_pairs",
77
+ type=str,
78
+ nargs="+",
79
+ default=None,
80
+ help="One or more pairs of query and target document IDs "
81
+ '(each separated by ":") to retrieve spans for. If provided, '
82
+ "--query_doc_id and --query_span_id are ignored.",
83
+ )
84
  args = parser.parse_args()
85
 
86
  logging.basicConfig(
 
100
  retriever.load_from_disc(args.data_path)
101
 
102
  search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
103
+ if args.doc_id_whitelist is not None:
104
+ search_kwargs["doc_id_whitelist"] = args.doc_id_whitelist
105
+ if args.doc_id_blacklist is not None:
106
+ search_kwargs["doc_id_blacklist"] = args.doc_id_blacklist
107
  logger.info(f"use search_kwargs: {search_kwargs}")
108
 
109
+ if args.query_target_doc_id_pairs is not None:
110
+ all_spans_for_all_documents = None
111
+ for doc_id_pair in args.query_target_doc_id_pairs:
112
+ query_doc_id, target_doc_id = doc_id_pair.split(":")
113
+ current_result = retrieve_all_relevant_spans(
114
+ retriever=retriever,
115
+ query_doc_id=query_doc_id,
116
+ doc_id_whitelist=[target_doc_id],
117
+ **search_kwargs,
118
+ )
119
+ if current_result is None:
120
+ logger.warning(
121
+ f"no relevant spans found for query_doc_id={query_doc_id} and "
122
+ f"target_doc_id={target_doc_id}"
123
+ )
124
+ continue
125
+ logger.info(
126
+ f"retrieved {len(current_result)} spans for query_doc_id={query_doc_id} "
127
+ f"and target_doc_id={target_doc_id}"
128
+ )
129
+ current_result["query_doc_id"] = query_doc_id
130
+ if all_spans_for_all_documents is None:
131
+ all_spans_for_all_documents = current_result
132
+ else:
133
+ all_spans_for_all_documents = pd.concat(
134
+ [all_spans_for_all_documents, current_result], ignore_index=True
135
+ )
136
+
137
+ elif args.query_span_id is not None:
138
  logger.warning(f"retrieving results for single span: {args.query_span_id}")
139
  all_spans_for_all_documents = retrieve_relevant_spans(
140
  retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
 
153
  logger.warning("no relevant spans found in any document")
154
  exit(0)
155
 
156
+ logger.info(f"dumping results ({len(all_spans_for_all_documents)}) to {args.output_path}...")
157
+ os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
158
  all_spans_for_all_documents.to_json(args.output_path)
159
 
160
  logger.info("done")
src/demo/retriever_utils.py CHANGED
@@ -8,10 +8,8 @@ from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
8
  from typing_extensions import Protocol
9
 
10
  from src.langchain_modules import DocumentAwareSpanRetriever
11
- from src.langchain_modules.span_retriever import (
12
- DocumentAwareSpanRetrieverWithRelations,
13
- _parse_config,
14
- )
15
 
16
  logger = logging.getLogger(__name__)
17
 
@@ -22,13 +20,13 @@ def get_document_as_dict(retriever: DocumentAwareSpanRetriever, doc_id: str) ->
22
 
23
 
24
  def load_retriever(
25
- retriever_config_str: str,
26
  config_format: str,
27
  device: str = "cpu",
28
  previous_retriever: Optional[DocumentAwareSpanRetrieverWithRelations] = None,
29
  ) -> DocumentAwareSpanRetrieverWithRelations:
30
  try:
31
- retriever_config = _parse_config(retriever_config_str, format=config_format)
32
  # set device for the embeddings pipeline
33
  retriever_config["vectorstore"]["embedding"]["pipeline_kwargs"]["device"] = device
34
  result = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config(retriever_config)
@@ -153,6 +151,7 @@ def _retrieve_for_all_spans(
153
  query_doc_id: str,
154
  retrieve_func: RetrieverCallable,
155
  query_span_id_column: str = "query_span_id",
 
156
  **kwargs,
157
  ) -> Optional[pd.DataFrame]:
158
  if not query_doc_id.strip():
@@ -177,6 +176,9 @@ def _retrieve_for_all_spans(
177
  # add column with query_span_id
178
  for query_span_id, query_span_result in span_results_not_empty.items():
179
  query_span_result[query_span_id_column] = query_span_id
 
 
 
180
 
181
  if len(span_results_not_empty) == 0:
182
  gr.Info(f"No results found for any ADU in document {query_doc_id}.")
 
8
  from typing_extensions import Protocol
9
 
10
  from src.langchain_modules import DocumentAwareSpanRetriever
11
+ from src.langchain_modules.span_retriever import DocumentAwareSpanRetrieverWithRelations
12
+ from src.utils import parse_config
 
 
13
 
14
  logger = logging.getLogger(__name__)
15
 
 
20
 
21
 
22
  def load_retriever(
23
+ config_str: str,
24
  config_format: str,
25
  device: str = "cpu",
26
  previous_retriever: Optional[DocumentAwareSpanRetrieverWithRelations] = None,
27
  ) -> DocumentAwareSpanRetrieverWithRelations:
28
  try:
29
+ retriever_config = parse_config(config_str, format=config_format)
30
  # set device for the embeddings pipeline
31
  retriever_config["vectorstore"]["embedding"]["pipeline_kwargs"]["device"] = device
32
  result = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config(retriever_config)
 
151
  query_doc_id: str,
152
  retrieve_func: RetrieverCallable,
153
  query_span_id_column: str = "query_span_id",
154
+ query_span_text_column: Optional[str] = None,
155
  **kwargs,
156
  ) -> Optional[pd.DataFrame]:
157
  if not query_doc_id.strip():
 
176
  # add column with query_span_id
177
  for query_span_id, query_span_result in span_results_not_empty.items():
178
  query_span_result[query_span_id_column] = query_span_id
179
+ if query_span_text_column is not None:
180
+ query_span = retriever.get_span_by_id(span_id=query_span_id)
181
+ query_span_result[query_span_text_column] = str(query_span)
182
 
183
  if len(span_results_not_empty) == 0:
184
  gr.Info(f"No results found for any ADU in document {query_doc_id}.")
src/document/processing.py CHANGED
@@ -1,13 +1,17 @@
1
  from __future__ import annotations
2
 
3
  import logging
4
- from typing import Any, Dict, Iterable, List, Sequence, Set, Tuple, TypeVar, Union
5
 
6
- import networkx as nx
7
- from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
8
  from pie_modules.documents import TextDocumentWithLabeledMultiSpansAndBinaryRelations
 
9
  from pytorch_ie import AnnotationLayer
10
  from pytorch_ie.core import Document
 
 
 
 
11
 
12
  logger = logging.getLogger(__name__)
13
 
@@ -64,76 +68,7 @@ def remove_overlapping_entities(
64
  return new_doc
65
 
66
 
67
- def _merge_spans_via_relation(
68
- spans: Sequence[LabeledSpan],
69
- relations: Sequence[BinaryRelation],
70
- link_relation_label: str,
71
- create_multi_spans: bool = True,
72
- ) -> Tuple[Union[Set[LabeledSpan], Set[LabeledMultiSpan]], Set[BinaryRelation]]:
73
- # convert list of relations to a graph to easily calculate connected components to merge
74
- g = nx.Graph()
75
- link_relations = []
76
- other_relations = []
77
- for rel in relations:
78
- if rel.label == link_relation_label:
79
- link_relations.append(rel)
80
- # never merge spans that have not the same label
81
- if (
82
- not (isinstance(rel.head, LabeledSpan) or isinstance(rel.tail, LabeledSpan))
83
- or rel.head.label == rel.tail.label
84
- ):
85
- g.add_edge(rel.head, rel.tail)
86
- else:
87
- logger.debug(
88
- f"spans to merge do not have the same label, do not merge them: {rel.head}, {rel.tail}"
89
- )
90
- else:
91
- other_relations.append(rel)
92
-
93
- span_mapping = {}
94
- connected_components: Set[LabeledSpan]
95
- for connected_components in nx.connected_components(g):
96
- # all spans in a connected component have the same label
97
- label = list(span.label for span in connected_components)[0]
98
- connected_components_sorted = sorted(connected_components, key=lambda span: span.start)
99
- if create_multi_spans:
100
- new_span = LabeledMultiSpan(
101
- slices=tuple((span.start, span.end) for span in connected_components_sorted),
102
- label=label,
103
- )
104
- else:
105
- new_span = LabeledSpan(
106
- start=min(span.start for span in connected_components_sorted),
107
- end=max(span.end for span in connected_components_sorted),
108
- label=label,
109
- )
110
- for span in connected_components_sorted:
111
- span_mapping[span] = new_span
112
- for span in spans:
113
- if span not in span_mapping:
114
- if create_multi_spans:
115
- span_mapping[span] = LabeledMultiSpan(
116
- slices=((span.start, span.end),), label=span.label, score=span.score
117
- )
118
- else:
119
- span_mapping[span] = LabeledSpan(
120
- start=span.start, end=span.end, label=span.label, score=span.score
121
- )
122
-
123
- new_spans = set(span_mapping.values())
124
- new_relations = set(
125
- BinaryRelation(
126
- head=span_mapping[rel.head],
127
- tail=span_mapping[rel.tail],
128
- label=rel.label,
129
- score=rel.score,
130
- )
131
- for rel in other_relations
132
- )
133
-
134
- return new_spans, new_relations
135
-
136
-
137
  def merge_spans_via_relation(
138
  document: D,
139
  relation_layer: str,
@@ -186,15 +121,50 @@ def merge_spans_via_relation(
186
 
187
 
188
  def remove_partitions_by_labels(
189
- document: D, partition_layer: str, label_blacklist: List[str]
190
  ) -> D:
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  document = document.copy()
192
- layer: AnnotationLayer = document[partition_layer]
193
  new_partitions = []
194
- for partition in layer.clear():
195
  if partition.label not in label_blacklist:
196
  new_partitions.append(partition)
197
- layer.extend(new_partitions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  return document
199
 
200
 
@@ -221,3 +191,168 @@ def replace_substrings_in_text(
221
  def replace_substrings_in_text_with_spaces(document: D_text, substrings: Iterable[str]) -> D_text:
222
  replacements = {substring: " " * len(substring) for substring in substrings}
223
  return replace_substrings_in_text(document, replacements=replacements)
1
  from __future__ import annotations
2
 
3
  import logging
4
+ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
5
 
6
+ from pie_modules.document.processing.merge_spans_via_relation import _merge_spans_via_relation
 
7
  from pie_modules.documents import TextDocumentWithLabeledMultiSpansAndBinaryRelations
8
+ from pie_modules.utils.span import have_overlap
9
  from pytorch_ie import AnnotationLayer
10
  from pytorch_ie.core import Document
11
+ from pytorch_ie.core.document import Annotation, _enumerate_dependencies
12
+
13
+ from src.utils import distance
14
+ from src.utils.span_utils import get_overlap_len
15
 
16
  logger = logging.getLogger(__name__)
17
 
 
68
  return new_doc
69
 
70
 
71
+ # TODO: remove and use pie_modules.document.processing.SpansViaRelationMerger instead
72
  def merge_spans_via_relation(
73
  document: D,
74
  relation_layer: str,
 
121
 
122
 
123
  def remove_partitions_by_labels(
124
+ document: D, partition_layer: str, label_blacklist: List[str], span_layer: Optional[str] = None
125
  ) -> D:
126
+ """Remove partitions with labels in the blacklist from a document.
127
+
128
+ Args:
129
+ document: The document to process.
130
+ partition_layer: The name of the partition layer.
131
+ label_blacklist: The list of labels to remove.
132
+ span_layer: The name of the span layer to remove spans from if they are not fully
133
+ contained in any remaining partition. Any dependent annotations will be removed as well.
134
+
135
+ Returns:
136
+ The processed document.
137
+ """
138
+
139
  document = document.copy()
140
+ p_layer: AnnotationLayer = document[partition_layer]
141
  new_partitions = []
142
+ for partition in p_layer.clear():
143
  if partition.label not in label_blacklist:
144
  new_partitions.append(partition)
145
+ p_layer.extend(new_partitions)
146
+
147
+ if span_layer is not None:
148
+ result = document.copy(with_annotations=False)
149
+ removed_span_ids = set()
150
+ for span in document[span_layer]:
151
+ # keep spans fully contained in any partition
152
+ if any(
153
+ partition.start <= span.start and span.end <= partition.end
154
+ for partition in new_partitions
155
+ ):
156
+ result[span_layer].append(span.copy())
157
+ else:
158
+ removed_span_ids.add(span._id)
159
+
160
+ result.add_all_annotations_from_other(
161
+ document,
162
+ removed_annotations={span_layer: removed_span_ids},
163
+ strict=False,
164
+ verbose=False,
165
+ )
166
+ document = result
167
+
168
  return document
169
 
170
 
 
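A usage sketch for the extended remove_partitions_by_labels above; the document and the layer/label names are illustrative only.

from src.document.processing import remove_partitions_by_labels

# drop e.g. abstract partitions and discard any span that no longer falls inside
# a kept partition (dependent annotations such as relations are dropped as well)
cleaned = remove_partitions_by_labels(
    doc,                                   # a PIE document with a partition and a span layer (placeholder)
    partition_layer="labeled_partitions",  # hypothetical layer name
    label_blacklist=["abstract"],          # hypothetical label
    span_layer="labeled_spans",            # new optional argument
)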
191
  def replace_substrings_in_text_with_spaces(document: D_text, substrings: Iterable[str]) -> D_text:
192
  replacements = {substring: " " * len(substring) for substring in substrings}
193
  return replace_substrings_in_text(document, replacements=replacements)
194
+
195
+
196
+ def relabel_annotations(
197
+ document: D,
198
+ label_mapping: Dict[str, Dict[str, str]],
199
+ ) -> D:
200
+ """
201
+ Replace annotation labels in a document.
202
+
203
+ Args:
204
+ document: The document to process.
205
+ label_mapping: A mapping from layer names to mappings from old labels to new labels.
206
+
207
+ Returns:
208
+ The processed document.
209
+
210
+ """
211
+
212
+ dependency_ordered_fields: List[str] = []
213
+ _enumerate_dependencies(
214
+ dependency_ordered_fields,
215
+ dependency_graph=document._annotation_graph,
216
+ nodes=document._annotation_graph["_artificial_root"],
217
+ )
218
+ result = document.copy(with_annotations=False)
219
+ store: Dict[int, Annotation] = {}
220
+ # not yet used
221
+ invalid_annotation_ids: Set[int] = set()
222
+ for field_name in dependency_ordered_fields:
223
+ if field_name in document._annotation_fields:
224
+ layer = document[field_name]
225
+ for is_prediction, anns in [(False, layer), (True, layer.predictions)]:
226
+ for ann in anns:
227
+ new_ann = ann.copy_with_store(
228
+ override_annotation_store=store,
229
+ invalid_annotation_ids=invalid_annotation_ids,
230
+ )
231
+ if field_name in label_mapping:
232
+ if ann.label in label_mapping[field_name]:
233
+ new_label = label_mapping[field_name][ann.label]
234
+ new_ann = new_ann.copy(label=new_label)
235
+ else:
236
+ raise ValueError(
237
+ f"Label {ann.label} not found in label mapping for {field_name}"
238
+ )
239
+ store[ann._id] = new_ann
240
+ target_layer = result[field_name]
241
+ if is_prediction:
242
+ target_layer.predictions.append(new_ann)
243
+ else:
244
+ target_layer.append(new_ann)
245
+
246
+ return result
247
+
248
+
249
+ DWithSpans = TypeVar("DWithSpans", bound=Document)
250
+
251
+
252
+ def align_predicted_span_annotations(
253
+ document: DWithSpans, span_layer: str, distance_type: str = "center", verbose: bool = False
254
+ ) -> DWithSpans:
255
+ """
256
+ Aligns predicted span annotations with the closest gold spans in a document.
257
+
258
+ First, calculates the distance between each predicted span and each gold span. Then,
259
+ for each predicted span, the gold span with the smallest distance is selected. If the
260
+ predicted span and the gold span have an overlap of at least half of the maximum length
261
+ of the two spans, the predicted span is aligned with the gold span.
262
+
263
+ Args:
264
+ document: The document to process.
265
+ span_layer: The name of the span layer.
266
+ distance_type: The type of distance to calculate. One of: center, inner, outer
267
+ verbose: Whether to print debug information.
268
+
269
+ Returns:
270
+ The processed document.
271
+ """
272
+ gold_spans = document[span_layer]
273
+ if len(gold_spans) == 0:
274
+ return document.copy()
275
+
276
+ pred_spans = document[span_layer].predictions
277
+ old2new_pred_span = {}
278
+ span_id2gold_span = {}
279
+ for pred_span in pred_spans:
280
+
281
+ gold_spans_with_distance = [
282
+ (
283
+ gold_span,
284
+ distance(
285
+ start_end=(pred_span.start, pred_span.end),
286
+ other_start_end=(gold_span.start, gold_span.end),
287
+ distance_type=distance_type,
288
+ ),
289
+ )
290
+ for gold_span in gold_spans
291
+ ]
292
+
293
+ closest_gold_span, min_distance = min(gold_spans_with_distance, key=lambda x: x[1])
294
+ # if the closest gold span is the same as the predicted span, we don't need to align
295
+ if min_distance == 0.0:
296
+ continue
297
+
298
+ if have_overlap(
299
+ start_end=(pred_span.start, pred_span.end),
300
+ other_start_end=(closest_gold_span.start, closest_gold_span.end),
301
+ ):
302
+ overlap_len = get_overlap_len(
303
+ (pred_span.start, pred_span.end), (closest_gold_span.start, closest_gold_span.end)
304
+ )
305
+ # get the maximum length of the two spans
306
+ l_max = max(
307
+ pred_span.end - pred_span.start, closest_gold_span.end - closest_gold_span.start
308
+ )
309
+ # if the overlap is at least half of the maximum length, we consider it a valid match for alignment
310
+ valid_match = overlap_len >= (l_max / 2)
311
+ else:
312
+ valid_match = False
313
+
314
+ if valid_match:
315
+ aligned_pred_span = pred_span.copy(
316
+ start=closest_gold_span.start, end=closest_gold_span.end
317
+ )
318
+ old2new_pred_span[pred_span._id] = aligned_pred_span
319
+ span_id2gold_span[pred_span._id] = closest_gold_span
320
+
321
+ result = document.copy(with_annotations=False)
322
+
323
+ # multiple predicted spans can be aligned with the same gold span,
324
+ # so we need to keep track of the added spans
325
+ added_pred_span_ids = dict()
326
+ for pred_span in pred_spans:
327
+ # just add the predicted span if it was not aligned with a gold span
328
+ if pred_span._id not in old2new_pred_span:
329
+ # if this was not added before (e.g. as aligned span), add it
330
+ if pred_span._id not in added_pred_span_ids:
331
+ keep_pred_span = pred_span.copy()
332
+ result[span_layer].predictions.append(keep_pred_span)
333
+ added_pred_span_ids[pred_span._id] = keep_pred_span
334
+ elif verbose:
335
+ print(f"Skipping duplicate predicted span. pred_span='{str(pred_span)}'")
336
+ else:
337
+ aligned_pred_span = old2new_pred_span[pred_span._id]
338
+ # if this was not added before (e.g. as aligned or original pred span), add it
339
+ if aligned_pred_span._id not in added_pred_span_ids:
340
+ result[span_layer].predictions.append(aligned_pred_span)
341
+ added_pred_span_ids[aligned_pred_span._id] = aligned_pred_span
342
+ elif verbose:
343
+ prev_pred_span = added_pred_span_ids[aligned_pred_span._id]
344
+ gold_span = span_id2gold_span[pred_span._id]
345
+ print(
346
+ f"Skipping duplicate aligned predicted span. aligned gold_span='{str(gold_span)}', "
347
+ f"prev_pred_span='{str(prev_pred_span)}', current_pred_span='{str(pred_span)}'"
348
+ )
349
+ # print("bbb")
350
+
351
+ result[span_layer].extend([span.copy() for span in gold_spans])
352
+
353
+ # add remaining gold and predicted spans (the result, _aligned_spans, is just for debugging)
354
+ _aligned_spans = result.add_all_annotations_from_other(
355
+ document, override_annotations={span_layer: old2new_pred_span}
356
+ )
357
+
358
+ return result
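A small usage sketch for the new relabel_annotations helper above; the document and the layer/label names are invented for illustration. Note that the mapping must cover every label that occurs in a mapped layer, otherwise a ValueError is raised.

from src.document.processing import relabel_annotations

relabeled = relabel_annotations(
    doc,  # a PIE document with a "labeled_spans" layer (placeholder)
    label_mapping={"labeled_spans": {"background_claim": "claim", "own_claim": "claim", "data": "data"}},
)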
src/hydra_callbacks/save_job_return_value.py CHANGED
@@ -3,7 +3,7 @@ import logging
3
  import os
4
  import pickle
5
  from pathlib import Path
6
- from typing import Any, Dict, Generator, List, Tuple, Union
7
 
8
  import numpy as np
9
  import pandas as pd
@@ -33,38 +33,48 @@ def to_py_obj(obj):
33
  def list_of_dicts_to_dict_of_lists_recursive(list_of_dicts):
34
  """Convert a list of dicts to a dict of lists recursively.
35
 
36
- Example:
37
  # works with nested dicts
38
  >>> list_of_dicts_to_dict_of_lists_recursive([{"a": 1, "b": {"c": 2}}, {"a": 3, "b": {"c": 4}}])
39
- {'b': {'c': [2, 4]}, 'a': [1, 3]}
40
  # works with incomplete dicts
41
  >>> list_of_dicts_to_dict_of_lists_recursive([{"a": 1, "b": 2}, {"a": 3}])
42
- {'b': [2, None], 'a': [1, 3]}
 
 
 
 
 
 
 
 
43
 
44
  Args:
45
  list_of_dicts (List[dict]): A list of dicts.
46
 
47
  Returns:
48
- dict: A dict of lists.
49
  """
50
- if isinstance(list_of_dicts, list):
51
- if len(list_of_dicts) == 0:
52
- return {}
53
- elif isinstance(list_of_dicts[0], dict):
54
- keys = set()
55
- for d in list_of_dicts:
56
- if not isinstance(d, dict):
57
- raise ValueError("Not all elements of the list are dicts.")
 
58
  keys.update(d.keys())
59
- return {
60
- k: list_of_dicts_to_dict_of_lists_recursive(
61
- [d.get(k, None) for d in list_of_dicts]
62
- )
63
- for k in keys
64
- }
65
- else:
66
- return list_of_dicts
67
  else:
 
68
  return list_of_dicts
69
 
70
 
@@ -77,20 +87,50 @@ def _flatten_dict_gen(d, parent_key: Tuple[str, ...] = ()) -> Generator:
77
  yield new_key, v
78
 
79
 
80
- def flatten_dict(d: Dict[str, Any]) -> Dict[Tuple[str, ...], Any]:
81
- return dict(_flatten_dict_gen(d))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
 
84
- def unflatten_dict(d: Dict[Tuple[str, ...], Any]) -> Union[Dict[str, Any], Any]:
85
- """Unflattens a dictionary with nested keys.
 
 
 
86
 
87
  Example:
88
  >>> d = {("a", "b", "c"): 1, ("a", "b", "d"): 2, ("a", "e"): 3}
89
  >>> unflatten_dict(d)
90
  {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
 
 
 
 
 
91
  """
92
  result: Dict[str, Any] = {}
93
  for k, v in d.items():
 
 
94
  if len(k) == 0:
95
  if len(result) > 1:
96
  raise ValueError("Cannot unflatten dictionary with multiple root keys.")
@@ -152,30 +192,82 @@ class SaveJobReturnValueCallback(Callback):
152
  nested), where the keys are the keys of the job return-values and the values are lists of the corresponding
153
  values of all jobs. This is useful if you want to access specific values of all jobs in a multi-run all at once.
154
  Also, aggregated values (e.g. mean, min, max) are created for all numeric values and saved in another file.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  """
156
 
157
  def __init__(
158
  self,
159
  filenames: Union[str, List[str]] = "job_return_value.json",
160
  integrate_multirun_result: bool = False,
 
 
 
 
 
 
 
 
161
  ) -> None:
162
  self.log = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
163
  self.filenames = [filenames] if isinstance(filenames, str) else filenames
164
  self.integrate_multirun_result = integrate_multirun_result
165
  self.job_returns: List[JobReturn] = []
 
 
 
 
 
 
 
 
166
 
167
  def on_job_end(self, config: DictConfig, job_return: JobReturn, **kwargs: Any) -> None:
168
  self.job_returns.append(job_return)
169
  output_dir = Path(config.hydra.runtime.output_dir) # / Path(config.hydra.output_subdir)
 
 
 
 
 
170
  for filename in self.filenames:
171
  self._save(obj=job_return.return_value, filename=filename, output_dir=output_dir)
172
 
173
  def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
 
 
 
 
 
 
174
  if self.integrate_multirun_result:
175
  # rearrange the job return-values of all jobs from a multi-run into a dict of lists (maybe nested),
176
  obj = list_of_dicts_to_dict_of_lists_recursive(
177
  [jr.return_value for jr in self.job_returns]
178
  )
 
 
 
 
 
179
  # also create an aggregated result
180
  # convert to python object to allow selecting numeric columns
181
  obj_py = to_py_obj(obj)
@@ -195,6 +287,11 @@ class SaveJobReturnValueCallback(Callback):
195
  else:
196
  # aggregate the numeric values
197
  df_described = df_numbers_only.describe()
 
 
 
 
 
198
  # add the aggregation keys (e.g. mean, min, ...) as most inner keys and convert back to dict
199
  obj_flat_aggregated = df_described.T.stack().to_dict()
200
  # unflatten because _save() works better with nested dicts
@@ -202,25 +299,46 @@ class SaveJobReturnValueCallback(Callback):
202
  else:
203
  # create a dict of the job return-values of all jobs from a multi-run
204
  # (_save() works better with nested dicts)
205
- ids = overrides_to_identifiers([jr.overrides for jr in self.job_returns])
206
- obj = {identifier: jr.return_value for identifier, jr in zip(ids, self.job_returns)}
 
207
  obj_aggregated = None
208
  output_dir = Path(config.hydra.sweep.dir)
 
 
 
 
 
 
 
 
209
  for filename in self.filenames:
210
  self._save(
211
  obj=obj,
212
  filename=filename,
213
  output_dir=output_dir,
214
- multi_run_result=self.integrate_multirun_result,
215
  )
216
  # if available, also save the aggregated result
217
  if obj_aggregated is not None:
218
  file_base_name, ext = os.path.splitext(filename)
219
  filename_aggregated = f"{file_base_name}.aggregated{ext}"
220
- self._save(obj=obj_aggregated, filename=filename_aggregated, output_dir=output_dir)
 
 
 
 
 
 
 
221
 
222
  def _save(
223
- self, obj: Any, filename: str, output_dir: Path, multi_run_result: bool = False
 
 
 
 
 
224
  ) -> None:
225
  self.log.info(f"Saving job_return in {output_dir / filename}")
226
  output_dir.mkdir(parents=True, exist_ok=True)
@@ -236,23 +354,43 @@ class SaveJobReturnValueCallback(Callback):
236
  elif filename.endswith(".md"):
237
  # Convert PyTorch tensors and numpy arrays to native python types
238
  obj_py = to_py_obj(obj)
 
 
239
  obj_py_flat = flatten_dict(obj_py)
240
 
241
- if multi_run_result:
242
- # In the case of multi-run, we expect to have multiple values for each key.
243
- # We therefore just convert the dict to a pandas DataFrame.
244
  result = pd.DataFrame(obj_py_flat)
 
 
 
 
 
 
245
  else:
246
- # In the case of a single job, we expect to have only one value for each key.
247
- # We therefore convert the dict to a pandas Series and ...
248
  series = pd.Series(obj_py_flat)
249
- if len(series.index.levels) > 1:
250
- # ... if the Series has multiple index levels, we create a DataFrame by unstacking the last level.
251
- result = series.unstack(-1)
252
- else:
253
- # ... otherwise we just unpack the one-entry index values and save the resulting Series.
254
  series.index = series.index.get_level_values(0)
255
  result = series
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
 
257
  with open(str(output_dir / filename), "w") as file:
258
  file.write(result.to_markdown())
 
3
  import os
4
  import pickle
5
  from pathlib import Path
6
+ from typing import Any, Dict, Generator, List, Optional, Tuple, Union
7
 
8
  import numpy as np
9
  import pandas as pd
 
33
  def list_of_dicts_to_dict_of_lists_recursive(list_of_dicts):
34
  """Convert a list of dicts to a dict of lists recursively.
35
 
36
+ Examples:
37
  # works with nested dicts
38
  >>> list_of_dicts_to_dict_of_lists_recursive([{"a": 1, "b": {"c": 2}}, {"a": 3, "b": {"c": 4}}])
39
+ {'a': [1, 3], 'b': {'c': [2, 4]}}
40
  # works with incomplete dicts
41
  >>> list_of_dicts_to_dict_of_lists_recursive([{"a": 1, "b": 2}, {"a": 3}])
42
+ {'a': [1, 3], 'b': [2, None]}
43
+
44
+ # works with nested incomplete dicts
45
+ >>> list_of_dicts_to_dict_of_lists_recursive([{"a": 1, "b": {"c": 2}}, {"a": 3}])
46
+ {'a': [1, 3], 'b': {'c': [2, None]}}
47
+
48
+ # works with nested incomplete dicts with None values
49
+ >>> list_of_dicts_to_dict_of_lists_recursive([{"a": 1, "b": {"c": 2}}, {"a": None}])
50
+ {'a': [1, None], 'b': {'c': [2, None]}}
51
 
52
  Args:
53
  list_of_dicts (List[dict]): A list of dicts.
54
 
55
  Returns:
56
+ dict: An arbitrarily nested dict of lists.
57
  """
58
+ if not list_of_dicts:
59
+ return {}
60
+
61
+ # Check if all elements are either None or dictionaries
62
+ if all(d is None or isinstance(d, dict) for d in list_of_dicts):
63
+ # Gather all keys from non-None dictionaries
64
+ keys = set()
65
+ for d in list_of_dicts:
66
+ if d is not None:
67
  keys.update(d.keys())
68
+
69
+ # Build up the result recursively
70
+ return {
71
+ k: list_of_dicts_to_dict_of_lists_recursive(
72
+ [(d[k] if d is not None and k in d else None) for d in list_of_dicts]
73
+ )
74
+ for k in keys
75
+ }
76
  else:
77
+ # If items are not all dict/None, just return the list as is (base case).
78
  return list_of_dicts
79
 
80
 
 
87
  yield new_key, v
88
 
89
 
90
+ def flatten_dict(d: Dict[str, Any], pad_keys: bool = True) -> Dict[Tuple[str, ...], Any]:
91
+ """Flattens a dictionary with nested keys. Per default, the keys are padded with np.nan to have
92
+ the same length.
93
+
94
+ Example:
95
+ >>> d = {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
96
+ >>> flatten_dict(d)
97
+ {('a', 'b', 'c'): 1, ('a', 'b', 'd'): 2, ('a', 'e', np.nan): 3}
98
+
99
+ # without padding the keys
100
+ >>> d = {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
101
+ >>> flatten_dict(d, pad_keys=False)
102
+ {('a', 'b', 'c'): 1, ('a', 'b', 'd'): 2, ('a', 'e'): 3}
103
+ """
104
+ result = dict(_flatten_dict_gen(d))
105
+ # pad the keys with np.nan to have the same length. We use np.nan to be pandas-friendly.
106
+ if pad_keys:
107
+ max_num_keys = max(len(k) for k in result.keys())
108
+ result = {
109
+ tuple(list(k) + [np.nan] * (max_num_keys - len(k))): v for k, v in result.items()
110
+ }
111
+ return result
112
 
113
 
114
+ def unflatten_dict(
115
+ d: Dict[Tuple[str, ...], Any], unpad_keys: bool = True
116
+ ) -> Union[Dict[str, Any], Any]:
117
+ """Unflattens a dictionary with nested keys. Per default, the keys are unpadded by removing
118
+ np.nan values.
119
 
120
  Example:
121
  >>> d = {("a", "b", "c"): 1, ("a", "b", "d"): 2, ("a", "e"): 3}
122
  >>> unflatten_dict(d)
123
  {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
124
+
125
+ # with unpad the keys
126
+ >>> d = {("a", "b", "c"): 1, ("a", "b", "d"): 2, ("a", "e", np.nan): 3}
127
+ >>> unflatten_dict(d)
128
+ {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
129
  """
130
  result: Dict[str, Any] = {}
131
  for k, v in d.items():
132
+ if unpad_keys:
133
+ k = tuple([ki for ki in k if not pd.isna(ki)])
134
  if len(k) == 0:
135
  if len(result) > 1:
136
  raise ValueError("Cannot unflatten dictionary with multiple root keys.")
 
192
  nested), where the keys are the keys of the job return-values and the values are lists of the corresponding
193
  values of all jobs. This is useful if you want to access specific values of all jobs in a multi-run all at once.
194
  Also, aggregated values (e.g. mean, min, max) are created for all numeric values and saved in another file.
195
+ multirun_aggregator_blacklist: List[str] (default: None)
196
+ A list of keys to exclude from the aggregation (of multirun results), such as "count" or "25%". If None,
197
+ all keys are included. See pd.DataFrame.describe() for possible aggregation keys.
198
+ For numeric values, it is recommended to use ["min", "25%", "50%", "75%", "max"]
199
+ which will result in keeping only the count, mean and std values.
200
+ multirun_create_ids_from_overrides: bool (default: True)
201
+ Create job identifiers from the overrides of the jobs in a multi-run. If False, the job index is used as
202
+ identifier.
203
+ markdown_round_digits: int (default: 3)
204
+ The number of digits to round the values in the markdown file. If None, no rounding is applied.
205
+ multirun_job_id_key: str (default: "job_id")
206
+ The key to use for the job identifiers in the integrated multi-run result.
207
+ paths_file: str (default: None)
208
+ The file to save the paths of the log directories to. If None, the paths are not saved.
209
+ path_id: str (default: None)
210
+ A prefix to add to each line in the paths_file separated by a colon. If None, no prefix is added.
211
+ multirun_paths_file: str (default: None)
212
+ The file to save the paths of the multi-run log directories to. If None, the paths are not saved.
213
+ multirun_path_id: str (default: None)
214
+ A prefix to add to each line in the multirun_paths_file separated by a colon. If None, no prefix is added.
215
  """
216
 
217
  def __init__(
218
  self,
219
  filenames: Union[str, List[str]] = "job_return_value.json",
220
  integrate_multirun_result: bool = False,
221
+ multirun_aggregator_blacklist: Optional[List[str]] = None,
222
+ multirun_create_ids_from_overrides: bool = True,
223
+ markdown_round_digits: Optional[int] = 3,
224
+ multirun_job_id_key: str = "job_id",
225
+ paths_file: Optional[str] = None,
226
+ path_id: Optional[str] = None,
227
+ multirun_paths_file: Optional[str] = None,
228
+ multirun_path_id: Optional[str] = None,
229
  ) -> None:
230
  self.log = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
231
  self.filenames = [filenames] if isinstance(filenames, str) else filenames
232
  self.integrate_multirun_result = integrate_multirun_result
233
  self.job_returns: List[JobReturn] = []
234
+ self.multirun_aggregator_blacklist = multirun_aggregator_blacklist
235
+ self.multirun_create_ids_from_overrides = multirun_create_ids_from_overrides
236
+ self.multirun_job_id_key = multirun_job_id_key
237
+ self.markdown_round_digits = markdown_round_digits
238
+ self.multirun_paths_file = multirun_paths_file
239
+ self.multirun_path_id = multirun_path_id
240
+ self.paths_file = paths_file
241
+ self.path_id = path_id
242
 
243
  def on_job_end(self, config: DictConfig, job_return: JobReturn, **kwargs: Any) -> None:
244
  self.job_returns.append(job_return)
245
  output_dir = Path(config.hydra.runtime.output_dir) # / Path(config.hydra.output_subdir)
246
+ if self.paths_file is not None:
247
+ # append the output_dir to the file
248
+ with open(self.paths_file, "a") as file:
249
+ file.write(f"{output_dir}\n")
250
+
251
  for filename in self.filenames:
252
  self._save(obj=job_return.return_value, filename=filename, output_dir=output_dir)
253
 
254
  def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
255
+ job_ids: Union[List[str], List[int]]
256
+ if self.multirun_create_ids_from_overrides:
257
+ job_ids = overrides_to_identifiers([jr.overrides for jr in self.job_returns])
258
+ else:
259
+ job_ids = list(range(len(self.job_returns)))
260
+
261
  if self.integrate_multirun_result:
262
  # rearrange the job return-values of all jobs from a multi-run into a dict of lists (maybe nested),
263
  obj = list_of_dicts_to_dict_of_lists_recursive(
264
  [jr.return_value for jr in self.job_returns]
265
  )
266
+ if not isinstance(obj, dict):
267
+ obj = {"value": obj}
268
+ if self.multirun_create_ids_from_overrides:
269
+ obj[self.multirun_job_id_key] = job_ids
270
+
271
  # also create an aggregated result
272
  # convert to python object to allow selecting numeric columns
273
  obj_py = to_py_obj(obj)
 
287
  else:
288
  # aggregate the numeric values
289
  df_described = df_numbers_only.describe()
290
+ # remove rows in the blacklist
291
+ if self.multirun_aggregator_blacklist is not None:
292
+ df_described = df_described.drop(
293
+ self.multirun_aggregator_blacklist, errors="ignore", axis="index"
294
+ )
295
  # add the aggregation keys (e.g. mean, min, ...) as most inner keys and convert back to dict
296
  obj_flat_aggregated = df_described.T.stack().to_dict()
297
  # unflatten because _save() works better with nested dicts
 
299
  else:
300
  # create a dict of the job return-values of all jobs from a multi-run
301
  # (_save() works better with nested dicts)
302
+ obj = {
303
+ identifier: jr.return_value for identifier, jr in zip(job_ids, self.job_returns)
304
+ }
305
  obj_aggregated = None
306
  output_dir = Path(config.hydra.sweep.dir)
307
+ if self.multirun_paths_file is not None:
308
+ # append the output_dir to the file
309
+ line = f"{output_dir}\n"
310
+ if self.multirun_path_id is not None:
311
+ line = f"{self.multirun_path_id}:{line}"
312
+ with open(self.multirun_paths_file, "a") as file:
313
+ file.write(line)
314
+
315
  for filename in self.filenames:
316
  self._save(
317
  obj=obj,
318
  filename=filename,
319
  output_dir=output_dir,
320
+ is_tabular_data=self.integrate_multirun_result,
321
  )
322
  # if available, also save the aggregated result
323
  if obj_aggregated is not None:
324
  file_base_name, ext = os.path.splitext(filename)
325
  filename_aggregated = f"{file_base_name}.aggregated{ext}"
326
+ self._save(
327
+ obj=obj_aggregated,
328
+ filename=filename_aggregated,
329
+ output_dir=output_dir,
330
+ # If we have aggregated (integrated multi-run) results, we unstack the last level,
331
+ # i.e. the aggregation key.
332
+ unstack_last_index_level=True,
333
+ )
334
 
335
  def _save(
336
+ self,
337
+ obj: Any,
338
+ filename: str,
339
+ output_dir: Path,
340
+ is_tabular_data: bool = False,
341
+ unstack_last_index_level: bool = False,
342
  ) -> None:
343
  self.log.info(f"Saving job_return in {output_dir / filename}")
344
  output_dir.mkdir(parents=True, exist_ok=True)
 
354
  elif filename.endswith(".md"):
355
  # Convert PyTorch tensors and numpy arrays to native python types
356
  obj_py = to_py_obj(obj)
357
+ if not isinstance(obj_py, dict):
358
+ obj_py = {"value": obj_py}
359
  obj_py_flat = flatten_dict(obj_py)
360
 
361
+ if is_tabular_data:
362
+ # In the case of (not aggregated) integrated multi-run result, we expect to have
363
+ # multiple values for each key. We therefore just convert the dict to a pandas DataFrame.
364
  result = pd.DataFrame(obj_py_flat)
365
+ job_id_column = (self.multirun_job_id_key,) + (np.nan,) * (
366
+ result.columns.nlevels - 1
367
+ )
368
+ if job_id_column in result.columns:
369
+ result = result.set_index(job_id_column)
370
+ result.index.name = self.multirun_job_id_key
371
  else:
372
+ # Otherwise, we have only one value for each key. We convert the dict to a pandas Series.
 
373
  series = pd.Series(obj_py_flat)
374
+ # The series has a MultiIndex because flatten_dict() uses a tuple as key.
375
+ if len(series.index.levels) <= 1:
376
+ # If there is only one level, we just use the first level values as index.
 
 
377
  series.index = series.index.get_level_values(0)
378
  result = series
379
+ else:
380
+ # If there are multiple levels, we unstack the series to get a DataFrame
381
+ # providing a better overview.
382
+ if unstack_last_index_level:
383
+ # If we have aggregated (integrated multi-run) results, we unstack the last level,
384
+ # i.e. the aggregation key.
385
+ result = series.unstack(-1)
386
+ else:
387
+ # Otherwise we have a default multi-run result and unstack the first level,
388
+ # i.e. the identifier created from the overrides, and transpose the result
389
+ # to have the individual jobs as rows.
390
+ result = series.unstack(0).T
391
+
392
+ if self.markdown_round_digits is not None:
393
+ result = result.round(self.markdown_round_digits)
394
 
395
  with open(str(output_dir / filename), "w") as file:
396
  file.write(result.to_markdown())
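The padding behaviour added to flatten_dict/unflatten_dict above can be summarized with a small round trip (values taken from the docstring examples):

from src.hydra_callbacks.save_job_return_value import flatten_dict, unflatten_dict

nested = {"a": {"b": {"c": 1, "d": 2}, "e": 3}}
flat = flatten_dict(nested)            # keys are padded with np.nan to equal length
# {('a', 'b', 'c'): 1, ('a', 'b', 'd'): 2, ('a', 'e', nan): 3}
assert unflatten_dict(flat) == nested  # the np.nan padding is stripped again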
src/langchain_modules/pie_document_store.py CHANGED
@@ -75,7 +75,7 @@ class PieDocumentStore(SerializableStore, BaseStore[str, LCDocument], abc.ABC):
75
  caption: pie_document[layer_name] for layer_name, caption in layer_captions.items()
76
  }
77
  layer_sizes = {
78
- f"num_{caption}s": len(layer) + (len(layer.predictions) if use_predictions else 0)
79
  for caption, layer in layers.items()
80
  }
81
  rows.append({"doc_id": doc_id, **layer_sizes})
 
75
  caption: pie_document[layer_name] for layer_name, caption in layer_captions.items()
76
  }
77
  layer_sizes = {
78
+ f"num_{caption}": len(layer) + (len(layer.predictions) if use_predictions else 0)
79
  for caption, layer in layers.items()
80
  }
81
  rows.append({"doc_id": doc_id, **layer_sizes})
src/langchain_modules/span_retriever.py CHANGED
@@ -23,6 +23,7 @@ from pytorch_ie.documents import (
23
  TextDocumentWithSpans,
24
  )
25
 
 
26
  from .pie_document_store import PieDocumentStore
27
  from .serializable_store import SerializableStore
28
  from .span_vectorstore import SpanVectorStore
@@ -30,20 +31,6 @@ from .span_vectorstore import SpanVectorStore
30
  logger = logging.getLogger(__name__)
31
 
32
 
33
- def _parse_config(config_string: str, format: str) -> Dict[str, Any]:
34
- """Parse a configuration string."""
35
- if format == "json":
36
- import json
37
-
38
- return json.loads(config_string)
39
- elif format == "yaml":
40
- import yaml
41
-
42
- return yaml.safe_load(config_string)
43
- else:
44
- raise ValueError(f"Unsupported format: {format}. Use 'json' or 'yaml'.")
45
-
46
-
47
  METADATA_KEY_CHILD_ID2IDX = "child_id2idx"
48
 
49
 
@@ -136,7 +123,7 @@ class DocumentAwareSpanRetriever(BaseRetriever, SerializableStore):
136
  ) -> "DocumentAwareSpanRetriever":
137
  """Instantiate a retriever from a configuration string."""
138
  return cls.instantiate_from_config(
139
- _parse_config(config_string, format=format), overwrites=overwrites
140
  )
141
 
142
  @classmethod
@@ -725,6 +712,8 @@ class DocumentAwareSpanRetrieverWithRelations(DocumentAwareSpanRetriever):
725
  """The list of span labels to consider."""
726
  reversed_relations_suffix: Optional[str] = None
727
  """Whether to consider reverse relations as well."""
 
 
728
 
729
  def get_relation_layer(
730
  self, pie_document: TextBasedDocument, use_predicted_annotations: bool
@@ -762,11 +751,19 @@ class DocumentAwareSpanRetrieverWithRelations(DocumentAwareSpanRetriever):
762
  )
763
 
764
  for relation in relations:
 
 
 
 
765
  if self.relation_labels is None or relation.label in self.relation_labels:
766
  head2label2tails_with_scores[span2id[relation.head]][relation.label].append(
767
  (span2id[relation.tail], relation.score)
768
  )
769
- if self.reversed_relations_suffix is not None:
 
 
 
 
770
  reversed_label = f"{relation.label}{self.reversed_relations_suffix}"
771
  if self.relation_labels is None or reversed_label in self.relation_labels:
772
  head2label2tails_with_scores[span2id[relation.tail]][
 
23
  TextDocumentWithSpans,
24
  )
25
 
26
+ from ..utils import parse_config
27
  from .pie_document_store import PieDocumentStore
28
  from .serializable_store import SerializableStore
29
  from .span_vectorstore import SpanVectorStore
 
31
  logger = logging.getLogger(__name__)
32
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  METADATA_KEY_CHILD_ID2IDX = "child_id2idx"
35
 
36
 
 
123
  ) -> "DocumentAwareSpanRetriever":
124
  """Instantiate a retriever from a configuration string."""
125
  return cls.instantiate_from_config(
126
+ parse_config(config_string, format=format), overwrites=overwrites
127
  )
128
 
129
  @classmethod
 
712
  """The list of span labels to consider."""
713
  reversed_relations_suffix: Optional[str] = None
714
  """Whether to consider reverse relations as well."""
715
+ symmetric_relations: Optional[list[str]] = None
716
+ """The list of relation labels that are symmetric."""
717
 
718
  def get_relation_layer(
719
  self, pie_document: TextBasedDocument, use_predicted_annotations: bool
 
751
  )
752
 
753
  for relation in relations:
754
+ is_symmetric = (
755
+ self.symmetric_relations is not None
756
+ and relation.label in self.symmetric_relations
757
+ )
758
  if self.relation_labels is None or relation.label in self.relation_labels:
759
  head2label2tails_with_scores[span2id[relation.head]][relation.label].append(
760
  (span2id[relation.tail], relation.score)
761
  )
762
+ if is_symmetric:
763
+ head2label2tails_with_scores[span2id[relation.tail]][
764
+ relation.label
765
+ ].append((span2id[relation.head], relation.score))
766
+ if self.reversed_relations_suffix is not None and not is_symmetric:
767
  reversed_label = f"{relation.label}{self.reversed_relations_suffix}"
768
  if self.relation_labels is None or reversed_label in self.relation_labels:
769
  head2label2tails_with_scores[span2id[relation.tail]][
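A sketch of how the new symmetric_relations field above might be set, assuming the retriever fields can be supplied at the top level of the config dict passed to instantiate_from_config (the label names are invented):

retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config(
    {
        # ... vectorstore / docstore entries omitted ...
        "relation_labels": ["supports", "contradicts", "semantically_same"],
        "reversed_relations_suffix": "_reversed",
        # symmetric labels are mirrored head<->tail with the same label instead of
        # getting a "_reversed" variant
        "symmetric_relations": ["semantically_same"],
    }
)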
src/pipeline/ner_re_pipeline.py CHANGED
@@ -2,7 +2,18 @@ from __future__ import annotations
2
 
3
  import logging
4
  from functools import partial
5
- from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, TypeVar, Union
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  from pie_modules.utils import resolve_type
8
  from pytorch_ie import AutoPipeline, WithDocumentTypeMixin
@@ -72,11 +83,13 @@ def add_annotations_from_other_documents(
72
  def process_pipeline_steps(
73
  documents: Sequence[Document],
74
  processors: Dict[str, Callable[[Sequence[Document]], Optional[Sequence[Document]]]],
 
75
  ) -> Sequence[Document]:
76
 
77
  # call the processors in the order they are provided
78
  for step_name, processor in processors.items():
79
- logger.info(f"process {step_name} ...")
 
80
  processed_documents = processor(documents)
81
  if processed_documents is not None:
82
  documents = processed_documents
@@ -120,6 +133,7 @@ class NerRePipeline:
120
  batch_size: Optional[int] = None,
121
  show_progress_bar: Optional[bool] = None,
122
  document_type: Optional[Union[Type[Document], str]] = None,
 
123
  **processor_kwargs,
124
  ):
125
  self.taskmodule = DummyTaskmodule(document_type)
@@ -128,6 +142,7 @@ class NerRePipeline:
128
  self.processor_kwargs = processor_kwargs or {}
129
  self.entity_layer = entity_layer
130
  self.relation_layer = relation_layer
 
131
  # set some values for the inference processors, if provided
132
  for inference_pipeline in ["ner_pipeline", "re_pipeline"]:
133
  if inference_pipeline not in self.processor_kwargs:
@@ -145,7 +160,29 @@ class NerRePipeline:
145
  ):
146
  self.processor_kwargs[inference_pipeline]["show_progress_bar"] = show_progress_bar
147
 
148
- def __call__(self, documents: Sequence[Document], inplace: bool = False) -> Sequence[Document]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  input_docs: Sequence[Document]
151
  # we need to keep the original documents to add the gold data back
@@ -166,24 +203,14 @@ class NerRePipeline:
166
  layer_names=[self.entity_layer, self.relation_layer],
167
  **self.processor_kwargs.get("clear_annotations", {}),
168
  ),
169
- "ner_pipeline": AutoPipeline.from_pretrained(
170
- self.ner_model_path, **self.processor_kwargs.get("ner_pipeline", {})
171
- ),
172
  "use_predicted_entities": partial(
173
  process_documents,
174
  processor=move_annotations_from_predictions,
175
  layer_names=[self.entity_layer],
176
  **self.processor_kwargs.get("use_predicted_entities", {}),
177
  ),
178
- # "create_candidate_relations": partial(
179
- # process_documents,
180
- # processor=CandidateRelationAdder(
181
- # **self.processor_kwargs.get("create_candidate_relations", {})
182
- # ),
183
- # ),
184
- "re_pipeline": AutoPipeline.from_pretrained(
185
- self.re_model_path, **self.processor_kwargs.get("re_pipeline", {})
186
- ),
187
  # otherwise we can not move the entities back to predictions
188
  "clear_candidate_relations": partial(
189
  process_documents,
@@ -204,5 +231,8 @@ class NerRePipeline:
204
  **self.processor_kwargs.get("re_add_gold_data", {}),
205
  ),
206
  },
 
207
  )
 
 
208
  return docs_with_predictions
 
2
 
3
  import logging
4
  from functools import partial
5
+ from typing import (
6
+ Callable,
7
+ Dict,
8
+ Iterable,
9
+ List,
10
+ Optional,
11
+ Sequence,
12
+ Type,
13
+ TypeVar,
14
+ Union,
15
+ overload,
16
+ )
17
 
18
  from pie_modules.utils import resolve_type
19
  from pytorch_ie import AutoPipeline, WithDocumentTypeMixin
 
83
  def process_pipeline_steps(
84
  documents: Sequence[Document],
85
  processors: Dict[str, Callable[[Sequence[Document]], Optional[Sequence[Document]]]],
86
+ verbose: bool = False,
87
  ) -> Sequence[Document]:
88
 
89
  # call the processors in the order they are provided
90
  for step_name, processor in processors.items():
91
+ if verbose:
92
+ logger.info(f"process {step_name} ...")
93
  processed_documents = processor(documents)
94
  if processed_documents is not None:
95
  documents = processed_documents
 
133
  batch_size: Optional[int] = None,
134
  show_progress_bar: Optional[bool] = None,
135
  document_type: Optional[Union[Type[Document], str]] = None,
136
+ verbose: bool = True,
137
  **processor_kwargs,
138
  ):
139
  self.taskmodule = DummyTaskmodule(document_type)
 
142
  self.processor_kwargs = processor_kwargs or {}
143
  self.entity_layer = entity_layer
144
  self.relation_layer = relation_layer
145
+ self.verbose = verbose
146
  # set some values for the inference processors, if provided
147
  for inference_pipeline in ["ner_pipeline", "re_pipeline"]:
148
  if inference_pipeline not in self.processor_kwargs:
 
160
  ):
161
  self.processor_kwargs[inference_pipeline]["show_progress_bar"] = show_progress_bar
162
 
163
+ self.ner_pipeline = AutoPipeline.from_pretrained(
164
+ self.ner_model_path, **self.processor_kwargs.get("ner_pipeline", {})
165
+ )
166
+ self.re_pipeline = AutoPipeline.from_pretrained(
167
+ self.re_model_path, **self.processor_kwargs.get("re_pipeline", {})
168
+ )
169
+
170
+ @overload
171
+ def __call__(
172
+ self, documents: Sequence[Document], inplace: bool = False
173
+ ) -> Sequence[Document]: ...
174
+
175
+ @overload
176
+ def __call__(self, documents: Document, inplace: bool = False) -> Document: ...
177
+
178
+ def __call__(
179
+ self, documents: Union[Sequence[Document], Document], inplace: bool = False
180
+ ) -> Union[Sequence[Document], Document]:
181
+
182
+ is_single_doc = False
183
+ if isinstance(documents, Document):
184
+ documents = [documents]
185
+ is_single_doc = True
186
 
187
  input_docs: Sequence[Document]
188
  # we need to keep the original documents to add the gold data back
 
203
  layer_names=[self.entity_layer, self.relation_layer],
204
  **self.processor_kwargs.get("clear_annotations", {}),
205
  ),
206
+ "ner_pipeline": self.ner_pipeline,
 
 
207
  "use_predicted_entities": partial(
208
  process_documents,
209
  processor=move_annotations_from_predictions,
210
  layer_names=[self.entity_layer],
211
  **self.processor_kwargs.get("use_predicted_entities", {}),
212
  ),
213
+ "re_pipeline": self.re_pipeline,
 
 
 
 
 
 
 
 
214
  # otherwise we can not move the entities back to predictions
215
  "clear_candidate_relations": partial(
216
  process_documents,
 
231
  **self.processor_kwargs.get("re_add_gold_data", {}),
232
  ),
233
  },
234
+ verbose=self.verbose,
235
  )
236
+ if is_single_doc:
237
+ return docs_with_predictions[0]
238
  return docs_with_predictions
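With the overloads added above, the pipeline now accepts a single document as well as a sequence; a usage sketch (model paths and layer names are example values):

from src.pipeline.ner_re_pipeline import NerRePipeline

pipeline = NerRePipeline(
    ner_model_path="path/to/ner_model",   # placeholder
    re_model_path="path/to/re_model",     # placeholder
    entity_layer="labeled_spans",
    relation_layer="binary_relations",
    verbose=False,                        # new: silences the per-step logging
)
doc_with_predictions = pipeline(doc)              # single Document in, single Document out
docs_with_predictions = pipeline([doc_a, doc_b])  # Sequence in, Sequence out

Note that both inference pipelines are now instantiated once in __init__ rather than on every call.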
src/predict.py CHANGED
@@ -106,8 +106,12 @@ def predict(cfg: DictConfig) -> Tuple[dict, dict]:
106
  # Per default, the model is loaded with .from_pretrained() which already loads the weights.
107
  # However, ckpt_path can be used to load different weights from any checkpoint.
108
  if cfg.ckpt_path is not None:
109
- pipeline.model = pipeline.model.load_from_checkpoint(checkpoint_path=cfg.ckpt_path).to(
110
- pipeline.device
 
 
 
 
111
  )
112
 
113
  # auto-convert the dataset if the metric specifies a document type
 
106
  # Per default, the model is loaded with .from_pretrained() which already loads the weights.
107
  # However, ckpt_path can be used to load different weights from any checkpoint.
108
  if cfg.ckpt_path is not None:
109
+ log.info(f"Loading model weights from checkpoint: {cfg.ckpt_path}")
110
+ pipeline.model = (
111
+ type(pipeline.model)
112
+ .load_from_checkpoint(checkpoint_path=cfg.ckpt_path)
113
+ .to(pipeline.device)
114
+ .to(dtype=pipeline.model.dtype)
115
  )
116
 
117
  # auto-convert the dataset if the metric specifies a document type
src/start_demo.py CHANGED
@@ -19,8 +19,10 @@ import yaml
19
  from src.demo.annotation_utils import load_argumentation_model
20
  from src.demo.backend_utils import (
21
  download_processed_documents,
 
22
  process_text_from_arxiv,
23
  process_uploaded_files,
 
24
  render_annotated_document,
25
  upload_processed_documents,
26
  wrapped_add_annotated_pie_documents_from_dataset,
@@ -31,6 +33,7 @@ from src.demo.frontend_utils import (
31
  escape_regex,
32
  get_cell_for_fixed_column_from_df,
33
  open_accordion,
 
34
  unescape_regex,
35
  )
36
  from src.demo.rendering_utils import AVAILABLE_RENDER_MODES, HIGHLIGHT_SPANS_JS
@@ -67,12 +70,10 @@ def main(cfg: DictConfig) -> None:
67
 
68
  example_text = cfg["example_text"]
69
 
70
- default_device = "cuda" if torch.cuda.is_available() else "cpu"
71
 
72
- default_retriever_config_str = load_yaml_config(cfg["default_retriever_config_path"])
73
-
74
- default_model_name = cfg["default_model_name"]
75
- default_model_revision = cfg["default_model_revision"]
76
  handle_parts_of_same = cfg["handle_parts_of_same"]
77
 
78
  default_arxiv_id = cfg["default_arxiv_id"]
@@ -97,19 +98,32 @@ def main(cfg: DictConfig) -> None:
97
  }
98
  render_caption2mode = {v: k for k, v in render_mode2caption.items()}
99
  default_min_similarity = cfg["default_min_similarity"]
 
100
  layer_caption_mapping = cfg["layer_caption_mapping"]
101
  relation_name_mapping = cfg["relation_name_mapping"]
102
 
 
 
 
 
 
 
 
103
  gr.Info("Loading models ...")
104
  argumentation_model = load_argumentation_model(
105
- model_name=default_model_name,
106
- revision=default_model_revision,
107
  device=default_device,
108
  )
109
  retriever = load_retriever(
110
- default_retriever_config_str, device=default_device, config_format="yaml"
111
  )
112
 
 
 
 
 
 
 
113
  with gr.Blocks() as demo:
114
  # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
115
  # models_state = gr.State((argumentation_model, embedding_model))
@@ -131,18 +145,16 @@ def main(cfg: DictConfig) -> None:
131
 
132
  with gr.Accordion("Model Configuration", open=False):
133
  with gr.Accordion("argumentation structure", open=True):
134
- model_name = gr.Textbox(
135
- label="Model Name",
136
- value=default_model_name,
137
- )
138
- model_revision = gr.Textbox(
139
- label="Model Revision",
140
- value=default_model_revision,
141
  )
142
  load_arg_model_btn = gr.Button("Load Argumentation Model")
143
 
144
  with gr.Accordion("retriever", open=True):
145
- retriever_config = gr.Code(
146
  language="yaml",
147
  label="Retriever Configuration",
148
  value=default_retriever_config_str,
@@ -155,26 +167,25 @@ def main(cfg: DictConfig) -> None:
155
  value=default_device,
156
  )
157
  load_arg_model_btn.click(
158
- fn=lambda _model_name, _model_revision, _device: (
159
  load_argumentation_model(
160
- model_name=_model_name,
161
- revision=_model_revision,
162
  device=_device,
163
  ),
164
  ),
165
- inputs=[model_name, model_revision, device],
166
  outputs=argumentation_model_state,
167
  )
168
  load_retriever_btn.click(
169
  fn=lambda _retriever_config, _device, _previous_retriever: (
170
  load_retriever(
171
- retriever_config_str=_retriever_config,
172
  device=_device,
173
  previous_retriever=_previous_retriever[0],
174
  config_format="yaml",
175
  ),
176
  ),
177
- inputs=[retriever_config, device, retriever_state],
178
  outputs=retriever_state,
179
  )
180
 
@@ -213,7 +224,7 @@ def main(cfg: DictConfig) -> None:
213
  with gr.Tabs() as right_tabs:
214
  with gr.Tab("Retrieval", id="retrieval") as retrieval_tab:
215
  with gr.Accordion(
216
- "Indexed Documents", open=False
217
  ) as processed_documents_accordion:
218
  processed_documents_df = gr.DataFrame(
219
  headers=["id", "num_adus", "num_relations"],
@@ -274,7 +285,7 @@ def main(cfg: DictConfig) -> None:
274
  minimum=2,
275
  maximum=50,
276
  step=1,
277
- value=10,
278
  )
279
  retrieve_similar_adus_btn = gr.Button(
280
  "Retrieve *similar* ADUs for *selected* ADU"
@@ -293,8 +304,10 @@ def main(cfg: DictConfig) -> None:
293
  "Retrieve *relevant* ADUs for *all* ADUs in the document"
294
  )
295
  all_relevant_adus_df = gr.DataFrame(
296
- headers=["doc_id", "adu_id", "score", "text"], interactive=False
 
297
  )
 
298
 
299
  with gr.Tab("Import Documents", id="import_documents") as import_documents_tab:
300
  upload_btn = gr.UploadButton(
@@ -303,6 +316,28 @@ def main(cfg: DictConfig) -> None:
303
  file_count="multiple",
304
  )
305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  with gr.Accordion("Import text from arXiv", open=False):
307
  arxiv_id = gr.Textbox(
308
  label="arXiv paper ID",
@@ -326,13 +361,25 @@ def main(cfg: DictConfig) -> None:
326
  load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")
327
 
328
  render_event_kwargs = dict(
329
- fn=lambda _retriever, _document_id, _render_as, _render_kwargs: render_annotated_document(
330
  retriever=_retriever[0],
331
  document_id=_document_id,
332
  render_with=render_caption2mode[_render_as],
333
  render_kwargs_json=_render_kwargs,
 
 
 
 
 
334
  ),
335
- inputs=[retriever_state, selected_document_id, render_as, render_kwargs],
 
 
 
 
 
 
 
336
  outputs=rendered_output,
337
  )
338
 
@@ -343,6 +390,16 @@ def main(cfg: DictConfig) -> None:
343
  inputs=[retriever_state],
344
  outputs=[processed_documents_df],
345
  )
 
 
 
 
 
 
 
 
 
 
346
  predict_btn.click(
347
  fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
348
  ).then(
@@ -367,6 +424,8 @@ def main(cfg: DictConfig) -> None:
367
  api_name="predict",
368
  ).success(
369
  **show_overview_kwargs
 
 
370
  ).success(
371
  **render_event_kwargs
372
  )
@@ -396,6 +455,8 @@ def main(cfg: DictConfig) -> None:
396
  api_name="predict",
397
  ).success(
398
  **show_overview_kwargs
 
 
399
  )
400
 
401
  load_pie_dataset_btn.click(
@@ -409,6 +470,8 @@ def main(cfg: DictConfig) -> None:
409
  ),
410
  inputs=[retriever_state, load_pie_dataset_kwargs_str],
411
  outputs=[processed_documents_df],
 
 
412
  )
413
 
414
  selected_document_id.change(
@@ -430,7 +493,9 @@ def main(cfg: DictConfig) -> None:
430
  file_names=_file_names,
431
  argumentation_model=_argumentation_model[0],
432
  retriever=_retriever[0],
433
- split_regex_escaped=unescape_regex(_split_regex_escaped),
 
 
434
  handle_parts_of_same=handle_parts_of_same,
435
  layer_captions=layer_caption_mapping,
436
  ),
@@ -441,7 +506,61 @@ def main(cfg: DictConfig) -> None:
441
  split_regex_escaped,
442
  ],
443
  outputs=[processed_documents_df],
444
  )
 
445
  processed_documents_df.select(
446
  fn=get_cell_for_fixed_column_from_df,
447
  inputs=[processed_documents_df, gr.State("doc_id")],
@@ -461,7 +580,7 @@ def main(cfg: DictConfig) -> None:
461
  ),
462
  inputs=[upload_processed_documents_btn, retriever_state],
463
  outputs=[processed_documents_df],
464
- )
465
 
466
  retrieve_relevant_adus_event_kwargs = dict(
467
  fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_relevant_spans(
@@ -533,12 +652,16 @@ def main(cfg: DictConfig) -> None:
533
  )
534
 
535
  retrieve_all_relevant_adus_btn.click(
536
- fn=lambda _retriever, _document_id, _min_similarity, _tok_k: retrieve_all_relevant_spans(
537
- retriever=_retriever[0],
538
- query_doc_id=_document_id,
539
- k=_tok_k,
540
- score_threshold=_min_similarity,
541
- query_span_id_column="query_span_id",
 
 
 
 
542
  ),
543
  inputs=[
544
  retriever_state,
@@ -546,9 +669,11 @@ def main(cfg: DictConfig) -> None:
546
  min_similarity,
547
  top_k,
548
  ],
549
- outputs=[all_relevant_adus_df],
550
  )
551
 
 
 
552
  # select query span id from the "retrieve all" result data frames
553
  all_similar_adus_df.select(
554
  fn=get_cell_for_fixed_column_from_df,
 
19
  from src.demo.annotation_utils import load_argumentation_model
20
  from src.demo.backend_utils import (
21
  download_processed_documents,
22
+ load_acl_anthology_venues,
23
  process_text_from_arxiv,
24
  process_uploaded_files,
25
+ process_uploaded_pdf_files,
26
  render_annotated_document,
27
  upload_processed_documents,
28
  wrapped_add_annotated_pie_documents_from_dataset,
 
33
  escape_regex,
34
  get_cell_for_fixed_column_from_df,
35
  open_accordion,
36
+ open_accordion_with_stats,
37
  unescape_regex,
38
  )
39
  from src.demo.rendering_utils import AVAILABLE_RENDER_MODES, HIGHLIGHT_SPANS_JS
 
70
 
71
  example_text = cfg["example_text"]
72
 
73
+ default_device = "cuda:0" if torch.cuda.is_available() else "cpu"
74
 
75
+ default_retriever_config_str = yaml.dump(cfg["retriever"])
76
+ default_argumentation_model_config_str = yaml.dump(cfg["argumentation_model"])
 
 
77
  handle_parts_of_same = cfg["handle_parts_of_same"]
78
 
79
  default_arxiv_id = cfg["default_arxiv_id"]
 
98
  }
99
  render_caption2mode = {v: k for k, v in render_mode2caption.items()}
100
  default_min_similarity = cfg["default_min_similarity"]
101
+ default_top_k = cfg["default_top_k"]
102
  layer_caption_mapping = cfg["layer_caption_mapping"]
103
  relation_name_mapping = cfg["relation_name_mapping"]
104
 
105
+ indexed_documents_label = "Indexed Documents"
106
+ indexed_documents_caption2column = {
107
+ "documents": "TOTAL",
108
+ "ADUs": "num_adus",
109
+ "Relations": "num_relations",
110
+ }
111
+
112
  gr.Info("Loading models ...")
113
  argumentation_model = load_argumentation_model(
114
+ config_str=default_argumentation_model_config_str,
 
115
  device=default_device,
116
  )
117
  retriever = load_retriever(
118
+ config_str=default_retriever_config_str, device=default_device, config_format="yaml"
119
  )
120
 
121
+ if cfg.get("pdf_fulltext_extractor"):
122
+ gr.Info("Loading PDF fulltext extractor ...")
123
+ pdf_fulltext_extractor = hydra.utils.instantiate(cfg["pdf_fulltext_extractor"])
124
+ else:
125
+ pdf_fulltext_extractor = None
126
+
127
  with gr.Blocks() as demo:
128
  # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
129
  # models_state = gr.State((argumentation_model, embedding_model))
 
145
 
146
  with gr.Accordion("Model Configuration", open=False):
147
  with gr.Accordion("argumentation structure", open=True):
148
+ argumentation_model_config_str = gr.Code(
149
+ language="yaml",
150
+ label="Argumentation Model Configuration",
151
+ value=default_argumentation_model_config_str,
152
+ lines=len(default_argumentation_model_config_str.split("\n")),
 
 
153
  )
154
  load_arg_model_btn = gr.Button("Load Argumentation Model")
155
 
156
  with gr.Accordion("retriever", open=True):
157
+ retriever_config_str = gr.Code(
158
  language="yaml",
159
  label="Retriever Configuration",
160
  value=default_retriever_config_str,
 
167
  value=default_device,
168
  )
169
  load_arg_model_btn.click(
170
+ fn=lambda _argumentation_model_config_str, _device: (
171
  load_argumentation_model(
172
+ config_str=_argumentation_model_config_str,
 
173
  device=_device,
174
  ),
175
  ),
176
+ inputs=[argumentation_model_config_str, device],
177
  outputs=argumentation_model_state,
178
  )
179
  load_retriever_btn.click(
180
  fn=lambda _retriever_config, _device, _previous_retriever: (
181
  load_retriever(
182
+ config_str=_retriever_config,
183
  device=_device,
184
  previous_retriever=_previous_retriever[0],
185
  config_format="yaml",
186
  ),
187
  ),
188
+ inputs=[retriever_config_str, device, retriever_state],
189
  outputs=retriever_state,
190
  )
191
 
 
224
  with gr.Tabs() as right_tabs:
225
  with gr.Tab("Retrieval", id="retrieval") as retrieval_tab:
226
  with gr.Accordion(
227
+ indexed_documents_label, open=False
228
  ) as processed_documents_accordion:
229
  processed_documents_df = gr.DataFrame(
230
  headers=["id", "num_adus", "num_relations"],
 
285
  minimum=2,
286
  maximum=50,
287
  step=1,
288
+ value=default_top_k,
289
  )
290
  retrieve_similar_adus_btn = gr.Button(
291
  "Retrieve *similar* ADUs for *selected* ADU"
 
304
  "Retrieve *relevant* ADUs for *all* ADUs in the document"
305
  )
306
  all_relevant_adus_df = gr.DataFrame(
307
+ headers=["doc_id", "adu_id", "score", "text", "query_span_id"],
308
+ interactive=False,
309
  )
310
+ all_relevant_adus_query_doc_id = gr.Textbox(visible=False)
311
 
312
  with gr.Tab("Import Documents", id="import_documents") as import_documents_tab:
313
  upload_btn = gr.UploadButton(
 
316
  file_count="multiple",
317
  )
318
 
319
+ upload_pdf_btn = gr.UploadButton(
320
+ "Batch Analyse PDFs",
321
+ # file_types=["pdf"],
322
+ file_count="multiple",
323
+ visible=pdf_fulltext_extractor is not None,
324
+ )
325
+
326
+ enable_acl_venue_loading = pdf_fulltext_extractor is not None and cfg.get(
327
+ "acl_anthology_pdf_dir"
328
+ )
329
+ acl_anthology_venues = gr.Textbox(
330
+ label="ACL Anthology Venues",
331
+ value="wiesp",
332
+ max_lines=1,
333
+ visible=enable_acl_venue_loading,
334
+ )
335
+ load_acl_anthology_venues_btn = gr.Button(
336
+ "Import from ACL Anthology",
337
+ variant="secondary",
338
+ visible=enable_acl_venue_loading,
339
+ )
340
+
341
  with gr.Accordion("Import text from arXiv", open=False):
342
  arxiv_id = gr.Textbox(
343
  label="arXiv paper ID",
 
361
  load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")
362
 
363
  render_event_kwargs = dict(
364
+ fn=lambda _retriever, _document_id, _render_as, _render_kwargs, _all_relevant_adus_df, _all_relevant_adus_query_doc_id: render_annotated_document(
365
  retriever=_retriever[0],
366
  document_id=_document_id,
367
  render_with=render_caption2mode[_render_as],
368
  render_kwargs_json=_render_kwargs,
369
+ highlight_span_ids=(
370
+ _all_relevant_adus_df["query_span_id"].tolist()
371
+ if _document_id == _all_relevant_adus_query_doc_id
372
+ else None
373
+ ),
374
  ),
375
+ inputs=[
376
+ retriever_state,
377
+ selected_document_id,
378
+ render_as,
379
+ render_kwargs,
380
+ all_relevant_adus_df,
381
+ all_relevant_adus_query_doc_id,
382
+ ],
383
  outputs=rendered_output,
384
  )
385
 
 
390
  inputs=[retriever_state],
391
  outputs=[processed_documents_df],
392
  )
393
+ show_stats_kwargs = dict(
394
+ fn=lambda _processed_documents_df: open_accordion_with_stats(
395
+ _processed_documents_df,
396
+ base_label=indexed_documents_label,
397
+ caption2column=indexed_documents_caption2column,
398
+ total_column="TOTAL",
399
+ ),
400
+ inputs=[processed_documents_df],
401
+ outputs=[processed_documents_accordion],
402
+ )
403
  predict_btn.click(
404
  fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
405
  ).then(
 
424
  api_name="predict",
425
  ).success(
426
  **show_overview_kwargs
427
+ ).success(
428
+ **show_stats_kwargs
429
  ).success(
430
  **render_event_kwargs
431
  )
 
455
  api_name="predict",
456
  ).success(
457
  **show_overview_kwargs
458
+ ).success(
459
+ **show_stats_kwargs
460
  )
461
 
462
  load_pie_dataset_btn.click(
 
470
  ),
471
  inputs=[retriever_state, load_pie_dataset_kwargs_str],
472
  outputs=[processed_documents_df],
473
+ ).success(
474
+ **show_stats_kwargs
475
  )
476
 
477
  selected_document_id.change(
 
493
  file_names=_file_names,
494
  argumentation_model=_argumentation_model[0],
495
  retriever=_retriever[0],
496
+ split_regex_escaped=(
497
+ unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
498
+ ),
499
  handle_parts_of_same=handle_parts_of_same,
500
  layer_captions=layer_caption_mapping,
501
  ),
 
506
  split_regex_escaped,
507
  ],
508
  outputs=[processed_documents_df],
509
+ ).success(
510
+ **show_stats_kwargs
511
+ )
512
+ upload_pdf_btn.upload(
513
+ fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
514
+ ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
515
+ fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_pdf_files(
516
+ file_names=_file_names,
517
+ argumentation_model=_argumentation_model[0],
518
+ retriever=_retriever[0],
519
+ split_regex_escaped=(
520
+ unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
521
+ ),
522
+ handle_parts_of_same=handle_parts_of_same,
523
+ layer_captions=layer_caption_mapping,
524
+ pdf_fulltext_extractor=pdf_fulltext_extractor,
525
+ ),
526
+ inputs=[
527
+ upload_pdf_btn,
528
+ argumentation_model_state,
529
+ retriever_state,
530
+ split_regex_escaped,
531
+ ],
532
+ outputs=[processed_documents_df],
533
+ ).success(
534
+ **show_stats_kwargs
535
+ )
536
+
537
+ load_acl_anthology_venues_btn.click(
538
+ fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
539
+ ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
540
+ fn=lambda _acl_anthology_venues, _argumentation_model, _retriever, _split_regex_escaped: load_acl_anthology_venues(
541
+ pdf_fulltext_extractor=pdf_fulltext_extractor,
542
+ venues=[venue.strip() for venue in _acl_anthology_venues.split(",")],
543
+ argumentation_model=_argumentation_model[0],
544
+ retriever=_retriever[0],
545
+ split_regex_escaped=(
546
+ unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
547
+ ),
548
+ handle_parts_of_same=handle_parts_of_same,
549
+ layer_captions=layer_caption_mapping,
550
+ acl_anthology_data_dir=cfg.get("acl_anthology_data_dir"),
551
+ pdf_output_dir=cfg.get("acl_anthology_pdf_dir"),
552
+ ),
553
+ inputs=[
554
+ acl_anthology_venues,
555
+ argumentation_model_state,
556
+ retriever_state,
557
+ split_regex_escaped,
558
+ ],
559
+ outputs=[processed_documents_df],
560
+ ).success(
561
+ **show_stats_kwargs
562
  )
563
+
564
  processed_documents_df.select(
565
  fn=get_cell_for_fixed_column_from_df,
566
  inputs=[processed_documents_df, gr.State("doc_id")],
 
580
  ),
581
  inputs=[upload_processed_documents_btn, retriever_state],
582
  outputs=[processed_documents_df],
583
+ ).success(**show_stats_kwargs)
584
 
585
  retrieve_relevant_adus_event_kwargs = dict(
586
  fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_relevant_spans(
 
652
  )
653
 
654
  retrieve_all_relevant_adus_btn.click(
655
+ fn=lambda _retriever, _document_id, _min_similarity, _tok_k: (
656
+ retrieve_all_relevant_spans(
657
+ retriever=_retriever[0],
658
+ query_doc_id=_document_id,
659
+ k=_tok_k,
660
+ score_threshold=_min_similarity,
661
+ query_span_id_column="query_span_id",
662
+ query_span_text_column="query_span_text",
663
+ ),
664
+ _document_id,
665
  ),
666
  inputs=[
667
  retriever_state,
 
669
  min_similarity,
670
  top_k,
671
  ],
672
+ outputs=[all_relevant_adus_df, all_relevant_adus_query_doc_id],
673
  )
674
 
675
+ all_relevant_adus_df.change(**render_event_kwargs)
676
+
677
  # select query span id from the "retrieve all" result data frames
678
  all_similar_adus_df.select(
679
  fn=get_cell_for_fixed_column_from_df,
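The event wiring above relies on two Gradio idioms: event kwargs are collected in plain dicts (`show_overview_kwargs`, `show_stats_kwargs`, `render_event_kwargs`) so they can be re-attached to several triggers, and follow-up steps are chained with `.then()` / `.success()`. A minimal, self-contained sketch of that pattern; all component and function names below are illustrative, not taken from the demo:

```python
import gradio as gr


def analyse(text: str) -> str:
    # stand-in for the actual prediction step
    return text.upper()


def compute_stats(result: str) -> str:
    # stand-in for open_accordion_with_stats
    return f"{len(result)} characters"


with gr.Blocks() as demo:
    text_in = gr.Textbox(label="Input")
    predict_btn = gr.Button("Analyse")
    result_out = gr.Textbox(label="Result")
    stats_out = gr.Textbox(label="Stats")

    # reusable event kwargs, analogous to show_overview_kwargs / show_stats_kwargs
    show_stats_kwargs = dict(fn=compute_stats, inputs=[result_out], outputs=[stats_out])

    # run the prediction, then update the stats only if the prediction succeeded
    predict_btn.click(
        fn=analyse, inputs=[text_in], outputs=[result_out]
    ).success(**show_stats_kwargs)

if __name__ == "__main__":
    demo.launch()
```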
src/train.py CHANGED
@@ -220,6 +220,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
  train_metrics = trainer.callback_metrics

  best_ckpt_path = trainer.checkpoint_callback.best_model_path
+ best_epoch = None
  if best_ckpt_path != "":
  log.info(f"Best ckpt path: {best_ckpt_path}")
  best_checkpoint_file = os.path.basename(best_ckpt_path)
@@ -228,6 +229,14 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
  best_checkpoint=best_checkpoint_file,
  checkpoint_dir=trainer.checkpoint_callback.dirpath,
  )
+ # get epoch from best_checkpoint_file (e.g. "epoch_078.ckpt")
+ try:
+ best_epoch = int(os.path.splitext(best_checkpoint_file)[0].split("_")[-1])
+ except Exception as e:
+ log.warning(
+ f'Could not retrieve epoch from best checkpoint file name: "{e}". '
+ f"Expected format: " + '"epoch_{best_epoch}.ckpt"'
+ )

  if not cfg.trainer.get("fast_dev_run"):
  if cfg.model_save_dir is not None:
@@ -259,6 +268,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
  trainer.test(model=model, datamodule=datamodule, ckpt_path=best_ckpt_path or None)

  test_metrics = trainer.callback_metrics
+ test_metrics["best_epoch"] = best_epoch

  # merge train and test metrics
  metric_dict = {**train_metrics, **test_metrics}
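For reference, the filename parsing added above behaves as follows, assuming checkpoints follow the `epoch_XXX.ckpt` naming shown in the comment:

```python
import os

# e.g. "epoch_078.ckpt" -> 78; any other naming scheme ends up in the except branch above
best_checkpoint_file = "epoch_078.ckpt"
best_epoch = int(os.path.splitext(best_checkpoint_file)[0].split("_")[-1])
print(best_epoch)  # 78
```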
src/utils/__init__.py CHANGED
@@ -1,4 +1,9 @@
- from .config_utils import execute_pipeline, instantiate_dict_entries, prepare_omegaconf
+ from .config_utils import (
+ execute_pipeline,
+ instantiate_dict_entries,
+ parse_config,
+ prepare_omegaconf,
+ )
  from .data_utils import download_and_unzip, filter_dataframe_and_get_column
  from .logging_utils import close_loggers, get_pylogger, log_hyperparameters
  from .rich_utils import enforce_tags, print_config_tree
src/utils/config_utils.py CHANGED
@@ -1,5 +1,5 @@
  from copy import copy
- from typing import Any, List, Optional
+ from typing import Any, Dict, List, Optional

  from hydra.utils import instantiate
  from omegaconf import DictConfig, OmegaConf
@@ -69,3 +69,17 @@ def prepare_omegaconf():
  OmegaConf.register_new_resolver("replace", lambda s, x, y: s.replace(x, y))
  else:
  logger.warning("OmegaConf resolver 'replace' is already registered")
+
+
+ def parse_config(config_string: str, format: str) -> Dict[str, Any]:
+ """Parse a configuration string."""
+ if format == "json":
+ import json
+
+ return json.loads(config_string)
+ elif format == "yaml":
+ import yaml
+
+ return yaml.safe_load(config_string)
+ else:
+ raise ValueError(f"Unsupported format: {format}. Use 'json' or 'yaml'.")
src/utils/pdf_utils/README.MD ADDED
@@ -0,0 +1,35 @@
+ # Generate paper JSON files from a collection XML file, with fulltext extraction
+
+ This is a slightly re-arranged version of Sotaro Takeshita's code, which is available at https://github.com/gengo-proj/data-factory.
+
+ ## Requirements
+
+ - Docker
+ - Python>=3.10
+ - python packages:
+   - acl-anthology-py>=0.4.3
+   - bs4
+   - jsonschema
+
+ ## Setup
+
+ Start the Grobid Docker container:
+
+ ```bash
+ docker run --rm --init --ulimit core=0 -p 8070:8070 lfoppiano/grobid:0.8.0
+ ```
+
+ Get the metadata from the ACL Anthology:
+
+ ```bash
+ git clone git@github.com:acl-org/acl-anthology.git
+ ```
+
+ ## Usage
+
+ ```bash
+ python src/data/acl_anthology_crawler.py \
+   --base-output-dir <path/to/save/raw-paper.json> \
+   --pdf-output-dir <path/to/save/downloaded/paper.pdf> \
+   --anthology-data-dir ./acl-anthology/data/
+ ```
src/utils/pdf_utils/__init__.py ADDED
File without changes
src/utils/pdf_utils/acl_anthology_utils.py ADDED
@@ -0,0 +1,77 @@
1
+ from dataclasses import dataclass
2
+ from typing import Iterator
3
+
4
+ from acl_anthology import Anthology
5
+
6
+ from .process_pdf import paper_url_to_uuid
7
+ from .raw_paper import RawPaper
8
+
9
+
10
+ @dataclass
11
+ class XML2RawPapers:
12
+ anthology: Anthology
13
+ collection_id_filters: list[str] | None = None
14
+ venue_id_whitelist: list[str] | None = None
15
+ verbose: bool = True
16
+
17
+ def __call__(self, *args, **kwargs) -> Iterator[RawPaper]:
18
+
19
+ for collection_id, collection in self.anthology.collections.items():
20
+ if self.collection_id_filters is not None:
21
+ if not any(
22
+ [
23
+ collection_id.find(filter_str) != -1
24
+ for filter_str in self.collection_id_filters
25
+ ]
26
+ ):
27
+ continue
28
+ if self.verbose:
29
+ print(f"Processing collection: {collection_id}")
30
+ for volume in collection.volumes():
31
+ if self.venue_id_whitelist is not None:
32
+ if not any(
33
+ [venue_id in volume.venue_ids for venue_id in self.venue_id_whitelist]
34
+ ):
35
+ continue
36
+
37
+ volume_id = f"{collection_id}-{volume.id}"
38
+
39
+ for paper in volume.papers():
40
+ fulltext, abstract = None, None
41
+ if (
42
+ paper.pdf is not None
43
+ and paper.pdf.name is not None
44
+ and paper.pdf.name.find("http") == -1
45
+ ):
46
+ name = paper.pdf.name
47
+ else:
48
+ name = (
49
+ f"{volume_id}.{paper.id.rjust(3, '0')}"
50
+ if len(collection_id) == 1
51
+ else f"{volume_id}.{paper.id}"
52
+ )
53
+
54
+ paper_uuid = paper_url_to_uuid(name)
55
+ raw_paper = RawPaper(
56
+ paper_uuid=str(paper_uuid),
57
+ name=name,
58
+ collection_id=collection_id,
59
+ collection_acronym=volume.venues()[0].acronym,
60
+ volume_id=volume_id,
61
+ booktitle=volume.title.as_text(),
62
+ paper_id=int(paper.id),
63
+ year=int(paper.year),
64
+ paper_title=paper.title.as_text(),
65
+ authors=[
66
+ {"first": author.first, "last": author.last}
67
+ for author in paper.authors
68
+ ],
69
+ abstract=(
70
+ paper.abstract.as_text() if paper.abstract is not None else abstract
71
+ ),
72
+ url=paper.pdf.url if paper.pdf is not None else None,
73
+ bibkey=paper.bibkey if paper.bibkey is not None else None,
74
+ doi=paper.doi if paper.doi is not None else None,
75
+ fulltext=fulltext,
76
+ )
77
+ yield raw_paper
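`XML2RawPapers` walks the ACL Anthology collections and yields one `RawPaper` per paper. A hedged sketch of driving it, assuming a local clone of `acl-org/acl-anthology` (see the README above) and that `Anthology` accepts the data directory via `datadir`:

```python
from acl_anthology import Anthology

from src.utils.pdf_utils.acl_anthology_utils import XML2RawPapers

xml2raw = XML2RawPapers(
    anthology=Anthology(datadir="./acl-anthology/data"),  # assumed constructor argument
    venue_id_whitelist=["wiesp"],
    verbose=False,
)
for raw_paper in xml2raw():
    print(raw_paper.paper_uuid, raw_paper.paper_title)
    raw_paper.save("./papers")  # validates against the schema and writes <name>.json
```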
src/utils/pdf_utils/client.py ADDED
@@ -0,0 +1,193 @@
1
+ """ Generic API Client """
2
+
3
+ import json
4
+ from copy import deepcopy
5
+
6
+ import requests
7
+
8
+ try:
9
+ from urlparse import urljoin
10
+ except ImportError:
11
+ from urllib.parse import urljoin
12
+
13
+
14
+ class ApiClient(object):
15
+ """Client to interact with a generic Rest API.
16
+
17
+ Subclasses should implement functionality accordingly with the provided
18
+ service methods, i.e. ``get``, ``post``, ``put`` and ``delete``.
19
+ """
20
+
21
+ accept_type = "application/xml"
22
+ api_base = None
23
+
24
+ def __init__(self, base_url, username=None, api_key=None, status_endpoint=None, timeout=60):
25
+ """Initialise client.
26
+
27
+ Args:
28
+ base_url (str): The base URL to the service being used.
29
+ username (str): The username to authenticate with.
30
+ api_key (str): The API key to authenticate with.
31
+ timeout (int): Maximum time before timing out.
32
+ """
33
+ self.base_url = base_url
34
+ self.username = username
35
+ self.api_key = api_key
36
+ self.status_endpoint = urljoin(self.base_url, status_endpoint)
37
+ self.timeout = timeout
38
+
39
+ @staticmethod
40
+ def encode(request, data):
41
+ """Add request content data to request body, set Content-type header.
42
+
43
+ Should be overridden by subclasses if not using JSON encoding.
44
+
45
+ Args:
46
+ request (HTTPRequest): The request object.
47
+ data (dict, None): Data to be encoded.
48
+
49
+ Returns:
50
+ HTTPRequest: The request object.
51
+ """
52
+ if data is None:
53
+ return request
54
+
55
+ request.add_header("Content-Type", "application/json")
56
+ request.data = json.dumps(data)
57
+
58
+ return request
59
+
60
+ @staticmethod
61
+ def decode(response):
62
+ """Decode the returned data in the response.
63
+
64
+ Should be overridden by subclasses if something else than JSON is
65
+ expected.
66
+
67
+ Args:
68
+ response (HTTPResponse): The response object.
69
+
70
+ Returns:
71
+ dict or None.
72
+ """
73
+ try:
74
+ return response.json()
75
+ except ValueError as e:
76
+ return str(e)
77
+
78
+ def get_credentials(self):
79
+ """Returns parameters to be added to authenticate the request.
80
+
81
+ This lives on its own to make it easier to re-implement it if needed.
82
+
83
+ Returns:
84
+ dict: A dictionary containing the credentials.
85
+ """
86
+ return {"username": self.username, "api_key": self.api_key}
87
+
88
+ def call_api(
89
+ self,
90
+ method,
91
+ url,
92
+ headers=None,
93
+ params=None,
94
+ data=None,
95
+ files=None,
96
+ timeout=None,
97
+ ):
98
+ """Call API.
99
+
100
+ This returns object containing data, with error details if applicable.
101
+
102
+ Args:
103
+ method (str): The HTTP method to use.
104
+ url (str): Resource location relative to the base URL.
105
+ headers (dict or None): Extra request headers to set.
106
+ params (dict or None): Query-string parameters.
107
+ data (dict or None): Request body contents for POST or PUT requests.
108
+ files (dict or None): Files to be passed to the request.
109
+ timeout (int): Maximum time before timing out.
110
+
111
+ Returns:
112
+ ResultParser or ErrorParser.
113
+ """
114
+ headers = deepcopy(headers) or {}
115
+ headers["Accept"] = self.accept_type
116
+ params = deepcopy(params) or {}
117
+ data = data or {}
118
+ files = files or {}
119
+ # if self.username is not None and self.api_key is not None:
120
+ # params.update(self.get_credentials())
121
+ r = requests.request(
122
+ method,
123
+ url,
124
+ headers=headers,
125
+ params=params,
126
+ files=files,
127
+ data=data,
128
+ timeout=timeout,
129
+ )
130
+
131
+ return r, r.status_code
132
+
133
+ def get(self, url, params=None, **kwargs):
134
+ """Call the API with a GET request.
135
+
136
+ Args:
137
+ url (str): Resource location relative to the base URL.
138
+ params (dict or None): Query-string parameters.
139
+
140
+ Returns:
141
+ ResultParser or ErrorParser.
142
+ """
143
+ return self.call_api("GET", url, params=params, **kwargs)
144
+
145
+ def delete(self, url, params=None, **kwargs):
146
+ """Call the API with a DELETE request.
147
+
148
+ Args:
149
+ url (str): Resource location relative to the base URL.
150
+ params (dict or None): Query-string parameters.
151
+
152
+ Returns:
153
+ ResultParser or ErrorParser.
154
+ """
155
+ return self.call_api("DELETE", url, params=params, **kwargs)
156
+
157
+ def put(self, url, params=None, data=None, files=None, **kwargs):
158
+ """Call the API with a PUT request.
159
+
160
+ Args:
161
+ url (str): Resource location relative to the base URL.
162
+ params (dict or None): Query-string parameters.
163
+ data (dict or None): Request body contents.
164
+ files (dict or None): Files to be passed to the request.
165
+
166
+ Returns:
167
+ An instance of ResultParser or ErrorParser.
168
+ """
169
+ return self.call_api("PUT", url, params=params, data=data, files=files, **kwargs)
170
+
171
+ def post(self, url, params=None, data=None, files=None, **kwargs):
172
+ """Call the API with a POST request.
173
+
174
+ Args:
175
+ url (str): Resource location relative to the base URL.
176
+ params (dict or None): Query-string parameters.
177
+ data (dict or None): Request body contents.
178
+ files (dict or None): Files to be passed to the request.
179
+
180
+ Returns:
181
+ An instance of ResultParser or ErrorParser.
182
+ """
183
+ return self.call_api(
184
+ method="POST", url=url, params=params, data=data, files=files, **kwargs
185
+ )
186
+
187
+ def service_status(self, **kwargs):
188
+ """Call the API to get the status of the service.
189
+
190
+ Returns:
191
+ An instance of ResultParser or ErrorParser.
192
+ """
193
+ return self.call_api("GET", self.status_endpoint, params={"format": "json"}, **kwargs)
src/utils/pdf_utils/grobid_client.py ADDED
@@ -0,0 +1,203 @@
1
+ import glob
2
+ import io
3
+ import ntpath
4
+ import os
5
+ import time
6
+ from typing import List, Optional
7
+
8
+ from .client import ApiClient
9
+
10
+ # This version uses the standard ProcessPoolExecutor for parallelizing the concurrent calls to the GROBID services.
11
+ # Given the limits of ThreadPoolExecutor (input stored in memory, blocking Executor.map until the whole input
12
+ # is acquired), it works with batches of PDF of a size indicated in the config.json file (default is 1000 entries).
13
+ # We are moving from first batch to the second one only when the first is entirely processed - which means it is
14
+ # slightly sub-optimal, but should scale better. However acquiring a list of million of files in directories would
15
+ # require something scalable too, which is not implemented for the moment.
16
+ DEFAULT_GROBID_CONFIG = {
17
+ "grobid_server": "localhost",
18
+ "grobid_port": "8070",
19
+ "batch_size": 1000,
20
+ "sleep_time": 5,
21
+ "generateIDs": False,
22
+ "consolidate_header": False,
23
+ "consolidate_citations": False,
24
+ "include_raw_citations": True,
25
+ "include_raw_affiliations": False,
26
+ "max_workers": 2,
27
+ }
28
+
29
+
30
+ class GrobidClient(ApiClient):
31
+ def __init__(self, config=None):
32
+ self.config = config or DEFAULT_GROBID_CONFIG
33
+ self.generate_ids = self.config["generateIDs"]
34
+ self.consolidate_header = self.config["consolidate_header"]
35
+ self.consolidate_citations = self.config["consolidate_citations"]
36
+ self.include_raw_citations = self.config["include_raw_citations"]
37
+ self.include_raw_affiliations = self.config["include_raw_affiliations"]
38
+ self.max_workers = self.config["max_workers"]
39
+ self.grobid_server = self.config["grobid_server"]
40
+ self.grobid_port = str(self.config["grobid_port"])
41
+ self.sleep_time = self.config["sleep_time"]
42
+
43
+ def process(self, input: str, output: str, service: str):
44
+ batch_size_pdf = self.config["batch_size"]
45
+ pdf_files = []
46
+
47
+ for pdf_file in glob.glob(input + "/*.pdf"):
48
+ pdf_files.append(pdf_file)
49
+
50
+ if len(pdf_files) == batch_size_pdf:
51
+ self.process_batch(pdf_files, output, service)
52
+ pdf_files = []
53
+
54
+ # last batch
55
+ if len(pdf_files) > 0:
56
+ self.process_batch(pdf_files, output, service)
57
+
58
+ def process_batch(self, pdf_files: List[str], output: str, service: str) -> None:
59
+ print(len(pdf_files), "PDF files to process")
60
+ for pdf_file in pdf_files:
61
+ self.process_pdf(pdf_file, output, service)
62
+
63
+ def process_pdf_stream(self, pdf_file: str, pdf_strm: bytes, output: str, service: str) -> str:
64
+ # process the stream
65
+ files = {"input": (pdf_file, pdf_strm, "application/pdf", {"Expires": "0"})}
66
+
67
+ the_url = "http://" + self.grobid_server
68
+ the_url += ":" + self.grobid_port
69
+ the_url += "/api/" + service
70
+
71
+ # set the GROBID parameters
72
+ the_data = {}
73
+ if self.generate_ids:
74
+ the_data["generateIDs"] = "1"
75
+ else:
76
+ the_data["generateIDs"] = "0"
77
+
78
+ if self.consolidate_header:
79
+ the_data["consolidateHeader"] = "1"
80
+ else:
81
+ the_data["consolidateHeader"] = "0"
82
+
83
+ if self.consolidate_citations:
84
+ the_data["consolidateCitations"] = "1"
85
+ else:
86
+ the_data["consolidateCitations"] = "0"
87
+
88
+ if self.include_raw_affiliations:
89
+ the_data["includeRawAffiliations"] = "1"
90
+ else:
91
+ the_data["includeRawAffiliations"] = "0"
92
+
93
+ if self.include_raw_citations:
94
+ the_data["includeRawCitations"] = "1"
95
+ else:
96
+ the_data["includeRawCitations"] = "0"
97
+
98
+ res, status = self.post(
99
+ url=the_url, files=files, data=the_data, headers={"Accept": "text/plain"}
100
+ )
101
+
102
+ if status == 503:
103
+ time.sleep(self.sleep_time)
104
+ # TODO: check if simply passing output as output is correct
105
+ return self.process_pdf_stream(
106
+ pdf_file=pdf_file, pdf_strm=pdf_strm, service=service, output=output
107
+ )
108
+ elif status != 200:
109
+ with open(os.path.join(output, "failed.log"), "a+") as failed:
110
+ failed.write(pdf_file.strip(".pdf") + "\n")
111
+ print("Processing failed with error " + str(status))
112
+ return ""
113
+ else:
114
+ return res.text
115
+
116
+ def process_pdf(self, pdf_file: str, output: str, service: str) -> None:
117
+ # check if TEI file is already produced
118
+ # we use ntpath here to be sure it will work on Windows too
119
+ pdf_file_name = ntpath.basename(pdf_file)
120
+ filename = os.path.join(output, os.path.splitext(pdf_file_name)[0] + ".tei.xml")
121
+ if os.path.isfile(filename):
122
+ return
123
+
124
+ print(pdf_file)
125
+ pdf_strm = open(pdf_file, "rb").read()
126
+ tei_text = self.process_pdf_stream(pdf_file, pdf_strm, output, service)
127
+
128
+ # writing TEI file
129
+ if tei_text:
130
+ with io.open(filename, "w+", encoding="utf8") as tei_file:
131
+ tei_file.write(tei_text)
132
+
133
+ def process_citation(self, bib_string: str, log_file: str) -> Optional[str]:
134
+ # process citation raw string and return corresponding dict
135
+ the_data = {"citations": bib_string, "consolidateCitations": "0"}
136
+
137
+ the_url = "http://" + self.grobid_server
138
+ the_url += ":" + self.grobid_port
139
+ the_url += "/api/processCitation"
140
+
141
+ for _ in range(5):
142
+ try:
143
+ res, status = self.post(
144
+ url=the_url, data=the_data, headers={"Accept": "text/plain"}
145
+ )
146
+ if status == 503:
147
+ time.sleep(self.sleep_time)
148
+ continue
149
+ elif status != 200:
150
+ with open(log_file, "a+") as failed:
151
+ failed.write("-- BIBSTR --\n")
152
+ failed.write(bib_string + "\n\n")
153
+ break
154
+ else:
155
+ return res.text
156
+ except Exception:
157
+ continue
158
+
159
+ return None
160
+
161
+ def process_header_names(self, header_string: str, log_file: str) -> Optional[str]:
162
+ # process author names from header string
163
+ the_data = {"names": header_string}
164
+
165
+ the_url = "http://" + self.grobid_server
166
+ the_url += ":" + self.grobid_port
167
+ the_url += "/api/processHeaderNames"
168
+
169
+ res, status = self.post(url=the_url, data=the_data, headers={"Accept": "text/plain"})
170
+
171
+ if status == 503:
172
+ time.sleep(self.sleep_time)
173
+ return self.process_header_names(header_string, log_file)
174
+ elif status != 200:
175
+ with open(log_file, "a+") as failed:
176
+ failed.write("-- AUTHOR --\n")
177
+ failed.write(header_string + "\n\n")
178
+ else:
179
+ return res.text
180
+
181
+ return None
182
+
183
+ def process_affiliations(self, aff_string: str, log_file: str) -> Optional[str]:
184
+ # process affiliation from input string
185
+ the_data = {"affiliations": aff_string}
186
+
187
+ the_url = "http://" + self.grobid_server
188
+ the_url += ":" + self.grobid_port
189
+ the_url += "/api/processAffiliations"
190
+
191
+ res, status = self.post(url=the_url, data=the_data, headers={"Accept": "text/plain"})
192
+
193
+ if status == 503:
194
+ time.sleep(self.sleep_time)
195
+ return self.process_affiliations(aff_string, log_file)
196
+ elif status != 200:
197
+ with open(log_file, "a+") as failed:
198
+ failed.write("-- AFFILIATION --\n")
199
+ failed.write(aff_string + "\n\n")
200
+ else:
201
+ return res.text
202
+
203
+ return None
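A hedged sketch of batch-converting a folder of PDFs with this client against a locally running Grobid server (directory names are placeholders):

```python
import os

from src.utils.pdf_utils.grobid_client import DEFAULT_GROBID_CONFIG, GrobidClient

os.makedirs("./tei", exist_ok=True)  # the client writes into an existing output directory
config = {**DEFAULT_GROBID_CONFIG, "grobid_server": "localhost", "grobid_port": "8070"}
client = GrobidClient(config)

# writes one <name>.tei.xml per PDF into ./tei; failures are appended to ./tei/failed.log
client.process("./pdfs", "./tei", "processFulltextDocument")
```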
src/utils/pdf_utils/grobid_util.py ADDED
@@ -0,0 +1,413 @@
1
+ import re
2
+ from collections import defaultdict
3
+ from typing import Dict, List, Optional, Union
4
+
5
+ import bs4
6
+ from bs4 import BeautifulSoup
7
+
8
+ SUBSTITUTE_TAGS = {"persName", "orgName", "publicationStmt", "titleStmt", "biblScope"}
9
+
10
+
11
+ def clean_tags(el: bs4.element.Tag):
12
+ """
13
+ Replace all tags with lowercase version
14
+ :param el:
15
+ :return:
16
+ """
17
+ for sub_tag in SUBSTITUTE_TAGS:
18
+ for sub_el in el.find_all(sub_tag):
19
+ sub_el.name = sub_tag.lower()
20
+
21
+
22
+ def soup_from_path(file_path: str):
23
+ """
24
+ Read XML file
25
+ :param file_path:
26
+ :return:
27
+ """
28
+ return BeautifulSoup(open(file_path, "rb").read(), "xml")
29
+
30
+
31
+ def get_title_from_grobid_xml(raw_xml: BeautifulSoup) -> str:
32
+ """
33
+ Returns title
34
+ :return:
35
+ """
36
+ for title_entry in raw_xml.find_all("title"):
37
+ if title_entry.has_attr("level") and title_entry["level"] == "a":
38
+ return title_entry.text
39
+ try:
40
+ return raw_xml.title.text
41
+ except AttributeError:
42
+ return ""
43
+
44
+
45
+ def get_author_names_from_grobid_xml(
46
+ raw_xml: BeautifulSoup,
47
+ ) -> List[Dict[str, Union[str, List[str]]]]:
48
+ """
49
+ Returns a list of dictionaries, one for each author,
50
+ containing the first and last names.
51
+
52
+ e.g.
53
+ {
54
+ "first": first,
55
+ "middle": middle,
56
+ "last": last,
57
+ "suffix": suffix
58
+ }
59
+ """
60
+ names = []
61
+
62
+ for author in raw_xml.find_all("author"):
63
+ if not author.persname:
64
+ continue
65
+
66
+ # forenames include first and middle names
67
+ forenames = author.persname.find_all("forename")
68
+
69
+ # surnames include last names
70
+ surnames = author.persname.find_all("surname")
71
+
72
+ # name suffixes
73
+ suffixes = author.persname.find_all("suffix")
74
+
75
+ first = ""
76
+ middle = []
77
+ last = ""
78
+ suffix = ""
79
+
80
+ for forename in forenames:
81
+ if forename["type"] == "first":
82
+ if not first:
83
+ first = forename.text
84
+ else:
85
+ middle.append(forename.text)
86
+ elif forename["type"] == "middle":
87
+ middle.append(forename.text)
88
+
89
+ if len(surnames) > 1:
90
+ for surname in surnames[:-1]:
91
+ middle.append(surname.text)
92
+ last = surnames[-1].text
93
+ elif len(surnames) == 1:
94
+ last = surnames[0].text
95
+
96
+ if len(suffixes) >= 1:
97
+ suffix = " ".join([suff.text for suff in suffixes])
98
+
99
+ names_dict: Dict[str, Union[str, List[str]]] = {
100
+ "first": first,
101
+ "middle": middle,
102
+ "last": last,
103
+ "suffix": suffix,
104
+ }
105
+
106
+ names.append(names_dict)
107
+ return names
108
+
109
+
110
+ def get_affiliation_from_grobid_xml(raw_xml: BeautifulSoup) -> Dict:
111
+ """
112
+ Get affiliation from grobid xml
113
+ :param raw_xml:
114
+ :return:
115
+ """
116
+ location_dict = dict()
117
+ laboratory_name = ""
118
+ institution_name = ""
119
+
120
+ if raw_xml and raw_xml.affiliation:
121
+ for child in raw_xml.affiliation:
122
+ if child.name == "orgname":
123
+ if child.has_attr("type"):
124
+ if child["type"] == "laboratory":
125
+ laboratory_name = child.text
126
+ elif child["type"] == "institution":
127
+ institution_name = child.text
128
+ elif child.name == "address":
129
+ for grandchild in child:
130
+ if grandchild.name and grandchild.text:
131
+ location_dict[grandchild.name] = grandchild.text
132
+
133
+ if laboratory_name or institution_name:
134
+ return {
135
+ "laboratory": laboratory_name,
136
+ "institution": institution_name,
137
+ "location": location_dict,
138
+ }
139
+
140
+ return {}
141
+
142
+
143
+ def get_author_data_from_grobid_xml(raw_xml: BeautifulSoup) -> List[Dict]:
144
+ """
145
+ Returns a list of dictionaries, one for each author,
146
+ containing the first and last names.
147
+
148
+ e.g.
149
+ {
150
+ "first": first,
151
+ "middle": middle,
152
+ "last": last,
153
+ "suffix": suffix,
154
+ "affiliation": {
155
+ "laboratory": "",
156
+ "institution": "",
157
+ "location": "",
158
+ },
159
+ "email": ""
160
+ }
161
+ """
162
+ authors = []
163
+
164
+ for author in raw_xml.find_all("author"):
165
+
166
+ first = ""
167
+ middle = []
168
+ last = ""
169
+ suffix = ""
170
+
171
+ if author.persname:
172
+ # forenames include first and middle names
173
+ forenames = author.persname.find_all("forename")
174
+
175
+ # surnames include last names
176
+ surnames = author.persname.find_all("surname")
177
+
178
+ # name suffixes
179
+ suffixes = author.persname.find_all("suffix")
180
+
181
+ for forename in forenames:
182
+ if forename.has_attr("type"):
183
+ if forename["type"] == "first":
184
+ if not first:
185
+ first = forename.text
186
+ else:
187
+ middle.append(forename.text)
188
+ elif forename["type"] == "middle":
189
+ middle.append(forename.text)
190
+
191
+ if len(surnames) > 1:
192
+ for surname in surnames[:-1]:
193
+ middle.append(surname.text)
194
+ last = surnames[-1].text
195
+ elif len(surnames) == 1:
196
+ last = surnames[0].text
197
+
198
+ if len(suffixes) >= 1:
199
+ suffix = " ".join([suffix.text for suffix in suffixes])
200
+
201
+ affiliation = get_affiliation_from_grobid_xml(author)
202
+
203
+ email = ""
204
+ if author.email:
205
+ email = author.email.text
206
+
207
+ author_dict = {
208
+ "first": first,
209
+ "middle": middle,
210
+ "last": last,
211
+ "suffix": suffix,
212
+ "affiliation": affiliation,
213
+ "email": email,
214
+ }
215
+
216
+ authors.append(author_dict)
217
+
218
+ return authors
219
+
220
+
221
+ def get_year_from_grobid_xml(raw_xml: BeautifulSoup) -> Optional[int]:
222
+ """
223
+ Returns date published if exists
224
+ :return:
225
+ """
226
+ if raw_xml.date and raw_xml.date.has_attr("when"):
227
+ # match year in date text (which is in some unspecified date format)
228
+ year_match = re.match(r"((19|20)\d{2})", raw_xml.date["when"])
229
+ if year_match:
230
+ year = year_match.group(0)
231
+ if year and year.isnumeric() and len(year) == 4:
232
+ return int(year)
233
+ return None
234
+
235
+
236
+ def get_venue_from_grobid_xml(raw_xml: BeautifulSoup, title_text: str) -> str:
237
+ """
238
+ Returns venue/journal/publisher of bib entry
239
+ Grobid ref documentation: https://grobid.readthedocs.io/en/latest/training/Bibliographical-references/
240
+ level="j": journal title
241
+ level="m": "non journal bibliographical item holding the cited article"
242
+ level="s": series title
243
+ :return:
244
+ """
245
+ title_names = []
246
+ keep_types = ["j", "m", "s"]
247
+ # get all titles of the above types
248
+ for title_entry in raw_xml.find_all("title"):
249
+ if (
250
+ title_entry.has_attr("level")
251
+ and title_entry["level"] in keep_types
252
+ and title_entry.text != title_text
253
+ ):
254
+ title_names.append((title_entry["level"], title_entry.text))
255
+ # return the title name that most likely belongs to the journal or publication venue
256
+ if title_names:
257
+ title_names.sort(key=lambda x: keep_types.index(x[0]))
258
+ return title_names[0][1]
259
+ return ""
260
+
261
+
262
+ def get_volume_from_grobid_xml(raw_xml: BeautifulSoup) -> str:
263
+ """
264
+ Returns the volume number of grobid bib entry
265
+ Grobid <biblscope unit="volume">
266
+ :return:
267
+ """
268
+ for bibl_entry in raw_xml.find_all("biblscope"):
269
+ if bibl_entry.has_attr("unit") and bibl_entry["unit"] == "volume":
270
+ return bibl_entry.text
271
+ return ""
272
+
273
+
274
+ def get_issue_from_grobid_xml(raw_xml: BeautifulSoup) -> str:
275
+ """
276
+ Returns the issue number of grobid bib entry
277
+ Grobid <biblscope unit="issue">
278
+ :return:
279
+ """
280
+ for bibl_entry in raw_xml.find_all("biblscope"):
281
+ if bibl_entry.has_attr("unit") and bibl_entry["unit"] == "issue":
282
+ return bibl_entry.text
283
+ return ""
284
+
285
+
286
+ def get_pages_from_grobid_xml(raw_xml: BeautifulSoup) -> str:
287
+ """
288
+ Returns the page numbers of grobid bib entry
289
+ Grobid <biblscope unit="page">
290
+ :return:
291
+ """
292
+ for bibl_entry in raw_xml.find_all("biblscope"):
293
+ if (
294
+ bibl_entry.has_attr("unit")
295
+ and bibl_entry["unit"] == "page"
296
+ and bibl_entry.has_attr("from")
297
+ ):
298
+ from_page = bibl_entry["from"]
299
+ if bibl_entry.has_attr("to"):
300
+ to_page = bibl_entry["to"]
301
+ return f"{from_page}--{to_page}"
302
+ else:
303
+ return from_page
304
+ return ""
305
+
306
+
307
+ def get_other_ids_from_grobid_xml(raw_xml: BeautifulSoup) -> Dict[str, List]:
308
+ """
309
+ Returns a dictionary of other identifiers from grobid bib entry (arxiv, pubmed, doi)
310
+ :param raw_xml:
311
+ :return:
312
+ """
313
+ other_ids = defaultdict(list)
314
+
315
+ for idno_entry in raw_xml.find_all("idno"):
316
+ if idno_entry.has_attr("type") and idno_entry.text:
317
+ other_ids[idno_entry["type"]].append(idno_entry.text)
318
+
319
+ return other_ids
320
+
321
+
322
+ def get_raw_bib_text_from_grobid_xml(raw_xml: BeautifulSoup) -> str:
323
+ """
324
+ Returns the raw bibliography string
325
+ :param raw_xml:
326
+ :return:
327
+ """
328
+ for note in raw_xml.find_all("note"):
329
+ if note.has_attr("type") and note["type"] == "raw_reference":
330
+ return note.text
331
+ return ""
332
+
333
+
334
+ def get_publication_datetime_from_grobid_xml(raw_xml: BeautifulSoup) -> str:
335
+ """
336
+ Finds and returns the publication datetime if it exists
337
+ :param raw_xml:
338
+ :return:
339
+ """
340
+ if raw_xml.publicationstmt:
341
+ for child in raw_xml.publicationstmt:
342
+ if (
343
+ child.name == "date"
344
+ and child.has_attr("type")
345
+ and child["type"] == "published"
346
+ and child.has_attr("when")
347
+ ):
348
+ return child["when"]
349
+ return ""
350
+
351
+
352
+ def parse_bib_entry(bib_entry: BeautifulSoup) -> Dict:
353
+ """
354
+ Parse one bib entry
355
+ :param bib_entry:
356
+ :return:
357
+ """
358
+ clean_tags(bib_entry)
359
+ title = get_title_from_grobid_xml(bib_entry)
360
+ return {
361
+ "ref_id": bib_entry.attrs.get("xml:id", None),
362
+ "title": title,
363
+ "authors": get_author_names_from_grobid_xml(bib_entry),
364
+ "year": get_year_from_grobid_xml(bib_entry),
365
+ "venue": get_venue_from_grobid_xml(bib_entry, title),
366
+ "volume": get_volume_from_grobid_xml(bib_entry),
367
+ "issue": get_issue_from_grobid_xml(bib_entry),
368
+ "pages": get_pages_from_grobid_xml(bib_entry),
369
+ "other_ids": get_other_ids_from_grobid_xml(bib_entry),
370
+ "raw_text": get_raw_bib_text_from_grobid_xml(bib_entry),
371
+ "urls": [],
372
+ }
373
+
374
+
375
+ def is_reference_tag(tag: bs4.element.Tag) -> bool:
376
+ return tag.name == "ref" and tag.attrs.get("type", "") == "bibr"
377
+
378
+
379
+ def extract_paper_metadata_from_grobid_xml(tag: bs4.element.Tag) -> Dict:
380
+ """
381
+ Extract paper metadata (title, authors, affiliation, year) from grobid xml
382
+ :param tag:
383
+ :return:
384
+ """
385
+ clean_tags(tag)
386
+ paper_metadata = {
387
+ "title": tag.titlestmt.title.text,
388
+ "authors": get_author_data_from_grobid_xml(tag),
389
+ "year": get_publication_datetime_from_grobid_xml(tag),
390
+ }
391
+ return paper_metadata
392
+
393
+
394
+ def parse_bibliography(soup: BeautifulSoup) -> List[Dict]:
395
+ """
396
+ Finds all bibliography entries in a grobid xml.
397
+ """
398
+ bibliography = soup.listBibl
399
+ if bibliography is None:
400
+ return []
401
+
402
+ entries = bibliography.find_all("biblStruct")
403
+
404
+ structured_entries = []
405
+ for entry in entries:
406
+ bib_entry = parse_bib_entry(entry)
407
+ # add bib entry only if it has a title
408
+ if bib_entry["title"]:
409
+ structured_entries.append(bib_entry)
410
+
411
+ bibliography.decompose()
412
+
413
+ return structured_entries
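For example, the bibliography of a TEI file produced by Grobid can be inspected like this (the path is a placeholder):

```python
from src.utils.pdf_utils.grobid_util import parse_bibliography, soup_from_path

soup = soup_from_path("./tei/example.tei.xml")
for bib_entry in parse_bibliography(soup):
    print(bib_entry["title"], bib_entry["year"], bib_entry["venue"], bib_entry["pages"])
```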
src/utils/pdf_utils/process_pdf.py ADDED
@@ -0,0 +1,276 @@
1
+ import json
2
+ import os
3
+ import uuid
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Dict, Optional
7
+
8
+ import requests
9
+ from bs4 import BeautifulSoup
10
+
11
+ from .grobid_client import GrobidClient
12
+ from .grobid_util import extract_paper_metadata_from_grobid_xml, parse_bibliography
13
+ from .s2orc_paper import Paper
14
+ from .utils import (
15
+ _clean_empty_and_duplicate_authors_from_grobid_parse,
16
+ check_if_citations_are_bracket_style,
17
+ extract_abstract_from_tei_xml,
18
+ extract_back_matter_from_tei_xml,
19
+ extract_body_text_from_tei_xml,
20
+ extract_figures_and_tables_from_tei_xml,
21
+ normalize_grobid_id,
22
+ sub_all_note_tags,
23
+ )
24
+
25
+ BASE_TEMP_DIR = "./grobid/temp"
26
+ BASE_OUTPUT_DIR = "./grobid/output"
27
+ BASE_LOG_DIR = "./grobid/log"
28
+
29
+
30
+ def convert_tei_xml_soup_to_s2orc_json(soup: BeautifulSoup, paper_id: str, pdf_hash: str) -> Paper:
31
+ """
32
+ Convert Grobid TEI XML to S2ORC json format
33
+ :param soup: BeautifulSoup of XML file content
34
+ :param paper_id: name of file
35
+ :param pdf_hash: hash of PDF
36
+ :return:
37
+ """
38
+ # extract metadata
39
+ metadata = extract_paper_metadata_from_grobid_xml(soup.fileDesc)
40
+ # clean metadata authors (remove dupes etc)
41
+ metadata["authors"] = _clean_empty_and_duplicate_authors_from_grobid_parse(metadata["authors"])
42
+
43
+ # parse bibliography entries (removes empty bib entries)
44
+ biblio_entries = parse_bibliography(soup)
45
+ bibkey_map = {normalize_grobid_id(bib["ref_id"]): bib for bib in biblio_entries}
46
+
47
+ # # process formulas and replace with text
48
+ # extract_formulas_from_tei_xml(soup)
49
+
50
+ # extract figure and table captions
51
+ refkey_map = extract_figures_and_tables_from_tei_xml(soup)
52
+
53
+ # get bracket style
54
+ is_bracket_style = check_if_citations_are_bracket_style(soup)
55
+
56
+ # substitute all note tags with p tags
57
+ soup = sub_all_note_tags(soup)
58
+
59
+ # process abstract if possible
60
+ abstract_entries = extract_abstract_from_tei_xml(
61
+ soup, bibkey_map, refkey_map, is_bracket_style
62
+ )
63
+
64
+ # process body text
65
+ body_entries = extract_body_text_from_tei_xml(soup, bibkey_map, refkey_map, is_bracket_style)
66
+
67
+ # parse back matter (acks, author statements, competing interests, abbrevs etc)
68
+ back_matter = extract_back_matter_from_tei_xml(soup, bibkey_map, refkey_map, is_bracket_style)
69
+
70
+ # form final paper entry
71
+ return Paper(
72
+ paper_id=paper_id,
73
+ pdf_hash=pdf_hash,
74
+ metadata=metadata,
75
+ abstract=abstract_entries,
76
+ body_text=body_entries,
77
+ back_matter=back_matter,
78
+ bib_entries=bibkey_map,
79
+ ref_entries=refkey_map,
80
+ )
81
+
82
+
83
+ def convert_tei_xml_file_to_s2orc_json(tei_file: str, pdf_hash: str = "") -> Paper:
84
+ """
85
+ Convert a TEI XML file to S2ORC JSON
86
+ :param tei_file:
87
+ :param pdf_hash:
88
+ :return:
89
+ """
90
+ if not os.path.exists(tei_file):
91
+ raise FileNotFoundError("Input TEI XML file doesn't exist")
92
+ paper_id = tei_file.split("/")[-1].split(".")[0]
93
+ soup = BeautifulSoup(open(tei_file, "rb").read(), "xml")
94
+ paper = convert_tei_xml_soup_to_s2orc_json(soup, paper_id, pdf_hash)
95
+ return paper
96
+
97
+
98
+ def process_pdf_stream(
99
+ input_file: str, sha: str, input_stream: bytes, grobid_config: Optional[Dict] = None
100
+ ) -> Dict:
101
+ """
102
+ Process PDF stream
103
+ :param input_file:
104
+ :param sha:
105
+ :param input_stream:
106
+ :return:
107
+ """
108
+ # process PDF through Grobid -> TEI.XML
109
+ client = GrobidClient(grobid_config)
110
+ tei_text = client.process_pdf_stream(
111
+ input_file, input_stream, "temp", "processFulltextDocument"
112
+ )
113
+
114
+ # make soup
115
+ soup = BeautifulSoup(tei_text, "xml")
116
+
117
+ # get paper
118
+ paper = convert_tei_xml_soup_to_s2orc_json(soup, input_file, sha)
119
+
120
+ return paper.release_json("pdf")
121
+
122
+
123
+ def process_pdf_file(
124
+ input_file: str,
125
+ temp_dir: str = BASE_TEMP_DIR,
126
+ output_dir: str = BASE_OUTPUT_DIR,
127
+ grobid_config: Optional[Dict] = None,
128
+ verbose: bool = True,
129
+ ) -> str:
130
+ """
131
+ Process a PDF file and get JSON representation
132
+ :param input_file:
133
+ :param temp_dir:
134
+ :param output_dir:
135
+ :return:
136
+ """
137
+ os.makedirs(temp_dir, exist_ok=True)
138
+ os.makedirs(output_dir, exist_ok=True)
139
+
140
+ # get paper id as the name of the file
141
+ paper_id = ".".join(input_file.split("/")[-1].split(".")[:-1])
142
+ tei_file = os.path.join(temp_dir, f"{paper_id}.tei.xml")
143
+ output_file = os.path.join(output_dir, f"{paper_id}.json")
144
+
145
+ # check if input file exists and output file doesn't
146
+ if not os.path.exists(input_file):
147
+ raise FileNotFoundError(f"{input_file} doesn't exist")
148
+ if os.path.exists(output_file):
149
+ if verbose:
150
+ print(f"{output_file} already exists!")
151
+ return output_file
152
+
153
+ # process PDF through Grobid -> TEI.XML
154
+ client = GrobidClient(grobid_config)
155
+ # TODO: compute PDF hash
156
+ # TODO: add grobid version number to output
157
+ client.process_pdf(input_file, temp_dir, "processFulltextDocument")
158
+
159
+ # process TEI.XML -> JSON
160
+ assert os.path.exists(tei_file)
161
+ paper = convert_tei_xml_file_to_s2orc_json(tei_file)
162
+
163
+ # write to file
164
+ with open(output_file, "w") as outf:
165
+ json.dump(paper.release_json(), outf, indent=4, sort_keys=False)
166
+
167
+ return output_file
168
+
169
+
170
+ UUID_NAMESPACE = uuid.UUID("bab08d37-ac12-40c4-847a-20ca337742fd")
171
+
172
+
173
+ def paper_url_to_uuid(paper_url: str) -> "uuid.UUID":
174
+ return uuid.uuid5(UUID_NAMESPACE, paper_url)
175
+
176
+
177
+ @dataclass
178
+ class PDFDownloader:
179
+ verbose: bool = True
180
+
181
+ def download(self, url: str, opath: str | Path) -> Path:
182
+ """Download a pdf file from URL and save locally.
183
+ Skip if there is a file at `opath` already.
184
+
185
+ Parameters
186
+ ----------
187
+ url : str
188
+ URL of the target PDF file
189
+ opath : str
190
+ Path to save downloaded PDF data.
191
+ """
192
+ if os.path.exists(opath):
193
+ return Path(opath)
194
+
195
+ if not os.path.exists(os.path.dirname(opath)):
196
+ os.makedirs(os.path.dirname(opath), exist_ok=True)
197
+
198
+ if self.verbose:
199
+ print(f"Downloading {url} into {opath}")
200
+ with open(opath, "wb") as f:
201
+ res = requests.get(url)
202
+ f.write(res.content)
203
+
204
+ return Path(opath)
205
+
206
+
207
+ @dataclass
208
+ class FulltextExtractor:
209
+
210
+ def __call__(self, pdf_file_path: Path | str) -> tuple[str, dict] | None:
211
+ """Extract plain text from a PDf file"""
212
+ raise NotImplementedError
213
+
214
+
215
+ @dataclass
216
+ class GrobidFulltextExtractor(FulltextExtractor):
217
+ tmp_dir: str = "./tmp/grobid"
218
+ grobid_config: Optional[Dict] = None
219
+ section_seperator: str = "\n\n"
220
+ paragraph_seperator: str = "\n"
221
+ verbose: bool = True
222
+
223
+ def construct_plain_text(self, extraction_result: dict) -> str:
224
+
225
+ section_strings = []
226
+
227
+ # add the title, if available (consider it as the first section)
228
+ title = extraction_result.get("title")
229
+ if title and title.strip():
230
+ section_strings.append(title.strip())
231
+
232
+ section_paragraphs: dict[str, list[str]] = extraction_result["sections"]
233
+ section_strings.extend(
234
+ self.paragraph_seperator.join(
235
+ # consider the section title as the first paragraph and
236
+ # remove empty paragraphs
237
+ filter(lambda s: len(s) > 0, map(lambda s: s.strip(), [section_name] + paragraphs))
238
+ )
239
+ for section_name, paragraphs in section_paragraphs.items()
240
+ )
241
+
242
+ return self.section_seperator.join(section_strings)
243
+
244
+ def postprocess_extraction_result(self, extraction_result: dict) -> dict:
245
+
246
+ # add sections
247
+ sections: dict[str, list[str]] = {}
248
+ for body_text in extraction_result["pdf_parse"]["body_text"]:
249
+ section_name = body_text["section"]
250
+
251
+ if section_name not in sections.keys():
252
+ sections[section_name] = []
253
+ sections[section_name] += [body_text["text"]]
254
+ extraction_result = {**extraction_result, "sections": sections}
255
+
256
+ return extraction_result
257
+
258
+ def __call__(self, pdf_file_path: Path | str) -> tuple[str, dict] | None:
259
+ """Extract plain text from a PDf file"""
260
+ try:
261
+ extraction_fpath = process_pdf_file(
262
+ str(pdf_file_path),
263
+ temp_dir=self.tmp_dir,
264
+ output_dir=self.tmp_dir,
265
+ grobid_config=self.grobid_config,
266
+ verbose=self.verbose,
267
+ )
268
+ with open(extraction_fpath, "r") as f:
269
+ extraction_result = json.load(f)
270
+
271
+ processed_extraction_result = self.postprocess_extraction_result(extraction_result)
272
+ plain_text = self.construct_plain_text(processed_extraction_result)
273
+ return plain_text, extraction_result
274
+ except AssertionError:
275
+ print("Grobid failed to parse this document.")
276
+ return None
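Putting the helpers in this module together, a single paper could be fetched and converted to plain text roughly as follows; the URL is a placeholder and a Grobid server must be reachable under the default config:

```python
from src.utils.pdf_utils.process_pdf import (
    GrobidFulltextExtractor,
    PDFDownloader,
    paper_url_to_uuid,
)

url = "https://aclanthology.org/2022.wiesp-1.1.pdf"  # placeholder paper URL
pdf_path = PDFDownloader().download(url, f"./pdfs/{paper_url_to_uuid(url)}.pdf")

extractor = GrobidFulltextExtractor(tmp_dir="./tmp/grobid")
result = extractor(pdf_path)  # returns None if Grobid fails to parse the document
if result is not None:
    plain_text, raw_extraction = result
    print(plain_text[:300])
```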
src/utils/pdf_utils/raw_paper.py ADDED
@@ -0,0 +1,90 @@
1
+ import json
2
+ import os
3
+ from dataclasses import asdict, dataclass
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ from jsonschema import validate
8
+
9
+ # TODO: load from file
10
+ schema = {
11
+ "title": "RawPaper",
12
+ "type": "object",
13
+ "properties": {
14
+ "paper_uuid": {"type": "string"},
15
+ "name": {"type": "string"},
16
+ "collection_id": {"type": "string"},
17
+ "collection_acronym": {"type": "string"},
18
+ "volume_id": {"type": "string"},
19
+ "booktitle": {"type": "string"},
20
+ "paper_id": {"type": "integer"},
21
+ "year": {"type": ["integer", "null"]},
22
+ "paper_title": {"type": "string"},
23
+ "authors": {
24
+ "type": "array",
25
+ "items": {
26
+ "type": "object",
27
+ "items": {
28
+ "first": {"type": ["string", "null"]},
29
+ "last": {"type": ["string", "null"]},
30
+ },
31
+ },
32
+ },
33
+ "abstract": {"type": ["string", "null"]},
34
+ "url": {"type": "string"},
35
+ "bibkey": {"type": ["string", "null"]},
36
+ "doi": {"type": ["string", "null"]},
37
+ "fulltext": {
38
+ "type": ["object", "null"],
39
+ "patternProperties": {"^.*$": {"type": "array", "items": {"type": "string"}}},
40
+ },
41
+ },
42
+ }
43
+
44
+ assert isinstance(schema, dict)
45
+
46
+
47
+ @dataclass
48
+ class RawPaper:
49
+ paper_uuid: str
50
+ name: str
51
+
52
+ collection_id: str
53
+ collection_acronym: str
54
+ volume_id: str
55
+ booktitle: str
56
+ paper_id: int
57
+ year: int | None
58
+
59
+ paper_title: str
60
+ authors: list[dict[str, str | None]]
61
+ abstract: str | None
62
+ url: str | None
63
+ bibkey: str
64
+ doi: str | None
65
+ fulltext: dict[str, list[str]] | None
66
+
67
+ @classmethod
68
+ def load_from_json(cls, fpath: str | Path) -> "RawPaper":
69
+ fpath = fpath if not isinstance(fpath, Path) else str(fpath)
70
+ # return cls(**sienna.load(fpath))
71
+ with open(fpath, "r") as f:
72
+ data = cls(**json.load(f))
73
+ return data
74
+
75
+ def get_fname(self) -> str:
76
+ return f"{self.name}.json"
77
+
78
+ def dumps(self) -> dict[str, Any]:
79
+ return asdict(self)
80
+
81
+ def validate(self) -> None:
82
+ validate(self.dumps(), schema=schema)
83
+
84
+ def save(self, odir: str) -> None:
85
+ self.validate()
86
+ if not os.path.exists(odir):
87
+ os.makedirs(odir, exist_ok=True)
88
+ opath = os.path.join(odir, self.get_fname())
89
+ with open(opath, "w") as f:
90
+ f.write(json.dumps(self.dumps(), indent=2))
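A small sketch of reloading and re-saving such a record; `save` validates against the schema before writing (paths are placeholders):

```python
from src.utils.pdf_utils.raw_paper import RawPaper

paper = RawPaper.load_from_json("./papers/2022.wiesp-1.1.json")
print(paper.paper_title, paper.year)
paper.save("./papers_copy")  # re-validates and writes ./papers_copy/<name>.json
```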
src/utils/pdf_utils/s2orc_paper.py ADDED
@@ -0,0 +1,478 @@
1
+ from datetime import datetime
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ S2ORC_NAME_STRING = "S2ORC"
5
+ S2ORC_VERSION_STRING = "1.0.0"
6
+
7
+ CORRECT_KEYS = {"issn": "issue", "type": "type_str"}
8
+
9
+ SKIP_KEYS = {"link", "bib_id"}
10
+
11
+ REFERENCE_OUTPUT_KEYS = {
12
+ "figure": {"text", "type_str", "uris", "num", "fig_num"},
13
+ "table": {"text", "type_str", "content", "num", "html"},
14
+ "footnote": {"text", "type_str", "num"},
15
+ "section": {"text", "type_str", "num", "parent"},
16
+ "equation": {"text", "type_str", "latex", "mathml", "num"},
17
+ }
18
+
19
+ METADATA_KEYS = {"title", "authors", "year", "venue", "identifiers"}
20
+
21
+
22
+ class ReferenceEntry:
23
+ """
24
+ Class for representing S2ORC figure and table references
25
+
26
+ An example json representation (values are examples, not accurate):
27
+
28
+ {
29
+ "FIGREF0": {
30
+ "text": "FIG. 2. Depth profiles of...",
31
+ "latex": null,
32
+ "type": "figure"
33
+ },
34
+ "TABREF2": {
35
+ "text": "Diversity indices of...",
36
+ "latex": null,
37
+ "type": "table",
38
+ "content": "",
39
+ "html": ""
40
+ }
41
+ }
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ ref_id: str,
47
+ text: str,
48
+ type_str: str,
49
+ latex: Optional[str] = None,
50
+ mathml: Optional[str] = None,
51
+ content: Optional[str] = None,
52
+ html: Optional[str] = None,
53
+ uris: Optional[List[str]] = None,
54
+ num: Optional[str] = None,
55
+ parent: Optional[str] = None,
56
+ fig_num: Optional[str] = None,
57
+ ):
58
+ self.ref_id = ref_id
59
+ self.text = text
60
+ self.type_str = type_str
61
+ self.latex = latex
62
+ self.mathml = mathml
63
+ self.content = content
64
+ self.html = html
65
+ self.uris = uris
66
+ self.num = num
67
+ self.parent = parent
68
+ self.fig_num = fig_num
69
+
70
+ def as_json(self):
71
+ keep_keys = REFERENCE_OUTPUT_KEYS.get(self.type_str, None)
72
+ if keep_keys:
73
+ return {k: self.__getattribute__(k) for k in keep_keys}
74
+ else:
75
+ return {
76
+ "text": self.text,
77
+ "type": self.type_str,
78
+ "latex": self.latex,
79
+ "mathml": self.mathml,
80
+ "content": self.content,
81
+ "html": self.html,
82
+ "uris": self.uris,
83
+ "num": self.num,
84
+ "parent": self.parent,
85
+ "fig_num": self.fig_num,
86
+ }
87
+
88
+
89
+ class BibliographyEntry:
90
+ """
91
+ Class for representing S2ORC parsed bibliography entries
92
+
93
+ An example json representation (values are examples, not accurate):
94
+
95
+ {
96
+ "title": "Mobility Reports...",
97
+ "authors": [
98
+ {
99
+ "first": "A",
100
+ "middle": ["A"],
101
+ "last": "Haija",
102
+ "suffix": ""
103
+ }
104
+ ],
105
+ "year": 2015,
106
+ "venue": "IEEE Wireless Commune Mag",
107
+ "volume": "42",
108
+ "issn": "9",
109
+ "pages": "80--92",
110
+ "other_ids": {
111
+ "doi": [
112
+ "10.1109/TWC.2014.2360196"
113
+ ],
114
+
115
+ }
116
+ }
117
+
118
+ """
119
+
120
+ def __init__(
121
+ self,
122
+ bib_id: str,
123
+ title: str,
124
+ authors: List[Dict[str, str]],
125
+ ref_id: Optional[str] = None,
126
+ year: Optional[int] = None,
127
+ venue: Optional[str] = None,
128
+ volume: Optional[str] = None,
129
+ issue: Optional[str] = None,
130
+ pages: Optional[str] = None,
131
+ other_ids: Optional[Dict[str, List]] = None,
132
+ num: Optional[int] = None,
133
+ urls: Optional[List] = None,
134
+ raw_text: Optional[str] = None,
135
+ links: Optional[List] = None,
136
+ ):
137
+ self.bib_id = bib_id
138
+ self.ref_id = ref_id
139
+ self.title = title
140
+ self.authors = authors
141
+ self.year = year
142
+ self.venue = venue
143
+ self.volume = volume
144
+ self.issue = issue
145
+ self.pages = pages
146
+ self.other_ids = other_ids
147
+ self.num = num
148
+ self.urls = urls
149
+ self.raw_text = raw_text
150
+ self.links = links
151
+
152
+ def as_json(self):
153
+ return {
154
+ "ref_id": self.ref_id,
155
+ "title": self.title,
156
+ "authors": self.authors,
157
+ "year": self.year,
158
+ "venue": self.venue,
159
+ "volume": self.volume,
160
+ "issue": self.issue,
161
+ "pages": self.pages,
162
+ "other_ids": self.other_ids,
163
+ "num": self.num,
164
+ "urls": self.urls,
165
+ "raw_text": self.raw_text,
166
+ "links": self.links,
167
+ }
168
+
169
+
170
+ class Affiliation:
171
+ """
172
+ Class for representing affiliation info
173
+
174
+ Example:
175
+ {
176
+ "laboratory": "Key Laboratory of Urban Environment and Health",
177
+ "institution": "Chinese Academy of Sciences",
178
+ "location": {
179
+ "postCode": "361021",
180
+ "settlement": "Xiamen",
181
+ "country": "People's Republic of China"
182
+ }
+ }
183
+ """
184
+
185
+ def __init__(self, laboratory: str, institution: str, location: Dict):
186
+ self.laboratory = laboratory
187
+ self.institution = institution
188
+ self.location = location
189
+
190
+ def as_json(self):
191
+ return {
192
+ "laboratory": self.laboratory,
193
+ "institution": self.institution,
194
+ "location": self.location,
195
+ }
196
+
197
+
198
+ class Author:
199
+ """
200
+ Class for representing paper authors
201
+
202
+ Example:
203
+
204
+ {
205
+ "first": "Anyi",
206
+ "middle": [],
207
+ "last": "Hu",
208
+ "suffix": "",
209
+ "affiliation": {
210
+ "laboratory": "Key Laboratory of Urban Environment and Health",
211
+ "institution": "Chinese Academy of Sciences",
212
+ "location": {
213
+ "postCode": "361021",
214
+ "settlement": "Xiamen",
215
+ "country": "People's Republic of China"
216
+ }
217
+ },
218
+ "email": ""
219
+ }
220
+ """
221
+
222
+ def __init__(
223
+ self,
224
+ first: str,
225
+ middle: List[str],
226
+ last: str,
227
+ suffix: str,
228
+ affiliation: Optional[Dict] = None,
229
+ email: Optional[str] = None,
230
+ ):
231
+ self.first = first
232
+ self.middle = middle
233
+ self.last = last
234
+ self.suffix = suffix
235
+ self.affiliation = Affiliation(**affiliation) if affiliation else {}
236
+ self.email = email
237
+
238
+ def as_json(self):
239
+ return {
240
+ "first": self.first,
241
+ "middle": self.middle,
242
+ "last": self.last,
243
+ "suffix": self.suffix,
244
+ "affiliation": self.affiliation.as_json() if self.affiliation else {},
245
+ "email": self.email,
246
+ }
247
+
248
+
249
+ class Metadata:
250
+ """
251
+ Class for representing paper metadata
252
+
253
+ Example:
254
+ {
255
+ "title": "Niche Partitioning...",
256
+ "authors": [
257
+ {
258
+ "first": "Anyi",
259
+ "middle": [],
260
+ "last": "Hu",
261
+ "suffix": "",
262
+ "affiliation": {
263
+ "laboratory": "Key Laboratory of Urban Environment and Health",
264
+ "institution": "Chinese Academy of Sciences",
265
+ "location": {
266
+ "postCode": "361021",
267
+ "settlement": "Xiamen",
268
+ "country": "People's Republic of China"
269
+ }
270
+ },
271
+ "email": ""
272
+ }
273
+ ],
274
+ "year": "2011-11"
275
+ }
276
+ """
277
+
278
+ def __init__(
279
+ self,
280
+ title: str,
281
+ authors: List[Dict],
282
+ year: Optional[str] = None,
283
+ venue: Optional[str] = None,
284
+ identifiers: Optional[Dict] = {},
285
+ ):
286
+ self.title = title
287
+ self.authors = [Author(**author) for author in authors]
288
+ self.year = year
289
+ self.venue = venue
290
+ self.identifiers = identifiers
291
+
292
+ def as_json(self):
293
+ return {
294
+ "title": self.title,
295
+ "authors": [author.as_json() for author in self.authors],
296
+ "year": self.year,
297
+ "venue": self.venue,
298
+ "identifiers": self.identifiers,
299
+ }
300
+
301
+
302
+ class Paragraph:
303
+ """
304
+ Class for representing a parsed paragraph from Grobid xml
305
+ All xml tags are removed from the paragraph text, all figures, equations, and tables are replaced
306
+ with a special token that maps to a reference identifier
307
+ Citation mention spans and section header are extracted
308
+
309
+ An example json representation (values are examples, not accurate):
310
+
311
+ {
312
+ "text": "Formal language techniques BID1 may be used to study FORMULA0 (see REF0)...",
313
+ "mention_spans": [
314
+ {
315
+ "start": 27,
316
+ "end": 31,
317
+ "text": "[1]")
318
+ ],
319
+ "ref_spans": [
320
+ {
321
+ "start": ,
322
+ "end": ,
323
+ "text": "Fig. 1"
324
+ }
325
+ ],
326
+ "eq_spans": [
327
+ {
328
+ "start": 53,
329
+ "end": 61,
330
+ "text": "α = 1",
331
+ "latex": "\\alpha = 1",
332
+ "ref_id": null
333
+ }
334
+ ],
335
+ "section": "Abstract"
336
+ }
337
+ """
338
+
339
+ def __init__(
340
+ self,
341
+ text: str,
342
+ cite_spans: List[Dict],
343
+ ref_spans: List[Dict],
344
+ eq_spans: Optional[List[Dict]] = [],
345
+ section: Optional[Any] = None,
346
+ sec_num: Optional[Any] = None,
347
+ ):
348
+ self.text = text
349
+ self.cite_spans = cite_spans
350
+ self.ref_spans = ref_spans
351
+ self.eq_spans = eq_spans
352
+ if type(section) is str:
353
+ if section:
354
+ sec_parts = section.split("::")
355
+ section_list = [[None, sec_name] for sec_name in sec_parts]
356
+ else:
357
+ section_list = None
358
+ if section_list and sec_num:
359
+ section_list[-1][0] = sec_num
360
+ else:
361
+ section_list = section
362
+ self.section = section_list
363
+
364
+ def as_json(self):
365
+ return {
366
+ "text": self.text,
367
+ "cite_spans": self.cite_spans,
368
+ "ref_spans": self.ref_spans,
369
+ "eq_spans": self.eq_spans,
370
+ "section": "::".join([sec[1] for sec in self.section]) if self.section else "",
371
+ "sec_num": self.section[-1][0] if self.section else None,
372
+ }
373
+
374
+
375
+ class Paper:
376
+ """
377
+ Class for representing a parsed S2ORC paper
378
+ """
379
+
380
+ def __init__(
381
+ self,
382
+ paper_id: str,
383
+ pdf_hash: str,
384
+ metadata: Dict,
385
+ abstract: List[Dict],
386
+ body_text: List[Dict],
387
+ back_matter: List[Dict],
388
+ bib_entries: Dict,
389
+ ref_entries: Dict,
390
+ ):
391
+ self.paper_id = paper_id
392
+ self.pdf_hash = pdf_hash
393
+ self.metadata = Metadata(**metadata)
394
+ self.abstract = [Paragraph(**para) for para in abstract]
395
+ self.body_text = [Paragraph(**para) for para in body_text]
396
+ self.back_matter = [Paragraph(**para) for para in back_matter]
397
+ self.bib_entries = [
398
+ BibliographyEntry(
399
+ bib_id=key,
400
+ **{
401
+ CORRECT_KEYS[k] if k in CORRECT_KEYS else k: v
402
+ for k, v in bib.items()
403
+ if k not in SKIP_KEYS
404
+ },
405
+ )
406
+ for key, bib in bib_entries.items()
407
+ ]
408
+ self.ref_entries = [
409
+ ReferenceEntry(
410
+ ref_id=key,
411
+ **{
412
+ CORRECT_KEYS[k] if k in CORRECT_KEYS else k: v
413
+ for k, v in ref.items()
414
+ if k != "ref_id"
415
+ },
416
+ )
417
+ for key, ref in ref_entries.items()
418
+ ]
419
+
420
+ def as_json(self):
421
+ return {
422
+ "paper_id": self.paper_id,
423
+ "pdf_hash": self.pdf_hash,
424
+ "metadata": self.metadata.as_json(),
425
+ "abstract": [para.as_json() for para in self.abstract],
426
+ "body_text": [para.as_json() for para in self.body_text],
427
+ "back_matter": [para.as_json() for para in self.back_matter],
428
+ "bib_entries": {bib.bib_id: bib.as_json() for bib in self.bib_entries},
429
+ "ref_entries": {ref.ref_id: ref.as_json() for ref in self.ref_entries},
430
+ }
431
+
432
+ @property
433
+ def raw_abstract_text(self) -> str:
434
+ """
435
+ Get all the abstract text joined by a newline
436
+ :return:
437
+ """
438
+ return "\n".join([para.text for para in self.abstract])
439
+
440
+ @property
441
+ def raw_body_text(self) -> str:
442
+ """
443
+ Get all the body text joined by a newline
444
+ :return:
445
+ """
446
+ return "\n".join([para.text for para in self.body_text])
447
+
448
+ def release_json(self, doc_type: str = "pdf") -> Dict:
449
+ """
450
+ Return in release JSON format
451
+ :return:
452
+ """
453
+ # TODO: not fully implemented; metadata format is not right; extra keys in some places
454
+ release_dict: Dict = {"paper_id": self.paper_id}
455
+ release_dict.update(
456
+ {
457
+ "header": {
458
+ "generated_with": f"{S2ORC_NAME_STRING} {S2ORC_VERSION_STRING}",
459
+ "date_generated": datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
460
+ }
461
+ }
462
+ )
463
+ release_dict.update(self.metadata.as_json())
464
+ release_dict.update({"abstract": self.raw_abstract_text})
465
+ release_dict.update(
466
+ {
467
+ f"{doc_type}_parse": {
468
+ "paper_id": self.paper_id,
469
+ "_pdf_hash": self.pdf_hash,
470
+ "abstract": [para.as_json() for para in self.abstract],
471
+ "body_text": [para.as_json() for para in self.body_text],
472
+ "back_matter": [para.as_json() for para in self.back_matter],
473
+ "bib_entries": {bib.bib_id: bib.as_json() for bib in self.bib_entries},
474
+ "ref_entries": {ref.ref_id: ref.as_json() for ref in self.ref_entries},
475
+ }
476
+ }
477
+ )
478
+ return release_dict
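Note: a small sketch of the classes above in isolation; the values are invented, only the constructor signatures and as_json methods are taken from the code.

# Sketch: build a bibliography entry and a paragraph by hand; all values are invented.
bib = BibliographyEntry(
    bib_id="BIBREF0",
    title="An Example Reference",
    authors=[{"first": "A", "middle": [], "last": "Haija", "suffix": ""}],
    year=2015,
)
para = Paragraph(
    text="See [1] for details.",
    cite_spans=[{"start": 4, "end": 7, "text": "[1]", "ref_id": "BIBREF0"}],
    ref_spans=[],
    section="Introduction",
    sec_num="1",
)
print(bib.as_json()["title"])     # "An Example Reference"
print(para.as_json()["section"])  # "Introduction"
print(para.as_json()["sec_num"])  # "1"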
src/utils/pdf_utils/s2orc_utils.py ADDED
@@ -0,0 +1,61 @@
1
+ from typing import Any, Dict
2
+
3
+ from .s2orc_paper import METADATA_KEYS, Paper
4
+
5
+
6
+ def load_s2orc(paper_dict: Dict[str, Any]) -> Paper:
7
+ """
8
+ Load release S2ORC into Paper class
9
+ :param paper_dict:
10
+ :return:
11
+ """
12
+ paper_id = paper_dict["paper_id"]
13
+ pdf_hash = paper_dict.get("_pdf_hash", paper_dict.get("s2_pdf_hash", None))
14
+
15
+ # 2019 gorc parses
16
+ grobid_parse = paper_dict.get("grobid_parse")
17
+ if grobid_parse:
18
+ metadata = {k: v for k, v in paper_dict["metadata"].items() if k in METADATA_KEYS}
19
+ abstract = grobid_parse.get("abstract", [])
20
+ body_text = grobid_parse.get("body_text", [])
21
+ back_matter = grobid_parse.get("back_matter", [])
22
+ bib_entries = grobid_parse.get("bib_entries", {})
23
+ for k, v in bib_entries.items():
24
+ if "link" in v:
25
+ v["links"] = [v["link"]]
26
+ ref_entries = grobid_parse.get("ref_entries", {})
27
+ # current and 2020 s2orc release_json
28
+ elif ("pdf_parse" in paper_dict and paper_dict.get("pdf_parse")) or (
29
+ "body_text" in paper_dict and paper_dict.get("body_text")
30
+ ):
31
+ if "pdf_parse" in paper_dict:
32
+ paper_dict = paper_dict["pdf_parse"]
33
+ if paper_dict.get("metadata"):
34
+ metadata = {
35
+ k: v for k, v in paper_dict.get("metadata", {}).items() if k in METADATA_KEYS
36
+ }
37
+ # 2020 s2orc releases (metadata is separate)
38
+ else:
39
+ metadata = {"title": None, "authors": [], "year": None}
40
+ abstract = paper_dict.get("abstract", [])
41
+ body_text = paper_dict.get("body_text", [])
42
+ back_matter = paper_dict.get("back_matter", [])
43
+ bib_entries = paper_dict.get("bib_entries", {})
44
+ for k, v in bib_entries.items():
45
+ if "link" in v:
46
+ v["links"] = [v["link"]]
47
+ ref_entries = paper_dict.get("ref_entries", {})
48
+ else:
49
+ print(paper_id)
50
+ raise NotImplementedError("Unknown S2ORC file type!")
51
+
52
+ return Paper(
53
+ paper_id=paper_id,
54
+ pdf_hash=pdf_hash,
55
+ metadata=metadata,
56
+ abstract=abstract,
57
+ body_text=body_text,
58
+ back_matter=back_matter,
59
+ bib_entries=bib_entries,
60
+ ref_entries=ref_entries,
61
+ )
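Note: a brief sketch of loading a release-format S2ORC/doc2json file with load_s2orc defined above; the file path is a placeholder.

# Sketch: load a release-format JSON file and access the parsed paper; the path is a placeholder.
import json

with open("paper.s2orc.json") as f:
    paper = load_s2orc(json.load(f))

print(paper.metadata.title)
print(paper.raw_abstract_text[:200])
print(len(paper.body_text), "body paragraphs")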
src/utils/pdf_utils/utils.py ADDED
@@ -0,0 +1,904 @@
1
+ import re
2
+ from typing import Dict, List, Tuple
3
+
4
+ import bs4
5
+
6
+
7
+ def replace_refspans(
8
+ spans_to_replace: List[Tuple[int, int, str, str]],
9
+ full_string: str,
10
+ pre_padding: str = "",
11
+ post_padding: str = "",
12
+ btwn_padding: str = ", ",
13
+ ) -> str:
14
+ """
15
+ For each span within the full string, replace that span with new text
16
+ :param spans_to_replace: list of tuples of form (start_ind, end_ind, span_text, new_substring)
17
+ :param full_string:
18
+ :param pre_padding:
19
+ :param post_padding:
20
+ :param btwn_padding:
21
+ :return:
22
+ """
23
+ # assert all spans are equal to full_text span
24
+ assert all([full_string[start:end] == span for start, end, span, _ in spans_to_replace])
25
+
26
+ # assert none of the spans start with the same start ind
27
+ start_inds = [rep[0] for rep in spans_to_replace]
28
+ assert len(set(start_inds)) == len(start_inds)
29
+
30
+ # sort by start index
31
+ spans_to_replace.sort(key=lambda x: x[0])
32
+
33
+ # form strings for each span group
34
+ for i, entry in enumerate(spans_to_replace):
35
+ start, end, span, new_string = entry
36
+
37
+ # skip empties
38
+ if end <= 0:
39
+ continue
40
+
41
+ # compute shift amount
42
+ shift_amount = len(new_string) - len(span) + len(pre_padding) + len(post_padding)
43
+
44
+ # shift remaining appropriately
45
+ for ind in range(i + 1, len(spans_to_replace)):
46
+ next_start, next_end, next_span, next_string = spans_to_replace[ind]
47
+ # skip empties
48
+ if next_end <= 0:
49
+ continue
50
+ # if overlap between ref span and current ref span, remove from replacement
51
+ if next_start < end:
52
+ next_start = 0
53
+ next_end = 0
54
+ next_string = ""
55
+ # if ref span abuts previous reference span
56
+ elif next_start == end:
57
+ next_start += shift_amount
58
+ next_end += shift_amount
59
+ next_string = btwn_padding + pre_padding + next_string + post_padding
60
+ # if ref span starts after, shift starts and ends
61
+ elif next_start > end:
62
+ next_start += shift_amount
63
+ next_end += shift_amount
64
+ next_string = pre_padding + next_string + post_padding
65
+ # save adjusted span
66
+ spans_to_replace[ind] = (next_start, next_end, next_span, next_string)
67
+
68
+ spans_to_replace = [entry for entry in spans_to_replace if entry[1] > 0]
69
+ spans_to_replace.sort(key=lambda x: x[0])
70
+
71
+ # apply shifts in series
72
+ for start, end, span, new_string in spans_to_replace:
73
+ assert full_string[start:end] == span
74
+ full_string = full_string[:start] + new_string + full_string[end:]
75
+
76
+ return full_string
77
+
78
+
79
+ BRACKET_REGEX = re.compile(r"\[[1-9]\d{0,2}([,;\-\s]+[1-9]\d{0,2})*;?\]")
80
+ BRACKET_STYLE_THRESHOLD = 5
81
+
82
+ SINGLE_BRACKET_REGEX = re.compile(r"\[([1-9]\d{0,2})\]")
83
+ EXPANSION_CHARS = {"-", "–"}
84
+
85
+ REPLACE_TABLE_TOKS = {
86
+ "<row>": "<tr>",
87
+ "<row/>": "<tr/>",
88
+ "</row>": "</tr>",
89
+ "<cell>": "<td>",
90
+ "<cell/>": "<td/>",
91
+ "</cell>": "</td>",
92
+ "<cell ": "<td ",
93
+ "cols=": "colspan=",
94
+ }
95
+
96
+
97
+ def span_already_added(sub_start: int, sub_end: int, span_indices: List[Tuple[int, int]]) -> bool:
98
+ """
99
+ Check if span is a subspan of existing span
100
+ :param sub_start:
101
+ :param sub_end:
102
+ :param span_indices:
103
+ :return:
104
+ """
105
+ for span_start, span_end in span_indices:
106
+ if sub_start >= span_start and sub_end <= span_end:
107
+ return True
108
+ return False
109
+
110
+
111
+ def is_expansion_string(between_string: str) -> bool:
112
+ """
113
+ Check if the string between two refs is an expansion string
114
+ :param between_string:
115
+ :return:
116
+ """
117
+ if (
118
+ len(between_string) <= 2
119
+ and any([c in EXPANSION_CHARS for c in between_string])
120
+ and all([c in EXPANSION_CHARS.union({" "}) for c in between_string])
121
+ ):
122
+ return True
123
+ return False
124
+
125
+
126
+ # TODO: still cases like `09bcee03baceb509d4fcf736fa1322cb8adf507f` w/ dups like ['L Jung', 'R Hessler', 'Louis Jung', 'Roland Hessler']
127
+ # example paper that has empties & duplicates: `09bce26cc7e825e15a4469e3e78b7a54898bb97f`
128
+ def _clean_empty_and_duplicate_authors_from_grobid_parse(
129
+ authors: List[Dict],
130
+ ) -> List[Dict]:
131
+ """
132
+ Within affiliation, `location` is a dict with fields <settlement>, <region>, <country>, <postCode>, etc.
133
+ Too much hassle, so just take the first one that's not empty.
134
+ """
135
+ # stripping empties
136
+ clean_authors_list = []
137
+ for author in authors:
138
+ clean_first = author["first"].strip()
139
+ clean_last = author["last"].strip()
140
+ clean_middle = [m.strip() for m in author["middle"]]
141
+ clean_suffix = author["suffix"].strip()
142
+ if clean_first or clean_last or clean_middle:
143
+ author["first"] = clean_first
144
+ author["last"] = clean_last
145
+ author["middle"] = clean_middle
146
+ author["suffix"] = clean_suffix
147
+ clean_authors_list.append(author)
148
+ # combining duplicates (preserve first occurrence of author name as position)
149
+ key_to_author_blobs = {}
150
+ ordered_keys_by_author_pos = []
151
+ for author in clean_authors_list:
152
+ key = (
153
+ author["first"],
154
+ author["last"],
155
+ " ".join(author["middle"]),
156
+ author["suffix"],
157
+ )
158
+ if key not in key_to_author_blobs:
159
+ key_to_author_blobs[key] = author
160
+ ordered_keys_by_author_pos.append(key)
161
+ else:
162
+ if author["email"]:
163
+ key_to_author_blobs[key]["email"] = author["email"]
164
+ if author["affiliation"] and (
165
+ author["affiliation"]["institution"]
166
+ or author["affiliation"]["laboratory"]
167
+ or author["affiliation"]["location"]
168
+ ):
169
+ key_to_author_blobs[key]["affiliation"] = author["affiliation"]
170
+ dedup_authors_list = [key_to_author_blobs[key] for key in ordered_keys_by_author_pos]
171
+ return dedup_authors_list
172
+
173
+
174
+ def sub_spans_and_update_indices(
175
+ spans_to_replace: List[Tuple[int, int, str, str]], full_string: str
176
+ ) -> Tuple[str, List]:
177
+ """
178
+ Replace all spans and recompute indices
179
+ :param spans_to_replace:
180
+ :param full_string:
181
+ :return:
182
+ """
183
+ # TODO: check no spans overlapping
184
+ # TODO: check all spans well-formed
185
+
186
+ # assert all spans are equal to full_text span
187
+ assert all([full_string[start:end] == token for start, end, token, _ in spans_to_replace])
188
+
189
+ # assert none of the spans start with the same start ind
190
+ start_inds = [rep[0] for rep in spans_to_replace]
191
+ assert len(set(start_inds)) == len(start_inds)
192
+
193
+ # sort by start index
194
+ spans_to_replace.sort(key=lambda x: x[0])
195
+
196
+ # compute offsets for each span
197
+ new_spans = [
198
+ (start, end, token, surface, 0) for start, end, token, surface in spans_to_replace
199
+ ]
200
+ for i, entry in enumerate(spans_to_replace):
201
+ start, end, token, surface = entry
202
+ new_end = start + len(surface)
203
+ offset = new_end - end
204
+ # new_spans[i][1] += offset
205
+ new_spans[i] = (
206
+ new_spans[i][0],
207
+ new_spans[i][1] + offset,
208
+ new_spans[i][2],
209
+ new_spans[i][3],
210
+ new_spans[i][4],
211
+ )
212
+ # for new_span_entry in new_spans[i + 1 :]:
213
+ # new_span_entry[4] += offset
214
+ for j in range(i + 1, len(new_spans)):
215
+ new_spans[j] = (
216
+ new_spans[j][0],
217
+ new_spans[j][1],
218
+ new_spans[j][2],
219
+ new_spans[j][3],
220
+ new_spans[j][4] + offset,
221
+ )
222
+
223
+ # generate new text and create final spans
224
+ new_text = replace_refspans(spans_to_replace, full_string, btwn_padding="")
225
+ result = [
226
+ (start + offset, end + offset, token, surface)
227
+ for start, end, token, surface, offset in new_spans
228
+ ]
229
+
230
+ return new_text, result
231
+
232
+
233
+ class UniqTokenGenerator:
234
+ """
235
+ Generate unique token
236
+ """
237
+
238
+ def __init__(self, tok_string):
239
+ self.tok_string = tok_string
240
+ self.ind = 0
241
+
242
+ def __iter__(self):
243
+ return self
244
+
245
+ def __next__(self):
246
+ return self.next()
247
+
248
+ def next(self):
249
+ new_token = f"{self.tok_string}{self.ind}"
250
+ self.ind += 1
251
+ return new_token
252
+
253
+
254
+ def normalize_grobid_id(grobid_id: str):
255
+ """
256
+ Normalize grobid object identifiers
257
+ :param grobid_id:
258
+ :return:
259
+ """
260
+ str_norm = grobid_id.upper().replace("_", "").replace("#", "")
261
+ if str_norm.startswith("B"):
262
+ return str_norm.replace("B", "BIBREF")
263
+ if str_norm.startswith("TAB"):
264
+ return str_norm.replace("TAB", "TABREF")
265
+ if str_norm.startswith("FIG"):
266
+ return str_norm.replace("FIG", "FIGREF")
267
+ if str_norm.startswith("FORMULA"):
268
+ return str_norm.replace("FORMULA", "EQREF")
269
+ return str_norm
270
+
271
+
272
+ def extract_formulas_from_tei_xml(sp: bs4.BeautifulSoup) -> None:
273
+ """
274
+ Replace all formulas with the text
275
+ :param sp:
276
+ :return:
277
+ """
278
+ for eq in sp.find_all("formula"):
279
+ eq.replace_with(sp.new_string(eq.text.strip()))
280
+
281
+
282
+ def table_to_html(table: bs4.element.Tag) -> str:
283
+ """
284
+ Sub table tags with html table tags
285
+ :param table:
286
+ :return:
287
+ """
288
+ for tag in table:
289
+ if tag.name != "row":
290
+ print(f"Unknown table subtag: {tag.name}")
291
+ tag.decompose()
292
+ table_str = str(table)
293
+ for token, subtoken in REPLACE_TABLE_TOKS.items():
294
+ table_str = table_str.replace(token, subtoken)
295
+ return table_str
296
+
297
+
298
+ def extract_figures_and_tables_from_tei_xml(sp: bs4.BeautifulSoup) -> Dict[str, Dict]:
299
+ """
300
+ Generate figure and table dicts
301
+ :param sp:
302
+ :return:
303
+ """
304
+ ref_map = dict()
305
+
306
+ for fig in sp.find_all("figure"):
307
+ try:
308
+ if fig.name and fig.get("xml:id"):
309
+ if fig.get("type") == "table":
310
+ ref_map[normalize_grobid_id(fig.get("xml:id"))] = {
311
+ "text": (
312
+ fig.figDesc.text.strip()
313
+ if fig.figDesc
314
+ else fig.head.text.strip() if fig.head else ""
315
+ ),
316
+ "latex": None,
317
+ "type": "table",
318
+ "content": table_to_html(fig.table),
319
+ "fig_num": fig.get("xml:id"),
320
+ }
321
+ else:
322
+ if True in [char.isdigit() for char in fig.findNext("head").findNext("label")]:
323
+ fig_num = fig.findNext("head").findNext("label").contents[0]
324
+ else:
325
+ fig_num = None
326
+ ref_map[normalize_grobid_id(fig.get("xml:id"))] = {
327
+ "text": fig.figDesc.text.strip() if fig.figDesc else "",
328
+ "latex": None,
329
+ "type": "figure",
330
+ "content": "",
331
+ "fig_num": fig_num,
332
+ }
333
+ except AttributeError:
334
+ continue
335
+ fig.decompose()
336
+
337
+ return ref_map
338
+
339
+
340
+ def check_if_citations_are_bracket_style(sp: bs4.BeautifulSoup) -> bool:
341
+ """
342
+ Check if the document has bracket style citations
343
+ :param sp:
344
+ :return:
345
+ """
346
+ cite_strings = []
347
+ if sp.body:
348
+ for div in sp.body.find_all("div"):
349
+ if div.head:
350
+ continue
351
+ for rtag in div.find_all("ref"):
352
+ ref_type = rtag.get("type")
353
+ if ref_type == "bibr":
354
+ cite_strings.append(rtag.text.strip())
355
+
356
+ # check how many match bracket style
357
+ bracket_style = [bool(BRACKET_REGEX.match(cite_str)) for cite_str in cite_strings]
358
+
359
+ # return true if the number of bracket-style citations exceeds the threshold
360
+ if sum(bracket_style) > BRACKET_STYLE_THRESHOLD:
361
+ return True
362
+
363
+ return False
364
+
365
+
366
+ def sub_all_note_tags(sp: bs4.BeautifulSoup) -> bs4.BeautifulSoup:
367
+ """
368
+ Sub all note tags with p tags
369
370
+ :param sp:
371
+ :return:
372
+ """
373
+ for ntag in sp.find_all("note"):
374
+ p_tag = sp.new_tag("p")
375
+ p_tag.string = ntag.text.strip()
376
+ ntag.replace_with(p_tag)
377
+ return sp
378
+
379
+
380
+ def process_formulas_in_paragraph(para_el: bs4.BeautifulSoup, sp: bs4.BeautifulSoup) -> None:
381
+ """
382
+ Process all formulas in paragraph and replace with text and label
383
+ :param para_el:
384
+ :param sp:
385
+ :return:
386
+ """
387
+ for ftag in para_el.find_all("formula"):
388
+ # get label if exists and insert a space between formula and label
389
+ if ftag.label:
390
+ label = " " + ftag.label.text
391
+ ftag.label.decompose()
392
+ else:
393
+ label = ""
394
+ ftag.replace_with(sp.new_string(f"{ftag.text.strip()}{label}"))
395
+
396
+
397
+ def process_references_in_paragraph(
398
+ para_el: bs4.BeautifulSoup, sp: bs4.BeautifulSoup, refs: Dict
399
+ ) -> Dict:
400
+ """
401
+ Process all references in paragraph and generate a dict that contains (type, ref_id, surface_form)
402
+ :param para_el:
403
+ :param sp:
404
+ :param refs:
405
+ :return:
406
+ """
407
+ tokgen = UniqTokenGenerator("REFTOKEN")
408
+ ref_dict = dict()
409
+ for rtag in para_el.find_all("ref"):
410
+ try:
411
+ ref_type = rtag.get("type")
412
+ # skip if citation
413
+ if ref_type == "bibr":
414
+ continue
415
+ if ref_type == "table" or ref_type == "figure":
416
+ ref_id = rtag.get("target")
417
+ if ref_id and normalize_grobid_id(ref_id) in refs:
418
+ # normalize reference string
419
+ rtag_string = normalize_grobid_id(ref_id)
420
+ else:
421
+ rtag_string = None
422
+ # add to ref set
423
+ ref_key = tokgen.next()
424
+ ref_dict[ref_key] = (rtag_string, rtag.text.strip(), ref_type)
425
+ rtag.replace_with(sp.new_string(f" {ref_key} "))
426
+ else:
427
+ # replace with surface form
428
+ rtag.replace_with(sp.new_string(rtag.text.strip()))
429
+ except AttributeError:
430
+ continue
431
+ return ref_dict
432
+
433
+
434
+ def process_citations_in_paragraph(
435
+ para_el: bs4.BeautifulSoup, sp: bs4.BeautifulSoup, bibs: Dict, bracket: bool
436
+ ) -> Dict:
437
+ """
438
+ Process all citations in paragraph and generate a dict for surface forms
439
+ :param para_el:
440
+ :param sp:
441
+ :param bibs:
442
+ :param bracket:
443
+ :return:
444
+ """
445
+
446
+ # CHECK if range between two surface forms is appropriate for bracket style expansion
447
+ def _get_surface_range(start_surface, end_surface):
448
+ span1_match = SINGLE_BRACKET_REGEX.match(start_surface)
449
+ span2_match = SINGLE_BRACKET_REGEX.match(end_surface)
450
+ if span1_match and span2_match:
451
+ # get numbers corresponding to citations
452
+ span1_num = int(span1_match.group(1))
453
+ span2_num = int(span2_match.group(1))
454
+ # expand if range is between 1 and 20
455
+ if 1 < span2_num - span1_num < 20:
456
+ return span1_num, span2_num
457
+ return None
458
+
459
+ # CREATE BIBREF range between two reference ids, e.g. BIBREF1-BIBREF4 -> BIBREF1 BIBREF2 BIBREF3 BIBREF4
460
+ def _create_ref_id_range(start_ref_id, end_ref_id):
461
+ start_ref_num = int(start_ref_id[6:])
462
+ end_ref_num = int(end_ref_id[6:])
463
+ return [f"BIBREF{curr_ref_num}" for curr_ref_num in range(start_ref_num, end_ref_num + 1)]
464
+
465
+ # CREATE surface form range between two bracket strings, e.g. [1]-[4] -> [1] [2] [3] [4]
466
+ def _create_surface_range(start_number, end_number):
467
+ return [f"[{n}]" for n in range(start_number, end_number + 1)]
468
+
469
+ # create citation dict with keywords
470
+ cite_map = dict()
471
+ tokgen = UniqTokenGenerator("CITETOKEN")
472
+
473
+ for rtag in para_el.find_all("ref"):
474
+ try:
475
+ # get surface span, e.g. [3]
476
+ surface_span = rtag.text.strip()
477
+
478
+ # check if target is available (#b2 -> BID2)
479
+ if rtag.get("target"):
480
+ # normalize reference string
481
+ rtag_ref_id = normalize_grobid_id(rtag.get("target"))
482
+
483
+ # skip if rtag ref_id not in bibliography
484
+ if rtag_ref_id not in bibs:
485
+ cite_key = tokgen.next()
486
+ rtag.replace_with(sp.new_string(f" {cite_key} "))
487
+ cite_map[cite_key] = (None, surface_span)
488
+ continue
489
+
490
+ # if bracket style, only keep if surface form is bracket
491
+ if bracket:
492
+ # valid bracket span
493
+ if surface_span and (
494
+ surface_span[0] == "["
495
+ or surface_span[-1] == "]"
496
+ or surface_span[-1] == ","
497
+ ):
498
+ pass
499
+ # invalid, replace tag with surface form and continue to next ref tag
500
+ else:
501
+ rtag.replace_with(sp.new_string(f" {surface_span} "))
502
+ continue
503
+ # not bracket, add cite span and move on
504
+ else:
505
+ cite_key = tokgen.next()
506
+ rtag.replace_with(sp.new_string(f" {cite_key} "))
507
+ cite_map[cite_key] = (rtag_ref_id, surface_span)
508
+ continue
509
+
510
+ # EXTRA PROCESSING FOR BRACKET STYLE CITATIONS; EXPAND RANGES ###
511
+ # look backward for range marker, e.g. [1]-*[3]*
512
+ backward_between_span = ""
513
+ for sib in rtag.previous_siblings:
514
+ if sib.name == "ref":
515
+ break
516
+ elif type(sib) is bs4.NavigableString:
517
+ backward_between_span += sib
518
+ else:
519
+ break
520
+
521
+ # check if there's a backwards expansion, e.g. need to expand [1]-[3] -> [1] [2] [3]
522
+ if is_expansion_string(backward_between_span):
523
+ # get surface number range
524
+ surface_num_range = _get_surface_range(
525
+ rtag.find_previous_sibling("ref").text.strip(), surface_span
526
+ )
527
+ # if the surface number range is reasonable (range < 20, in order), EXPAND
528
+ if surface_num_range:
529
+ # delete previous ref tag and anything in between (i.e. delete "-" and extra spaces)
530
+ for sib in rtag.previous_siblings:
531
+ if sib.name == "ref":
532
+ break
533
+ elif type(sib) is bs4.NavigableString:
534
+ sib.replace_with(sp.new_string(""))
535
+ else:
536
+ break
537
+
538
+ # get ref id of previous ref, e.g. [1] (#b0 -> BID0)
539
+ previous_rtag = rtag.find_previous_sibling("ref")
540
+ previous_rtag_ref_id = normalize_grobid_id(previous_rtag.get("target"))
541
+ previous_rtag.decompose()
542
+
543
+ # replace this ref tag with the full range expansion, e.g. [3] (#b2 -> BID1 BID2)
544
+ id_range = _create_ref_id_range(previous_rtag_ref_id, rtag_ref_id)
545
+ surface_range = _create_surface_range(
546
+ surface_num_range[0], surface_num_range[1]
547
+ )
548
+ replace_string = ""
549
+ for range_ref_id, range_surface_form in zip(id_range, surface_range):
550
+ # only replace if ref id is in bibliography, else add none
551
+ if range_ref_id in bibs:
552
+ cite_key = tokgen.next()
553
+ cite_map[cite_key] = (range_ref_id, range_surface_form)
554
+ else:
555
+ cite_key = tokgen.next()
556
+ cite_map[cite_key] = (None, range_surface_form)
557
+ replace_string += cite_key + " "
558
+ rtag.replace_with(sp.new_string(f" {replace_string} "))
559
+ # ELSE do not expand backwards and replace previous and current rtag with appropriate ref id
560
+ else:
561
+ # add mapping between ref id and surface form for previous ref tag
562
+ previous_rtag = rtag.find_previous_sibling("ref")
563
+ previous_rtag_ref_id = normalize_grobid_id(previous_rtag.get("target"))
564
+ previous_rtag_surface = previous_rtag.text.strip()
565
+ cite_key = tokgen.next()
566
+ previous_rtag.replace_with(sp.new_string(f" {cite_key} "))
567
+ cite_map[cite_key] = (
568
+ previous_rtag_ref_id,
569
+ previous_rtag_surface,
570
+ )
571
+
572
+ # add mapping between ref id and surface form for current reftag
573
+ cite_key = tokgen.next()
574
+ rtag.replace_with(sp.new_string(f" {cite_key} "))
575
+ cite_map[cite_key] = (rtag_ref_id, surface_span)
576
+ else:
577
+ # look forward and see if expansion string, e.g. *[1]*-[3]
578
+ forward_between_span = ""
579
+ for sib in rtag.next_siblings:
580
+ if sib.name == "ref":
581
+ break
582
+ elif type(sib) is bs4.NavigableString:
583
+ forward_between_span += sib
584
+ else:
585
+ break
586
+ # look forward for range marker (if is a range, continue -- range will be expanded
587
+ # when we get to the second value)
588
+ if is_expansion_string(forward_between_span):
589
+ continue
590
+ # else treat like normal reference
591
+ else:
592
+ cite_key = tokgen.next()
593
+ rtag.replace_with(sp.new_string(f" {cite_key} "))
594
+ cite_map[cite_key] = (rtag_ref_id, surface_span)
595
+
596
+ else:
597
+ cite_key = tokgen.next()
598
+ rtag.replace_with(sp.new_string(f" {cite_key} "))
599
+ cite_map[cite_key] = (None, surface_span)
600
+ except AttributeError:
601
+ continue
602
+
603
+ return cite_map
604
+
605
+
606
+ def process_paragraph(
607
+ sp: bs4.BeautifulSoup,
608
+ para_el: bs4.element.Tag,
609
+ section_names: List[Tuple],
610
+ bib_dict: Dict,
611
+ ref_dict: Dict,
612
+ bracket: bool,
613
+ ) -> Dict:
614
+ """
615
+ Process one paragraph
616
+ :param sp:
617
+ :param para_el:
618
+ :param section_names:
619
+ :param bib_dict:
620
+ :param ref_dict:
621
+ :param bracket: if bracket style, expand and clean up citations
622
+ :return:
623
+ """
624
+ # return empty paragraph if no text
625
+ if not para_el.text:
626
+ return {
627
+ "text": "",
628
+ "cite_spans": [],
629
+ "ref_spans": [],
630
+ "eq_spans": [],
631
+ "section": section_names,
632
+ }
633
+
634
+ # replace formulas with formula text
635
+ process_formulas_in_paragraph(para_el, sp)
636
+
637
+ # get references to tables and figures
638
+ ref_map = process_references_in_paragraph(para_el, sp, ref_dict)
639
+
640
+ # generate citation map for paragraph element (keep only cite spans with bib entry or unlinked)
641
+ cite_map = process_citations_in_paragraph(para_el, sp, bib_dict, bracket)
642
+
643
+ # substitute space characters
644
+ para_text = re.sub(r"\s+", " ", para_el.text)
645
+ para_text = re.sub(r"\s", " ", para_text)
646
+
647
+ # get all cite and ref spans
648
+ all_spans_to_replace = []
649
+ for span in re.finditer(r"(CITETOKEN\d+)", para_text):
650
+ uniq_token = span.group()
651
+ ref_id, surface_text = cite_map[uniq_token]
652
+ all_spans_to_replace.append(
653
+ (span.start(), span.start() + len(uniq_token), uniq_token, surface_text)
654
+ )
655
+ for span in re.finditer(r"(REFTOKEN\d+)", para_text):
656
+ uniq_token = span.group()
657
+ ref_id, surface_text, ref_type = ref_map[uniq_token]
658
+ all_spans_to_replace.append(
659
+ (span.start(), span.start() + len(uniq_token), uniq_token, surface_text)
660
+ )
661
+
662
+ # replace cite and ref spans and create json blobs
663
+ para_text, all_spans_to_replace = sub_spans_and_update_indices(all_spans_to_replace, para_text)
664
+
665
+ cite_span_blobs = [
666
+ {"start": start, "end": end, "text": surface, "ref_id": cite_map[token][0]}
667
+ for start, end, token, surface in all_spans_to_replace
668
+ if token.startswith("CITETOKEN")
669
+ ]
670
+
671
+ ref_span_blobs = [
672
+ {"start": start, "end": end, "text": surface, "ref_id": ref_map[token][0]}
673
+ for start, end, token, surface in all_spans_to_replace
674
+ if token.startswith("REFTOKEN")
675
+ ]
676
+
677
+ for cite_blob in cite_span_blobs:
678
+ assert para_text[cite_blob["start"] : cite_blob["end"]] == cite_blob["text"]
679
+
680
+ for ref_blob in ref_span_blobs:
681
+ assert para_text[ref_blob["start"] : ref_blob["end"]] == ref_blob["text"]
682
+
683
+ return {
684
+ "text": para_text,
685
+ "cite_spans": cite_span_blobs,
686
+ "ref_spans": ref_span_blobs,
687
+ "eq_spans": [],
688
+ "section": section_names,
689
+ }
690
+
691
+
692
+ def extract_abstract_from_tei_xml(
693
+ sp: bs4.BeautifulSoup, bib_dict: Dict, ref_dict: Dict, cleanup_bracket: bool
694
+ ) -> List[Dict]:
695
+ """
696
+ Parse abstract from soup
697
+ :param sp:
698
+ :param bib_dict:
699
+ :param ref_dict:
700
+ :param cleanup_bracket:
701
+ :return:
702
+ """
703
+ abstract_text = []
704
+ if sp.abstract:
705
+ # process all divs
706
+ if sp.abstract.div:
707
+ for div in sp.abstract.find_all("div"):
708
+ if div.text:
709
+ if div.p:
710
+ for para in div.find_all("p"):
711
+ if para.text:
712
+ abstract_text.append(
713
+ process_paragraph(
714
+ sp,
715
+ para,
716
+ [(None, "Abstract")],
717
+ bib_dict,
718
+ ref_dict,
719
+ cleanup_bracket,
720
+ )
721
+ )
722
+ else:
723
+ if div.text:
724
+ abstract_text.append(
725
+ process_paragraph(
726
+ sp,
727
+ div,
728
+ [(None, "Abstract")],
729
+ bib_dict,
730
+ ref_dict,
731
+ cleanup_bracket,
732
+ )
733
+ )
734
+ # process all paragraphs
735
+ elif sp.abstract.p:
736
+ for para in sp.abstract.find_all("p"):
737
+ if para.text:
738
+ abstract_text.append(
739
+ process_paragraph(
740
+ sp,
741
+ para,
742
+ [(None, "Abstract")],
743
+ bib_dict,
744
+ ref_dict,
745
+ cleanup_bracket,
746
+ )
747
+ )
748
+ # else just try to get the text
749
+ else:
750
+ if sp.abstract.text:
751
+ abstract_text.append(
752
+ process_paragraph(
753
+ sp,
754
+ sp.abstract,
755
+ [(None, "Abstract")],
756
+ bib_dict,
757
+ ref_dict,
758
+ cleanup_bracket,
759
+ )
760
+ )
761
+ sp.abstract.decompose()
762
+ return abstract_text
763
+
764
+
765
+ def extract_body_text_from_div(
766
+ sp: bs4.BeautifulSoup,
767
+ div: bs4.element.Tag,
768
+ sections: List[Tuple],
769
+ bib_dict: Dict,
770
+ ref_dict: Dict,
771
+ cleanup_bracket: bool,
772
+ ) -> List[Dict]:
773
+ """
774
+ Parse body text from soup
775
+ :param sp:
776
+ :param div:
777
+ :param sections:
778
+ :param bib_dict:
779
+ :param ref_dict:
780
+ :param cleanup_bracket:
781
+ :return:
782
+ """
783
+ chunks = []
784
+ # check if nested divs; recursively process
785
+ if div.div:
786
+ for subdiv in div.find_all("div"):
787
+ # has header, add to section list and process
788
+ if subdiv.head:
789
+ chunks += extract_body_text_from_div(
790
+ sp,
791
+ subdiv,
792
+ sections + [(subdiv.head.get("n", None), subdiv.head.text.strip())],
793
+ bib_dict,
794
+ ref_dict,
795
+ cleanup_bracket,
796
+ )
797
+ subdiv.head.decompose()
798
+ # no header, process with same section list
799
+ else:
800
+ chunks += extract_body_text_from_div(
801
+ sp, subdiv, sections, bib_dict, ref_dict, cleanup_bracket
802
+ )
803
+ # process tags individually
804
+ for tag in div:
805
+ try:
806
+ if tag.name == "p":
807
+ if tag.text:
808
+ chunks.append(
809
+ process_paragraph(sp, tag, sections, bib_dict, ref_dict, cleanup_bracket)
810
+ )
811
+ elif tag.name == "formula":
812
+ # e.g. <formula xml:id="formula_0">Y = W T X.<label>(1)</label></formula>
813
+ label = tag.label.text
814
+ tag.label.decompose()
815
+ eq_text = tag.text
816
+ chunks.append(
817
+ {
818
+ "text": "EQUATION",
819
+ "cite_spans": [],
820
+ "ref_spans": [],
821
+ "eq_spans": [
822
+ {
823
+ "start": 0,
824
+ "end": 8,
825
+ "text": "EQUATION",
826
+ "ref_id": "EQREF",
827
+ "raw_str": eq_text,
828
+ "eq_num": label,
829
+ }
830
+ ],
831
+ "section": sections,
832
+ }
833
+ )
834
+ except AttributeError:
835
+ if tag.text:
836
+ chunks.append(
837
+ process_paragraph(sp, tag, sections, bib_dict, ref_dict, cleanup_bracket)
838
+ )
839
+
840
+ return chunks
841
+
842
+
843
+ def extract_body_text_from_tei_xml(
844
+ sp: bs4.BeautifulSoup, bib_dict: Dict, ref_dict: Dict, cleanup_bracket: bool
845
+ ) -> List[Dict]:
846
+ """
847
+ Parse body text from soup
848
+ :param sp:
849
+ :param bib_dict:
850
+ :param ref_dict:
851
+ :param cleanup_bracket:
852
+ :return:
853
+ """
854
+ body_text = []
855
+ if sp.body:
856
+ body_text = extract_body_text_from_div(
857
+ sp, sp.body, [], bib_dict, ref_dict, cleanup_bracket
858
+ )
859
+ sp.body.decompose()
860
+ return body_text
861
+
862
+
863
+ def extract_back_matter_from_tei_xml(
864
+ sp: bs4.BeautifulSoup, bib_dict: Dict, ref_dict: Dict, cleanup_bracket: bool
865
+ ) -> List[Dict]:
866
+ """
867
+ Parse back matter from soup
868
+ :param sp:
869
+ :param bib_dict:
870
+ :param ref_dict:
871
+ :param cleanup_bracket:
872
+ :return:
873
+ """
874
+ back_text = []
875
+
876
+ if sp.back:
877
+ for div in sp.back.find_all("div"):
878
+ if div.get("type"):
879
+ section_type = div.get("type")
880
+ else:
881
+ section_type = ""
882
+
883
+ for child_div in div.find_all("div"):
884
+ if child_div.head:
885
+ section_title = child_div.head.text.strip()
886
+ section_num = child_div.head.get("n", None)
887
+ child_div.head.decompose()
888
+ else:
889
+ section_title = section_type
890
+ section_num = None
891
+ if child_div.text:
892
+ if child_div.text:
893
+ back_text.append(
894
+ process_paragraph(
895
+ sp,
896
+ child_div,
897
+ [(section_num, section_title)],
898
+ bib_dict,
899
+ ref_dict,
900
+ cleanup_bracket,
901
+ )
902
+ )
903
+ sp.back.decompose()
904
+ return back_text
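Note: to close, a self-contained example of two helpers defined above, replace_refspans and normalize_grobid_id; the input text and spans are invented.

# Example input for the helpers above; the text and spans are invented.
text = "See [1] and Fig. 2 for details."
spans = [
    (4, 7, "[1]", "BIBREF0"),      # (start, end, span_text, new_substring)
    (12, 18, "Fig. 2", "FIGREF1"),
]
print(replace_refspans(spans, text))
# -> "See BIBREF0 and FIGREF1 for details."

print(normalize_grobid_id("#b0"))    # -> "BIBREF0"
print(normalize_grobid_id("fig_1"))  # -> "FIGREF1"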