ArneBinder committed
Commit 3133b5e · verified · 1 Parent(s): af21245

https://github.com/ArneBinder/pie-document-level/pull/312

This view is limited to 50 files because the commit contains too many changes.
Files changed (50):
  1. requirements.txt +38 -18
  2. src/datamodules/__init__.py +1 -0
  3. src/datamodules/components/__init__.py +0 -0
  4. src/datamodules/components/sampler.py +67 -0
  5. src/datamodules/datamodule.py +154 -0
  6. src/dataset/__init__.py +0 -0
  7. src/dataset/processing.py +29 -0
  8. src/demo/__init__.py +0 -0
  9. src/demo/annotation_utils.py +137 -0
  10. src/demo/backend_utils.py +221 -0
  11. src/demo/data_utils.py +63 -0
  12. src/demo/frontend_utils.py +56 -0
  13. src/demo/rendering_utils.py +296 -0
  14. src/demo/rendering_utils_displacy.py +217 -0
  15. src/demo/retrieve_and_dump_all_relevant.py +101 -0
  16. src/demo/retriever_utils.py +313 -0
  17. src/document/__init__.py +0 -0
  18. src/document/processing.py +223 -0
  19. src/evaluate.py +137 -0
  20. src/evaluate_documents.py +116 -0
  21. src/hydra_callbacks/__init__.py +1 -0
  22. src/hydra_callbacks/save_job_return_value.py +261 -0
  23. src/langchain_modules/span_retriever.py +3 -3
  24. src/metrics/__init__.py +2 -0
  25. src/metrics/annotation_processor.py +23 -0
  26. src/metrics/coref_sklearn.py +162 -0
  27. src/metrics/coref_torchmetrics.py +107 -0
  28. src/models/__init__.py +5 -0
  29. src/models/components/__init__.py +0 -0
  30. src/models/components/pooler.py +79 -0
  31. src/models/sequence_classification_with_pooler.py +166 -0
  32. src/models/utils/__init__.py +5 -1
  33. src/models/utils/loading.py +4 -4
  34. src/pipeline/__init__.py +2 -0
  35. src/pipeline/ner_re_pipeline.py +208 -0
  36. src/pipeline/span_retrieval_based_re_pipeline.py +130 -0
  37. src/predict.py +183 -0
  38. src/serializer/__init__.py +1 -0
  39. src/serializer/interface.py +16 -0
  40. src/serializer/json.py +179 -0
  41. src/start_demo.py +578 -0
  42. src/taskmodules/__init__.py +8 -0
  43. src/taskmodules/components/__init__.py +0 -0
  44. src/taskmodules/cross_text_binary_coref.py +116 -0
  45. src/taskmodules/cross_text_binary_coref_nli.py +166 -0
  46. src/taskmodules/re_text_classification_with_indices.py +176 -0
  47. src/train.py +294 -0
  48. src/utils/__init__.py +6 -0
  49. src/utils/config_utils.py +71 -0
  50. src/utils/data_utils.py +33 -0
requirements.txt CHANGED
@@ -1,14 +1,3 @@
- # this requires python>=3.10
- gradio~=5.4.0
- prettytable==3.10.0
- beautifulsoup4==4.12.3
- # numpy 2.0.0 breaks the code
- numpy==1.25.2
- scipy==1.13.0
- arxiv==2.1.3
- pyrootutils>=1.0.0,<1.1.0
-
- ########## from root requirements ##########
  # --------- pytorch-ie --------- #
  pytorch-ie>=0.29.6,<0.32.0
  pie-datasets>=0.10.5,<0.11.0
@@ -16,20 +5,51 @@ pie-modules>=0.14.0,<0.15.0

  # --------- models -------- #
  adapters>=0.1.2,<0.2.0
- # ADU retrieval (and demo, in future):
+ pytorch-crf~=0.7.2
+ # --------- retriever -------- #
  langchain>=0.3.0,<0.4.0
  langchain-core>=0.3.0,<0.4.0
  langchain-community>=0.3.0,<0.4.0
  # we use QDrant as vectorstore backend
  langchain-qdrant>=0.1.0,<0.2.0
  qdrant-client>=1.12.0,<2.0.0
- # 0.26 seems to be broken when used with adapters, see https://github.com/adapter-hub/adapters/issues/748
- huggingface_hub<0.26.0 # 0.26 seems to be broken
- # to handle segmented entities (if HANDLE_PARTS_OF_SAME=True)
- networkx>=3.0.0,<4.0.0

- # --------- config --------- #
+ # --------- demo -------- #
+ gradio~=5.4.0
+ arxiv~=2.1.3
+
+ # --------- hydra --------- #
  hydra-core>=1.3.0
+ hydra-colorlog>=1.2.0
+ hydra-optuna-sweeper>=1.2.0
+
+ # --------- loggers --------- #
+ wandb
+ # neptune-client
+ # mlflow
+ # comet-ml
+ # tensorboard
+ # aim

- # --------- dev --------- #
+ # --------- linters --------- #
  pre-commit # hooks for applying linters on commit
+ black # code formatting
+ isort # import sorting
+ flake8 # code analysis
+ nbstripout # remove output from jupyter notebooks
+
+ # --------- others --------- #
+ pyrootutils # standardizing the project root setup
+ python-dotenv # loading env variables from .env file
+ rich # beautiful text formatting in terminal
+ pytest # tests
+ pytest-cov # test coverage
+ sh # for running bash commands in some tests
+ pudb # debugger
+ tabulate # show statistics as markdown
+ plotext # show statistics as plots
+ prettytable # rendering annotated docs as table (demo)
+ beautifulsoup4 # rendering annotated docs with displacy + highlighted relations (demo)
+ # 0.26 seems to be broken when used with adapters, see https://github.com/adapter-hub/adapters/issues/748
+ huggingface_hub<0.26.0 # interaction with HF hub
+ networkx~=3.2.1 # to handle segmented entities (e.g. if HANDLE_PARTS_OF_SAME=True in demo)
src/datamodules/__init__.py ADDED
@@ -0,0 +1 @@
+ from .datamodule import PieDataModule
src/datamodules/components/__init__.py ADDED
File without changes
src/datamodules/components/sampler.py ADDED
@@ -0,0 +1,67 @@
+ """This is a slightly modified version of https://github.com/ufoym/imbalanced-dataset-sampler."""
+
+ from typing import Callable, List, Optional
+
+ import pandas as pd
+ import torch
+ import torch.utils.data
+
+
+ class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
+     """Samples elements randomly from a given list of indices for an imbalanced dataset.
+
+     Arguments:
+         indices: a list of indices
+         num_samples: number of samples to draw
+         callback_get_label: a callback-like function which takes one argument - the dataset
+     """
+
+     def __init__(
+         self,
+         dataset,
+         labels: Optional[List] = None,
+         indices: Optional[List] = None,
+         num_samples: Optional[int] = None,
+         callback_get_label: Optional[Callable] = None,
+     ):
+         # if indices is not provided, all elements in the dataset will be considered
+         self.indices = list(range(len(dataset))) if indices is None else indices
+
+         # define custom callback
+         self.callback_get_label = callback_get_label
+
+         # if num_samples is not provided, draw `len(indices)` samples in each iteration
+         self.num_samples = len(self.indices) if num_samples is None else num_samples
+
+         # distribution of classes in the dataset
+         df = pd.DataFrame()
+         df["label"] = self._get_labels(dataset) if labels is None else labels
+         df.index = self.indices
+         df = df.sort_index()
+
+         label_to_count = df["label"].value_counts()
+
+         weights = 1.0 / label_to_count[df["label"]]
+
+         self.weights = torch.DoubleTensor(weights.to_list())
+
+     def _get_labels(self, dataset):
+         if self.callback_get_label:
+             return self.callback_get_label(dataset)
+         elif isinstance(dataset, torch.utils.data.TensorDataset):
+             return dataset.tensors[1]
+         elif isinstance(dataset, torch.utils.data.Subset):
+             return dataset.dataset.imgs[:][1]
+         elif isinstance(dataset, torch.utils.data.Dataset):
+             return dataset.get_labels()
+         else:
+             raise NotImplementedError
+
+     def __iter__(self):
+         return (
+             self.indices[i]
+             for i in torch.multinomial(self.weights, self.num_samples, replacement=True)
+         )
+
+     def __len__(self):
+         return self.num_samples
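A minimal usage sketch for the new sampler (not part of the commit); the toy dataset and the label callback are made up for illustration. The callback avoids the need for a `get_labels()` method on the dataset.

import torch
from torch.utils.data import DataLoader

from src.datamodules.components.sampler import ImbalancedDatasetSampler

# toy dataset: 90 examples of class 0, 10 examples of class 1
data = [(torch.randn(4), 0) for _ in range(90)] + [(torch.randn(4), 1) for _ in range(10)]

# the callback extracts the label from each (features, label) pair
sampler = ImbalancedDatasetSampler(data, callback_get_label=lambda ds: [y for _, y in ds])

# sampling with replacement and inverse-frequency weights yields roughly class-balanced batches
loader = DataLoader(data, sampler=sampler, batch_size=20)
features, labels = next(iter(loader))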
src/datamodules/datamodule.py ADDED
@@ -0,0 +1,154 @@
+ from typing import Any, Dict, Generic, Optional, Sequence, TypeVar, Union
+
+ from pytorch_ie.core import Document
+ from pytorch_ie.core.taskmodule import (
+     IterableTaskEncodingDataset,
+     TaskEncoding,
+     TaskEncodingDataset,
+     TaskModule,
+ )
+ from pytorch_lightning import LightningDataModule
+ from torch.utils.data import DataLoader, Sampler
+ from typing_extensions import TypeAlias
+
+ from .components.sampler import ImbalancedDatasetSampler
+
+ DocumentType = TypeVar("DocumentType", bound=Document)
+ InputEncoding = TypeVar("InputEncoding")
+ TargetEncoding = TypeVar("TargetEncoding")
+ DatasetType: TypeAlias = Union[
+     TaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
+     IterableTaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
+ ]
+
+
+ class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, TargetEncoding]):
+     """A simple LightningDataModule for PIE document datasets.
+
+     A DataModule implements 5 key methods:
+         - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
+         - setup (things to do on every accelerator in distributed mode)
+         - train_dataloader (the training dataloader)
+         - val_dataloader (the validation dataloader(s))
+         - test_dataloader (the test dataloader(s))
+
+     This allows you to share a full dataset without explaining how to download,
+     split, transform and process the data.
+
+     Read the docs:
+         https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
+     """
+
+     def __init__(
+         self,
+         taskmodule: TaskModule[DocumentType, InputEncoding, TargetEncoding, Any, Any, Any],
+         dataset: Dict[str, Sequence[DocumentType]],
+         data_config_path: Optional[str] = None,
+         train_split: Optional[str] = "train",
+         val_split: Optional[str] = "validation",
+         test_split: Optional[str] = "test",
+         show_progress_for_encode: bool = False,
+         train_sampler: Optional[str] = None,
+         **dataloader_kwargs,
+     ):
+         super().__init__()
+
+         self.taskmodule = taskmodule
+         self.config_path = data_config_path
+         self.dataset = dataset
+         self.train_split = train_split
+         self.val_split = val_split
+         self.test_split = test_split
+         self.show_progress_for_encode = show_progress_for_encode
+         self.train_sampler_name = train_sampler
+         self.dataloader_kwargs = dataloader_kwargs
+
+         self._data: Dict[str, DatasetType] = {}
+
+     @property
+     def num_train(self) -> int:
+         if self.train_split is None:
+             raise ValueError("no train_split assigned")
+         data_train = self._data.get(self.train_split, None)
+         if data_train is None:
+             raise ValueError("can not get train size if setup() was not yet called")
+         if isinstance(data_train, IterableTaskEncodingDataset):
+             raise TypeError("IterableTaskEncodingDataset has no length")
+         return len(data_train)
+
+     def setup(self, stage: str):
+         if stage == "fit":
+             split_names = [self.train_split, self.val_split]
+         elif stage == "validate":
+             split_names = [self.val_split]
+         elif stage == "test":
+             split_names = [self.test_split]
+         else:
+             raise NotImplementedError(f"not implemented for stage={stage}")
+
+         for split in split_names:
+             if split is None or split not in self.dataset:
+                 continue
+             task_encoding_dataset = self.taskmodule.encode(
+                 self.dataset[split],
+                 encode_target=True,
+                 as_dataset=True,
+                 show_progress=self.show_progress_for_encode,
+             )
+             if not isinstance(
+                 task_encoding_dataset,
+                 (TaskEncodingDataset, IterableTaskEncodingDataset),
+             ):
+                 raise TypeError(
+                     f"taskmodule.encode did not return a (Iterable)TaskEncodingDataset, but: {type(task_encoding_dataset)}"
+                 )
+             self._data[split] = task_encoding_dataset
+
+     def data_split(self, split: Optional[str] = None) -> DatasetType:
+         if split is None or split not in self._data:
+             raise ValueError(f"data for split={split} not available")
+         return self._data[split]
+
+     def get_train_sampler(
+         self,
+         sampler_name: str,
+         dataset: DatasetType,
+     ) -> Sampler:
+         if sampler_name == "imbalanced_dataset":
+             # for now, this works only with targets that have a single entry
+             return ImbalancedDatasetSampler(
+                 dataset, callback_get_label=lambda ds: [x.targets[0] for x in ds]
+             )
+         else:
+             raise ValueError(f"unknown sampler name: {sampler_name}")
+
+     def train_dataloader(self):
+         ds = self.data_split(self.train_split)
+         if self.train_sampler_name is not None:
+             sampler = self.get_train_sampler(sampler_name=self.train_sampler_name, dataset=ds)
+         else:
+             sampler = None
+         return DataLoader(
+             dataset=ds,
+             sampler=sampler,
+             collate_fn=self.taskmodule.collate,
+             # don't shuffle streamed datasets or if we use a sampler
+             shuffle=not (isinstance(ds, IterableTaskEncodingDataset) or sampler is not None),
+             **self.dataloader_kwargs,
+         )
+
+     def val_dataloader(self):
+         return DataLoader(
+             dataset=self.data_split(self.val_split),
+             collate_fn=self.taskmodule.collate,
+             shuffle=False,
+             **self.dataloader_kwargs,
+         )
+
+     def test_dataloader(self):
+         return DataLoader(
+             dataset=self.data_split(self.test_split),
+             collate_fn=self.taskmodule.collate,
+             shuffle=False,
+             **self.dataloader_kwargs,
+         )
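A hypothetical wiring sketch for the new data module; `taskmodule` and `dataset` are assumed to be prepared elsewhere (a pytorch-ie TaskModule and a dict of document splits). Extra keyword arguments are forwarded to the DataLoaders via `**dataloader_kwargs`.

from src.datamodules import PieDataModule

datamodule = PieDataModule(
    taskmodule=taskmodule,  # assumed: a prepared pytorch-ie TaskModule
    dataset=dataset,        # assumed: e.g. {"train": [...], "validation": [...]}
    train_sampler="imbalanced_dataset",  # optional: re-balance rare target labels
    batch_size=32,   # forwarded to the DataLoaders
    num_workers=2,   # forwarded to the DataLoaders
)
datamodule.setup(stage="fit")
train_loader = datamodule.train_dataloader()  # shuffles unless streamed or a sampler is used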
src/dataset/__init__.py ADDED
File without changes
src/dataset/processing.py ADDED
@@ -0,0 +1,29 @@
+ from typing import Callable, Type, Union
+
+ from pie_datasets import Dataset, DatasetDict
+ from pytorch_ie import Document
+ from pytorch_ie.utils.hydra import resolve_optional_document_type, resolve_target
+
+
+ # TODO: simply use DatasetDict.map() with set_batch_size_to_split_size=True and
+ #   batched=True instead when https://github.com/ArneBinder/pie-datasets/pull/155 is merged
+ def apply_func_to_splits(
+     dataset: DatasetDict,
+     function: Union[str, Callable],
+     result_document_type: Type[Document],
+     **kwargs,
+ ):
+     resolved_func = resolve_target(function)
+     resolved_document_type = resolve_optional_document_type(document_type=result_document_type)
+     result_dict = dict()
+     split: Dataset
+     for split_name, split in dataset.items():
+         converted_dataset = split.map(
+             function=resolved_func,
+             batched=True,
+             batch_size=len(split),
+             result_document_type=resolved_document_type,
+             **kwargs,
+         )
+         result_dict[split_name] = converted_dataset
+     return DatasetDict(result_dict)
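A hedged sketch of how `apply_func_to_splits` might be called; the dotted function path and the result document type are placeholders. Note that `function` may be given as a string, which `resolve_target()` turns into the actual callable.

from src.dataset.processing import apply_func_to_splits

converted = apply_func_to_splits(
    dataset=dataset_dict,  # assumed: a pie_datasets.DatasetDict loaded elsewhere
    function="my_package.convert_batch",  # placeholder dotted path to a batch-converter
    result_document_type=MyDocumentType,  # placeholder Document subclass
)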
src/demo/__init__.py ADDED
File without changes
src/demo/annotation_utils.py ADDED
@@ -0,0 +1,137 @@
+ import logging
+ from typing import Optional, Sequence, Union
+
+ import gradio as gr
+ from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger
+
+ # this is required to dynamically load the PIE models
+ from pie_modules.models import *  # noqa: F403
+ from pie_modules.taskmodules import *  # noqa: F403
+ from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
+ from pytorch_ie import Pipeline
+ from pytorch_ie.annotations import LabeledSpan
+ from pytorch_ie.auto import AutoPipeline
+ from pytorch_ie.documents import (
+     TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+     TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+ )
+
+ # this is required to dynamically load the PIE models
+ from pytorch_ie.models import *  # noqa: F403
+ from pytorch_ie.taskmodules import *  # noqa: F403
+
+ logger = logging.getLogger(__name__)
+
+
+ def annotate_document(
+     document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+     argumentation_model: Pipeline,
+     handle_parts_of_same: bool = False,
+ ) -> Union[
+     TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+     TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+ ]:
+     """Annotate a document with the provided pipeline.
+
+     Args:
+         document: The document to annotate.
+         argumentation_model: The pipeline to use for annotation.
+         handle_parts_of_same: Whether to merge spans that are part of the same entity
+             into a single multi span.
+     """
+
+     # execute prediction pipeline
+     argumentation_model(document)
+
+     if handle_parts_of_same:
+         merger = SpansViaRelationMerger(
+             relation_layer="binary_relations",
+             link_relation_label="parts_of_same",
+             create_multi_spans=True,
+             result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+             result_field_mapping={
+                 "labeled_spans": "labeled_multi_spans",
+                 "binary_relations": "binary_relations",
+                 "labeled_partitions": "labeled_partitions",
+             },
+         )
+         document = merger(document)
+
+     return document
+
+
+ def create_document(
+     text: str, doc_id: str, split_regex: Optional[str] = None
+ ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
+     """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the
+     provided text.
+
+     Parameters:
+         text: The text to process.
+         doc_id: The ID of the document.
+         split_regex: A regular expression pattern to use for splitting the text into partitions.
+
+     Returns:
+         The processed document.
+     """
+
+     document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
+         id=doc_id, text=text, metadata={}
+     )
+     if split_regex is not None:
+         partitioner = RegexPartitioner(
+             pattern=split_regex, partition_layer_name="labeled_partitions"
+         )
+         document = partitioner(document)
+     else:
+         # add single partition from the whole text (the model only considers text in partitions)
+         document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
+     return document
+
+
+ def load_argumentation_model(
+     model_name: str,
+     revision: Optional[str] = None,
+     device: str = "cpu",
+ ) -> Pipeline:
+     try:
+         # the Pipeline class expects an integer for the device
+         if device == "cuda":
+             pipeline_device = 0
+         elif device.startswith("cuda:"):
+             pipeline_device = int(device.split(":")[1])
+         elif device == "cpu":
+             pipeline_device = -1
+         else:
+             raise gr.Error(f"Invalid device: {device}")
+
+         model = AutoPipeline.from_pretrained(
+             model_name,
+             device=pipeline_device,
+             num_workers=0,
+             taskmodule_kwargs=dict(revision=revision),
+             model_kwargs=dict(revision=revision),
+         )
+         gr.Info(
+             f"Loaded argumentation model: model_name={model_name}, revision={revision}, device={device}"
+         )
+     except Exception as e:
+         raise gr.Error(f"Failed to load argumentation model: {e}")
+
+     return model
+
+
+ def set_relation_types(
+     argumentation_model: Pipeline,
+     default: Optional[Sequence[str]] = None,
+ ) -> gr.Dropdown:
+     if isinstance(argumentation_model.taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
+         relation_types = argumentation_model.taskmodule.labels_per_layer["binary_relations"]
+     else:
+         raise gr.Error("Unsupported taskmodule for relation types")
+
+     return gr.Dropdown(
+         choices=relation_types,
+         label="Argumentative Relation Types",
+         value=default,
+         multiselect=True,
+     )
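A sketch of the intended call sequence; the model identifier below is a placeholder, not taken from this commit.

from src.demo.annotation_utils import annotate_document, create_document, load_argumentation_model

doc = create_document(
    text="Exercise is healthy. Therefore, everyone should do sports.",
    doc_id="example-1",
    split_regex=None,  # a single partition spanning the whole text is added
)
pipeline = load_argumentation_model("some-org/argumentation-model", device="cpu")  # placeholder id
annotated = annotate_document(doc, pipeline, handle_parts_of_same=True)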
src/demo/backend_utils.py ADDED
@@ -0,0 +1,221 @@
+ import json
+ import logging
+ import os
+ import tempfile
+ from typing import Iterable, List, Optional, Sequence
+
+ import gradio as gr
+ import pandas as pd
+ from pie_datasets import Dataset, IterableDataset, load_dataset
+ from pytorch_ie import Pipeline
+ from pytorch_ie.documents import (
+     TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+ )
+
+ from src.demo.annotation_utils import annotate_document, create_document
+ from src.demo.data_utils import load_text_from_arxiv
+ from src.demo.rendering_utils import (
+     RENDER_WITH_DISPLACY,
+     RENDER_WITH_PRETTY_TABLE,
+     render_displacy,
+     render_pretty_table,
+ )
+ from src.demo.retriever_utils import get_text_spans_and_relations_from_document
+ from src.langchain_modules import (
+     DocumentAwareSpanRetriever,
+     DocumentAwareSpanRetrieverWithRelations,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ def add_annotated_pie_documents(
+     retriever: DocumentAwareSpanRetriever,
+     pie_documents: Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions],
+     use_predicted_annotations: bool,
+     verbose: bool = False,
+ ) -> None:
+     if verbose:
+         gr.Info(f"Create span embeddings for {len(pie_documents)} documents...")
+     num_docs_before = len(retriever.docstore)
+     retriever.add_pie_documents(pie_documents, use_predicted_annotations=use_predicted_annotations)
+     # number of documents that were overwritten
+     num_overwritten_docs = num_docs_before + len(pie_documents) - len(retriever.docstore)
+     # warn if documents were overwritten
+     if num_overwritten_docs > 0:
+         gr.Warning(f"{num_overwritten_docs} documents were overwritten.")
+
+
+ def process_texts(
+     texts: Iterable[str],
+     doc_ids: Iterable[str],
+     argumentation_model: Pipeline,
+     retriever: DocumentAwareSpanRetriever,
+     split_regex_escaped: Optional[str],
+     handle_parts_of_same: bool = False,
+     verbose: bool = False,
+ ) -> None:
+     # check that doc_ids are unique
+     if len(set(doc_ids)) != len(list(doc_ids)):
+         raise gr.Error("Document IDs must be unique.")
+     pie_documents = [
+         create_document(text=text, doc_id=doc_id, split_regex=split_regex_escaped)
+         for text, doc_id in zip(texts, doc_ids)
+     ]
+     if verbose:
+         gr.Info(f"Annotate {len(pie_documents)} documents...")
+     pie_documents = [
+         annotate_document(
+             document=pie_document,
+             argumentation_model=argumentation_model,
+             handle_parts_of_same=handle_parts_of_same,
+         )
+         for pie_document in pie_documents
+     ]
+     add_annotated_pie_documents(
+         retriever=retriever,
+         pie_documents=pie_documents,
+         use_predicted_annotations=True,
+         verbose=verbose,
+     )
+
+
+ def add_annotated_pie_documents_from_dataset(
+     retriever: DocumentAwareSpanRetriever, verbose: bool = False, **load_dataset_kwargs
+ ) -> None:
+     try:
+         gr.Info(
+             "Loading PIE dataset with parameters:\n" + json.dumps(load_dataset_kwargs, indent=2)
+         )
+         dataset = load_dataset(**load_dataset_kwargs)
+         if not isinstance(dataset, (Dataset, IterableDataset)):
+             raise gr.Error("Loaded dataset is not of type PIE (Iterable)Dataset.")
+         dataset_converted = dataset.to_document_type(
+             TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
+         )
+         add_annotated_pie_documents(
+             retriever=retriever,
+             pie_documents=dataset_converted,
+             use_predicted_annotations=False,
+             verbose=verbose,
+         )
+     except Exception as e:
+         raise gr.Error(f"Failed to load dataset: {e}")
+
+
+ def wrapped_process_text(
+     doc_id: str, text: str, retriever: DocumentAwareSpanRetriever, **kwargs
+ ) -> str:
+     try:
+         process_texts(doc_ids=[doc_id], texts=[text], retriever=retriever, **kwargs)
+     except Exception as e:
+         raise gr.Error(f"Failed to process text: {e}")
+     # just return the doc_id: the processed document is stored in the retriever
+     return doc_id
+
+
+ def process_uploaded_files(
+     file_names: List[str],
+     retriever: DocumentAwareSpanRetriever,
+     layer_captions: dict[str, str],
+     **kwargs,
+ ) -> pd.DataFrame:
+     try:
+         doc_ids = []
+         texts = []
+         for file_name in file_names:
+             if file_name.lower().endswith(".txt"):
+                 # read the file content
+                 with open(file_name, "r", encoding="utf-8") as f:
+                     text = f.read()
+                 base_file_name = os.path.basename(file_name)
+                 doc_ids.append(base_file_name)
+                 texts.append(text)
+             else:
+                 raise gr.Error(f"Unsupported file format: {file_name}")
+         process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
+     except Exception as e:
+         raise gr.Error(f"Failed to process uploaded files: {e}")
+
+     return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)
+
+
+ def wrapped_add_annotated_pie_documents_from_dataset(
+     retriever: DocumentAwareSpanRetriever, verbose: bool, layer_captions: dict[str, str], **kwargs
+ ) -> pd.DataFrame:
+     try:
+         add_annotated_pie_documents_from_dataset(retriever=retriever, verbose=verbose, **kwargs)
+     except Exception as e:
+         raise gr.Error(f"Failed to add annotated PIE documents from dataset: {e}")
+     return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)
+
+
+ def download_processed_documents(
+     retriever: DocumentAwareSpanRetriever,
+     file_name: str = "retriever_store",
+ ) -> Optional[str]:
+     if len(retriever.docstore) == 0:
+         gr.Warning("No documents to download.")
+         return None
+
+     # zip the directory
+     file_path = os.path.join(tempfile.gettempdir(), file_name)
+
+     gr.Info(f"Zipping the retriever store to '{file_name}' ...")
+     result_file_path = retriever.save_to_archive(base_name=file_path, format="zip")
+
+     return result_file_path
+
+
+ def upload_processed_documents(
+     file_name: str,
+     retriever: DocumentAwareSpanRetriever,
+     layer_captions: dict[str, str],
+ ) -> pd.DataFrame:
+     # load the documents from the zip file or directory
+     retriever.load_from_disc(file_name)
+     # return the overview of the document store
+     return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)
+
+
+ def process_text_from_arxiv(
+     arxiv_id: str, retriever: DocumentAwareSpanRetriever, abstract_only: bool = False, **kwargs
+ ) -> str:
+     try:
+         text, doc_id = load_text_from_arxiv(arxiv_id=arxiv_id, abstract_only=abstract_only)
+     except Exception as e:
+         raise gr.Error(f"Failed to load text from arXiv: {e}")
+     return wrapped_process_text(doc_id=doc_id, text=text, retriever=retriever, **kwargs)
+
+
+ def render_annotated_document(
+     retriever: DocumentAwareSpanRetrieverWithRelations,
+     document_id: str,
+     render_with: str,
+     render_kwargs_json: str,
+ ) -> str:
+     text, spans, span_id2idx, relations = get_text_spans_and_relations_from_document(
+         retriever=retriever, document_id=document_id
+     )
+
+     render_kwargs = json.loads(render_kwargs_json)
+     if render_with == RENDER_WITH_PRETTY_TABLE:
+         html = render_pretty_table(
+             text=text,
+             spans=spans,
+             span_id2idx=span_id2idx,
+             binary_relations=relations,
+             **render_kwargs,
+         )
+     elif render_with == RENDER_WITH_DISPLACY:
+         html = render_displacy(
+             text=text,
+             spans=spans,
+             span_id2idx=span_id2idx,
+             binary_relations=relations,
+             **render_kwargs,
+         )
+     else:
+         raise ValueError(f"Unknown render_with value: {render_with}")
+
+     return html
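How the pieces could fit together end to end; `pipeline` and `retriever` are assumed to be created via `load_argumentation_model` (annotation_utils.py) and `load_retriever` (retriever_utils.py).

from src.demo.backend_utils import process_texts, render_annotated_document
from src.demo.rendering_utils import RENDER_WITH_DISPLACY

process_texts(
    texts=["Exercise is healthy. Therefore, everyone should do sports."],
    doc_ids=["doc-1"],
    argumentation_model=pipeline,  # assumed
    retriever=retriever,           # assumed
    split_regex_escaped=None,
)
html = render_annotated_document(
    retriever=retriever,
    document_id="doc-1",
    render_with=RENDER_WITH_DISPLACY,
    render_kwargs_json="{}",
)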
src/demo/data_utils.py ADDED
@@ -0,0 +1,63 @@
+ import logging
+ import re
+ from typing import Tuple
+
+ import arxiv
+ import gradio as gr
+ import requests
+ from bs4 import BeautifulSoup
+
+ logger = logging.getLogger(__name__)
+
+
+ def clean_spaces(text: str) -> str:
+     # replace all multiple spaces with a single space
+     text = re.sub(" +", " ", text)
+     # reduce more than two newlines to two newlines
+     text = re.sub("\n\n+", "\n\n", text)
+     # remove leading and trailing whitespaces
+     text = text.strip()
+     return text
+
+
+ def get_cleaned_arxiv_paper_text(html_content: str) -> str:
+     # parse the HTML content with BeautifulSoup
+     soup = BeautifulSoup(html_content, "html.parser")
+     # get alerts (this is one div with classes "package-alerts" and "ltx_document")
+     alerts = soup.find("div", class_="package-alerts ltx_document")
+     # get the "article" html element
+     article = soup.find("article")
+     article_text = article.get_text()
+     # cleanup the text
+     article_text_clean = clean_spaces(article_text)
+     return article_text_clean
+
+
+ def load_text_from_arxiv(arxiv_id: str, abstract_only: bool = False) -> Tuple[str, str]:
+     search_by_id = arxiv.Search(id_list=[arxiv_id])
+     try:
+         result = list(arxiv.Client().results(search_by_id))
+     except arxiv.HTTPError as e:
+         raise gr.Error(f"Failed to fetch arXiv data: {e}")
+     if len(result) == 0:
+         raise gr.Error(f"Could not find any paper with arXiv ID '{arxiv_id}'")
+     first_result = result[0]
+     if abstract_only:
+         abstract_clean = first_result.summary.replace("\n", " ")
+         return abstract_clean, first_result.entry_id
+     if "/abs/" not in first_result.entry_id:
+         raise gr.Error(
+             f"Could not create the HTML URL for arXiv ID '{arxiv_id}' because its entry ID has "
+             f"an unexpected format: {first_result.entry_id}"
+         )
+     html_url = first_result.entry_id.replace("/abs/", "/html/")
+     request_result = requests.get(html_url)
+     if request_result.status_code != 200:
+         raise gr.Error(
+             f"Could not fetch the HTML content for arXiv ID '{arxiv_id}', status code: "
+             f"{request_result.status_code}"
+         )
+     html_content = request_result.text
+     text_clean = get_cleaned_arxiv_paper_text(html_content)
+     return text_clean, html_url
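For example (network access required; 1706.03762 is just a well-known arXiv ID used for illustration):

from src.demo.data_utils import load_text_from_arxiv

# abstract only: returns (abstract_text, entry_id)
abstract, entry_id = load_text_from_arxiv("1706.03762", abstract_only=True)

# full paper: fetches the arXiv HTML rendering and strips it down to plain text
full_text, html_url = load_text_from_arxiv("1706.03762")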
src/demo/frontend_utils.py ADDED
@@ -0,0 +1,56 @@
+ from typing import Any, Union
+
+ import gradio as gr
+ import pandas as pd
+
+
+ # see https://github.com/gradio-app/gradio/issues/9288#issuecomment-2356163329
+ def get_fix_df_height_css(css_class: str, max_height: int) -> str:
+     # return ".qa-pairs .table-wrap {min-height: 170px; max-height: 170px;}"
+     return "." + css_class + " .table-wrap {max-height: " + str(max_height) + "px;}"
+
+
+ def escape_regex(regex: str) -> str:
+     # "double escape" the backslashes
+     result = regex.encode("unicode_escape").decode("utf-8")
+     return result
+
+
+ def unescape_regex(regex: str) -> str:
+     # reverse of escape_regex
+     result = regex.encode("utf-8").decode("unicode_escape")
+     return result
+
+
+ def open_accordion():
+     return gr.Accordion(open=True)
+
+
+ def close_accordion():
+     return gr.Accordion(open=False)
+
+
+ def change_tab(id: Union[int, str]):
+     return gr.Tabs(selected=id)
+
+
+ def get_cell_for_fixed_column_from_df(
+     evt: gr.SelectData,
+     df: pd.DataFrame,
+     column: str,
+ ) -> Any:
+     """Get the value of the fixed column for the selected row in the DataFrame.
+
+     This can *not* be done with a lambda function because a lambda would not
+     receive the evt parameter.
+
+     Args:
+         evt: The event object.
+         df: The DataFrame.
+         column: The name of the column.
+
+     Returns:
+         The value of the fixed column for the selected row.
+     """
+     row_idx, col_idx = evt.index
+     doc_id = df.iloc[row_idx][column]
+     return doc_id
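The escape helpers exist because regex patterns travel through Gradio textboxes as display strings; a quick round-trip sketch:

from src.demo.frontend_utils import escape_regex, unescape_regex

pattern = "\n\n"                 # split documents on blank lines
escaped = escape_regex(pattern)  # "\\n\\n", safe to show in a textbox
assert unescape_regex(escaped) == pattern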
src/demo/rendering_utils.py ADDED
@@ -0,0 +1,296 @@
+ import json
+ import logging
+ from collections import defaultdict
+ from typing import Any, Dict, List, Optional, Sequence, Union
+
+ from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
+
+ from .rendering_utils_displacy import EntityRenderer
+
+ logger = logging.getLogger(__name__)
+
+ RENDER_WITH_DISPLACY = "displacy"
+ RENDER_WITH_PRETTY_TABLE = "pretty_table"
+ AVAILABLE_RENDER_MODES = [RENDER_WITH_DISPLACY, RENDER_WITH_PRETTY_TABLE]
+
+ # adjusted from rendering_utils_displacy.TPL_ENT
+ TPL_ENT_WITH_ID = """
+ <mark class="entity" data-entity-id="{entity_id}" data-slice-idx="{slice_idx}" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
+     {text}
+     <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
+ </mark>
+ """
+
+ HIGHLIGHT_SPANS_JS = """
+ () => {
+     function maybeSetColor(entity, colorAttributeKey, colorDictKey) {
+         var color = entity.getAttribute('data-color-' + colorAttributeKey);
+         // if color is a json string, parse it and use the value at colorDictKey
+         try {
+             const colors = JSON.parse(color);
+             color = colors[colorDictKey];
+         } catch (e) {}
+         if (color) {
+             entity.style.backgroundColor = color;
+             entity.style.color = '#000';
+         }
+     }
+
+     function highlightRelationArguments(entityId) {
+         const entities = document.querySelectorAll('.entity');
+         // reset all entities
+         entities.forEach(entity => {
+             const color = entity.getAttribute('data-color-original');
+             entity.style.backgroundColor = color;
+             entity.style.color = '';
+         });
+
+         if (entityId !== null) {
+             var visitedEntities = new Set();
+             // highlight selected entity:
+             // get all elements with attribute data-entity-id==entityId
+             const selectedEntityParts = document.querySelectorAll(`[data-entity-id="${entityId}"]`);
+             selectedEntityParts.forEach(selectedEntityPart => {
+                 const label = selectedEntityPart.getAttribute('data-label');
+                 maybeSetColor(selectedEntityPart, 'selected', label);
+                 visitedEntities.add(selectedEntityPart);
+             });
+             // if there is at least one part, get the first one and ...
+             if (selectedEntityParts.length > 0) {
+                 const selectedEntity = selectedEntityParts[0];
+
+                 // ... highlight tails and ...
+                 const relationTailsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-tails'));
+                 relationTailsAndLabels.forEach(relationTail => {
+                     const tailEntityId = relationTail['entity-id'];
+                     const tailEntityParts = document.querySelectorAll(`[data-entity-id="${tailEntityId}"]`);
+                     tailEntityParts.forEach(tailEntity => {
+                         const label = relationTail['label'];
+                         maybeSetColor(tailEntity, 'tail', label);
+                         visitedEntities.add(tailEntity);
+                     });
+                 });
+                 // ... highlight heads
+                 const relationHeadsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-heads'));
+                 relationHeadsAndLabels.forEach(relationHead => {
+                     const headEntityId = relationHead['entity-id'];
+                     const headEntityParts = document.querySelectorAll(`[data-entity-id="${headEntityId}"]`);
+                     headEntityParts.forEach(headEntity => {
+                         const label = relationHead['label'];
+                         maybeSetColor(headEntity, 'head', label);
+                         visitedEntities.add(headEntity);
+                     });
+                 });
+             }
+
+             // highlight other entities
+             entities.forEach(entity => {
+                 if (!visitedEntities.has(entity)) {
+                     const label = entity.getAttribute('data-label');
+                     maybeSetColor(entity, 'other', label);
+                 }
+             });
+         }
+     }
+     function setHoverAduId(entityId) {
+         // get the textarea element that holds the hover adu id
+         let hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
+         // set the value of the input field
+         hoverAduIdDiv.value = entityId;
+         // trigger an input event to update the state
+         var event = new Event('input');
+         hoverAduIdDiv.dispatchEvent(event);
+     }
+     function setReferenceAduIdFromHover() {
+         // get the hover adu id
+         const hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
+         // get the value of the input field
+         const entityId = hoverAduIdDiv.value;
+         // get the textarea element that holds the reference adu id
+         let referenceAduIdDiv = document.querySelector('#selected_adu_id textarea');
+         // set the value of the input field
+         referenceAduIdDiv.value = entityId;
+         // trigger an input event to update the state
+         var event = new Event('input');
+         referenceAduIdDiv.dispatchEvent(event);
+     }
+
+     const entities = document.querySelectorAll('.entity');
+     entities.forEach(entity => {
+         // make the cursor a pointer
+         entity.style.cursor = 'pointer';
+         const alreadyHasListener = entity.getAttribute('data-has-listener');
+         if (alreadyHasListener) {
+             return;
+         }
+         entity.addEventListener('mouseover', () => {
+             const entityId = entity.getAttribute('data-entity-id');
+             highlightRelationArguments(entityId);
+             setHoverAduId(entityId);
+         });
+         entity.addEventListener('mouseout', () => {
+             highlightRelationArguments(null);
+         });
+         entity.setAttribute('data-has-listener', 'true');
+     });
+     const entityContainer = document.querySelector('.entities');
+     if (entityContainer) {
+         entityContainer.addEventListener('click', () => {
+             setReferenceAduIdFromHover();
+         });
+         // make the cursor a pointer
+         // entityContainer.style.cursor = 'pointer';
+     }
+ }
+ """
+
+
+ def render_pretty_table(
+     text: str,
+     spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
+     span_id2idx: Dict[str, int],
+     binary_relations: Sequence[BinaryRelation],
+     **render_kwargs,
+ ):
+     from prettytable import PrettyTable
+
+     t = PrettyTable()
+     t.field_names = ["head", "tail", "relation"]
+     t.align = "l"
+     for relation in binary_relations:
+         t.add_row([str(relation.head), str(relation.tail), relation.label])
+
+     html = t.get_html_string(format=True)
+     html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
+
+     return html
+
+
+ def render_displacy(
+     text: str,
+     spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
+     span_id2idx: Dict[str, int],
+     binary_relations: Sequence[BinaryRelation],
+     inject_relations=True,
+     colors_hover=None,
+     entity_options={},
+     **render_kwargs,
+ ):
+     ents: List[Dict[str, Any]] = []
+     for entity_id, idx in span_id2idx.items():
+         labeled_span = spans[idx]
+         # pass the ID as a parameter to the entity. The id is required to fetch the entity
+         # annotations on hover and to inject the relation data.
+         if isinstance(labeled_span, LabeledSpan):
+             ents.append(
+                 {
+                     "start": labeled_span.start,
+                     "end": labeled_span.end,
+                     "label": labeled_span.label,
+                     "params": {"entity_id": entity_id, "slice_idx": 0},
+                 }
+             )
+         elif isinstance(labeled_span, LabeledMultiSpan):
+             for i, (start, end) in enumerate(labeled_span.slices):
+                 ents.append(
+                     {
+                         "start": start,
+                         "end": end,
+                         "label": labeled_span.label,
+                         "params": {"entity_id": entity_id, "slice_idx": i},
+                     }
+                 )
+         else:
+             raise ValueError(f"Unsupported labeled span type: {type(labeled_span)}")
+
+     ents_sorted = sorted(ents, key=lambda x: (x["start"], x["end"]))
+     spacy_doc = {
+         "text": text,
+         # the ents MUST be sorted by start and end
+         "ents": ents_sorted,
+         "title": None,
+     }
+
+     # copy to avoid modifying the original options
+     entity_options = entity_options.copy()
+     # use the custom template with the entity ID
+     entity_options["template"] = TPL_ENT_WITH_ID
+     renderer = EntityRenderer(options=entity_options)
+     html = renderer.render([spacy_doc], page=True, minify=True).strip()
+
+     html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
+     if inject_relations:
+         html = inject_relation_data(
+             html,
+             spans=spans,
+             span_id2idx=span_id2idx,
+             binary_relations=binary_relations,
+             additional_colors=colors_hover,
+         )
+     return html
+
+
+ def inject_relation_data(
+     html: str,
+     spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
+     span_id2idx: Dict[str, int],
+     binary_relations: Sequence[BinaryRelation],
+     additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
+ ) -> str:
+     from bs4 import BeautifulSoup
+
+     # Parse the HTML using BeautifulSoup
+     soup = BeautifulSoup(html, "html.parser")
+
+     entity2tails = defaultdict(list)
+     entity2heads = defaultdict(list)
+     for relation in binary_relations:
+         entity2heads[relation.tail].append((relation.head, relation.label))
+         entity2tails[relation.head].append((relation.tail, relation.label))
+
+     annotation2id = {spans[span_idx]: span_id for span_id, span_idx in span_id2idx.items()}
+     # Add unique IDs to each entity
+     entities = soup.find_all(class_="entity")
+     for entity in entities:
+         original_color = entity["style"].split("background:")[1].split(";")[0].strip()
+         entity["data-color-original"] = original_color
+         if additional_colors is not None:
+             for key, color in additional_colors.items():
+                 entity[f"data-color-{key}"] = (
+                     json.dumps(color) if isinstance(color, dict) else color
+                 )
+
+         entity_annotation = spans[span_id2idx[entity["data-entity-id"]]]
+
+         # sanity check
+         if isinstance(entity_annotation, LabeledSpan):
+             annotation_text = entity_annotation.resolve()[1]
+         elif isinstance(entity_annotation, LabeledMultiSpan):
+             slice_idx = int(entity["data-slice-idx"])
+             annotation_text = entity_annotation.resolve()[1][slice_idx]
+         else:
+             raise ValueError(f"Unsupported entity type: {type(entity_annotation)}")
+         annotation_text_without_newline = annotation_text.replace("\n", "")
+         # just check the start, because the text has the label attached to the end
+         if not entity.text.startswith(annotation_text_without_newline):
+             logger.warning(f"Entity text mismatch: {entity_annotation} != {entity.text}")
+
+         entity["data-label"] = entity_annotation.label
+         entity["data-relation-tails"] = json.dumps(
+             [
+                 {"entity-id": annotation2id[tail], "label": label}
+                 for tail, label in entity2tails.get(entity_annotation, [])
+                 if tail in annotation2id
+             ]
+         )
+         entity["data-relation-heads"] = json.dumps(
+             [
+                 {"entity-id": annotation2id[head], "label": label}
+                 for head, label in entity2heads.get(entity_annotation, [])
+                 if head in annotation2id
+             ]
+         )
+
+     # Return the modified HTML as a string
+     return str(soup)
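A small self-contained sketch of the displaCy-based renderer; relation injection is disabled here because detached `LabeledSpan` objects cannot be resolved against a document.

from pytorch_ie.annotations import LabeledSpan

from src.demo.rendering_utils import render_displacy

text = "Exercise is healthy. Therefore, everyone should do sports."
spans = [
    LabeledSpan(start=0, end=20, label="claim"),
    LabeledSpan(start=21, end=58, label="claim"),
]
html = render_displacy(
    text=text,
    spans=spans,
    span_id2idx={"span-0": 0, "span-1": 1},
    binary_relations=[],
    inject_relations=False,  # would require spans attached to a document
)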
src/demo/rendering_utils_displacy.py ADDED
@@ -0,0 +1,217 @@
+ # This code is mainly taken from
+ # https://github.com/explosion/spaCy/blob/master/spacy/displacy/templates.py, and from
+ # https://github.com/explosion/spaCy/blob/master/spacy/displacy/render.py.
+
+ # Setting explicit height and max-width: none on the SVG is required for
+ # Jupyter to render it properly in a cell
+
+ TPL_DEP_SVG = """
+ <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:lang="{lang}" id="{id}" class="displacy" width="{width}" height="{height}" direction="{dir}" style="max-width: none; height: {height}px; color: {color}; background: {bg}; font-family: {font}; direction: {dir}">{content}</svg>
+ """
+
+
+ TPL_DEP_WORDS = """
+ <text class="displacy-token" fill="currentColor" text-anchor="middle" y="{y}">
+     <tspan class="displacy-word" fill="currentColor" x="{x}">{text}</tspan>
+     <tspan class="displacy-tag" dy="2em" fill="currentColor" x="{x}">{tag}</tspan>
+ </text>
+ """
+
+
+ TPL_DEP_WORDS_LEMMA = """
+ <text class="displacy-token" fill="currentColor" text-anchor="middle" y="{y}">
+     <tspan class="displacy-word" fill="currentColor" x="{x}">{text}</tspan>
+     <tspan class="displacy-lemma" dy="2em" fill="currentColor" x="{x}">{lemma}</tspan>
+     <tspan class="displacy-tag" dy="2em" fill="currentColor" x="{x}">{tag}</tspan>
+ </text>
+ """
+
+
+ TPL_DEP_ARCS = """
+ <g class="displacy-arrow">
+     <path class="displacy-arc" id="arrow-{id}-{i}" stroke-width="{stroke}px" d="{arc}" fill="none" stroke="currentColor"/>
+     <text dy="1.25em" style="font-size: 0.8em; letter-spacing: 1px">
+         <textPath xlink:href="#arrow-{id}-{i}" class="displacy-label" startOffset="50%" side="{label_side}" fill="currentColor" text-anchor="middle">{label}</textPath>
+     </text>
+     <path class="displacy-arrowhead" d="{head}" fill="currentColor"/>
+ </g>
+ """
+
+
+ TPL_FIGURE = """
+ <figure style="margin-bottom: 6rem">{content}</figure>
+ """
+
+ TPL_TITLE = """
+ <h2 style="margin: 0">{title}</h2>
+ """
+
+
+ TPL_ENTS = """
+ <div class="entities" style="line-height: 2.5; direction: {dir}">{content}</div>
+ """
+
+
+ TPL_ENT = """
+ <mark class="entity" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
+     {text}
+     <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
+ </mark>
+ """
+
+ TPL_ENT_RTL = """
+ <mark class="entity" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em">
+     {text}
+     <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-right: 0.5rem">{label}</span>
+ </mark>
+ """
+
+
+ TPL_PAGE = """
+ <!DOCTYPE html>
+ <html lang="{lang}">
+     <head>
+         <title>displaCy</title>
+     </head>
+
+     <body style="font-size: 16px; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Helvetica, Arial, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol'; padding: 4rem 2rem; direction: {dir}">{content}</body>
+ </html>
+ """
+
+
+ DEFAULT_LANG = "en"
+ DEFAULT_DIR = "ltr"
+
+
+ def minify_html(html):
+     """Perform a template-specific, rudimentary HTML minification for displaCy.
+
+     Disclaimer: NOT a general-purpose solution, only removes indentation and newlines.
+
+     html (unicode): Markup to minify.
+     RETURNS (unicode): "Minified" HTML.
+     """
+     return html.strip().replace("    ", "").replace("\n", "")
+
+
+ def escape_html(text):
+     """Replace <, >, &, " with their HTML encoded representation. Intended to prevent HTML
+     errors in rendered displaCy markup.
+
+     text (unicode): The original text.
+     RETURNS (unicode): Equivalent text to be safely used within HTML.
+     """
+     text = text.replace("&", "&amp;")
+     text = text.replace("<", "&lt;")
+     text = text.replace(">", "&gt;")
+     text = text.replace('"', "&quot;")
+     return text
+
+
+ class EntityRenderer(object):
+     """Render named entities as HTML."""
+
+     style = "ent"
+
+     def __init__(self, options={}):
+         """Initialise entity renderer.
+
+         options (dict): Visualiser-specific options (colors, ents)
+         """
+         colors = {
+             "ORG": "#7aecec",
+             "PRODUCT": "#bfeeb7",
+             "GPE": "#feca74",
+             "LOC": "#ff9561",
+             "PERSON": "#aa9cfc",
+             "NORP": "#c887fb",
+             "FACILITY": "#9cc9cc",
+             "EVENT": "#ffeb80",
+             "LAW": "#ff8197",
+             "LANGUAGE": "#ff8197",
+             "WORK_OF_ART": "#f0d0ff",
+             "DATE": "#bfe1d9",
+             "TIME": "#bfe1d9",
+             "MONEY": "#e4e7d2",
+             "QUANTITY": "#e4e7d2",
+             "ORDINAL": "#e4e7d2",
+             "CARDINAL": "#e4e7d2",
+             "PERCENT": "#e4e7d2",
+         }
+         # user_colors = registry.displacy_colors.get_all()
+         # for user_color in user_colors.values():
+         #     colors.update(user_color)
+         colors.update(options.get("colors", {}))
+         self.default_color = "#ddd"
+         self.colors = colors
+         self.ents = options.get("ents", None)
+         self.direction = DEFAULT_DIR
+         self.lang = DEFAULT_LANG
+
+         template = options.get("template")
+         if template:
+             self.ent_template = template
+         else:
+             if self.direction == "rtl":
+                 self.ent_template = TPL_ENT_RTL
+             else:
+                 self.ent_template = TPL_ENT
+
+     def render(self, parsed, page=False, minify=False):
+         """Render complete markup.
+
+         parsed (list): Dependency parses to render.
+         page (bool): Render parses wrapped as full HTML page.
+         minify (bool): Minify HTML markup.
+         RETURNS (unicode): Rendered HTML markup.
+         """
+         rendered = []
+         for i, p in enumerate(parsed):
+             if i == 0:
+                 settings = p.get("settings", {})
+                 self.direction = settings.get("direction", DEFAULT_DIR)
+                 self.lang = settings.get("lang", DEFAULT_LANG)
+             rendered.append(self.render_ents(p["text"], p["ents"], p.get("title")))
+         if page:
+             docs = "".join([TPL_FIGURE.format(content=doc) for doc in rendered])
+             markup = TPL_PAGE.format(content=docs, lang=self.lang, dir=self.direction)
+         else:
+             markup = "".join(rendered)
+         if minify:
+             return minify_html(markup)
+         return markup
+
+     def render_ents(self, text, spans, title):
+         """Render entities in text.
+
+         text (unicode): Original text.
+         spans (list): Individual entity spans and their start, end and label.
+         title (unicode or None): Document title set in Doc.user_data['title'].
+         """
+         markup = ""
+         offset = 0
+         for span in spans:
+             label = span["label"]
+             start = span["start"]
+             end = span["end"]
+             additional_params = span.get("params", {})
+             entity = escape_html(text[start:end])
+             fragments = text[offset:start].split("\n")
+             for i, fragment in enumerate(fragments):
+                 markup += escape_html(fragment)
+                 if len(fragments) > 1 and i != len(fragments) - 1:
+                     markup += "<br/>"
+             if self.ents is None or label.upper() in self.ents:
+                 color = self.colors.get(label.upper(), self.default_color)
+                 ent_settings = {"label": label, "text": entity, "bg": color}
+                 ent_settings.update(additional_params)
+                 markup += self.ent_template.format(**ent_settings)
+             else:
+                 markup += entity
+             offset = end
+         fragments = text[offset:].split("\n")
+         for i, fragment in enumerate(fragments):
+             markup += escape_html(fragment)
+             if len(fragments) > 1 and i != len(fragments) - 1:
+                 markup += "<br/>"
+         markup = TPL_ENTS.format(content=markup, dir=self.direction)
+         if title:
+             markup = TPL_TITLE.format(title=title) + markup
+         return markup
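The vendored renderer can also be exercised directly with a displaCy-style dict, e.g.:

from src.demo.rendering_utils_displacy import EntityRenderer

renderer = EntityRenderer(options={"colors": {"CLAIM": "#7aecec"}})
doc = {
    "text": "Exercise is healthy.",
    "ents": [{"start": 0, "end": 20, "label": "claim"}],
    "title": None,
}
html = renderer.render([doc], page=False, minify=True)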
src/demo/retrieve_and_dump_all_relevant.py ADDED
@@ -0,0 +1,101 @@
+import pyrootutils
+
+root = pyrootutils.setup_root(
+    search_from=__file__,
+    indicator=[".project-root"],
+    pythonpath=True,
+    dotenv=True,
+)
+
+import argparse
+import logging
+
+from src.demo.retriever_utils import (
+    retrieve_all_relevant_spans,
+    retrieve_all_relevant_spans_for_all_documents,
+    retrieve_relevant_spans,
+)
+from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations
+
+logger = logging.getLogger(__name__)
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-c",
+        "--config_path",
+        type=str,
+        default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
+    )
+    parser.add_argument(
+        "--data_path",
+        type=str,
+        required=True,
+        help="Path to a zip or directory containing a retriever dump.",
+    )
+    parser.add_argument("-k", "--top_k", type=int, default=10)
+    parser.add_argument("-t", "--threshold", type=float, default=0.95)
+    parser.add_argument(
+        "-o",
+        "--output_path",
+        type=str,
+        required=True,
+    )
+    parser.add_argument(
+        "--query_doc_id",
+        type=str,
+        default=None,
+        help="If provided, retrieve all spans for only this query document.",
+    )
+    parser.add_argument(
+        "--query_span_id",
+        type=str,
+        default=None,
+        help="If provided, retrieve all spans for only this query span.",
+    )
+    args = parser.parse_args()
+
+    logging.basicConfig(
+        format="%(asctime)s %(levelname)-8s %(message)s",
+        level=logging.INFO,
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+    if not args.output_path.endswith(".json"):
+        raise ValueError("only JSON output is supported")
+
+    logger.info(f"instantiating retriever from {args.config_path}...")
+    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
+        args.config_path
+    )
+    logger.info(f"loading data from {args.data_path}...")
+    retriever.load_from_disc(args.data_path)
+
+    search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
+    logger.info(f"use search_kwargs: {search_kwargs}")
+
+    if args.query_span_id is not None:
+        logger.warning(f"retrieving results for single span: {args.query_span_id}")
+        all_spans_for_all_documents = retrieve_relevant_spans(
+            retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
+        )
+    elif args.query_doc_id is not None:
+        logger.warning(f"retrieving results for single document: {args.query_doc_id}")
+        all_spans_for_all_documents = retrieve_all_relevant_spans(
+            retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
+        )
+    else:
+        all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
+            retriever=retriever, **search_kwargs
+        )
+
+    if all_spans_for_all_documents is None:
+        logger.warning("no relevant spans found in any document")
+        exit(0)
+
+    logger.info(f"dumping results to {args.output_path}...")
+    all_spans_for_all_documents.to_json(args.output_path)
+
+    logger.info("done")
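Note: the script above can also be driven programmatically instead of via the CLI. A minimal, hedged sketch, reusing only functions defined in this commit; the config and dump paths below are assumptions, not files guaranteed to ship with the repo:

    from src.demo.retriever_utils import retrieve_all_relevant_spans_for_all_documents
    from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations

    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        "configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml"
    )
    retriever.load_from_disc("data/retriever_dump.zip")  # assumed path to a previously created dump

    df = retrieve_all_relevant_spans_for_all_documents(
        retriever=retriever, k=10, score_threshold=0.95
    )
    if df is not None:
        df.to_json("all_relevant_spans.json")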
src/demo/retriever_utils.py ADDED
@@ -0,0 +1,313 @@
+import logging
+from typing import Dict, Optional, Sequence, Tuple, Union
+
+import gradio as gr
+import pandas as pd
+from pytorch_ie import Annotation
+from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
+from typing_extensions import Protocol
+
+from src.langchain_modules import DocumentAwareSpanRetriever
+from src.langchain_modules.span_retriever import (
+    DocumentAwareSpanRetrieverWithRelations,
+    _parse_config,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def get_document_as_dict(retriever: DocumentAwareSpanRetriever, doc_id: str) -> Dict:
+    document = retriever.get_document(doc_id=doc_id)
+    return retriever.docstore.as_dict(document)
+
+
+def load_retriever(
+    retriever_config_str: str,
+    config_format: str,
+    device: str = "cpu",
+    previous_retriever: Optional[DocumentAwareSpanRetrieverWithRelations] = None,
+) -> DocumentAwareSpanRetrieverWithRelations:
+    try:
+        retriever_config = _parse_config(retriever_config_str, format=config_format)
+        # set device for the embeddings pipeline
+        retriever_config["vectorstore"]["embedding"]["pipeline_kwargs"]["device"] = device
+        result = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config(retriever_config)
+        # if a previous retriever is provided, load all documents and vectors from the previous retriever
+        if previous_retriever is not None:
+            # documents
+            all_doc_ids = list(previous_retriever.docstore.yield_keys())
+            gr.Info(f"Storing {len(all_doc_ids)} documents from previous retriever...")
+            all_docs = previous_retriever.docstore.mget(all_doc_ids)
+            result.docstore.mset([(doc.id, doc) for doc in all_docs])
+            # spans (with vectors)
+            all_span_ids = list(previous_retriever.vectorstore.yield_keys())
+            all_spans = previous_retriever.vectorstore.mget(all_span_ids)
+            result.vectorstore.mset([(span.id, span) for span in all_spans])
+
+        gr.Info("Retriever loaded successfully.")
+        return result
+    except Exception as e:
+        raise gr.Error(f"Failed to load retriever: {e}")
+
+
+def retrieve_similar_spans(
+    retriever: DocumentAwareSpanRetriever,
+    query_span_id: str,
+    **kwargs,
+) -> pd.DataFrame:
+    if not query_span_id.strip():
+        raise gr.Error("No query span selected.")
+    try:
+        retrieval_result = retriever.invoke(input=query_span_id, **kwargs)
+        records = []
+        for similar_span_doc in retrieval_result:
+            pie_doc, metadata = retriever.docstore.unwrap_with_metadata(similar_span_doc)
+            span_ann = metadata["attached_span"]
+            records.append(
+                {
+                    "doc_id": pie_doc.id,
+                    "span_id": similar_span_doc.id,
+                    "score": metadata["relevance_score"],
+                    "label": span_ann.label,
+                    "text": str(span_ann),
+                }
+            )
+        return (
+            pd.DataFrame(records, columns=["doc_id", "score", "label", "text", "span_id"])
+            .sort_values(by="score", ascending=False)
+            .round(3)
+        )
+    except Exception as e:
+        raise gr.Error(f"Failed to retrieve similar ADUs: {e}")
+
+
+def retrieve_relevant_spans(
+    retriever: DocumentAwareSpanRetriever,
+    query_span_id: str,
+    relation_label_mapping: Optional[dict[str, str]] = None,
+    **kwargs,
+) -> pd.DataFrame:
+    if not query_span_id.strip():
+        raise gr.Error("No query span selected.")
+    try:
+        relation_label_mapping = relation_label_mapping or {}
+        retrieval_result = retriever.invoke(input=query_span_id, return_related=True, **kwargs)
+        records = []
+        for relevant_span_doc in retrieval_result:
+            pie_doc, metadata = retriever.docstore.unwrap_with_metadata(relevant_span_doc)
+            span_ann = metadata["attached_span"]
+            tail_span_ann = metadata["attached_tail_span"]
+            mapped_relation_label = relation_label_mapping.get(
+                metadata["relation_label"], metadata["relation_label"]
+            )
+            records.append(
+                {
+                    "doc_id": pie_doc.id,
+                    "type": mapped_relation_label,
+                    "rel_score": metadata["relation_score"],
+                    "text": str(tail_span_ann),
+                    "span_id": relevant_span_doc.id,
+                    "label": tail_span_ann.label,
+                    "ref_score": metadata["relevance_score"],
+                    "ref_label": span_ann.label,
+                    "ref_text": str(span_ann),
+                    "ref_span_id": metadata["head_id"],
+                }
+            )
+        return (
+            pd.DataFrame(
+                records,
+                columns=[
+                    "type",
+                    # omitted for now, we get no valid relation scores for the generative model
+                    # "rel_score",
+                    "ref_score",
+                    "label",
+                    "text",
+                    "ref_label",
+                    "ref_text",
+                    "doc_id",
+                    "span_id",
+                    "ref_span_id",
+                ],
+            )
+            .sort_values(by=["ref_score"], ascending=False)
+            .round(3)
+        )
+    except Exception as e:
+        raise gr.Error(f"Failed to retrieve relevant ADUs: {e}")
+
+
+class RetrieverCallable(Protocol):
+    def __call__(
+        self,
+        retriever: DocumentAwareSpanRetriever,
+        query_span_id: str,
+        **kwargs,
+    ) -> Optional[pd.DataFrame]:
+        pass
+
+
+def _retrieve_for_all_spans(
+    retriever: DocumentAwareSpanRetriever,
+    query_doc_id: str,
+    retrieve_func: RetrieverCallable,
+    query_span_id_column: str = "query_span_id",
+    **kwargs,
+) -> Optional[pd.DataFrame]:
+    if not query_doc_id.strip():
+        raise gr.Error("No query document selected.")
+    try:
+        span_id2idx = retriever.get_span_id2idx_from_doc(query_doc_id)
+        gr.Info(f"Retrieving results for {len(span_id2idx)} ADUs in document {query_doc_id}...")
+        span_results = {
+            query_span_id: retrieve_func(
+                retriever=retriever,
+                query_span_id=query_span_id,
+                **kwargs,
+            )
+            for query_span_id in span_id2idx.keys()
+        }
+        span_results_not_empty = {
+            query_span_id: df
+            for query_span_id, df in span_results.items()
+            if df is not None and not df.empty
+        }
+
+        # add column with query_span_id
+        for query_span_id, query_span_result in span_results_not_empty.items():
+            query_span_result[query_span_id_column] = query_span_id
+
+        if len(span_results_not_empty) == 0:
+            gr.Info(f"No results found for any ADU in document {query_doc_id}.")
+            return None
+        else:
+            result = pd.concat(span_results_not_empty.values(), ignore_index=True)
+            gr.Info(f"Retrieved {len(result)} ADUs for document {query_doc_id}.")
+            return result
+    except Exception as e:
+        raise gr.Error(
+            f'Failed to retrieve results for all ADUs in document "{query_doc_id}": {e}'
+        )
+
+
+def retrieve_all_similar_spans(
+    retriever: DocumentAwareSpanRetriever,
+    query_doc_id: str,
+    **kwargs,
+) -> Optional[pd.DataFrame]:
+    return _retrieve_for_all_spans(
+        retriever=retriever,
+        query_doc_id=query_doc_id,
+        retrieve_func=retrieve_similar_spans,
+        **kwargs,
+    )
+
+
+def retrieve_all_relevant_spans(
+    retriever: DocumentAwareSpanRetriever,
+    query_doc_id: str,
+    **kwargs,
+) -> Optional[pd.DataFrame]:
+    return _retrieve_for_all_spans(
+        retriever=retriever,
+        query_doc_id=query_doc_id,
+        retrieve_func=retrieve_relevant_spans,
+        **kwargs,
+    )
+
+
+class RetrieverForAllSpansCallable(Protocol):
+    def __call__(
+        self,
+        retriever: DocumentAwareSpanRetriever,
+        query_doc_id: str,
+        **kwargs,
+    ) -> Optional[pd.DataFrame]:
+        pass
+
+
+def _retrieve_for_all_documents(
+    retriever: DocumentAwareSpanRetriever,
+    retrieve_func: RetrieverForAllSpansCallable,
+    query_doc_id_column: str = "query_doc_id",
+    **kwargs,
+) -> Optional[pd.DataFrame]:
+    try:
+        all_doc_ids = list(retriever.docstore.yield_keys())
+        gr.Info(f"Retrieving results for {len(all_doc_ids)} documents...")
+        doc_results = {
+            doc_id: retrieve_func(retriever=retriever, query_doc_id=doc_id, **kwargs)
+            for doc_id in all_doc_ids
+        }
+        doc_results_not_empty = {
+            doc_id: df for doc_id, df in doc_results.items() if df is not None and not df.empty
+        }
+        # add column with query_doc_id
+        for doc_id, doc_result in doc_results_not_empty.items():
+            doc_result[query_doc_id_column] = doc_id
+
+        if len(doc_results_not_empty) == 0:
+            gr.Info("No results found for any document.")
+            return None
+        else:
+            result = pd.concat(doc_results_not_empty, ignore_index=True)
+            gr.Info(f"Retrieved {len(result)} ADUs for all documents.")
+            return result
+    except Exception as e:
+        raise gr.Error(f"Failed to retrieve results for all documents: {e}")
+
+
+def retrieve_all_similar_spans_for_all_documents(
+    retriever: DocumentAwareSpanRetriever,
+    **kwargs,
+) -> Optional[pd.DataFrame]:
+    return _retrieve_for_all_documents(
+        retriever=retriever,
+        retrieve_func=retrieve_all_similar_spans,
+        **kwargs,
+    )
+
+
+def retrieve_all_relevant_spans_for_all_documents(
+    retriever: DocumentAwareSpanRetriever,
+    **kwargs,
+) -> Optional[pd.DataFrame]:
+    return _retrieve_for_all_documents(
+        retriever=retriever,
+        retrieve_func=retrieve_all_relevant_spans,
+        **kwargs,
+    )
+
+
+def get_text_spans_and_relations_from_document(
+    retriever: DocumentAwareSpanRetrieverWithRelations, document_id: str
+) -> Tuple[
+    str,
+    Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
+    Dict[str, int],
+    Sequence[BinaryRelation],
+]:
+    document = retriever.get_document(doc_id=document_id)
+    pie_document = retriever.docstore.unwrap(document)
+    use_predicted_annotations = retriever.use_predicted_annotations(document)
+    spans = retriever.get_base_layer(
+        pie_document=pie_document, use_predicted_annotations=use_predicted_annotations
+    )
+    relations = retriever.get_relation_layer(
+        pie_document=pie_document, use_predicted_annotations=use_predicted_annotations
+    )
+    span_id2idx = retriever.get_span_id2idx_from_doc(document)
+    return pie_document.text, spans, span_id2idx, relations
+
+
+def get_span_annotation(
+    retriever: DocumentAwareSpanRetriever,
+    span_id: str,
+) -> Annotation:
+    if span_id.strip() == "":
+        raise gr.Error("No span selected.")
+    try:
+        return retriever.get_span_by_id(span_id=span_id)
+    except Exception as e:
+        raise gr.Error(f"Failed to retrieve span annotation: {e}")
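For orientation, a hedged sketch of how these helpers compose outside the Gradio UI, assuming a retriever loaded as in `load_retriever` above; the span id below is made up, real ids come from `retriever.vectorstore.yield_keys()`:

    similar = retrieve_similar_spans(
        retriever=retriever,
        query_span_id="some-span-uuid",  # hypothetical id
        k=5,
        score_threshold=0.8,
    )
    print(similar[["doc_id", "score", "label", "text"]].head())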
src/document/__init__.py ADDED
File without changes
src/document/processing.py ADDED
@@ -0,0 +1,223 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, Dict, Iterable, List, Sequence, Set, Tuple, TypeVar, Union
+
+import networkx as nx
+from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
+from pie_modules.documents import TextDocumentWithLabeledMultiSpansAndBinaryRelations
+from pytorch_ie import AnnotationLayer
+from pytorch_ie.core import Document
+
+logger = logging.getLogger(__name__)
+
+
+D = TypeVar("D", bound=Document)
+
+
+def _remove_overlapping_entities(
+    entities: Iterable[Dict[str, Any]], relations: Iterable[Dict[str, Any]]
+) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+    sorted_entities = sorted(entities, key=lambda span: span["start"])
+    entities_wo_overlap = []
+    skipped_entities = []
+    last_end = 0
+    for entity_dict in sorted_entities:
+        if entity_dict["start"] < last_end:
+            skipped_entities.append(entity_dict)
+        else:
+            entities_wo_overlap.append(entity_dict)
+            last_end = entity_dict["end"]
+    if len(skipped_entities) > 0:
+        logger.warning(f"skipped overlapping entities: {skipped_entities}")
+    valid_entity_ids = set(entity_dict["_id"] for entity_dict in entities_wo_overlap)
+    valid_relations = [
+        relation_dict
+        for relation_dict in relations
+        if relation_dict["head"] in valid_entity_ids and relation_dict["tail"] in valid_entity_ids
+    ]
+    return entities_wo_overlap, valid_relations
+
+
+def remove_overlapping_entities(
+    doc: D,
+    entity_layer_name: str = "entities",
+    relation_layer_name: str = "relations",
+) -> D:
+    # TODO: use document.add_all_annotations_from_other()
+    document_dict = doc.asdict()
+    entities_wo_overlap, valid_relations = _remove_overlapping_entities(
+        entities=document_dict[entity_layer_name]["annotations"],
+        relations=document_dict[relation_layer_name]["annotations"],
+    )
+
+    document_dict[entity_layer_name] = {
+        "annotations": entities_wo_overlap,
+        "predictions": [],
+    }
+    document_dict[relation_layer_name] = {
+        "annotations": valid_relations,
+        "predictions": [],
+    }
+    new_doc = type(doc).fromdict(document_dict)
+
+    return new_doc
+
+
+def _merge_spans_via_relation(
+    spans: Sequence[LabeledSpan],
+    relations: Sequence[BinaryRelation],
+    link_relation_label: str,
+    create_multi_spans: bool = True,
+) -> Tuple[Union[Set[LabeledSpan], Set[LabeledMultiSpan]], Set[BinaryRelation]]:
+    # convert the list of relations to a graph to easily calculate connected components to merge
+    g = nx.Graph()
+    link_relations = []
+    other_relations = []
+    for rel in relations:
+        if rel.label == link_relation_label:
+            link_relations.append(rel)
+            # never merge spans that do not have the same label
+            if (
+                not (isinstance(rel.head, LabeledSpan) or isinstance(rel.tail, LabeledSpan))
+                or rel.head.label == rel.tail.label
+            ):
+                g.add_edge(rel.head, rel.tail)
+            else:
+                logger.debug(
+                    f"spans to merge do not have the same label, do not merge them: {rel.head}, {rel.tail}"
+                )
+        else:
+            other_relations.append(rel)
+
+    span_mapping = {}
+    connected_components: Set[LabeledSpan]
+    for connected_components in nx.connected_components(g):
+        # all spans in a connected component have the same label
+        label = list(span.label for span in connected_components)[0]
+        connected_components_sorted = sorted(connected_components, key=lambda span: span.start)
+        if create_multi_spans:
+            new_span = LabeledMultiSpan(
+                slices=tuple((span.start, span.end) for span in connected_components_sorted),
+                label=label,
+            )
+        else:
+            new_span = LabeledSpan(
+                start=min(span.start for span in connected_components_sorted),
+                end=max(span.end for span in connected_components_sorted),
+                label=label,
+            )
+        for span in connected_components_sorted:
+            span_mapping[span] = new_span
+    for span in spans:
+        if span not in span_mapping:
+            if create_multi_spans:
+                span_mapping[span] = LabeledMultiSpan(
+                    slices=((span.start, span.end),), label=span.label, score=span.score
+                )
+            else:
+                span_mapping[span] = LabeledSpan(
+                    start=span.start, end=span.end, label=span.label, score=span.score
+                )
+
+    new_spans = set(span_mapping.values())
+    new_relations = set(
+        BinaryRelation(
+            head=span_mapping[rel.head],
+            tail=span_mapping[rel.tail],
+            label=rel.label,
+            score=rel.score,
+        )
+        for rel in other_relations
+    )
+
+    return new_spans, new_relations
+
+
+def merge_spans_via_relation(
+    document: D,
+    relation_layer: str,
+    link_relation_label: str,
+    use_predicted_spans: bool = False,
+    process_predictions: bool = True,
+    create_multi_spans: bool = False,
+) -> D:
+
+    rel_layer = document[relation_layer]
+    span_layer = rel_layer.target_layer
+    new_gold_spans, new_gold_relations = _merge_spans_via_relation(
+        spans=span_layer,
+        relations=rel_layer,
+        link_relation_label=link_relation_label,
+        create_multi_spans=create_multi_spans,
+    )
+    if process_predictions:
+        new_pred_spans, new_pred_relations = _merge_spans_via_relation(
+            spans=span_layer.predictions if use_predicted_spans else span_layer,
+            relations=rel_layer.predictions,
+            link_relation_label=link_relation_label,
+            create_multi_spans=create_multi_spans,
+        )
+    else:
+        assert not use_predicted_spans
+        new_pred_spans = set(span_layer.predictions.clear())
+        new_pred_relations = set(rel_layer.predictions.clear())
+
+    relation_layer_name = relation_layer
+    span_layer_name = document[relation_layer].target_name
+    if create_multi_spans:
+        doc_dict = document.asdict()
+        for f in document.annotation_fields():
+            doc_dict.pop(f.name)
+
+        result = TextDocumentWithLabeledMultiSpansAndBinaryRelations.fromdict(doc_dict)
+        result.labeled_multi_spans.extend(new_gold_spans)
+        result.labeled_multi_spans.predictions.extend(new_pred_spans)
+        result.binary_relations.extend(new_gold_relations)
+        result.binary_relations.predictions.extend(new_pred_relations)
+    else:
+        result = document.copy(with_annotations=False)
+        result[span_layer_name].extend(new_gold_spans)
+        result[span_layer_name].predictions.extend(new_pred_spans)
+        result[relation_layer_name].extend(new_gold_relations)
+        result[relation_layer_name].predictions.extend(new_pred_relations)
+
+    return result
+
+
+def remove_partitions_by_labels(
+    document: D, partition_layer: str, label_blacklist: List[str]
+) -> D:
+    document = document.copy()
+    layer: AnnotationLayer = document[partition_layer]
+    new_partitions = []
+    for partition in layer.clear():
+        if partition.label not in label_blacklist:
+            new_partitions.append(partition)
+    layer.extend(new_partitions)
+    return document
+
+
+D_text = TypeVar("D_text", bound=Document)
+
+
+def replace_substrings_in_text(
+    document: D_text, replacements: Dict[str, str], enforce_same_length: bool = True
+) -> D_text:
+    new_text = document.text
+    for old_str, new_str in replacements.items():
+        if enforce_same_length and len(old_str) != len(new_str):
+            raise ValueError(
+                f'Replacement strings must have the same length, but got "{old_str}" -> "{new_str}"'
+            )
+        new_text = new_text.replace(old_str, new_str)
+    result_dict = document.asdict()
+    result_dict["text"] = new_text
+    result = type(document).fromdict(result_dict)
+    result.text = new_text
+    return result
+
+
+def replace_substrings_in_text_with_spaces(document: D_text, substrings: Iterable[str]) -> D_text:
+    replacements = {substring: " " * len(substring) for substring in substrings}
+    return replace_substrings_in_text(document, replacements=replacements)
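The span-merging above hinges on connected components over the link relations. A tiny self-contained illustration of just that graph step (labels and offsets are invented, and spans are modeled as plain tuples for brevity):

    import networkx as nx

    # Spans connected by link relations (e.g. "parts_of_same") end up in one
    # component and are merged into a single (multi) span.
    g = nx.Graph()
    g.add_edge(("claim", 0, 5), ("claim", 10, 15))    # two fragments of one claim
    g.add_edge(("claim", 10, 15), ("claim", 20, 25))  # a third fragment
    g.add_node(("background", 30, 40))                # an unlinked span stays alone

    for component in nx.connected_components(g):
        slices = sorted((start, end) for _, start, end in component)
        print(slices)  # e.g. [(0, 5), (10, 15), (20, 25)]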
src/evaluate.py ADDED
@@ -0,0 +1,137 @@
+import pyrootutils
+
+root = pyrootutils.setup_root(
+    search_from=__file__,
+    indicator=[".project-root"],
+    pythonpath=True,
+    dotenv=True,
+)
+
+# ------------------------------------------------------------------------------------ #
+# `pyrootutils.setup_root(...)` is an optional line at the top of each entry file
+# that helps to make the environment more robust and convenient
+#
+# the main advantages are:
+# - allows you to keep all entry files in "src/" without installing project as a package
+# - makes paths and scripts always work no matter what your current working directory is
+# - automatically loads environment variables from ".env" file if it exists
+#
+# how it works:
+# - the line above recursively searches for either ".git" or "pyproject.toml" in present
+#   and parent dirs, to determine the project root dir
+# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from
+#   any place without installing project as a package
+# - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml"
+#   to make all paths always relative to the project root
+# - loads environment variables from ".env" file in root dir (if `dotenv=True`)
+#
+# you can remove `pyrootutils.setup_root(...)` if you:
+# 1. either install project as a package or move each entry file to the project root dir
+# 2. simply remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml"
+# 3. always run entry files from the project root dir
+#
+# https://github.com/ashleve/pyrootutils
+# ------------------------------------------------------------------------------------ #
+
+from typing import Any, Tuple
+
+import hydra
+import pytorch_lightning as pl
+from omegaconf import DictConfig
+from pie_datasets import DatasetDict
+from pie_modules.models import *  # noqa: F403
+from pie_modules.taskmodules import *  # noqa: F403
+from pytorch_ie.core import PyTorchIEModel, TaskModule
+from pytorch_ie.models import *  # noqa: F403
+from pytorch_ie.taskmodules import *  # noqa: F403
+from pytorch_lightning import Trainer
+
+from src import utils
+from src.datamodules import PieDataModule
+from src.models import *  # noqa: F403
+from src.taskmodules import *  # noqa: F403
+
+log = utils.get_pylogger(__name__)
+
+
+@utils.task_wrapper
+def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
+    """Evaluates the given checkpoint on a datamodule test set.
+
+    This method is wrapped in an optional @task_wrapper decorator which applies extra utilities
+    before and after the call.
+
+    Args:
+        cfg (DictConfig): Configuration composed by Hydra.
+
+    Returns:
+        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+    """
+
+    # Set seed for random number generators in pytorch, numpy and python.random
+    if cfg.get("seed"):
+        pl.seed_everything(cfg.seed, workers=True)
+
+    # Init pytorch-ie dataset
+    log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
+    dataset: DatasetDict = hydra.utils.instantiate(cfg.dataset, _convert_="partial")
+
+    # Init pytorch-ie taskmodule
+    log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>")
+    taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial")
+
+    # auto-convert the dataset if the taskmodule specifies a document type
+    dataset = taskmodule.convert_dataset(dataset)
+
+    # Init pytorch-ie datamodule
+    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
+    datamodule: PieDataModule = hydra.utils.instantiate(
+        cfg.datamodule, dataset=dataset, taskmodule=taskmodule, _convert_="partial"
+    )
+
+    # Init pytorch-ie model
+    log.info(f"Instantiating model <{cfg.model._target_}>")
+    model: PyTorchIEModel = hydra.utils.instantiate(cfg.model, _convert_="partial")
+
+    # Init lightning loggers
+    logger = utils.instantiate_dict_entries(cfg, "logger")
+
+    # Init lightning trainer
+    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger, _convert_="partial")
+
+    object_dict = {
+        "cfg": cfg,
+        "taskmodule": taskmodule,
+        "dataset": dataset,
+        "model": model,
+        "logger": logger,
+        "trainer": trainer,
+    }
+
+    if logger:
+        log.info("Logging hyperparameters!")
+        utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg)
+
+    log.info("Starting testing!")
+    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
+
+    # for predictions use trainer.predict(...)
+    # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)
+
+    metric_dict = trainer.callback_metrics
+
+    return metric_dict, object_dict
+
+
+@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="evaluate.yaml")
+def main(cfg: DictConfig) -> Any:
+    metric_dict, _ = evaluate(cfg)
+
+    return metric_dict
+
+
+if __name__ == "__main__":
+    utils.replace_sys_args_with_values_from_files()
+    utils.prepare_omegaconf()
+    main()
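For reference, this entry point is meant to be launched through Hydra, so config values are set via CLI overrides; a hypothetical invocation (the checkpoint path is a placeholder): `python src/evaluate.py ckpt_path=/path/to/checkpoint.ckpt trainer.accelerator=cpu`.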
src/evaluate_documents.py ADDED
@@ -0,0 +1,116 @@
+import pyrootutils
+
+root = pyrootutils.setup_root(
+    search_from=__file__,
+    indicator=[".project-root"],
+    pythonpath=True,
+    dotenv=True,
+)
+
+# ------------------------------------------------------------------------------------ #
+# `pyrootutils.setup_root(...)` is an optional line at the top of each entry file
+# that helps to make the environment more robust and convenient
+#
+# the main advantages are:
+# - allows you to keep all entry files in "src/" without installing project as a package
+# - makes paths and scripts always work no matter what your current working directory is
+# - automatically loads environment variables from ".env" file if it exists
+#
+# how it works:
+# - the line above recursively searches for either ".git" or "pyproject.toml" in present
+#   and parent dirs, to determine the project root dir
+# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from
+#   any place without installing project as a package
+# - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml"
+#   to make all paths always relative to the project root
+# - loads environment variables from ".env" file in root dir (if `dotenv=True`)
+#
+# you can remove `pyrootutils.setup_root(...)` if you:
+# 1. either install project as a package or move each entry file to the project root dir
+# 2. simply remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml"
+# 3. always run entry files from the project root dir
+#
+# https://github.com/ashleve/pyrootutils
+# ------------------------------------------------------------------------------------ #
+
+from typing import Any, Tuple
+
+import hydra
+import pytorch_lightning as pl
+from omegaconf import DictConfig
+from pie_datasets import DatasetDict
+from pytorch_ie.core import DocumentMetric
+from pytorch_ie.metrics import *  # noqa: F403
+
+from src import utils
+from src.metrics import *  # noqa: F403
+
+log = utils.get_pylogger(__name__)
+
+
+@utils.task_wrapper
+def evaluate_documents(cfg: DictConfig) -> Tuple[dict, dict]:
+    """Evaluates serialized PIE documents.
+
+    This method is wrapped in an optional @task_wrapper decorator which applies extra utilities
+    before and after the call.
+
+    Args:
+        cfg (DictConfig): Configuration composed by Hydra.
+
+    Returns:
+        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+    """
+
+    # Set seed for random number generators in pytorch, numpy and python.random
+    if cfg.get("seed"):
+        pl.seed_everything(cfg.seed, workers=True)
+
+    # Init pytorch-ie dataset
+    log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
+    dataset: DatasetDict = hydra.utils.instantiate(cfg.dataset, _convert_="partial")
+
+    # Init pytorch-ie metric
+    log.info(f"Instantiating metric <{cfg.metric._target_}>")
+    metric: DocumentMetric = hydra.utils.instantiate(cfg.metric, _convert_="partial")
+
+    # auto-convert the dataset if the metric specifies a document type
+    dataset = metric.convert_dataset(dataset)
+
+    # Init lightning loggers
+    loggers = utils.instantiate_dict_entries(cfg, "logger")
+
+    object_dict = {
+        "cfg": cfg,
+        "dataset": dataset,
+        "metric": metric,
+        "logger": loggers,
+    }
+
+    if loggers:
+        log.info("Logging hyperparameters!")
+        # send hparams to all loggers
+        for logger in loggers:
+            logger.log_hyperparams(cfg)
+
+    splits = cfg.get("splits", None)
+    if splits is None:
+        documents = dataset
+    else:
+        documents = type(dataset)({k: v for k, v in dataset.items() if k in splits})
+
+    metric_dict = metric(documents)
+
+    return metric_dict, object_dict
+
+
+@hydra.main(
+    version_base="1.2", config_path=str(root / "configs"), config_name="evaluate_documents.yaml"
+)
+def main(cfg: DictConfig) -> Any:
+    metric_dict, _ = evaluate_documents(cfg)
+    return metric_dict
+
+
+if __name__ == "__main__":
+    utils.replace_sys_args_with_values_from_files()
+    utils.prepare_omegaconf()
+    main()
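This script is likewise Hydra-driven; a hypothetical invocation restricting evaluation to the test split (the metric/dataset config group names are placeholders): `python src/evaluate_documents.py splits=[test]`.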
src/hydra_callbacks/__init__.py ADDED
@@ -0,0 +1 @@
+from .save_job_return_value import SaveJobReturnValueCallback
src/hydra_callbacks/save_job_return_value.py ADDED
@@ -0,0 +1,261 @@
+import json
+import logging
+import os
+import pickle
+from pathlib import Path
+from typing import Any, Dict, Generator, List, Tuple, Union
+
+import numpy as np
+import pandas as pd
+import torch
+from hydra.core.utils import JobReturn
+from hydra.experimental.callback import Callback
+from omegaconf import DictConfig
+
+
+def to_py_obj(obj):
+    """Convert a PyTorch tensor, Numpy array or python list to a python list.
+
+    Modified version of transformers.utils.generic.to_py_obj.
+    """
+    if isinstance(obj, dict):
+        return {k: to_py_obj(v) for k, v in obj.items()}
+    elif isinstance(obj, (list, tuple)):
+        return type(obj)(to_py_obj(o) for o in obj)
+    elif isinstance(obj, torch.Tensor):
+        return obj.detach().cpu().tolist()
+    elif isinstance(obj, (np.ndarray, np.number)):  # tolist also works on 0d np arrays
+        return obj.tolist()
+    else:
+        return obj
+
+
+def list_of_dicts_to_dict_of_lists_recursive(list_of_dicts):
+    """Convert a list of dicts to a dict of lists recursively.
+
+    Example:
+        # works with nested dicts
+        >>> list_of_dicts_to_dict_of_lists_recursive([{"a": 1, "b": {"c": 2}}, {"a": 3, "b": {"c": 4}}])
+        {'b': {'c': [2, 4]}, 'a': [1, 3]}
+        # works with incomplete dicts
+        >>> list_of_dicts_to_dict_of_lists_recursive([{"a": 1, "b": 2}, {"a": 3}])
+        {'b': [2, None], 'a': [1, 3]}
+
+    Args:
+        list_of_dicts (List[dict]): A list of dicts.
+
+    Returns:
+        dict: A dict of lists.
+    """
+    if isinstance(list_of_dicts, list):
+        if len(list_of_dicts) == 0:
+            return {}
+        elif isinstance(list_of_dicts[0], dict):
+            keys = set()
+            for d in list_of_dicts:
+                if not isinstance(d, dict):
+                    raise ValueError("Not all elements of the list are dicts.")
+                keys.update(d.keys())
+            return {
+                k: list_of_dicts_to_dict_of_lists_recursive(
+                    [d.get(k, None) for d in list_of_dicts]
+                )
+                for k in keys
+            }
+        else:
+            return list_of_dicts
+    else:
+        return list_of_dicts
+
+
+def _flatten_dict_gen(d, parent_key: Tuple[str, ...] = ()) -> Generator:
+    for k, v in d.items():
+        new_key = parent_key + (k,)
+        if isinstance(v, dict):
+            yield from dict(_flatten_dict_gen(v, new_key)).items()
+        else:
+            yield new_key, v
+
+
+def flatten_dict(d: Dict[str, Any]) -> Dict[Tuple[str, ...], Any]:
+    return dict(_flatten_dict_gen(d))
+
+
+def unflatten_dict(d: Dict[Tuple[str, ...], Any]) -> Union[Dict[str, Any], Any]:
+    """Unflattens a dictionary with nested keys.
+
+    Example:
+        >>> d = {("a", "b", "c"): 1, ("a", "b", "d"): 2, ("a", "e"): 3}
+        >>> unflatten_dict(d)
+        {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
+    """
+    result: Dict[str, Any] = {}
+    for k, v in d.items():
+        if len(k) == 0:
+            if len(result) > 1:
+                raise ValueError("Cannot unflatten dictionary with multiple root keys.")
+            return v
+        current = result
+        for key in k[:-1]:
+            current = current.setdefault(key, {})
+        current[k[-1]] = v
+    return result
+
+
+def overrides_to_identifiers(overrides_per_result: List[List[str]], sep: str = "-") -> List[str]:
+    """Converts a list of lists of overrides to a list of identifiers. Only the overrides that
+    are not identical across all results are taken into account.
+
+    Example:
+        >>> overrides_per_result = [
+        ...     ["a=1", "b=2", "c=3"],
+        ...     ["a=1", "b=2", "c=4"],
+        ...     ["a=1", "b=3", "c=3"],
+        ... ]
+        >>> overrides_to_identifiers(overrides_per_result)
+        ['b=2-c=3', 'b=2-c=4', 'b=3-c=3']
+
+    Args:
+        overrides_per_result (List[List[str]]): A list of lists of overrides.
+        sep (str, optional): The separator to use between the overrides. Defaults to "-".
+
+    Returns:
+        List[str]: A list of identifiers.
+    """
+    # get the overrides that are not identical for all results
+    overrides_per_result_transposed = np.array(overrides_per_result).T.tolist()
+    indices = [
+        i for i, entries in enumerate(overrides_per_result_transposed) if len(set(entries)) > 1
+    ]
+    # convert the overrides to identifiers
+    identifiers = [
+        sep.join([overrides[idx] for idx in indices]) for overrides in overrides_per_result
+    ]
+    return identifiers
+
+
+class SaveJobReturnValueCallback(Callback):
+    """Save the job return-value in ${output_dir}/{job_return_value_filename}.
+
+    This also works for multi-runs (e.g. sweeps for hyperparameter search). In this case, the result will be saved
+    additionally in a common file in the multi-run log directory. If integrate_multirun_result=True, the
+    job return-values are also aggregated (e.g. mean, min, max) and saved in another file.
+
+    params:
+    -------
+    filenames: str or List[str] (default: "job_return_value.json")
+        The filename(s) of the file(s) to save the job return-value to. If it ends with ".json",
+        the return-value will be saved as a json file. If it ends with ".pkl", the return-value will be
+        saved as a pickle file, if it ends with ".md", the return-value will be saved as a markdown file.
+    integrate_multirun_result: bool (default: False)
+        If True, the job return-values of all jobs from a multi-run will be rearranged into a dict of lists (maybe
+        nested), where the keys are the keys of the job return-values and the values are lists of the corresponding
+        values of all jobs. This is useful if you want to access specific values of all jobs in a multi-run all at once.
+        Also, aggregated values (e.g. mean, min, max) are created for all numeric values and saved in another file.
+    """
+
+    def __init__(
+        self,
+        filenames: Union[str, List[str]] = "job_return_value.json",
+        integrate_multirun_result: bool = False,
+    ) -> None:
+        self.log = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
+        self.filenames = [filenames] if isinstance(filenames, str) else filenames
+        self.integrate_multirun_result = integrate_multirun_result
+        self.job_returns: List[JobReturn] = []
+
+    def on_job_end(self, config: DictConfig, job_return: JobReturn, **kwargs: Any) -> None:
+        self.job_returns.append(job_return)
+        output_dir = Path(config.hydra.runtime.output_dir)  # / Path(config.hydra.output_subdir)
+        for filename in self.filenames:
+            self._save(obj=job_return.return_value, filename=filename, output_dir=output_dir)
+
+    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
+        if self.integrate_multirun_result:
+            # rearrange the job return-values of all jobs from a multi-run into a dict of lists (maybe nested)
+            obj = list_of_dicts_to_dict_of_lists_recursive(
+                [jr.return_value for jr in self.job_returns]
+            )
+            # also create an aggregated result
+            # convert to python object to allow selecting numeric columns
+            obj_py = to_py_obj(obj)
+            obj_flat = flatten_dict(obj_py)
+            # create dataframe from flattened dict
+            df_flat = pd.DataFrame(obj_flat)
+            # select only the numeric values
+            df_numbers_only = df_flat.select_dtypes(["number"])
+            cols_removed = set(df_flat.columns) - set(df_numbers_only.columns)
+            if len(cols_removed) > 0:
+                self.log.warning(
+                    f"Removed the following columns from the aggregated result because they are not numeric: "
+                    f"{cols_removed}"
+                )
+            if len(df_numbers_only.columns) == 0:
+                obj_aggregated = None
+            else:
+                # aggregate the numeric values
+                df_described = df_numbers_only.describe()
+                # add the aggregation keys (e.g. mean, min, ...) as innermost keys and convert back to dict
+                obj_flat_aggregated = df_described.T.stack().to_dict()
+                # unflatten because _save() works better with nested dicts
+                obj_aggregated = unflatten_dict(obj_flat_aggregated)
+        else:
+            # create a dict of the job return-values of all jobs from a multi-run
+            # (_save() works better with nested dicts)
+            ids = overrides_to_identifiers([jr.overrides for jr in self.job_returns])
+            obj = {identifier: jr.return_value for identifier, jr in zip(ids, self.job_returns)}
+            obj_aggregated = None
+        output_dir = Path(config.hydra.sweep.dir)
+        for filename in self.filenames:
+            self._save(
+                obj=obj,
+                filename=filename,
+                output_dir=output_dir,
+                multi_run_result=self.integrate_multirun_result,
+            )
+            # if available, also save the aggregated result
+            if obj_aggregated is not None:
+                file_base_name, ext = os.path.splitext(filename)
+                filename_aggregated = f"{file_base_name}.aggregated{ext}"
+                self._save(obj=obj_aggregated, filename=filename_aggregated, output_dir=output_dir)
+
+    def _save(
+        self, obj: Any, filename: str, output_dir: Path, multi_run_result: bool = False
+    ) -> None:
+        self.log.info(f"Saving job_return in {output_dir / filename}")
+        output_dir.mkdir(parents=True, exist_ok=True)
+        if filename.endswith(".pkl"):
+            with open(str(output_dir / filename), "wb") as file:
+                pickle.dump(obj, file, protocol=4)
+        elif filename.endswith(".json"):
+            # Convert PyTorch tensors and numpy arrays to native python types
+            obj_py = to_py_obj(obj)
+            with open(str(output_dir / filename), "w") as file:
+                json.dump(obj_py, file, indent=2)
+        elif filename.endswith(".md"):
+            # Convert PyTorch tensors and numpy arrays to native python types
+            obj_py = to_py_obj(obj)
+            obj_py_flat = flatten_dict(obj_py)
+
+            if multi_run_result:
+                # In the case of a multi-run, we expect to have multiple values for each key.
+                # We therefore just convert the dict to a pandas DataFrame.
+                result = pd.DataFrame(obj_py_flat)
+            else:
+                # In the case of a single job, we expect to have only one value for each key.
+                # We therefore convert the dict to a pandas Series and ...
+                series = pd.Series(obj_py_flat)
+                if len(series.index.levels) > 1:
+                    # ... if the Series has multiple index levels, we create a DataFrame by unstacking the last level.
+                    result = series.unstack(-1)
+                else:
+                    # ... otherwise we just unpack the one-entry index values and save the resulting Series.
+                    series.index = series.index.get_level_values(0)
+                    result = series
+
+            with open(str(output_dir / filename), "w") as file:
+                file.write(result.to_markdown())
+
+        else:
+            raise ValueError("Unknown file extension")
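A quick round-trip through the helpers above, runnable as-is assuming the module is importable under the path shown (the nested dict is arbitrary demo data):

    from src.hydra_callbacks.save_job_return_value import (
        flatten_dict,
        list_of_dicts_to_dict_of_lists_recursive,
        unflatten_dict,
    )

    # two hypothetical job return-values from a multi-run
    runs = [{"val": {"f1": 0.81}, "seed": 1}, {"val": {"f1": 0.84}, "seed": 2}]
    merged = list_of_dicts_to_dict_of_lists_recursive(runs)
    # -> {"val": {"f1": [0.81, 0.84]}, "seed": [1, 2]}

    flat = flatten_dict(merged)  # keys become tuples: ("val", "f1"), ("seed",)
    assert unflatten_dict(flat) == merged  # flatten/unflatten are inverses here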
src/langchain_modules/span_retriever.py CHANGED
@@ -4,7 +4,7 @@ import uuid
 from collections import defaultdict
 from copy import copy
 from enum import Enum
-from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Type, Union
 
 from langchain_core.callbacks import (
     AsyncCallbackManagerForRetrieverRun,
@@ -674,14 +674,14 @@ class DocumentAwareSpanRetriever(BaseRetriever, SerializableStore):
 
     def add_pie_documents(
         self,
-        documents: List[TextBasedDocument],
+        documents: Iterable[TextBasedDocument],
         use_predicted_annotations: bool,
         metadata: Optional[Dict[str, Any]] = None,
     ) -> None:
         """Add pie documents to the retriever.
 
         Args:
-            documents: List of pie documents to add
+            documents: Iterable of pie documents to add
             use_predicted_annotations: Whether to use the predicted annotations or the gold annotations
             metadata: Optional metadata to add to each document
         """
src/metrics/__init__.py ADDED
@@ -0,0 +1,2 @@
+from .coref_sklearn import CorefMetricsSKLearn
+from .coref_torchmetrics import CorefMetricsTorchmetrics
src/metrics/annotation_processor.py ADDED
@@ -0,0 +1,23 @@
+from typing import Tuple
+
+from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, Span
+from pytorch_ie.core import Annotation
+
+
+def decode_span_without_label(ann: Annotation) -> Tuple[Tuple[int, int], ...]:
+    if isinstance(ann, Span):
+        return ((ann.start, ann.end),)
+    elif isinstance(ann, LabeledMultiSpan):
+        return ann.slices
+    else:
+        raise ValueError("Annotation must be a Span or LabeledMultiSpan")
+
+
+def to_binary_relation_without_argument_labels(
+    ann: Annotation,
+) -> Tuple[Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int], ...], str]:
+    if not isinstance(ann, BinaryRelation):
+        raise ValueError("Annotation must be a BinaryRelation")
+    return (
+        decode_span_without_label(ann.head),
+        decode_span_without_label(ann.tail),
+        ann.label,
+    )
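A small usage sketch, assuming the module is importable under `src.metrics.annotation_processor`; offsets and labels are invented:

    from pytorch_ie.annotations import BinaryRelation, LabeledSpan

    from src.metrics.annotation_processor import to_binary_relation_without_argument_labels

    rel = BinaryRelation(
        head=LabeledSpan(start=0, end=5, label="claim"),
        tail=LabeledSpan(start=10, end=20, label="data"),
        label="supports",
    )
    print(to_binary_relation_without_argument_labels(rel))
    # (((0, 5),), ((10, 20),), 'supports')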
src/metrics/coref_sklearn.py ADDED
@@ -0,0 +1,162 @@
+import logging
+import math
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from pandas import MultiIndex
+from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+from pytorch_ie import DocumentMetric
+from pytorch_ie.core.metric import T
+from pytorch_ie.utils.hydra import resolve_target
+from torchmetrics import Metric, MetricCollection
+
+from src.hydra_callbacks.save_job_return_value import to_py_obj
+
+logger = logging.getLogger(__name__)
+
+
+def get_num_total(targets: List[int], preds: List[float]):
+    return len(targets)
+
+
+def get_num_positives(targets: List[int], preds: List[float], positive_idx: int = 1):
+    return len([v for v in targets if v == positive_idx])
+
+
+def discretize(
+    values: List[float], threshold: Union[float, List[float], dict]
+) -> Union[List[float], Dict[Any, List[float]]]:
+    if isinstance(threshold, float):
+        result = (np.array(values) >= threshold).astype(int).tolist()
+        return result
+    if isinstance(threshold, list):
+        return {t: discretize(values=values, threshold=t) for t in threshold}  # type: ignore
+    if isinstance(threshold, dict):
+        thresholds = (
+            np.arange(threshold["start"], threshold["end"], threshold["step"]).round(4).tolist()
+        )
+        return discretize(values, threshold=thresholds)
+    raise TypeError(f"threshold has unknown type: {threshold}")
+
+
+class CorefMetricsSKLearn(DocumentMetric):
+    DOCUMENT_TYPE = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+
+    def __init__(
+        self,
+        metrics: Dict[str, str],
+        thresholds: Optional[Dict[str, float]] = None,
+        default_target_idx: int = 0,
+        default_prediction_score: float = 0.0,
+        show_as_markdown: bool = False,
+        markdown_precision: int = 4,
+        plot: bool = False,
+    ):
+        self.metrics = {name: resolve_target(metric) for name, metric in metrics.items()}
+        self.thresholds = thresholds or {}
+        thresholds_not_in_metrics = {
+            name: t for name, t in self.thresholds.items() if name not in self.metrics
+        }
+        if len(thresholds_not_in_metrics) > 0:
+            logger.warning(
+                f"there are discretizing thresholds that do not have a metric: {thresholds_not_in_metrics}"
+            )
+        self.default_target_idx = default_target_idx
+        self.default_prediction_score = default_prediction_score
+        self.show_as_markdown = show_as_markdown
+        self.markdown_precision = markdown_precision
+        self.plot = plot
+
+        super().__init__()
+
+    def reset(self) -> None:
+        self._preds: List[float] = []
+        self._targets: List[int] = []
+
+    def _update(self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations) -> None:
+        target_args2idx = {
+            (rel.head, rel.tail): int(rel.score) for rel in document.binary_coref_relations
+        }
+        prediction_args2score = {
+            (rel.head, rel.tail): rel.score for rel in document.binary_coref_relations.predictions
+        }
+        all_args = set(target_args2idx) | set(prediction_args2score)
+        all_targets: List[int] = []
+        all_predictions: List[float] = []
+        for args in all_args:
+            target_idx = target_args2idx.get(args, self.default_target_idx)
+            prediction_score = prediction_args2score.get(args, self.default_prediction_score)
+            all_targets.append(target_idx)
+            all_predictions.append(prediction_score)
+        # prediction_scores = torch.tensor(all_predictions)
+        # target_indices = torch.tensor(all_targets)
+        # self.metrics.update(preds=prediction_scores, target=target_indices)
+        self._preds.extend(all_predictions)
+        self._targets.extend(all_targets)
+
+    def do_plot(self):
+        raise NotImplementedError()
+
+        # NOTE: the code below is unreachable; it is kept as scaffolding from the
+        # torchmetrics-based variant, where self.metrics is a MetricCollection with a plot method.
+        from matplotlib import pyplot as plt
+
+        # Get the number of metrics
+        num_metrics = len(self.metrics)
+
+        # Calculate rows and columns for subplots (aim for a square-like layout)
+        ncols = math.ceil(math.sqrt(num_metrics))
+        nrows = math.ceil(num_metrics / ncols)
+
+        # Create the subplots
+        fig, ax_list = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10))
+
+        # Flatten the ax_list if necessary (in case of multiple rows/columns)
+        ax_list = ax_list.flatten().tolist()  # Ensure it's a list, and flatten it if necessary
+
+        # Ensure that we pass exactly the number of axes required by metrics
+        ax_list = ax_list[:num_metrics]
+
+        # Plot the metrics using the list of axes
+        self.metrics.plot(ax=ax_list, together=False)
+
+        # Adjust layout to avoid overlapping plots
+        plt.tight_layout()
+        plt.show()
+
+    def _compute(self) -> T:
+
+        if self.plot:
+            self.do_plot()
+
+        result = {}
+        for name, metric in self.metrics.items():
+
+            if name in self.thresholds:
+                preds = discretize(values=self._preds, threshold=self.thresholds[name])
+            else:
+                preds = self._preds
+            if isinstance(preds, dict):
+                metric_results = {
+                    t: metric(self._targets, t_preds) for t, t_preds in preds.items()
+                }
+                # just get the max
+                max_t, max_v = max(metric_results.items(), key=lambda k_v: k_v[1])
+                result[f"{name}-{max_t}"] = max_v
+            else:
+                result[name] = metric(self._targets, preds)
+
+        result = to_py_obj(result)
+        if self.show_as_markdown:
+            import pandas as pd
+
+            series = pd.Series(result)
+            if isinstance(series.index, MultiIndex):
+                if len(series.index.levels) > 1:
+                    # in fact, this is not a series anymore
+                    series = series.unstack(-1)
+                else:
+                    series.index = series.index.get_level_values(0)
+            logger.info(
+                f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
+            )
+        return result
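The threshold sweep in `_compute` rests on `discretize`; a runnable illustration of both threshold forms, assuming the module is importable under `src.metrics.coref_sklearn`:

    from src.metrics.coref_sklearn import discretize

    scores = [0.2, 0.55, 0.8]
    print(discretize(scores, threshold=0.5))
    # [0, 1, 1]
    print(discretize(scores, threshold={"start": 0.3, "end": 0.9, "step": 0.3}))
    # {0.3: [0, 1, 1], 0.6: [0, 0, 1]}  -> _compute keeps the best-scoring threshold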
src/metrics/coref_torchmetrics.py ADDED
@@ -0,0 +1,107 @@
+import logging
+import math
+from typing import Dict
+
+import torch
+from pandas import MultiIndex
+from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+from pytorch_ie import DocumentMetric
+from pytorch_ie.core.metric import T
+from torchmetrics import Metric, MetricCollection
+
+from src.hydra_callbacks.save_job_return_value import to_py_obj
+
+logger = logging.getLogger(__name__)
+
+
+class CorefMetricsTorchmetrics(DocumentMetric):
+    DOCUMENT_TYPE = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+
+    def __init__(
+        self,
+        metrics: Dict[str, Metric],
+        default_target_idx: int = 0,
+        default_prediction_score: float = 0.0,
+        show_as_markdown: bool = False,
+        markdown_precision: int = 4,
+        plot: bool = False,
+    ):
+        self.metrics = MetricCollection(metrics)
+        self.default_target_idx = default_target_idx
+        self.default_prediction_score = default_prediction_score
+        self.show_as_markdown = show_as_markdown
+        self.markdown_precision = markdown_precision
+        self.plot = plot
+
+        super().__init__()
+
+    def reset(self) -> None:
+        self.metrics.reset()
+
+    def _update(self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations) -> None:
+        target_args2idx = {
+            (rel.head, rel.tail): int(rel.score) for rel in document.binary_coref_relations
+        }
+        prediction_args2score = {
+            (rel.head, rel.tail): rel.score for rel in document.binary_coref_relations.predictions
+        }
+        all_args = set(target_args2idx) | set(prediction_args2score)
+        all_targets = []
+        all_predictions = []
+        for args in all_args:
+            target_idx = target_args2idx.get(args, self.default_target_idx)
+            prediction_score = prediction_args2score.get(args, self.default_prediction_score)
+            all_targets.append(target_idx)
+            all_predictions.append(prediction_score)
+        prediction_scores = torch.tensor(all_predictions)
+        target_indices = torch.tensor(all_targets)
+        self.metrics.update(preds=prediction_scores, target=target_indices)
+
+    def do_plot(self):
+        from matplotlib import pyplot as plt
+
+        # Get the number of metrics
+        num_metrics = len(self.metrics)
+
+        # Calculate rows and columns for subplots (aim for a square-like layout)
+        ncols = math.ceil(math.sqrt(num_metrics))
+        nrows = math.ceil(num_metrics / ncols)
+
+        # Create the subplots
+        fig, ax_list = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10))
+
+        # Flatten the ax_list if necessary (in case of multiple rows/columns)
+        ax_list = ax_list.flatten().tolist()  # Ensure it's a list, and flatten it if necessary
+
+        # Ensure that we pass exactly the number of axes required by metrics
+        ax_list = ax_list[:num_metrics]
+
+        # Plot the metrics using the list of axes
+        self.metrics.plot(ax=ax_list, together=False)
+
+        # Adjust layout to avoid overlapping plots
+        plt.tight_layout()
+        plt.show()
+
+    def _compute(self) -> T:
+
+        if self.plot:
+            self.do_plot()
+
+        result = self.metrics.compute()
+
+        result = to_py_obj(result)
+        if self.show_as_markdown:
+            import pandas as pd
+
+            series = pd.Series(result)
+            if isinstance(series.index, MultiIndex):
+                if len(series.index.levels) > 1:
+                    # in fact, this is not a series anymore
+                    series = series.unstack(-1)
+                else:
+                    series.index = series.index.get_level_values(0)
+            logger.info(
+                f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
+            )
+        return result
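A hedged configuration sketch; in the repo the `metrics` dict would normally come from a Hydra config, and the two torchmetrics classes below are assumptions (available in recent torchmetrics releases):

    from torchmetrics.classification import BinaryAUROC, BinaryF1Score

    from src.metrics import CorefMetricsTorchmetrics

    metric = CorefMetricsTorchmetrics(
        metrics={"auroc": BinaryAUROC(), "f1": BinaryF1Score()},
        show_as_markdown=True,
    )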
src/models/__init__.py CHANGED
@@ -0,0 +1,5 @@
+from .sequence_classification_with_pooler import (
+    SequencePairSimilarityModelWithMaxCosineSim,
+    SequencePairSimilarityModelWithPooler2,
+    SequencePairSimilarityModelWithPoolerAndAdapter,
+)
src/models/components/__init__.py ADDED
File without changes
src/models/components/pooler.py ADDED
@@ -0,0 +1,79 @@
+ import torch
+ from torch import Tensor, cat, nn
+
+
+ class SpanMeanPooler(nn.Module):
+     """Pooler that takes the mean hidden state over spans. If the start or end index is negative, a
+     learned embedding is used. The indices are expected to have the shape [batch_size,
+     num_indices].
+
+     The resulting embeddings are concatenated, so the output shape is [batch_size, num_indices * input_dim].
+     Note this is a slightly modified version of pie_modules.models.components.pooler.SpanMaxPooler,
+     i.e. we changed the aggregation method from torch.amax to torch.mean.
+
+     Args:
+         input_dim: The input dimension of the hidden state.
+         num_indices: The number of indices to pool.
+
+     Returns:
+         The pooled hidden states with shape [batch_size, num_indices * input_dim].
+     """
+
+     def __init__(self, input_dim: int, num_indices: int = 2, **kwargs):
+         super().__init__(**kwargs)
+         self.input_dim = input_dim
+         self.num_indices = num_indices
+         self.missing_embeddings = nn.Parameter(torch.empty(num_indices, self.input_dim))
+         nn.init.normal_(self.missing_embeddings)
+
+     def forward(
+         self, hidden_state: Tensor, start_indices: Tensor, end_indices: Tensor, **kwargs
+     ) -> Tensor:
+         batch_size, seq_len, hidden_size = hidden_state.shape
+         if start_indices.shape[1] != self.num_indices:
+             raise ValueError(
+                 f"number of start indices [{start_indices.shape[1]}] has to be the same as num_indices [{self.num_indices}]"
+             )
+
+         if end_indices.shape[1] != self.num_indices:
+             raise ValueError(
+                 f"number of end indices [{end_indices.shape[1]}] has to be the same as num_indices [{self.num_indices}]"
+             )
+
+         # check that start_indices are before end_indices
+         mask_both_positive = (start_indices >= 0) & (end_indices >= 0)
+         mask_start_before_end = start_indices < end_indices
+         mask_valid = mask_start_before_end | ~mask_both_positive
+         if not torch.all(mask_valid):
+             raise ValueError(
+                 f"values in start_indices have to be smaller than respective values in "
+                 f"end_indices, but start_indices=\n{start_indices}\n and end_indices=\n{end_indices}"
+             )
+
+         # times num_indices due to concat
+         result = torch.zeros(
+             batch_size, hidden_size * self.num_indices, device=hidden_state.device
+         )
+         for batch_idx in range(batch_size):
+             current_start_indices = start_indices[batch_idx]
+             current_end_indices = end_indices[batch_idx]
+             current_embeddings = [
+                 (
+                     torch.mean(
+                         hidden_state[
+                             batch_idx, current_start_indices[i] : current_end_indices[i], :
+                         ],
+                         dim=0,
+                     )
+                     if current_start_indices[i] >= 0 and current_end_indices[i] >= 0
+                     else self.missing_embeddings[i]
+                 )
+                 for i in range(self.num_indices)
+             ]
+             result[batch_idx] = cat(current_embeddings, 0)
+
+         return result
+
+     @property
+     def output_dim(self) -> int:
+         return self.input_dim * self.num_indices
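Editor's note: a minimal smoke test for SpanMeanPooler (illustrative sketch, not part of the diff; shapes follow the docstring, and negative indices select the learned missing-span embedding):

import torch
from src.models.components.pooler import SpanMeanPooler

pooler = SpanMeanPooler(input_dim=8, num_indices=2)
hidden_state = torch.randn(2, 10, 8)  # [batch_size, seq_len, input_dim]
start_indices = torch.tensor([[1, 4], [0, -1]])  # -1 marks a missing span
end_indices = torch.tensor([[3, 6], [2, -1]])
pooled = pooler(hidden_state, start_indices=start_indices, end_indices=end_indices)
assert pooled.shape == (2, pooler.output_dim)  # output_dim == 2 * 8 == 16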
src/models/sequence_classification_with_pooler.py ADDED
@@ -0,0 +1,166 @@
+ import abc
+ import logging
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn.functional as F
+ from adapters import AutoAdapterModel
+ from pie_modules.models import SequencePairSimilarityModelWithPooler
+ from pie_modules.models.components.pooler import MENTION_POOLING
+ from pie_modules.models.sequence_classification_with_pooler import (
+     InputType,
+     OutputType,
+     SequenceClassificationModelWithPooler,
+     SequenceClassificationModelWithPoolerBase,
+     TargetType,
+     separate_arguments_by_prefix,
+ )
+ from pytorch_ie import PyTorchIEModel
+ from torch import FloatTensor, Tensor
+ from transformers import AutoConfig, PreTrainedModel
+ from transformers.modeling_outputs import SequenceClassifierOutput
+
+ from src.models.components.pooler import SpanMeanPooler
+
+ logger = logging.getLogger(__name__)
+
+
+ class SequenceClassificationModelWithPoolerBase2(
+     SequenceClassificationModelWithPoolerBase, abc.ABC
+ ):
+     def setup_pooler(self, input_dim: int) -> Tuple[Callable, int]:
+         aggregate = self.pooler_config.get("aggregate", "max")
+         if self.pooler_config["type"] == MENTION_POOLING and aggregate != "max":
+             if aggregate == "mean":
+                 pooler_config = dict(self.pooler_config)
+                 pooler_config.pop("type")
+                 pooler_config.pop("aggregate")
+                 pooler = SpanMeanPooler(input_dim=input_dim, **pooler_config)
+                 return pooler, pooler.output_dim
+             else:
+                 raise ValueError(f"Unknown aggregation method: {aggregate}")
+         else:
+             return super().setup_pooler(input_dim)
+
+
+ class SequenceClassificationModelWithPoolerAndAdapterBase(
+     SequenceClassificationModelWithPoolerBase2, abc.ABC
+ ):
+     def __init__(self, adapter_name_or_path: Optional[str] = None, **kwargs):
+         self.adapter_name_or_path = adapter_name_or_path
+         super().__init__(**kwargs)
+
+     def setup_base_model(self) -> PreTrainedModel:
+         if self.adapter_name_or_path is None:
+             return super().setup_base_model()
+         else:
+             config = AutoConfig.from_pretrained(self.model_name_or_path)
+             if self.is_from_pretrained:
+                 model = AutoAdapterModel.from_config(config=config)
+             else:
+                 model = AutoAdapterModel.from_pretrained(self.model_name_or_path, config=config)
+             # load the adapter in any case (it looks like it is not saved in the state or loaded
+             # from a serialized state)
+             logger.info(f"load adapter: {self.adapter_name_or_path}")
+             model.load_adapter(self.adapter_name_or_path, source="hf", set_active=True)
+             return model
+
+
+ @PyTorchIEModel.register()
+ class SequencePairSimilarityModelWithPooler2(
+     SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerBase2
+ ):
+     pass
+
+
+ @PyTorchIEModel.register()
+ class SequencePairSimilarityModelWithPoolerAndAdapter(
+     SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
+ ):
+     pass
+
+
+ @PyTorchIEModel.register()
+ class SequenceClassificationModelWithPoolerAndAdapter(
+     SequenceClassificationModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
+ ):
+     pass
+
+
+ def get_max_cosine_sim(embeddings: Tensor, embeddings_pair: Tensor) -> Tensor:
+     # Normalize the embeddings
+     embeddings_normalized = F.normalize(embeddings, p=2, dim=1)  # Shape: (n, k)
+     embeddings_normalized_pair = F.normalize(embeddings_pair, p=2, dim=1)  # Shape: (m, k)
+
+     # Compute the cosine similarity matrix
+     cosine_sim = torch.mm(embeddings_normalized, embeddings_normalized_pair.T)  # Shape: (n, m)
+
+     # Get the overall maximum cosine similarity value
+     max_cosine_sim = torch.max(cosine_sim)  # This will return a scalar
+     return max_cosine_sim
+
+
+ def get_span_embeddings(
+     embeddings: FloatTensor, start_indices: Tensor, end_indices: Tensor
+ ) -> List[FloatTensor]:
+     result = []
+     for embeds, starts, ends in zip(embeddings, start_indices, end_indices):
+         span_embeds = embeds[starts[0] : ends[0]]
+         result.append(span_embeds)
+     return result
+
+
+ @PyTorchIEModel.register()
+ class SequencePairSimilarityModelWithMaxCosineSim(SequencePairSimilarityModelWithPooler):
+     def get_pooled_output(self, model_inputs, pooler_inputs) -> List[FloatTensor]:
+         output = self.model(**model_inputs)
+         hidden_state = output.last_hidden_state
+         # pooled_output = self.pooler(hidden_state, **pooler_inputs)
+         # pooled_output = self.dropout(pooled_output)
+         span_embeds = get_span_embeddings(hidden_state, **pooler_inputs)
+         return span_embeds
+
+     def forward(
+         self,
+         inputs: InputType,
+         targets: Optional[TargetType] = None,
+         return_hidden_states: bool = False,
+     ) -> OutputType:
+         sanitized_inputs = separate_arguments_by_prefix(
+             # Note that the order of the prefixes is important because one is a prefix of the
+             # other, so we need to start with the longer one!
+             arguments=inputs,
+             prefixes=["pooler_pair_", "pooler_"],
+         )
+
+         span_embeddings = self.get_pooled_output(
+             model_inputs=sanitized_inputs["remaining"]["encoding"],
+             pooler_inputs=sanitized_inputs["pooler_"],
+         )
+         span_embeddings_pair = self.get_pooled_output(
+             model_inputs=sanitized_inputs["remaining"]["encoding_pair"],
+             pooler_inputs=sanitized_inputs["pooler_pair_"],
+         )
+
+         logits_list = [
+             get_max_cosine_sim(span_embeds, span_embeds_pair)
+             for span_embeds, span_embeds_pair in zip(span_embeddings, span_embeddings_pair)
+         ]
+         logits = torch.stack(logits_list)
+
+         result = {"logits": logits}
+         if targets is not None:
+             labels = targets["scores"]
+             loss = self.loss_fct(logits, labels)
+             result["loss"] = loss
+         if return_hidden_states:
+             raise NotImplementedError("return_hidden_states is not yet implemented")
+
+         return SequenceClassifierOutput(**result)
+
+
+ @PyTorchIEModel.register()
+ class SequencePairSimilarityModelWithMaxCosineSimAndAdapter(
+     SequencePairSimilarityModelWithMaxCosineSim, SequencePairSimilarityModelWithPoolerAndAdapter
+ ):
+     pass
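Editor's note: a small check of get_max_cosine_sim (illustrative, not part of the diff). The result is the single largest cosine similarity over all n * m span-embedding pairs:

import torch

a = torch.tensor([[1.0, 0.0], [0.0, 1.0]])   # n=2 span embeddings, k=2
b = torch.tensor([[0.0, 2.0], [-1.0, 0.0]])  # m=2 span embeddings, k=2
print(get_max_cosine_sim(a, b))  # tensor(1.) -- [0., 1.] vs. [0., 2.]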
src/models/utils/__init__.py CHANGED
@@ -1 +1,5 @@
- from .loading import load_model_from_pie_model, load_model_with_adapter, load_tokenizer_from_pie_taskmodule
+ from .loading import (
+     load_model_from_pie_model,
+     load_model_with_adapter,
+     load_tokenizer_from_pie_taskmodule,
+ )
src/models/utils/loading.py CHANGED
@@ -23,10 +23,10 @@ def load_tokenizer_from_pie_taskmodule(taskmodule_kwargs: Dict[str, Any]) -> Pre
 
 
  def load_model_with_adapter(
      model_kwargs: Dict[str, Any], adapter_kwargs: Dict[str, Any]
- ) -> "ModelAdaptersMixin":
-     from adapters import AutoAdapterModel, ModelAdaptersMixin
+ ) -> PreTrainedModel:
+     from adapters import AutoAdapterModel
 
      model = AutoAdapterModel.from_pretrained(**model_kwargs)
      model.load_adapter(set_active=True, **adapter_kwargs)
      return model
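Editor's note: a hypothetical call of the updated helper (the model and adapter identifiers below are placeholders, not taken from this PR):

# Placeholder identifiers, for illustration only:
model = load_model_with_adapter(
    model_kwargs={"pretrained_model_name_or_path": "bert-base-uncased"},
    adapter_kwargs={"adapter_name_or_path": "some-org/some-adapter", "source": "hf"},
)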
src/pipeline/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .ner_re_pipeline import NerRePipeline
+ from .span_retrieval_based_re_pipeline import SpanRetrievalBasedRelationExtractionPipeline
src/pipeline/ner_re_pipeline.py ADDED
@@ -0,0 +1,208 @@
+ from __future__ import annotations
+
+ import logging
+ from functools import partial
+ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, TypeVar, Union
+
+ from pie_modules.utils import resolve_type
+ from pytorch_ie import AutoPipeline, WithDocumentTypeMixin
+ from pytorch_ie.core import Document
+
+ logger = logging.getLogger(__name__)
+
+
+ D = TypeVar("D", bound=Document)
+
+
+ def clear_annotation_layers(doc: D, layer_names: List[str], predictions: bool = False) -> None:
+     for layer_name in layer_names:
+         if predictions:
+             doc[layer_name].predictions.clear()
+         else:
+             doc[layer_name].clear()
+
+
+ def move_annotations_from_predictions(doc: D, layer_names: List[str]) -> None:
+     for layer_name in layer_names:
+         annotations = list(doc[layer_name].predictions)
+         # remove any previous annotations
+         doc[layer_name].clear()
+         # each annotation can be attached to just one annotation container, so we need to clear the predictions
+         doc[layer_name].predictions.clear()
+         doc[layer_name].extend(annotations)
+
+
+ def move_annotations_to_predictions(doc: D, layer_names: List[str]) -> None:
+     for layer_name in layer_names:
+         annotations = list(doc[layer_name])
+         # each annotation can be attached to just one annotation container, so we need to clear the layer
+         doc[layer_name].clear()
+         # remove any previous annotations
+         doc[layer_name].predictions.clear()
+         doc[layer_name].predictions.extend(annotations)
+
+
+ def add_annotations_from_other_documents(
+     docs: Iterable[D],
+     other_docs: Sequence[Document],
+     layer_names: List[str],
+     from_predictions: bool = False,
+     to_predictions: bool = False,
+     clear_before: bool = True,
+ ) -> None:
+     for i, doc in enumerate(docs):
+         other_doc = other_docs[i]
+         # copy to not modify the input
+         other_doc = type(other_doc).fromdict(other_doc.asdict())
+
+         for layer_name in layer_names:
+             if clear_before:
+                 doc[layer_name].clear()
+             other_layer = other_doc[layer_name]
+             if from_predictions:
+                 other_layer = other_layer.predictions
+             other_annotations = list(other_layer)
+             other_layer.clear()
+             if to_predictions:
+                 doc[layer_name].predictions.extend(other_annotations)
+             else:
+                 doc[layer_name].extend(other_annotations)
+
+
+ def process_pipeline_steps(
+     documents: Sequence[Document],
+     processors: Dict[str, Callable[[Sequence[Document]], Optional[Sequence[Document]]]],
+ ) -> Sequence[Document]:
+
+     # call the processors in the order they are provided
+     for step_name, processor in processors.items():
+         logger.info(f"process {step_name} ...")
+         processed_documents = processor(documents)
+         if processed_documents is not None:
+             documents = processed_documents
+
+     return documents
+
+
+ def process_documents(
+     documents: List[Document], processor: Callable[..., Optional[Document]], **kwargs
+ ) -> List[Document]:
+     result = []
+     for doc in documents:
+         processed_doc = processor(doc, **kwargs)
+         if processed_doc is not None:
+             result.append(processed_doc)
+         else:
+             result.append(doc)
+     return result
+
+
+ class DummyTaskmodule(WithDocumentTypeMixin):
+     def __init__(self, document_type: Optional[Union[Type[Document], str]]):
+         if isinstance(document_type, str):
+             self._document_type = resolve_type(document_type, expected_super_type=Document)
+         else:
+             self._document_type = document_type
+
+     @property
+     def document_type(self) -> Optional[Type[Document]]:
+         return self._document_type
+
+
+ class NerRePipeline:
+     def __init__(
+         self,
+         ner_model_path: str,
+         re_model_path: str,
+         entity_layer: str,
+         relation_layer: str,
+         device: Optional[int] = None,
+         batch_size: Optional[int] = None,
+         show_progress_bar: Optional[bool] = None,
+         document_type: Optional[Union[Type[Document], str]] = None,
+         **processor_kwargs,
+     ):
+         self.taskmodule = DummyTaskmodule(document_type)
+         self.ner_model_path = ner_model_path
+         self.re_model_path = re_model_path
+         self.processor_kwargs = processor_kwargs or {}
+         self.entity_layer = entity_layer
+         self.relation_layer = relation_layer
+         # set some values for the inference processors, if provided
+         for inference_pipeline in ["ner_pipeline", "re_pipeline"]:
+             if inference_pipeline not in self.processor_kwargs:
+                 self.processor_kwargs[inference_pipeline] = {}
+             if "device" not in self.processor_kwargs[inference_pipeline] and device is not None:
+                 self.processor_kwargs[inference_pipeline]["device"] = device
+             if (
+                 "batch_size" not in self.processor_kwargs[inference_pipeline]
+                 and batch_size is not None
+             ):
+                 self.processor_kwargs[inference_pipeline]["batch_size"] = batch_size
+             if (
+                 "show_progress_bar" not in self.processor_kwargs[inference_pipeline]
+                 and show_progress_bar is not None
+             ):
+                 self.processor_kwargs[inference_pipeline]["show_progress_bar"] = show_progress_bar
+
+     def __call__(self, documents: Sequence[Document], inplace: bool = False) -> Sequence[Document]:
+
+         input_docs: Sequence[Document]
+         # we need to keep the original documents to add the gold data back
+         original_docs: Sequence[Document]
+         if inplace:
+             input_docs = documents
+             original_docs = [doc.copy() for doc in documents]
+         else:
+             input_docs = [doc.copy() for doc in documents]
+             original_docs = documents
+
+         docs_with_predictions = process_pipeline_steps(
+             documents=input_docs,
+             processors={
+                 "clear_annotations": partial(
+                     process_documents,
+                     processor=clear_annotation_layers,
+                     layer_names=[self.entity_layer, self.relation_layer],
+                     **self.processor_kwargs.get("clear_annotations", {}),
+                 ),
+                 "ner_pipeline": AutoPipeline.from_pretrained(
+                     self.ner_model_path, **self.processor_kwargs.get("ner_pipeline", {})
+                 ),
+                 "use_predicted_entities": partial(
+                     process_documents,
+                     processor=move_annotations_from_predictions,
+                     layer_names=[self.entity_layer],
+                     **self.processor_kwargs.get("use_predicted_entities", {}),
+                 ),
+                 # "create_candidate_relations": partial(
+                 #     process_documents,
+                 #     processor=CandidateRelationAdder(
+                 #         **self.processor_kwargs.get("create_candidate_relations", {})
+                 #     ),
+                 # ),
+                 "re_pipeline": AutoPipeline.from_pretrained(
+                     self.re_model_path, **self.processor_kwargs.get("re_pipeline", {})
+                 ),
+                 # otherwise we can not move the entities back to predictions
+                 "clear_candidate_relations": partial(
+                     process_documents,
+                     processor=clear_annotation_layers,
+                     layer_names=[self.relation_layer],
+                     **self.processor_kwargs.get("clear_candidate_relations", {}),
+                 ),
+                 "move_entities_to_predictions": partial(
+                     process_documents,
+                     processor=move_annotations_to_predictions,
+                     layer_names=[self.entity_layer],
+                     **self.processor_kwargs.get("move_entities_to_predictions", {}),
+                 ),
+                 "re_add_gold_data": partial(
+                     add_annotations_from_other_documents,
+                     other_docs=original_docs,
+                     layer_names=[self.entity_layer, self.relation_layer],
+                     **self.processor_kwargs.get("re_add_gold_data", {}),
+                 ),
+             },
+         )
+         return docs_with_predictions
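Editor's note: a hypothetical instantiation of NerRePipeline (illustrative; the model paths are placeholders, and the layer names must match the annotation layers of the documents' type):

pipeline = NerRePipeline(
    ner_model_path="models/ner",   # placeholder path
    re_model_path="models/re",     # placeholder path
    entity_layer="labeled_spans",
    relation_layer="binary_relations",
    batch_size=8,
)
docs_with_predictions = pipeline(documents, inplace=False)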
src/pipeline/span_retrieval_based_re_pipeline.py ADDED
@@ -0,0 +1,130 @@
+ import logging
+ from typing import Optional, Sequence, Type
+
+ from langchain_core.documents import Document as LCDocument
+ from pie_datasets import Dataset, IterableDataset
+ from pytorch_ie import Document, WithDocumentTypeMixin
+ from pytorch_ie.annotations import BinaryRelation, LabeledSpan
+ from pytorch_ie.documents import TextBasedDocument
+
+ from src.langchain_modules import DocumentAwareSpanRetriever
+
+ logger = logging.getLogger(__name__)
+
+
+ class DummyTaskmodule(WithDocumentTypeMixin):
+     def __init__(self, document_type: Type[Document]):
+         self._document_type = document_type
+
+     @property
+     def document_type(self) -> Optional[Type[Document]]:
+         return self._document_type
+
+
+ class SpanRetrievalBasedRelationExtractionPipeline:
+     """Pipeline for adding binary relations between spans based on span retrieval within the same document.
+
+     This pipeline retrieves spans for all existing spans as query and adds binary relations between the
+     query spans and the retrieved spans.
+
+     Args:
+         retriever: The span retriever to use for retrieving spans.
+         relation_label: The label to use for the binary relations.
+         relation_layer_name: The name of the annotation layer to add the binary relations to.
+         use_predicted_annotations: Whether to use predicted (instead of gold) span annotations.
+         load_store_path: If provided, the retriever store(s) will be loaded from this path before processing.
+         save_store_path: If provided, the retriever store(s) will be saved to this path after processing.
+         fast_dev_run: Whether to run the pipeline in fast dev mode, i.e. only processing the first 2 documents.
+     """
+
+     def __init__(
+         self,
+         retriever: DocumentAwareSpanRetriever,
+         relation_label: str,
+         relation_layer_name: str = "binary_relations",
+         use_predicted_annotations: bool = False,
+         load_store_path: Optional[str] = None,
+         save_store_path: Optional[str] = None,
+         fast_dev_run: bool = False,
+     ):
+         self.retriever = retriever
+         if not self.retriever.retrieve_from_same_document:
+             raise NotImplementedError("Retriever must retrieve from the same document")
+         self.relation_label = relation_label
+         self.relation_layer_name = relation_layer_name
+         self.use_predicted_annotations = use_predicted_annotations
+         self.load_store_path = load_store_path
+         self.save_store_path = save_store_path
+         if self.load_store_path is not None:
+             self.retriever.load_from_directory(path=self.load_store_path)
+
+         self.fast_dev_run = fast_dev_run
+
+     # to make auto-conversion work: we request documents of type pipeline.taskmodule.document_type
+     # from the dataset
+     @property
+     def taskmodule(self) -> DummyTaskmodule:
+         return DummyTaskmodule(self.retriever.pie_document_type)
+
+     def _construct_similarity_relations(
+         self,
+         query_results: list[LCDocument],
+         query_span: LabeledSpan,
+     ) -> list[BinaryRelation]:
+         return [
+             BinaryRelation(
+                 head=query_span,
+                 tail=lc_doc.metadata["attached_span"],
+                 label=self.relation_label,
+                 score=float(lc_doc.metadata["relevance_score"]),
+             )
+             for lc_doc in query_results
+         ]
+
+     def _process_single_document(
+         self,
+         document: Document,
+     ) -> TextBasedDocument:
+         if not isinstance(document, TextBasedDocument):
+             raise ValueError("Document must be a TextBasedDocument")
+
+         self.retriever.add_pie_documents(
+             [document], use_predicted_annotations=self.use_predicted_annotations
+         )
+
+         all_new_rels = []
+         spans = self.retriever.get_base_layer(
+             document, use_predicted_annotations=self.use_predicted_annotations
+         )
+         span_id2idx = self.retriever.get_span_id2idx_from_doc(document.id)
+         for span_id, span_idx in span_id2idx.items():
+             query_span = spans[span_idx]
+             query_result = self.retriever.invoke(input=span_id)
+             query_rels = self._construct_similarity_relations(query_result, query_span=query_span)
+             all_new_rels.extend(query_rels)
+
+         if self.relation_layer_name not in document:
+             raise ValueError(f"Document does not have a layer named {self.relation_layer_name}")
+         document[self.relation_layer_name].predictions.extend(all_new_rels)
+
+         if self.retriever.retrieve_from_same_document and self.save_store_path is None:
+             self.retriever.delete_documents([document.id])
+
+         return document
+
+     def __call__(self, documents: Sequence[Document], inplace: bool = False) -> Sequence[Document]:
+         if inplace:
+             raise NotImplementedError("Inplace processing is not supported yet")
+
+         if self.fast_dev_run:
+             logger.warning("Fast dev run enabled, only processing the first 2 documents")
+             documents = documents[:2]
+
+         if not isinstance(documents, (Dataset, IterableDataset)):
+             documents = Dataset.from_documents(documents)
+
+         mapped_documents = documents.map(self._process_single_document)
+
+         if self.save_store_path is not None:
+             self.retriever.save_to_directory(path=self.save_store_path)
+
+         return mapped_documents
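Editor's note: a hypothetical usage sketch (assumes `retriever` is a configured DocumentAwareSpanRetriever with retrieve_from_same_document enabled; the relation label is a placeholder):

pipeline = SpanRetrievalBasedRelationExtractionPipeline(
    retriever=retriever,
    relation_label="semantically_same",  # placeholder label
    relation_layer_name="binary_relations",
)
documents_with_new_relations = pipeline(documents)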
src/predict.py ADDED
@@ -0,0 +1,183 @@
+ import pyrootutils
+
+ root = pyrootutils.setup_root(
+     search_from=__file__,
+     indicator=[".project-root"],
+     pythonpath=True,
+     dotenv=True,
+ )
+
+ # ------------------------------------------------------------------------------------ #
+ # `pyrootutils.setup_root(...)` is an optional line at the top of each entry file
+ # that helps to make the environment more robust and convenient
+ #
+ # the main advantages are:
+ # - allows you to keep all entry files in "src/" without installing project as a package
+ # - makes paths and scripts always work no matter where your current work dir is
+ # - automatically loads environment variables from the ".env" file if it exists
+ #
+ # how it works:
+ # - the line above recursively searches for either ".git" or "pyproject.toml" in present
+ #   and parent dirs, to determine the project root dir
+ # - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from
+ #   any place without installing project as a package
+ # - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml"
+ #   to make all paths always relative to the project root
+ # - loads environment variables from ".env" file in root dir (if `dotenv=True`)
+ #
+ # you can remove `pyrootutils.setup_root(...)` if you:
+ # 1. either install project as a package or move each entry file to the project root dir
+ # 2. simply remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml"
+ # 3. always run entry files from the project root dir
+ #
+ # https://github.com/ashleve/pyrootutils
+ # ------------------------------------------------------------------------------------ #
+
+ import os
+ import timeit
+ from collections.abc import Iterable, Sequence
+ from typing import Any, Dict, Optional, Tuple, Union
+
+ import hydra
+ import pytorch_lightning as pl
+ from omegaconf import DictConfig, OmegaConf
+ from pie_datasets import Dataset, DatasetDict
+ from pie_modules.models import *  # noqa: F403
+ from pie_modules.taskmodules import *  # noqa: F403
+ from pytorch_ie import Document, Pipeline
+ from pytorch_ie.models import *  # noqa: F403
+ from pytorch_ie.taskmodules import *  # noqa: F403
+
+ from src import utils
+ from src.models import *  # noqa: F403
+ from src.serializer.interface import DocumentSerializer
+ from src.taskmodules import *  # noqa: F403
+
+ log = utils.get_pylogger(__name__)
+
+
+ def document_batch_iter(
+     dataset: Union[Sequence[Document], Iterable[Document]], batch_size: int
+ ) -> Iterable[Sequence[Document]]:
+     if isinstance(dataset, Sequence):
+         for i in range(0, len(dataset), batch_size):
+             yield dataset[i : i + batch_size]
+     elif isinstance(dataset, Iterable):
+         docs = []
+         for doc in dataset:
+             docs.append(doc)
+             if len(docs) == batch_size:
+                 yield docs
+                 docs = []
+         if docs:
+             yield docs
+     else:
+         raise ValueError(f"Unsupported dataset type: {type(dataset)}")
+
+
+ @utils.task_wrapper
+ def predict(cfg: DictConfig) -> Tuple[dict, dict]:
+     """Contains a minimal example of the prediction pipeline. Uses a pretrained model to annotate
+     documents from a dataset and serializes them.
+
+     Args:
+         cfg (DictConfig): Configuration composed by Hydra.
+
+     Returns:
+         A tuple of a result dictionary (e.g. serializer output and prediction time) and a
+         dictionary holding the instantiated objects.
+     """
+
+     # Set seed for random number generators in pytorch, numpy and python.random
+     if cfg.get("seed"):
+         pl.seed_everything(cfg.seed, workers=True)
+
+     # Init pytorch-ie dataset
+     log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
+     dataset: DatasetDict = hydra.utils.instantiate(cfg.dataset, _convert_="partial")
+
+     # Init pytorch-ie pipeline
+     # The pipeline, and therefore the inference step, is optional to allow for easy testing
+     # of the dataset creation and processing.
+     pipeline: Optional[Pipeline] = None
+     if cfg.get("pipeline") and cfg.pipeline.get("_target_"):
+         log.info(f"Instantiating pipeline <{cfg.pipeline._target_}> from {cfg.model_name_or_path}")
+         pipeline = hydra.utils.instantiate(cfg.pipeline, _convert_="partial")
+
+         # Per default, the model is loaded with .from_pretrained() which already loads the weights.
+         # However, ckpt_path can be used to load different weights from any checkpoint.
+         if cfg.ckpt_path is not None:
+             pipeline.model = pipeline.model.load_from_checkpoint(
+                 checkpoint_path=cfg.ckpt_path
+             ).to(pipeline.device)
+
+         # auto-convert the dataset if the pipeline's taskmodule specifies a document type
+         dataset = pipeline.taskmodule.convert_dataset(dataset)
+
+     # Init the serializer
+     serializer: Optional[DocumentSerializer] = None
+     if cfg.get("serializer") and cfg.serializer.get("_target_"):
+         log.info(f"Instantiating serializer <{cfg.serializer._target_}>")
+         serializer = hydra.utils.instantiate(cfg.serializer, _convert_="partial")
+
+     # select the dataset split for prediction
+     dataset_predict = dataset[cfg.dataset_split]
+
+     object_dict = {
+         "cfg": cfg,
+         "dataset": dataset,
+         "pipeline": pipeline,
+         "serializer": serializer,
+     }
+     result: Dict[str, Any] = {}
+     if pipeline is not None:
+         log.info("Starting inference!")
+         prediction_time = 0.0
+     else:
+         log.warning("No prediction pipeline is defined, skip inference!")
+         prediction_time = None
+     document_batch_size = cfg.get("document_batch_size", None)
+     for docs_batch in (
+         document_batch_iter(dataset_predict, document_batch_size)
+         if document_batch_size
+         else [dataset_predict]
+     ):
+         if pipeline is not None:
+             t_start = timeit.default_timer()
+             docs_batch = pipeline(docs_batch, inplace=False)
+             prediction_time += timeit.default_timer() - t_start  # type: ignore
+
+         # serialize the documents
+         if serializer is not None:
+             # the serializer should not return the serialized documents, but write them to disk
+             # and instead return some metadata such as the path to the serialized documents
+             serializer_result = serializer(docs_batch)
+             if "serializer" in result and result["serializer"] != serializer_result:
+                 log.warning(
+                     f"serializer result changed from {result['serializer']} to {serializer_result}"
+                     " during prediction. Only the last result is returned."
+                 )
+             result["serializer"] = serializer_result
+
+     if prediction_time is not None:
+         result["prediction_time"] = prediction_time
+
+     # serialize config with resolved paths
+     if cfg.get("config_out_path"):
+         config_out_dir = os.path.dirname(cfg.config_out_path)
+         os.makedirs(config_out_dir, exist_ok=True)
+         OmegaConf.save(config=cfg, f=cfg.config_out_path)
+         result["config"] = cfg.config_out_path
+
+     return result, object_dict
+
+
+ @hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="predict.yaml")
+ def main(cfg: DictConfig) -> dict:
+     result_dict, _ = predict(cfg)
+     return result_dict
+
+
+ if __name__ == "__main__":
+     utils.replace_sys_args_with_values_from_files()
+     utils.prepare_omegaconf()
+     main()
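Editor's note: document_batch_iter also yields the trailing partial batch; a quick illustration with plain integers standing in for documents (the function only slices and counts, so the element type does not matter here):

for batch in document_batch_iter(list(range(5)), batch_size=2):
    print(batch)  # [0, 1], then [2, 3], then [4]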
src/serializer/__init__.py ADDED
@@ -0,0 +1 @@
+ from .json import JsonSerializer, JsonSerializer2
src/serializer/interface.py ADDED
@@ -0,0 +1,16 @@
+ from abc import ABC, abstractmethod
+ from typing import Any, Sequence
+
+ from pytorch_ie.core import Document
+
+
+ class DocumentSerializer(ABC):
+     """This defines the interface for a document serializer.
+
+     Instead of returning the serialized documents, the serializer should write them to disk
+     and return some metadata, such as the path to the serialized documents.
+     """
+
+     @abstractmethod
+     def __call__(self, documents: Sequence[Document]) -> Any:
+         pass
src/serializer/json.py ADDED
@@ -0,0 +1,179 @@
+ import json
+ import os
+ from typing import Dict, List, Optional, Sequence, Type, TypeVar
+
+ from pie_datasets import Dataset, DatasetDict, IterableDataset
+ from pie_datasets.core.dataset_dict import METADATA_FILE_NAME
+ from pytorch_ie.core import Document
+ from pytorch_ie.utils.hydra import resolve_optional_document_type, serialize_document_type
+
+ from src.serializer.interface import DocumentSerializer
+ from src.utils import get_pylogger
+
+ log = get_pylogger(__name__)
+
+ D = TypeVar("D", bound=Document)
+
+
+ def as_json_lines(file_name: str) -> bool:
+     if file_name.lower().endswith(".jsonl"):
+         return True
+     elif file_name.lower().endswith(".json"):
+         return False
+     else:
+         raise Exception(f"unknown file extension: {file_name}")
+
+
+ class JsonSerializer(DocumentSerializer):
+     def __init__(self, **kwargs):
+         self.default_kwargs = kwargs
+
+     @classmethod
+     def write(
+         cls,
+         documents: Sequence[Document],
+         path: str,
+         file_name: str = "documents.jsonl",
+         metadata_file_name: str = METADATA_FILE_NAME,
+         split: Optional[str] = None,
+         **kwargs,
+     ) -> Dict[str, str]:
+         realpath = os.path.realpath(path)
+         log.info(f'serialize documents to "{realpath}" ...')
+         os.makedirs(realpath, exist_ok=True)
+
+         # dump metadata including the document_type
+         if len(documents) == 0:
+             raise Exception("cannot serialize empty list of documents")
+         document_type = type(documents[0])
+         metadata = {"document_type": serialize_document_type(document_type)}
+         full_metadata_file_name = os.path.join(realpath, metadata_file_name)
+         if os.path.exists(full_metadata_file_name):
+             # load previous metadata
+             with open(full_metadata_file_name) as f:
+                 previous_metadata = json.load(f)
+             if previous_metadata != metadata:
+                 raise ValueError(
+                     f"metadata file {full_metadata_file_name} already exists, "
+                     "but the content does not match the current metadata"
+                     f"\nprevious metadata: {previous_metadata}"
+                     f"\ncurrent metadata: {metadata}"
+                 )
+         else:
+             with open(full_metadata_file_name, "w") as f:
+                 json.dump(metadata, f, indent=2)
+
+         if split is not None:
+             realpath = os.path.join(realpath, split)
+             os.makedirs(realpath, exist_ok=True)
+         full_file_name = os.path.join(realpath, file_name)
+         if as_json_lines(file_name):
+             # if the file already exists, append to it
+             mode = "a" if os.path.exists(full_file_name) else "w"
+             with open(full_file_name, mode) as f:
+                 for doc in documents:
+                     f.write(json.dumps(doc.asdict(), **kwargs) + "\n")
+         else:
+             docs_list = [doc.asdict() for doc in documents]
+             if os.path.exists(full_file_name):
+                 # load previous documents
+                 with open(full_file_name) as f:
+                     previous_doc_list = json.load(f)
+                 docs_list = previous_doc_list + docs_list
+             with open(full_file_name, "w") as f:
+                 json.dump(docs_list, fp=f, **kwargs)
+         return {"path": realpath, "file_name": file_name, "metadata_file_name": metadata_file_name}
+
+     @classmethod
+     def read(
+         cls,
+         path: str,
+         document_type: Optional[Type[D]] = None,
+         file_name: str = "documents.jsonl",
+         metadata_file_name: str = METADATA_FILE_NAME,
+         split: Optional[str] = None,
+     ) -> List[D]:
+         realpath = os.path.realpath(path)
+         log.info(f'load documents from "{realpath}" ...')
+
+         # try to load metadata including the document_type
+         full_metadata_file_name = os.path.join(realpath, metadata_file_name)
+         if os.path.exists(full_metadata_file_name):
+             with open(full_metadata_file_name) as f:
+                 metadata = json.load(f)
+             document_type = resolve_optional_document_type(metadata.get("document_type"))
+
+         if document_type is None:
+             raise Exception("document_type is required to load serialized documents")
+
+         if split is not None:
+             realpath = os.path.join(realpath, split)
+         full_file_name = os.path.join(realpath, file_name)
+         documents = []
+         if as_json_lines(str(file_name)):
+             with open(full_file_name) as f:
+                 for line in f:
+                     json_dict = json.loads(line)
+                     documents.append(document_type.fromdict(json_dict))
+         else:
+             with open(full_file_name) as f:
+                 json_list = json.load(f)
+                 for json_dict in json_list:
+                     documents.append(document_type.fromdict(json_dict))
+         return documents
+
+     def read_with_defaults(self, **kwargs) -> List[D]:
+         all_kwargs = {**self.default_kwargs, **kwargs}
+         return self.read(**all_kwargs)
+
+     def write_with_defaults(self, **kwargs) -> Dict[str, str]:
+         all_kwargs = {**self.default_kwargs, **kwargs}
+         return self.write(**all_kwargs)
+
+     def __call__(self, documents: Sequence[Document], **kwargs) -> Dict[str, str]:
+         return self.write_with_defaults(documents=documents, **kwargs)
+
+
+ class JsonSerializer2(DocumentSerializer):
+     def __init__(self, **kwargs):
+         self.default_kwargs = kwargs
+
+     @classmethod
+     def write(
+         cls,
+         documents: Sequence[Document],
+         path: str,
+         split: str = "train",
+     ) -> Dict[str, str]:
+         if not isinstance(documents, (Dataset, IterableDataset)):
+             documents = Dataset.from_documents(documents)
+         dataset_dict = DatasetDict({split: documents})
+         dataset_dict.to_json(path=path)
+         return {"path": path, "split": split}
+
+     @classmethod
+     def read(
+         cls,
+         path: str,
+         document_type: Optional[Type[D]] = None,
+         split: Optional[str] = None,
+     ) -> Dataset[Document]:
+         dataset_dict = DatasetDict.from_json(
+             data_dir=path, document_type=document_type, split=split
+         )
+         if split is not None:
+             return dataset_dict[split]
+         if len(dataset_dict) == 1:
+             return dataset_dict[list(dataset_dict.keys())[0]]
+         raise ValueError(f"multiple splits found in dataset_dict: {list(dataset_dict.keys())}")
+
+     def read_with_defaults(self, **kwargs) -> Sequence[D]:
+         all_kwargs = {**self.default_kwargs, **kwargs}
+         return self.read(**all_kwargs)
+
+     def write_with_defaults(self, **kwargs) -> Dict[str, str]:
+         all_kwargs = {**self.default_kwargs, **kwargs}
+         return self.write(**all_kwargs)
+
+     def __call__(self, documents: Sequence[Document], **kwargs) -> Dict[str, str]:
+         return self.write_with_defaults(documents=documents, **kwargs)
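Editor's note: a sketch of the JsonSerializer write/read round trip (illustrative; `documents` is assumed to be a non-empty sequence of pytorch-ie documents, and the output path is a placeholder):

serializer = JsonSerializer(path="out/predictions")  # placeholder path
info = serializer(documents)  # writes documents.jsonl plus metadata.json with the document type
loaded = JsonSerializer.read(path=info["path"])  # document type is restored from metadata.json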
src/start_demo.py ADDED
@@ -0,0 +1,578 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hydra
2
+ import pyrootutils
3
+ from omegaconf import DictConfig, OmegaConf, SCMode
4
+
5
+ root = pyrootutils.setup_root(
6
+ search_from=__file__,
7
+ indicator=[".project-root"],
8
+ pythonpath=True,
9
+ dotenv=True,
10
+ )
11
+
12
+ import json
13
+ import logging
14
+
15
+ import gradio as gr
16
+ import torch
17
+ import yaml
18
+
19
+ from src.demo.annotation_utils import load_argumentation_model
20
+ from src.demo.backend_utils import (
21
+ download_processed_documents,
22
+ process_text_from_arxiv,
23
+ process_uploaded_files,
24
+ render_annotated_document,
25
+ upload_processed_documents,
26
+ wrapped_add_annotated_pie_documents_from_dataset,
27
+ wrapped_process_text,
28
+ )
29
+ from src.demo.frontend_utils import (
30
+ change_tab,
31
+ escape_regex,
32
+ get_cell_for_fixed_column_from_df,
33
+ get_fix_df_height_css,
34
+ open_accordion,
35
+ unescape_regex,
36
+ )
37
+ from src.demo.rendering_utils import AVAILABLE_RENDER_MODES, HIGHLIGHT_SPANS_JS
38
+ from src.demo.retriever_utils import (
39
+ get_document_as_dict,
40
+ get_span_annotation,
41
+ load_retriever,
42
+ retrieve_all_relevant_spans,
43
+ retrieve_all_similar_spans,
44
+ retrieve_relevant_spans,
45
+ retrieve_similar_spans,
46
+ )
47
+
48
+
49
+ def load_yaml_config(path: str) -> str:
50
+ with open(path, "r") as file:
51
+ yaml_string = file.read()
52
+ config = yaml.safe_load(yaml_string)
53
+ return yaml.dump(config)
54
+
55
+
56
+ def resolve_config(cfg) -> dict:
57
+ return OmegaConf.to_container(cfg, resolve=True, structured_config_mode=SCMode.DICT)
58
+
59
+
60
+ @hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="demo.yaml")
61
+ def main(cfg: DictConfig) -> None:
62
+
63
+ # configure logging
64
+ logging.basicConfig()
65
+
66
+ # resolve everything in the config to prevent any issues with to json serialization etc.
67
+ cfg = resolve_config(cfg)
68
+
69
+ example_text = cfg["example_text"]
70
+
71
+ default_device = "cuda" if torch.cuda.is_available() else "cpu"
72
+
73
+ default_retriever_config_str = load_yaml_config(cfg["default_retriever_config_path"])
74
+
75
+ default_model_name = cfg["default_model_name"]
76
+ default_model_revision = cfg["default_model_revision"]
77
+ handle_parts_of_same = cfg["handle_parts_of_same"]
78
+
79
+ default_arxiv_id = cfg["default_arxiv_id"]
80
+ default_load_pie_dataset_kwargs_str = json.dumps(
81
+ cfg["default_load_pie_dataset_kwargs"], indent=2
82
+ )
83
+
84
+ default_render_mode = cfg["default_render_mode"]
85
+ if default_render_mode not in AVAILABLE_RENDER_MODES:
86
+ raise ValueError(
87
+ f"Invalid default render mode '{default_render_mode}'. "
88
+ f"Choose one of {AVAILABLE_RENDER_MODES}."
89
+ )
90
+ default_render_kwargs = cfg["default_render_kwargs"]
91
+
92
+ # captions for better readability
93
+ default_split_regex = cfg["default_split_regex"]
94
+ # map from render mode to the corresponding caption
95
+ render_mode2caption = {
96
+ render_mode: cfg["render_mode_captions"].get(render_mode, render_mode)
97
+ for render_mode in AVAILABLE_RENDER_MODES
98
+ }
99
+ render_caption2mode = {v: k for k, v in render_mode2caption.items()}
100
+ default_min_similarity = cfg["default_min_similarity"]
101
+ layer_caption_mapping = cfg["layer_caption_mapping"]
102
+ relation_name_mapping = cfg["relation_name_mapping"]
103
+
104
+ gr.Info("Loading models ...")
105
+ argumentation_model = load_argumentation_model(
106
+ model_name=default_model_name,
107
+ revision=default_model_revision,
108
+ device=default_device,
109
+ )
110
+ retriever = load_retriever(
111
+ default_retriever_config_str, device=default_device, config_format="yaml"
112
+ )
113
+
114
+ with gr.Blocks(css=get_fix_df_height_css(css_class="df-docstore", max_height=300)) as demo:
115
+ # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
116
+ # models_state = gr.State((argumentation_model, embedding_model))
117
+ argumentation_model_state = gr.State((argumentation_model,))
118
+ retriever_state = gr.State((retriever,))
119
+
120
+ with gr.Row():
121
+ with gr.Tabs() as left_tabs:
122
+ with gr.Tab("User Input", id="user_input") as user_input_tab:
123
+ doc_id = gr.Textbox(
124
+ label="Document ID",
125
+ value="user_input",
126
+ )
127
+ doc_text = gr.Textbox(
128
+ label="Text",
129
+ lines=20,
130
+ value=example_text,
131
+ )
132
+
133
+ with gr.Accordion("Model Configuration", open=False):
134
+ with gr.Accordion("argumentation structure", open=True):
135
+ model_name = gr.Textbox(
136
+ label="Model Name",
137
+ value=default_model_name,
138
+ )
139
+ model_revision = gr.Textbox(
140
+ label="Model Revision",
141
+ value=default_model_revision,
142
+ )
143
+ load_arg_model_btn = gr.Button("Load Argumentation Model")
144
+
145
+ with gr.Accordion("retriever", open=True):
146
+ retriever_config = gr.Code(
147
+ language="yaml",
148
+ label="Retriever Configuration",
149
+ value=default_retriever_config_str,
150
+ lines=len(default_retriever_config_str.split("\n")),
151
+ )
152
+ load_retriever_btn = gr.Button("Load Retriever")
153
+
154
+ device = gr.Textbox(
155
+ label="Device (e.g. 'cuda' or 'cpu')",
156
+ value=default_device,
157
+ )
158
+ load_arg_model_btn.click(
159
+ fn=lambda _model_name, _model_revision, _device: (
160
+ load_argumentation_model(
161
+ model_name=_model_name,
162
+ revision=_model_revision,
163
+ device=_device,
164
+ ),
165
+ ),
166
+ inputs=[model_name, model_revision, device],
167
+ outputs=argumentation_model_state,
168
+ )
169
+ load_retriever_btn.click(
170
+ fn=lambda _retriever_config, _device, _previous_retriever: (
171
+ load_retriever(
172
+ retriever_config_str=_retriever_config,
173
+ device=_device,
174
+ previous_retriever=_previous_retriever[0],
175
+ config_format="yaml",
176
+ ),
177
+ ),
178
+ inputs=[retriever_config, device, retriever_state],
179
+ outputs=retriever_state,
180
+ )
181
+
182
+ split_regex_escaped = gr.Textbox(
183
+ label="Regex to partition the text",
184
+ placeholder="Regular expression pattern to split the text into partitions",
185
+ value=escape_regex(default_split_regex),
186
+ )
187
+
188
+ predict_btn = gr.Button("Analyse")
189
+
190
+ with gr.Tab("Analysed Document", id="analysed_document") as analysed_document_tab:
191
+ selected_document_id = gr.Textbox(
192
+ label="Document ID", max_lines=1, interactive=False
193
+ )
194
+ rendered_output = gr.HTML(label="Rendered Output")
195
+
196
+ with gr.Accordion("Render Options", open=False):
197
+ render_as = gr.Dropdown(
198
+ label="Render with",
199
+ choices=list(render_mode2caption.values()),
200
+ value=render_mode2caption[default_render_mode],
201
+ )
202
+ render_kwargs = gr.Code(
203
+ language="json",
204
+ label="Render Arguments",
205
+ lines=len(json.dumps(default_render_kwargs, indent=2).split("\n")),
206
+ value=json.dumps(default_render_kwargs, indent=2),
207
+ )
208
+ render_btn = gr.Button("Re-render")
209
+
210
+ with gr.Accordion("See plain result ...", open=False):
211
+ get_document_json_btn = gr.Button("Fetch annotated document as JSON")
212
+ document_json = gr.JSON(label="Model Output")
213
+
214
+ with gr.Tabs() as right_tabs:
215
+ with gr.Tab("Retrieval", id="retrieval") as retrieval_tab:
216
+ with gr.Accordion(
217
+ "Indexed Documents", open=False
218
+ ) as processed_documents_accordion:
219
+ processed_documents_df = gr.DataFrame(
220
+ headers=["id", "num_adus", "num_relations"],
221
+ interactive=False,
222
+ elem_classes="df-docstore",
223
+ )
224
+ gr.Markdown("Data Snapshot:")
225
+ with gr.Row():
226
+ download_processed_documents_btn = gr.DownloadButton("Download")
227
+ upload_processed_documents_btn = gr.UploadButton(
228
+ "Upload", file_types=["file"]
229
+ )
230
+
231
+ # currently not used
232
+ # relation_types = set_relation_types(
233
+ # argumentation_model_state.value[0], default=["supports_reversed", "contradicts_reversed"]
234
+ # )
235
+
236
+ # Dummy textbox to hold the hover adu id. On click on the rendered output,
237
+ # its content will be copied to selected_adu_id which will trigger the retrieval.
238
+ hover_adu_id = gr.Textbox(
239
+ label="ID (hover)",
240
+ elem_id="hover_adu_id",
241
+ interactive=False,
242
+ visible=False,
243
+ )
244
+ selected_adu_id = gr.Textbox(
245
+ label="ID (selected)",
246
+ elem_id="selected_adu_id",
247
+ interactive=False,
248
+ visible=False,
249
+ )
250
+ selected_adu_text = gr.Textbox(label="Selected ADU", interactive=False)
251
+
252
+ with gr.Accordion("Relevant ADUs from other documents", open=True):
253
+ relevant_adus_df = gr.DataFrame(
254
+ headers=[
255
+ "relation",
256
+ "adu",
257
+ "reference_adu",
258
+ "doc_id",
259
+ "sim_score",
260
+ "rel_score",
261
+ ],
262
+ interactive=False,
263
+ )
264
+
265
+ with gr.Accordion("Retrieval Configuration", open=False):
266
+ min_similarity = gr.Slider(
267
+ label="Minimum Similarity",
268
+ minimum=0.0,
269
+ maximum=1.0,
270
+ step=0.01,
271
+ value=default_min_similarity,
272
+ )
273
+ top_k = gr.Slider(
274
+ label="Top K",
275
+ minimum=2,
276
+ maximum=50,
277
+ step=1,
278
+ value=10,
279
+ )
280
+ retrieve_similar_adus_btn = gr.Button(
281
+ "Retrieve *similar* ADUs for *selected* ADU"
282
+ )
283
+ similar_adus_df = gr.DataFrame(
284
+ headers=["doc_id", "adu_id", "score", "text"], interactive=False
285
+ )
286
+ retrieve_all_similar_adus_btn = gr.Button(
287
+ "Retrieve *similar* ADUs for *all* ADUs in the document"
288
+ )
289
+ all_similar_adus_df = gr.DataFrame(
290
+ headers=["doc_id", "query_adu_id", "adu_id", "score", "text"],
291
+ interactive=False,
292
+ )
293
+ retrieve_all_relevant_adus_btn = gr.Button(
294
+ "Retrieve *relevant* ADUs for *all* ADUs in the document"
295
+ )
296
+ all_relevant_adus_df = gr.DataFrame(
297
+ headers=["doc_id", "adu_id", "score", "text"], interactive=False
298
+ )
299
+
300
+ with gr.Tab("Import Documents", id="import_documents") as import_documents_tab:
301
+ upload_btn = gr.UploadButton(
302
+ "Batch Analyse Texts",
303
+ file_types=["text"],
304
+ file_count="multiple",
305
+ )
306
+
307
+ with gr.Accordion("Import text from arXiv", open=False):
308
+ arxiv_id = gr.Textbox(
309
+ label="arXiv paper ID",
310
+ placeholder=f"e.g. {default_arxiv_id}",
311
+ max_lines=1,
312
+ )
313
+ load_arxiv_only_abstract = gr.Checkbox(label="abstract only", value=False)
314
+ load_arxiv_btn = gr.Button(
315
+ "Load & Analyse from arXiv", variant="secondary"
316
+ )
317
+
318
+ with gr.Accordion(
319
+ "Import argument structure annotated PIE dataset", open=False
320
+ ):
321
+ load_pie_dataset_kwargs_str = gr.Code(
322
+ language="json",
323
+ label="Parameters for Loading the PIE Dataset",
324
+ value=default_load_pie_dataset_kwargs_str,
325
+ lines=len(default_load_pie_dataset_kwargs_str.split("\n")),
326
+ )
327
+ load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")
328
+
329
+ render_event_kwargs = dict(
330
+ fn=lambda _retriever, _document_id, _render_as, _render_kwargs: render_annotated_document(
331
+ retriever=_retriever[0],
332
+ document_id=_document_id,
333
+ render_with=render_caption2mode[_render_as],
334
+ render_kwargs_json=_render_kwargs,
335
+ ),
336
+ inputs=[retriever_state, selected_document_id, render_as, render_kwargs],
337
+ outputs=rendered_output,
338
+ )
339
+
340
+ show_overview_kwargs = dict(
341
+ fn=lambda _retriever: _retriever[0].docstore.overview(
342
+ layer_captions=layer_caption_mapping, use_predictions=True
343
+ ),
344
+ inputs=[retriever_state],
345
+ outputs=[processed_documents_df],
346
+ )
347
+ predict_btn.click(
348
+ fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
349
+ ).then(
350
+ fn=lambda _doc_text, _doc_id, _argumentation_model, _retriever, _split_regex_escaped: wrapped_process_text(
351
+ text=_doc_text,
352
+ doc_id=_doc_id,
353
+ argumentation_model=_argumentation_model[0],
354
+ retriever=_retriever[0],
355
+ split_regex_escaped=(
356
+ unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
357
+ ),
358
+ handle_parts_of_same=handle_parts_of_same,
359
+ ),
360
+ inputs=[
361
+ doc_text,
362
+ doc_id,
363
+ argumentation_model_state,
364
+ retriever_state,
365
+ split_regex_escaped,
366
+ ],
367
+ outputs=[selected_document_id],
368
+ api_name="predict",
369
+ ).success(
370
+ **show_overview_kwargs
371
+ ).success(
372
+ **render_event_kwargs
373
+ )
374
+ render_btn.click(**render_event_kwargs, api_name="render")
375
+
376
+ load_arxiv_btn.click(
377
+ fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
378
+ ).then(
379
+ fn=lambda _arxiv_id, _load_arxiv_only_abstract, _argumentation_model, _retriever, _split_regex_escaped: process_text_from_arxiv(
380
+ arxiv_id=_arxiv_id.strip() or default_arxiv_id,
381
+ abstract_only=_load_arxiv_only_abstract,
382
+ argumentation_model=_argumentation_model[0],
383
+ retriever=_retriever[0],
384
+ split_regex_escaped=(
385
+ unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
386
+ ),
387
+ handle_parts_of_same=handle_parts_of_same,
388
+ ),
389
+ inputs=[
390
+ arxiv_id,
391
+ load_arxiv_only_abstract,
392
+ argumentation_model_state,
393
+ retriever_state,
394
+ split_regex_escaped,
395
+ ],
396
+ outputs=[selected_document_id],
397
+ api_name="predict",
398
+ ).success(
399
+ **show_overview_kwargs
400
+ )
401
+
402
+ load_pie_dataset_btn.click(
403
+ fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
404
+ ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
405
+ fn=lambda _retriever, _load_pie_dataset_kwargs_str: wrapped_add_annotated_pie_documents_from_dataset(
406
+ retriever=_retriever[0],
407
+ verbose=True,
408
+ layer_captions=layer_caption_mapping,
409
+ **json.loads(_load_pie_dataset_kwargs_str),
410
+ ),
411
+ inputs=[retriever_state, load_pie_dataset_kwargs_str],
412
+ outputs=[processed_documents_df],
413
+ )
414
+
415
+ selected_document_id.change(
416
+ fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
417
+ ).then(**render_event_kwargs)
418
+
419
+ get_document_json_btn.click(
420
+ fn=lambda _retriever, _document_id: get_document_as_dict(
421
+ retriever=_retriever[0], doc_id=_document_id
422
+ ),
423
+ inputs=[retriever_state, selected_document_id],
424
+ outputs=[document_json],
425
+ )
426
+
427
+ upload_btn.upload(
428
+ fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
429
+ ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
430
+ fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_files(
431
+ file_names=_file_names,
432
+ argumentation_model=_argumentation_model[0],
433
+ retriever=_retriever[0],
434
+ split_regex_escaped=unescape_regex(_split_regex_escaped),
435
+ handle_parts_of_same=handle_parts_of_same,
436
+ layer_captions=layer_caption_mapping,
437
+ ),
438
+ inputs=[
439
+ upload_btn,
440
+ argumentation_model_state,
441
+ retriever_state,
442
+ split_regex_escaped,
443
+ ],
444
+ outputs=[processed_documents_df],
445
+ )
446
+ processed_documents_df.select(
447
+ fn=get_cell_for_fixed_column_from_df,
448
+ inputs=[processed_documents_df, gr.State("doc_id")],
449
+ outputs=[selected_document_id],
450
+ )
451
+
452
+ download_processed_documents_btn.click(
453
+ fn=lambda _retriever: download_processed_documents(
454
+ _retriever[0], file_name="processed_documents"
455
+ ),
456
+ inputs=[retriever_state],
457
+ outputs=[download_processed_documents_btn],
458
+ )
459
+ upload_processed_documents_btn.upload(
460
+ fn=lambda file_name, _retriever: upload_processed_documents(
461
+ file_name, retriever=_retriever[0], layer_captions=layer_caption_mapping
462
+ ),
463
+ inputs=[upload_processed_documents_btn, retriever_state],
464
+ outputs=[processed_documents_df],
465
+ )
466
+
467
+ retrieve_relevant_adus_event_kwargs = dict(
468
+ fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_relevant_spans(
469
+ retriever=_retriever[0],
470
+ query_span_id=_selected_adu_id,
471
+ k=_top_k,
472
+ score_threshold=_min_similarity,
473
+ relation_label_mapping=relation_name_mapping,
474
+ # columns=relevant_adus.headers
475
+ ),
476
+ inputs=[
477
+ retriever_state,
478
+ selected_adu_id,
479
+ min_similarity,
480
+ top_k,
481
+ ],
482
+ outputs=[relevant_adus_df],
483
+ )
484
+ relevant_adus_df.select(
485
+ fn=get_cell_for_fixed_column_from_df,
486
+ inputs=[relevant_adus_df, gr.State("doc_id")],
487
+ outputs=[selected_document_id],
488
+ )
489
+
490
+ selected_adu_id.change(
491
+ fn=lambda _retriever, _selected_adu_id: get_span_annotation(
492
+ retriever=_retriever[0], span_id=_selected_adu_id
493
+ ),
494
+ inputs=[retriever_state, selected_adu_id],
495
+ outputs=[selected_adu_text],
496
+ ).success(**retrieve_relevant_adus_event_kwargs)
497
+
+ retrieve_similar_adus_btn.click(
+ fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_similar_spans(
+ retriever=_retriever[0],
+ query_span_id=_selected_adu_id,
+ k=_top_k,
+ score_threshold=_min_similarity,
+ ),
+ inputs=[
+ retriever_state,
+ selected_adu_id,
+ min_similarity,
+ top_k,
+ ],
+ outputs=[similar_adus_df],
+ )
+ similar_adus_df.select(
+ fn=get_cell_for_fixed_column_from_df,
+ inputs=[similar_adus_df, gr.State("doc_id")],
+ outputs=[selected_document_id],
+ )
+
+ retrieve_all_similar_adus_btn.click(
+ fn=lambda _retriever, _document_id, _min_similarity, _top_k: retrieve_all_similar_spans(
+ retriever=_retriever[0],
+ query_doc_id=_document_id,
+ k=_top_k,
+ score_threshold=_min_similarity,
+ query_span_id_column="query_span_id",
+ ),
+ inputs=[
+ retriever_state,
+ selected_document_id,
+ min_similarity,
+ top_k,
+ ],
+ outputs=[all_similar_adus_df],
+ )
+
+ retrieve_all_relevant_adus_btn.click(
+ fn=lambda _retriever, _document_id, _min_similarity, _top_k: retrieve_all_relevant_spans(
+ retriever=_retriever[0],
+ query_doc_id=_document_id,
+ k=_top_k,
+ score_threshold=_min_similarity,
+ query_span_id_column="query_span_id",
+ ),
+ inputs=[
+ retriever_state,
+ selected_document_id,
+ min_similarity,
+ top_k,
+ ],
+ outputs=[all_relevant_adus_df],
+ )
+
+ # select query span id from the "retrieve all" result data frames
+ all_similar_adus_df.select(
+ fn=get_cell_for_fixed_column_from_df,
+ inputs=[all_similar_adus_df, gr.State("query_span_id")],
+ outputs=[selected_adu_id],
+ )
+ all_relevant_adus_df.select(
+ fn=get_cell_for_fixed_column_from_df,
+ inputs=[all_relevant_adus_df, gr.State("query_span_id")],
+ outputs=[selected_adu_id],
+ )
+
+ # argumentation_model_state.change(
+ # fn=lambda _argumentation_model: set_relation_types(_argumentation_model[0]),
+ # inputs=[argumentation_model_state],
+ # outputs=[relation_types],
+ # )
+
+ rendered_output.change(fn=None, js=HIGHLIGHT_SPANS_JS, inputs=[], outputs=[])
+
+ demo.launch()
+
+
+ if __name__ == "__main__":
+
+ main()
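Helper functions like change_tab and open_accordion are defined earlier in this file (not shown in this excerpt). A minimal sketch of how such helpers can look, assuming Gradio's component-update pattern where a callback returns a reconfigured gr.Tabs / gr.Accordion:

    import gradio as gr

    def change_tab(tab_id):
        # returning a Tabs component with `selected` set switches the visible tab
        return gr.Tabs(selected=tab_id)

    def open_accordion():
        # returning an Accordion with open=True expands it in the UI
        return gr.Accordion(open=True)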
src/taskmodules/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .cross_text_binary_coref import CrossTextBinaryCorefTaskModuleWithOptionalContext
+ from .cross_text_binary_coref_nli import CrossTextBinaryCorefTaskModuleByNli
+ from .re_text_classification_with_indices import (
+ CrossTextBinaryCorefByRETextClassificationTaskModule,
+ RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers,
+ )
+
+ CrossTextBinaryCorefTaskModule2 = CrossTextBinaryCorefTaskModule2 = CrossTextBinaryCorefByRETextClassificationTaskModule
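The final alias keeps the shorter name importable (presumably so existing configs keep working); both names resolve to the same registered class:

    from src.taskmodules import (
        CrossTextBinaryCorefByRETextClassificationTaskModule,
        CrossTextBinaryCorefTaskModule2,
    )

    assert CrossTextBinaryCorefTaskModule2 is CrossTextBinaryCorefByRETextClassificationTaskModule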
src/taskmodules/components/__init__.py ADDED
File without changes
src/taskmodules/cross_text_binary_coref.py ADDED
@@ -0,0 +1,116 @@
+ import logging
+ from typing import Optional, Sequence, TypeVar, Union
+
+ from pie_modules.taskmodules import CrossTextBinaryCorefTaskModule
+ from pie_modules.taskmodules.cross_text_binary_coref import (
+ DocumentType,
+ SpanDoesNotFitIntoAvailableWindow,
+ TaskEncodingType,
+ )
+ from pie_modules.utils.tokenization import SpanNotAlignedWithTokenException
+ from pytorch_ie.annotations import Span
+ from pytorch_ie.core import TaskEncoding, TaskModule
+
+ logger = logging.getLogger(__name__)
+
+
+ S = TypeVar("S", bound=Span)
+
+
+ def shift_span(span: S, offset: int) -> S:
+ return span.copy(start=span.start + offset, end=span.end + offset)
+
+
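shift_span simply moves a span's character offsets; below it is used with a negative offset to map a span from document coordinates into the coordinate system of its own text. For example:

    span = Span(start=10, end=15)
    shift_span(span, -span.start)  # -> Span(start=0, end=5)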
+ @TaskModule.register()
+ class CrossTextBinaryCorefTaskModuleWithOptionalContext(CrossTextBinaryCorefTaskModule):
+ """Same as CrossTextBinaryCorefTaskModule, but it can optionally encode the head and
+ tail spans without their surrounding context (controlled via `without_context`).
+ """
+
+ def __init__(
+ self,
+ without_context: bool = False,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.without_context = without_context
+
+ def encode_input(
+ self,
+ document: DocumentType,
+ is_training: bool = False,
+ ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
+ if self.without_context:
+ return self.encode_input_without_context(document)
+ else:
+ return super().encode_input(document)
+
+ def encode_input_without_context(
+ self, document: DocumentType
+ ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
+ self.collect_all_relations(kind="available", relations=document.binary_coref_relations)
+ tokenizer_kwargs = dict(
+ padding=False,
+ truncation=False,
+ add_special_tokens=False,
+ )
+
+ task_encodings = []
+ for coref_rel in document.binary_coref_relations:
+
+ # TODO: This can miss instances if both texts are the same. We could check that
+ # coref_rel.head is in document.labeled_spans (same for the tail), but would this
+ # slow down the encoding?
+ if not (
+ coref_rel.head.target == document.text
+ or coref_rel.tail.target == document.text_pair
+ ):
+ raise ValueError(
+ f"It is expected that coref relations go from (head) spans over 'text' "
+ f"to (tail) spans over 'text_pair', but this is not the case for this "
+ f"relation (i.e. it points into the other direction): {coref_rel.resolve()}"
+ )
+ encoding = self.tokenizer(text=str(coref_rel.head), **tokenizer_kwargs)
+ encoding_pair = self.tokenizer(text=str(coref_rel.tail), **tokenizer_kwargs)
+
+ try:
+ current_encoding, token_span = self.truncate_encoding_around_span(
+ encoding=encoding, char_span=shift_span(coref_rel.head, -coref_rel.head.start)
+ )
+ current_encoding_pair, token_span_pair = self.truncate_encoding_around_span(
+ encoding=encoding_pair,
+ char_span=shift_span(coref_rel.tail, -coref_rel.tail.start),
+ )
+ except SpanNotAlignedWithTokenException as e:
+ logger.warning(
+ f"Could not get token offsets for argument ({e.span}) of coref relation: "
+ f"{coref_rel.resolve()}. Skipping it."
+ )
+ self.collect_relation(kind="skipped_args_not_aligned", relation=coref_rel)
+ continue
+ except SpanDoesNotFitIntoAvailableWindow as e:
+ logger.warning(
+ f"Argument span [{e.span}] does not fit into the available token window "
+ f"({self.available_window}). Skipping it."
+ )
+ self.collect_relation(
+ kind="skipped_span_does_not_fit_into_window", relation=coref_rel
+ )
+ continue
+
+ task_encodings.append(
+ TaskEncoding(
+ document=document,
+ inputs={
+ "encoding": current_encoding,
+ "encoding_pair": current_encoding_pair,
+ "pooler_start_indices": token_span.start,
+ "pooler_end_indices": token_span.end,
+ "pooler_pair_start_indices": token_span_pair.start,
+ "pooler_pair_end_indices": token_span_pair.end,
+ },
+ metadata={"candidate_annotation": coref_rel},
+ )
+ )
+ self.collect_relation("used", coref_rel)
+ return task_encodings
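A minimal usage sketch; without_context is the only new parameter, everything else (e.g. the tokenizer) is configured as in the base CrossTextBinaryCorefTaskModule, so the other constructor arguments below are assumptions based on that base class:

    taskmodule = CrossTextBinaryCorefTaskModuleWithOptionalContext(
        tokenizer_name_or_path="bert-base-cased",  # illustrative model name
        without_context=True,
    )
    task_encodings = taskmodule.encode(documents)  # documents: list of DocumentType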
src/taskmodules/cross_text_binary_coref_nli.py ADDED
@@ -0,0 +1,166 @@
+ import logging
+ from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, TypedDict, Union
+
+ import torch
+ from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+ from pie_modules.taskmodules.common.mixins import RelationStatisticsMixin
+ from pytorch_ie import Annotation
+ from pytorch_ie.core import TaskEncoding, TaskModule
+ from transformers import AutoTokenizer
+ from typing_extensions import TypeAlias
+
+ logger = logging.getLogger(__name__)
+
+ InputEncodingType: TypeAlias = Dict[str, Any]
+ TargetEncodingType: TypeAlias = Sequence[float]
+ DocumentType: TypeAlias = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+
+ TaskEncodingType: TypeAlias = TaskEncoding[
+ DocumentType,
+ InputEncodingType,
+ TargetEncodingType,
+ ]
+
+
+ class TaskOutputType(TypedDict, total=False):
+ label_pair: Tuple[str, str]
+ entailment_probability_pair: Tuple[float, float]
+
+
+ ModelInputType: TypeAlias = Dict[str, torch.Tensor]
+ ModelTargetType: TypeAlias = Dict[str, torch.Tensor]
+ ModelOutputType: TypeAlias = Dict[str, torch.Tensor]
+
+ TaskModuleType: TypeAlias = TaskModule[
+ # _InputEncoding, _TargetEncoding, _TaskBatchEncoding, _ModelBatchOutput, _TaskOutput
+ DocumentType,
+ InputEncodingType,
+ TargetEncodingType,
+ Tuple[ModelInputType, Optional[ModelTargetType]],
+ ModelTargetType,
+ TaskOutputType,
+ ]
+
+
+ @TaskModule.register()
+ class CrossTextBinaryCorefTaskModuleByNli(RelationStatisticsMixin, TaskModuleType):
+ """This taskmodule processes documents of type
+ TextPairDocumentWithLabeledSpansAndBinaryCorefRelations in preparation for a sequence
+ classification model trained for NLI. The assumption is that if the entailment class is
+ predicted for both directions, a coreference relation exists between the two spans.
+
+ It simply tokenizes and encodes the head and tail texts of the coreference relations as text
+ pairs, i.e. no context of head and tail is considered. During decoding, coreference relations
+ are created if the entailment class (see parameter entailment_label) is predicted for both
+ directions and the average probability is used as the score.
+ """
+
+ DOCUMENT_TYPE = DocumentType
+
+ def __init__(
+ self,
+ tokenizer_name_or_path: str,
+ labels: List[str],
+ entailment_label: str,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.save_hyperparameters()
+
+ self.labels = labels
+ self.entailment_label = entailment_label
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
+
+ def _post_prepare(self):
+ self.id_to_label = dict(enumerate(self.labels))
+ self.label_to_id = {v: k for k, v in self.id_to_label.items()}
+ self.entailment_idx = self.label_to_id[self.entailment_label]
+
+ def encode(self, documents: Union[DocumentType, Iterable[DocumentType]], **kwargs):
+ self.reset_statistics()
+ result = super().encode(documents=documents, **kwargs)
+ self.show_statistics()
+ return result
+
+ def encode_input(
+ self,
+ document: DocumentType,
+ is_training: bool = False,
+ ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
+ self.collect_all_relations(kind="available", relations=document.binary_coref_relations)
+ result = []
+ for coref_rel in document.binary_coref_relations:
+ head_text = str(coref_rel.head)
+ tail_text = str(coref_rel.tail)
+ task_encoding = TaskEncoding(
+ document=document,
+ inputs={"text": [head_text, tail_text], "text_pair": [tail_text, head_text]},
+ metadata={"candidate_annotation": coref_rel},
+ )
+ result.append(task_encoding)
+ self.collect_relation("used", coref_rel)
+ return result
+
+ def encode_target(
+ self,
+ task_encoding: TaskEncodingType,
+ ) -> Optional[TargetEncodingType]:
+ raise NotImplementedError()
+
+ def collate(
+ self,
+ task_encodings: Sequence[
+ TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType]
+ ],
+ ) -> Tuple[ModelInputType, Optional[ModelTargetType]]:
+ all_texts = []
+ all_texts_pair = []
+ for task_encoding in task_encodings:
+ all_texts.extend(task_encoding.inputs["text"])
+ all_texts_pair.extend(task_encoding.inputs["text_pair"])
+ inputs = self.tokenizer(
+ text=all_texts,
+ text_pair=all_texts_pair,
+ truncation=True,
+ padding=True,
+ return_tensors="pt",
+ )
+ if not task_encodings[0].has_targets:
+ return inputs, None
+ raise NotImplementedError()
+
+ def unbatch_output(self, model_output: ModelTargetType) -> Sequence[TaskOutputType]:
+ probs_tensor = model_output["probabilities"]
+ labels_tensor = model_output["labels"]
+
+ bs, num_classes = probs_tensor.size()
+ # Reshape the probs tensor to (bs/2, 2, num_classes)
+ probs_paired = probs_tensor.view(bs // 2, 2, num_classes).detach().cpu().tolist()
+
+ # Reshape the labels tensor to (bs/2, 2)
+ labels_paired = labels_tensor.view(bs // 2, 2).detach().cpu().tolist()
+
+ result = []
+ for (label_id, label_id_pair), (probs_list, probs_list_pair) in zip(
+ labels_paired, probs_paired
+ ):
+ task_output: TaskOutputType = {
+ "label_pair": (self.id_to_label[label_id], self.id_to_label[label_id_pair]),
+ "entailment_probability_pair": (
+ probs_list[self.entailment_idx],
+ probs_list_pair[self.entailment_idx],
+ ),
+ }
+ result.append(task_output)
+ return result
+
+ def create_annotations_from_output(
+ self,
+ task_encoding: TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType],
+ task_output: TaskOutputType,
+ ) -> Iterator[Tuple[str, Annotation]]:
+ if all(label == self.entailment_label for label in task_output["label_pair"]):
+ probs = task_output["entailment_probability_pair"]
+ score = (probs[0] + probs[1]) / 2
+ new_coref_rel = task_encoding.metadata["candidate_annotation"].copy(score=score)
+ yield "binary_coref_relations", new_coref_rel
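To make the direction pairing in unbatch_output concrete, here is a small self-contained sketch of the reshape (batch rows come in forward/backward pairs, as constructed in encode_input; entailment_idx = 0 is assumed):

    import torch

    # probabilities for two candidate relations, each encoded in both directions
    # (batch size 4, three NLI classes)
    probs = torch.tensor(
        [
            [0.80, 0.10, 0.10],  # rel 1, head -> tail
            [0.70, 0.20, 0.10],  # rel 1, tail -> head
            [0.20, 0.60, 0.20],  # rel 2, head -> tail
            [0.90, 0.05, 0.05],  # rel 2, tail -> head
        ]
    )
    paired = probs.view(2, 2, 3)  # (bs/2, 2, num_classes)
    # rel 1: both directions predict entailment -> coref with score (0.80 + 0.70) / 2 = 0.75
    # rel 2: the forward direction is not entailment -> no coref relation is created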
src/taskmodules/re_text_classification_with_indices.py ADDED
@@ -0,0 +1,176 @@
+ import copy
+ from itertools import chain
+ from typing import Dict, Optional, Sequence, Type
+
+ import torch
+ from pie_modules.annotations import BinaryCorefRelation
+ from pie_modules.document.processing.text_pair import shift_span
+ from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+ from pie_modules.taskmodules import RETextClassificationWithIndicesTaskModule
+ from pie_modules.taskmodules.common import TaskModuleWithDocumentConverter
+ from pie_modules.taskmodules.re_text_classification_with_indices import MarkerFactory
+ from pie_modules.taskmodules.re_text_classification_with_indices import (
+ ModelTargetType as REModelTargetType,
+ )
+ from pie_modules.taskmodules.re_text_classification_with_indices import (
+ TaskOutputType as RETaskOutputType,
+ )
+ from pytorch_ie import Document, TaskModule
+ from pytorch_ie.annotations import LabeledSpan
+ from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations
+
+
+ class SharpBracketMarkerFactory(MarkerFactory):
+ def _get_marker(self, role: str, is_start: bool, label: Optional[str] = None) -> str:
+ result = "<"
+ if not is_start:
+ result += "/"
+ result += self._get_role_marker(role)
+ if label is not None:
+ result += f":{label}"
+ result += ">"
+ return result
+
+ def get_append_marker(self, role: str, label: Optional[str] = None) -> str:
+ role_marker = self._get_role_marker(role)
+ if label is None:
+ return f"<{role_marker}>"
+ else:
+ return f"<{role_marker}={label}>"
+
+
+ @TaskModule.register()
+ class RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers(
+ RETextClassificationWithIndicesTaskModule
+ ):
+ def __init__(self, use_sharp_marker: bool = False, **kwargs):
+ super().__init__(**kwargs)
+ self.use_sharp_marker = use_sharp_marker
+
+ def get_marker_factory(self) -> MarkerFactory:
+ if self.use_sharp_marker:
+ return SharpBracketMarkerFactory(role_to_marker=self.argument_role_to_marker)
+ else:
+ return MarkerFactory(role_to_marker=self.argument_role_to_marker)
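To illustrate the marker format, a sketch assuming the base MarkerFactory resolves roles via the role_to_marker dict (as the constructor call above suggests):

    factory = SharpBracketMarkerFactory(role_to_marker={"head": "H", "tail": "T"})
    factory._get_marker(role="head", is_start=True, label="claim")  # -> "<H:claim>"
    factory._get_marker(role="head", is_start=False)                # -> "</H>"
    factory.get_append_marker(role="tail", label="claim")           # -> "<T=claim>"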
+
+
+ def construct_text_document_from_text_pair_coref_document(
+ document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
+ glue_text: str,
+ no_relation_label: str,
+ relation_label_mapping: Optional[Dict[str, str]] = None,
+ add_span_mapping_to_metadata: bool = False,
+ ) -> TextDocumentWithLabeledSpansAndBinaryRelations:
+ if document.text == document.text_pair:
+ new_doc = TextDocumentWithLabeledSpansAndBinaryRelations(
+ id=document.id, metadata=copy.deepcopy(document.metadata), text=document.text
+ )
+ old2new_spans: Dict[LabeledSpan, LabeledSpan] = {}
+ new2new_spans: Dict[LabeledSpan, LabeledSpan] = {}
+ for old_span in chain(document.labeled_spans, document.labeled_spans_pair):
+ new_span = old_span.copy()
+ # when detaching / copying the span, it may be the same as a previous span from the
+ # other side (text vs. text_pair), so we deduplicate via new2new_spans
+ new_span = new2new_spans.get(new_span, new_span)
+ new2new_spans[new_span] = new_span
+ old2new_spans[old_span] = new_span
+ else:
+ new_doc = TextDocumentWithLabeledSpansAndBinaryRelations(
+ text=document.text + glue_text + document.text_pair,
+ id=document.id,
+ metadata=copy.deepcopy(document.metadata),
+ )
+ old2new_spans = {}
+ old2new_spans.update({span: span.copy() for span in document.labeled_spans})
+ offset = len(document.text) + len(glue_text)
+ old2new_spans.update(
+ {span: shift_span(span.copy(), offset) for span in document.labeled_spans_pair}
+ )
+
+ # sort to make the order deterministic
+ new_doc.labeled_spans.extend(
+ sorted(old2new_spans.values(), key=lambda s: (s.start, s.end, s.label))
+ )
+ for old_rel in document.binary_coref_relations:
+ label = old_rel.label if old_rel.score > 0.0 else no_relation_label
+ if relation_label_mapping is not None:
+ label = relation_label_mapping.get(label, label)
+ new_rel = old_rel.copy(
+ head=old2new_spans[old_rel.head],
+ tail=old2new_spans[old_rel.tail],
+ label=label,
+ score=1.0,
+ )
+ new_doc.binary_relations.append(new_rel)
+
+ if add_span_mapping_to_metadata:
+ new_doc.metadata["span_mapping"] = old2new_spans
+ return new_doc
+
+
+ @TaskModule.register()
+ class CrossTextBinaryCorefByRETextClassificationTaskModule(
+ TaskModuleWithDocumentConverter,
+ RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers,
+ ):
+ def __init__(
+ self,
+ coref_relation_label: str,
+ relation_annotation: str = "binary_relations",
+ probability_threshold: float = 0.0,
+ **kwargs,
+ ):
+ if relation_annotation != "binary_relations":
+ raise ValueError(
+ f"{type(self).__name__} requires relation_annotation='binary_relations', "
+ f"but it is: {relation_annotation}"
+ )
+ super().__init__(relation_annotation=relation_annotation, **kwargs)
+ self.coref_relation_label = coref_relation_label
+ self.probability_threshold = probability_threshold
+
+ @property
+ def document_type(self) -> Optional[Type[Document]]:
+ return TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+
+ def _get_glue_text(self) -> str:
+ result = self.tokenizer.decode(self._get_glue_token_ids())
+ return result
+
+ def _convert_document(
+ self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+ ) -> TextDocumentWithLabeledSpansAndBinaryRelations:
+ return construct_text_document_from_text_pair_coref_document(
+ document,
+ glue_text=self._get_glue_text(),
+ relation_label_mapping={"coref": self.coref_relation_label},
+ no_relation_label=self.none_label,
+ add_span_mapping_to_metadata=True,
+ )
+
+ def _integrate_predictions_from_converted_document(
+ self,
+ document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
+ converted_document: TextDocumentWithLabeledSpansAndBinaryRelations,
+ ) -> None:
+ original2converted_span = converted_document.metadata["span_mapping"]
+ new2original_span = {
+ converted_s: orig_s for orig_s, converted_s in original2converted_span.items()
+ }
+
+ for rel in converted_document.binary_relations.predictions:
+ original_head = new2original_span[rel.head]
+ original_tail = new2original_span[rel.tail]
+ if rel.label != self.coref_relation_label:
+ raise ValueError(f"unexpected label: {rel.label}")
+ if rel.score >= self.probability_threshold:
+ original_predicted_rel = BinaryCorefRelation(
+ head=original_head, tail=original_tail, label="coref", score=rel.score
+ )
+ document.binary_coref_relations.predictions.append(original_predicted_rel)
+
+ def unbatch_output(self, model_output: REModelTargetType) -> Sequence[RETaskOutputType]:
+ coref_relation_idx = self.label_to_id[self.coref_relation_label]
+ # we are only concerned with the coref class, so we overwrite the labels field
+ model_output = copy.copy(model_output)
+ model_output["labels"] = torch.ones_like(model_output["labels"]) * coref_relation_idx
+ return super().unbatch_output(model_output=model_output)
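A short worked example of the offset arithmetic in the non-identical-text branch of construct_text_document_from_text_pair_coref_document (values illustrative):

    text, glue_text, text_pair = "Premise one.", " ", "Premise two."
    offset = len(text) + len(glue_text)  # 13
    # a span (0, 7) over `text_pair` ("Premise") lands at (13, 20) in the glued document
    assert (text + glue_text + text_pair)[0 + offset : 7 + offset] == text_pair[0:7]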
src/train.py ADDED
@@ -0,0 +1,294 @@
+ import pyrootutils
+
+ root = pyrootutils.setup_root(
+ search_from=__file__,
+ indicator=[".project-root"],
+ pythonpath=True,
+ dotenv=True,
+ )
+
+ # ------------------------------------------------------------------------------------ #
+ # `pyrootutils.setup_root(...)` is an optional line at the top of each entry file
+ # that helps to make the environment more robust and convenient
+ #
+ # the main advantages are:
+ # - allows you to keep all entry files in "src/" without installing the project as a package
+ # - makes paths and scripts always work no matter what your current working directory is
+ # - automatically loads environment variables from a ".env" file if it exists
+ #
+ # how it works:
+ # - the line above recursively searches the current and parent dirs for the ".project-root"
+ # indicator file (as configured above) to determine the project root dir
+ # - adds the root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from
+ # any place without installing the project as a package
+ # - sets the PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml"
+ # to make all paths always relative to the project root
+ # - loads environment variables from the ".env" file in the root dir (if `dotenv=True`)
+ #
+ # you can remove `pyrootutils.setup_root(...)` if you:
+ # 1. either install the project as a package or move each entry file to the project root dir
+ # 2. simply remove the PROJECT_ROOT variable from the paths in "configs/paths/default.yaml"
+ # 3. always run entry files from the project root dir
+ #
+ # https://github.com/ashleve/pyrootutils
+ # ------------------------------------------------------------------------------------ #
+
+ import os.path
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import hydra
+ import pytorch_lightning as pl
+ from omegaconf import DictConfig
+ from pie_datasets import DatasetDict
+ from pie_modules.models import * # noqa: F403
+ from pie_modules.models import SimpleGenerativeModel
+ from pie_modules.models.interface import RequiresTaskmoduleConfig
+ from pie_modules.taskmodules import * # noqa: F403
+ from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
+ from pytorch_ie.core import PyTorchIEModel, TaskModule
+ from pytorch_ie.models import * # noqa: F403
+ from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
+ from pytorch_ie.taskmodules import * # noqa: F403
+ from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize
+ from pytorch_lightning import Callback, Trainer
+ from pytorch_lightning.loggers import Logger
+
+ from src import utils
+ from src.datamodules import PieDataModule
+ from src.models import * # noqa: F403
+ from src.taskmodules import * # noqa: F403
+
+ log = utils.get_pylogger(__name__)
+
+
+ def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]:
+ """Safely retrieves the value of the metric logged in the LightningModule."""
+
+ if not metric_name:
+ log.info("Metric name is None! Skipping metric value retrieval...")
+ return None
+
+ if metric_name not in metric_dict:
+ raise Exception(
+ f"Metric value not found! <metric_name={metric_name}>\n"
+ "Make sure the metric name logged in the LightningModule is correct!\n"
+ "Make sure the `optimized_metric` name in the `hparams_search` config is correct!"
+ )
+
+ metric_value = metric_dict[metric_name].item()
+ log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
+
+ return metric_value
+
+
+ @utils.task_wrapper
+ def train(cfg: DictConfig) -> Tuple[dict, dict]:
+ """Trains the model. Can additionally evaluate on a testset, using the best weights obtained
+ during training.
+
+ This method is wrapped in the optional @task_wrapper decorator which applies extra utilities
+ before and after the call.
+
+ Args:
+ cfg (DictConfig): Configuration composed by Hydra.
+
+ Returns:
+ Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+ """
+
+ # set seed for random number generators in pytorch, numpy and python.random
+ if cfg.get("seed"):
+ pl.seed_everything(cfg.seed, workers=True)
+
+ # Init pytorch-ie taskmodule
+ log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>")
+ taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial")
+
+ # Init pytorch-ie dataset
+ log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
+ dataset: DatasetDict = hydra.utils.instantiate(
+ cfg.dataset,
+ _convert_="partial",
+ )
+
+ # auto-convert the dataset if the taskmodule specifies a document type
+ dataset = taskmodule.convert_dataset(dataset)
+
+ # Init pytorch-ie datamodule
+ log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
+ datamodule: PieDataModule = hydra.utils.instantiate(
+ cfg.datamodule, dataset=dataset, taskmodule=taskmodule, _convert_="partial"
+ )
+ # Use the train dataset split to prepare the taskmodule
+ taskmodule.prepare(dataset[datamodule.train_split])
+
+ # Init the pytorch-ie model
+ log.info(f"Instantiating model <{cfg.model._target_}>")
+ # get additional model arguments
+ additional_model_kwargs: Dict[str, Any] = {}
+ model_cls = hydra.utils.get_class(cfg.model["_target_"])
+ # NOTE: MODIFY THE additional_model_kwargs IF YOUR MODEL REQUIRES ANY MORE PARAMETERS FROM THE TASKMODULE!
+ # SEE EXAMPLES BELOW.
+ if issubclass(model_cls, RequiresNumClasses):
+ additional_model_kwargs["num_classes"] = len(taskmodule.label_to_id)
+ if issubclass(model_cls, RequiresModelNameOrPath):
+ if "model_name_or_path" not in cfg.model:
+ raise Exception(
+ f"Please specify model_name_or_path in the model config for {model_cls.__name__}."
+ )
+ if isinstance(taskmodule, ChangesTokenizerVocabSize):
+ additional_model_kwargs["tokenizer_vocab_size"] = len(taskmodule.tokenizer)
+
+ pooler_config = cfg["model"].get("pooler")
+ if pooler_config is not None:
+ if isinstance(pooler_config, str):
+ pooler_config = {"type": pooler_config}
+ pooler_config = dict(pooler_config)
+ if pooler_config["type"] in ["start_tokens", "mention_pooling"]:
+ # NOTE: This is very hacky, we should create a new interface class, e.g. RequiresPoolerNumIndices
+ if hasattr(taskmodule, "argument_role2idx"):
+ pooler_config["num_indices"] = len(taskmodule.argument_role2idx)
+ else:
+ pooler_config["num_indices"] = 1
+ elif pooler_config["type"] == "cls_token":
+ pass
+ else:
+ raise Exception(
+ f"unknown pooler type: {pooler_config['type']}. Please adjust the train.py script for that type."
+ )
+ additional_model_kwargs["pooler"] = pooler_config
+
161
+ if issubclass(model_cls, RequiresTaskmoduleConfig):
162
+ additional_model_kwargs["taskmodule_config"] = taskmodule.config
163
+
164
+ if model_cls == SimpleGenerativeModel:
165
+ # There may be already some base_model_config entries in the model config. Also need to convert the
166
+ # base_model_config to a dict, because it is a OmegaConf object which does not accept additional entries.
167
+ base_model_config = (
168
+ dict(cfg.model.base_model_config) if "base_model_config" in cfg.model else {}
169
+ )
170
+ if isinstance(taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
171
+ base_model_config.update(
172
+ dict(
173
+ bos_token_id=taskmodule.bos_id,
174
+ eos_token_id=taskmodule.eos_id,
175
+ pad_token_id=taskmodule.eos_id,
176
+ target_token_ids=taskmodule.target_token_ids,
177
+ embedding_weight_mapping=taskmodule.label_embedding_weight_mapping,
178
+ )
179
+ )
180
+ additional_model_kwargs["base_model_config"] = base_model_config
181
+
182
+ # initialize the model
183
+ model: PyTorchIEModel = hydra.utils.instantiate(
184
+ cfg.model, _convert_="partial", **additional_model_kwargs
185
+ )
186
+
187
+ log.info("Instantiating callbacks...")
188
+ callbacks: List[Callback] = utils.instantiate_dict_entries(cfg, key="callbacks")
189
+
190
+ log.info("Instantiating loggers...")
191
+ logger: List[Logger] = utils.instantiate_dict_entries(cfg, key="logger")
192
+
193
+ log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
194
+ trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
195
+
196
+ object_dict = {
197
+ "cfg": cfg,
198
+ "dataset": dataset,
199
+ "taskmodule": taskmodule,
200
+ "model": model,
201
+ "callbacks": callbacks,
202
+ "logger": logger,
203
+ "trainer": trainer,
204
+ }
205
+
206
+ if logger:
207
+ log.info("Logging hyperparameters!")
208
+ utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg)
209
+
210
+ if cfg.model_save_dir is not None:
211
+ log.info(f"Save taskmodule to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
212
+ taskmodule.save_pretrained(save_directory=cfg.model_save_dir, push_to_hub=cfg.push_to_hub)
213
+ else:
214
+ log.warning("the taskmodule is not saved because no save_dir is specified")
215
+
216
+ if cfg.get("train"):
217
+ log.info("Starting training!")
218
+ trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
219
+
220
+ train_metrics = trainer.callback_metrics
221
+
222
+ best_ckpt_path = trainer.checkpoint_callback.best_model_path
223
+ if best_ckpt_path != "":
224
+ log.info(f"Best ckpt path: {best_ckpt_path}")
225
+ best_checkpoint_file = os.path.basename(best_ckpt_path)
226
+ utils.log_hyperparameters(
227
+ logger=logger,
228
+ best_checkpoint=best_checkpoint_file,
229
+ checkpoint_dir=trainer.checkpoint_callback.dirpath,
230
+ )
231
+
232
+ if not cfg.trainer.get("fast_dev_run"):
233
+ if cfg.model_save_dir is not None:
234
+ if best_ckpt_path == "":
235
+ log.warning("Best ckpt not found! Using current weights for saving...")
236
+ else:
237
+ model = type(model).load_from_checkpoint(best_ckpt_path)
238
+
239
+ log.info(f"Save model to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
240
+ model.save_pretrained(save_directory=cfg.model_save_dir, push_to_hub=cfg.push_to_hub)
241
+ else:
242
+ log.warning("the model is not saved because no save_dir is specified")
243
+
244
+ if cfg.get("validate"):
245
+ log.info("Starting validation!")
246
+ if best_ckpt_path == "":
247
+ log.warning("Best ckpt not found! Using current weights for validation...")
248
+ trainer.validate(model=model, datamodule=datamodule, ckpt_path=best_ckpt_path or None)
249
+ elif cfg.get("train"):
250
+ log.warning(
251
+ "Validation after training is skipped! That means, the finally reported validation scores are "
252
+ "the values from the *last* checkpoint, not from the *best* checkpoint (which is saved)!"
253
+ )
254
+
255
+ if cfg.get("test"):
256
+ log.info("Starting testing!")
257
+ if best_ckpt_path == "":
258
+ log.warning("Best ckpt not found! Using current weights for testing...")
259
+ trainer.test(model=model, datamodule=datamodule, ckpt_path=best_ckpt_path or None)
260
+
261
+ test_metrics = trainer.callback_metrics
262
+
263
+ # merge train and test metrics
264
+ metric_dict = {**train_metrics, **test_metrics}
265
+
266
+ # add model_save_dir to the result so that it gets dumped to job_return_value.json
267
+ # if we use hydra_callbacks.SaveJobReturnValueCallback
268
+ if cfg.get("model_save_dir") is not None:
269
+ metric_dict["model_save_dir"] = cfg.model_save_dir
270
+
271
+ return metric_dict, object_dict
272
+
273
+
274
+ @hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="train.yaml")
275
+ def main(cfg: DictConfig) -> Optional[float]:
276
+ # train the model
277
+ metric_dict, _ = train(cfg)
278
+
279
+ # safely retrieve metric value for hydra-based hyperparameter optimization
280
+ if cfg.get("optimized_metric") is not None:
281
+ metric_value = get_metric_value(
282
+ metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
283
+ )
284
+
285
+ # return optimized metric
286
+ return metric_value
287
+ else:
288
+ return metric_dict
289
+
290
+
291
+ if __name__ == "__main__":
292
+ utils.replace_sys_args_with_values_from_files()
293
+ utils.prepare_omegaconf()
294
+ main()
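To make get_metric_value concrete, a small sketch (trainer.callback_metrics holds tensors, hence the .item() call in the function):

    import torch

    metric_dict = {"val/f1": torch.tensor(0.87)}
    get_metric_value(metric_dict, "val/f1")  # -> 0.87
    get_metric_value(metric_dict, "")        # falsy name: logged, returns None
    # a name missing from metric_dict raises with a hint about `optimized_metric`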
src/utils/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .config_utils import execute_pipeline, instantiate_dict_entries, prepare_omegaconf
+ from .data_utils import download_and_unzip, filter_dataframe_and_get_column
+ from .logging_utils import close_loggers, get_pylogger, log_hyperparameters
+ from .rich_utils import enforce_tags, print_config_tree
+ from .span_utils import distance
+ from .task_utils import extras, replace_sys_args_with_values_from_files, save_file, task_wrapper
src/utils/config_utils.py ADDED
@@ -0,0 +1,71 @@
+ from copy import copy
+ from typing import Any, List, Optional
+
+ from hydra.utils import instantiate
+ from omegaconf import DictConfig, OmegaConf
+
+ from src.utils.logging_utils import get_pylogger
+
+ logger = get_pylogger(__name__)
+
+
+ def execute_pipeline(
+ input: Any,
+ setup: Optional[Any] = None,
+ **processors,
+ ) -> Any:
+ if setup is not None and callable(setup):
+ setup()
+ result = input
+ for processor_name, processor_config in processors.items():
+ if not isinstance(processor_config, dict) or "_processor_" not in processor_config:
+ continue
+ logger.info(f"call processor: {processor_name}")
+ config = copy(processor_config)
+ if not config.pop("_enabled_", True):
+ logger.warning(f"skip processor because it is disabled: {processor_name}")
+ continue
+ # rename the key "_processor_" to "_target_"
+ if "_target_" in config:
+ raise ValueError(
+ f"processor {processor_name} has a key '_target_', which is not allowed"
+ )
+ config["_target_"] = config.pop("_processor_")
+ # IMPORTANT: We pass the result as the first positional argument after the config instead of adding
+ # it to the config. By doing so, we prevent it from being converted into an OmegaConf object, which
+ # would later be converted back to a plain dict and thereby break all the DatasetDict methods.
+ tmp_result = instantiate(config, result, _convert_="partial")
+ if tmp_result is not None:
+ result = tmp_result
+ else:
+ logger.warning(f'processor "{processor_name}" did not return a result')
+ return result
+
44
+
45
+ def instantiate_dict_entries(
46
+ config: DictConfig, key: str, entry_description: Optional[str] = None
47
+ ) -> List:
48
+ entries: List = []
49
+ key_config = config.get(key)
50
+
51
+ if not key_config:
52
+ logger.warning(f"{key} config is empty.")
53
+ return entries
54
+
55
+ if not isinstance(key_config, DictConfig):
56
+ raise TypeError("Logger config must be a DictConfig!")
57
+
58
+ for _, entry_conf in key_config.items():
59
+ if isinstance(entry_conf, DictConfig) and "_target_" in entry_conf:
60
+ logger.info(f"Instantiating {entry_description or key} <{entry_conf._target_}>")
61
+ entries.append(instantiate(entry_conf, _convert_="partial"))
62
+
63
+ return entries
64
+
65
+
66
+ def prepare_omegaconf():
67
+ # register replace resolver (used to replace "/" with "-" in names to use them as e.g. wandb project names)
68
+ if not OmegaConf.has_resolver("replace"):
69
+ OmegaConf.register_new_resolver("replace", lambda s, x, y: s.replace(x, y))
70
+ else:
71
+ logger.warning("OmegaConf resolver 'replace' is already registered")
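A quick, runnable sketch of the replace resolver in action:

    from omegaconf import OmegaConf

    OmegaConf.register_new_resolver("replace", lambda s, x, y: s.replace(x, y))
    cfg = OmegaConf.create({"model": "org/model-name", "project": "${replace:${model},/,-}"})
    assert cfg.project == "org-model-name"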
src/utils/data_utils.py ADDED
@@ -0,0 +1,33 @@
+ import os
+ from typing import List
+ from urllib.request import urlretrieve
+ from zipfile import ZipFile
+
+ import pandas as pd
+
+ from src.utils.logging_utils import get_pylogger
+
+ log = get_pylogger(__name__)
+
+
+ def filter_dataframe_and_get_column(
+ dataframe: pd.DataFrame, filter_column: str, filter_value: str, select_column: str
+ ) -> List[str]:
+ return dataframe[dataframe[filter_column] == filter_value][select_column].tolist()
+
+
+ def download_and_unzip(
+ url: str, target_path: str, force_download: bool = False, remove_tmp_file: bool = False
+ ):
+ log.warning(f"downloading zip file from {url} to {target_path} ...")
+ if not (url.startswith("http://") or url.startswith("https://")):
+ raise ValueError(f"url needs to point to a http(s) address, but it is: {url}")
+ tmp_file = os.path.join(target_path, os.path.basename(url))
+ if os.path.exists(tmp_file) and not force_download:
+ log.warning(f"tmp file {tmp_file} already exists, skip downloading {url}")
+ else:
+ urlretrieve(url, tmp_file) # nosec
+ with ZipFile(tmp_file, "r") as zfile:
+ zfile.extractall(target_path)
+ if remove_tmp_file:
+ os.remove(tmp_file)
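A usage sketch (the URL is hypothetical; target_path must be an existing directory, since the temporary file is written into it):

    download_and_unzip(
        "https://example.com/corpus.zip",
        target_path="data",
        remove_tmp_file=True,
    )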