https://github.com/ArneBinder/pie-document-level/pull/312
This view is limited to 50 files because the pull request contains too many changes. See the raw diff for the full change set.
- requirements.txt +38 -18
- src/datamodules/__init__.py +1 -0
- src/datamodules/components/__init__.py +0 -0
- src/datamodules/components/sampler.py +67 -0
- src/datamodules/datamodule.py +154 -0
- src/dataset/__init__.py +0 -0
- src/dataset/processing.py +29 -0
- src/demo/__init__.py +0 -0
- src/demo/annotation_utils.py +137 -0
- src/demo/backend_utils.py +221 -0
- src/demo/data_utils.py +63 -0
- src/demo/frontend_utils.py +56 -0
- src/demo/rendering_utils.py +296 -0
- src/demo/rendering_utils_displacy.py +217 -0
- src/demo/retrieve_and_dump_all_relevant.py +101 -0
- src/demo/retriever_utils.py +313 -0
- src/document/__init__.py +0 -0
- src/document/processing.py +223 -0
- src/evaluate.py +137 -0
- src/evaluate_documents.py +116 -0
- src/hydra_callbacks/__init__.py +1 -0
- src/hydra_callbacks/save_job_return_value.py +261 -0
- src/langchain_modules/span_retriever.py +3 -3
- src/metrics/__init__.py +2 -0
- src/metrics/annotation_processor.py +23 -0
- src/metrics/coref_sklearn.py +162 -0
- src/metrics/coref_torchmetrics.py +107 -0
- src/models/__init__.py +5 -0
- src/models/components/__init__.py +0 -0
- src/models/components/pooler.py +79 -0
- src/models/sequence_classification_with_pooler.py +166 -0
- src/models/utils/__init__.py +5 -1
- src/models/utils/loading.py +4 -4
- src/pipeline/__init__.py +2 -0
- src/pipeline/ner_re_pipeline.py +208 -0
- src/pipeline/span_retrieval_based_re_pipeline.py +130 -0
- src/predict.py +183 -0
- src/serializer/__init__.py +1 -0
- src/serializer/interface.py +16 -0
- src/serializer/json.py +179 -0
- src/start_demo.py +578 -0
- src/taskmodules/__init__.py +8 -0
- src/taskmodules/components/__init__.py +0 -0
- src/taskmodules/cross_text_binary_coref.py +116 -0
- src/taskmodules/cross_text_binary_coref_nli.py +166 -0
- src/taskmodules/re_text_classification_with_indices.py +176 -0
- src/train.py +294 -0
- src/utils/__init__.py +6 -0
- src/utils/config_utils.py +71 -0
- src/utils/data_utils.py +33 -0
requirements.txt
CHANGED
@@ -1,14 +1,3 @@
-# this requires python>=3.10
-gradio~=5.4.0
-prettytable==3.10.0
-beautifulsoup4==4.12.3
-# numpy 2.0.0 breaks the code
-numpy==1.25.2
-scipy==1.13.0
-arxiv==2.1.3
-pyrootutils>=1.0.0,<1.1.0
-
-########## from root requirements ##########
 # --------- pytorch-ie --------- #
 pytorch-ie>=0.29.6,<0.32.0
 pie-datasets>=0.10.5,<0.11.0
@@ -16,20 +5,51 @@ pie-modules>=0.14.0,<0.15.0
 
 # --------- models -------- #
 adapters>=0.1.2,<0.2.0
-
+pytorch-crf~=0.7.2
+
+# --------- retriever -------- #
 langchain>=0.3.0,<0.4.0
 langchain-core>=0.3.0,<0.4.0
 langchain-community>=0.3.0,<0.4.0
 # we use QDrant as vectorstore backend
 langchain-qdrant>=0.1.0,<0.2.0
 qdrant-client>=1.12.0,<2.0.0
-# 0.26 seems to be broken when used with adapters, see https://github.com/adapter-hub/adapters/issues/748
-huggingface_hub<0.26.0 # 0.26 seems to be broken
-# to handle segmented entities (if HANDLE_PARTS_OF_SAME=True)
-networkx>=3.0.0,<4.0.0
 
+# --------- demo -------- #
+gradio~=5.4.0
+arxiv~=2.1.3
+
-# ---------
+# --------- hydra --------- #
 hydra-core>=1.3.0
+hydra-colorlog>=1.2.0
+hydra-optuna-sweeper>=1.2.0
+
+# --------- loggers --------- #
+wandb
+# neptune-client
+# mlflow
+# comet-ml
+# tensorboard
+# aim
 
-# ---------
+# --------- linters --------- #
 pre-commit # hooks for applying linters on commit
+black # code formatting
+isort # import sorting
+flake8 # code analysis
+nbstripout # remove output from jupyter notebooks
+
+# --------- others --------- #
+pyrootutils # standardizing the project root setup
+python-dotenv # loading env variables from .env file
+rich # beautiful text formatting in terminal
+pytest # tests
+pytest-cov # test coverage
+sh # for running bash commands in some tests
+pudb # debugger
+tabulate # show statistics as markdown
+plotext # show statistics as plots
+prettytable # rendering annotated docs as table (demo)
+beautifulsoup4 # rendering annotated docs with displacy + highlighted relations (demo)
+# 0.26 seems to be broken when used with adapters, see https://github.com/adapter-hub/adapters/issues/748
+huggingface_hub<0.26.0 # interaction with HF hub
+networkx~=3.2.1 # to handle segmented entities (e.g. if HANDLE_PARTS_OF_SAME=True in demo)
src/datamodules/__init__.py
ADDED
@@ -0,0 +1 @@
+from .datamodule import PieDataModule
src/datamodules/components/__init__.py
ADDED
File without changes
src/datamodules/components/sampler.py
ADDED
@@ -0,0 +1,67 @@
+"""This is a slightly modified version of https://github.com/ufoym/imbalanced-dataset-sampler."""
+
+from typing import Callable, List, Optional
+
+import pandas as pd
+import torch
+import torch.utils.data
+
+
+class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
+    """Samples elements randomly from a given list of indices for imbalanced dataset.
+
+    Arguments:
+        indices: a list of indices
+        num_samples: number of samples to draw
+        callback_get_label: a callback-like function which takes one argument - the dataset
+    """
+
+    def __init__(
+        self,
+        dataset,
+        labels: Optional[List] = None,
+        indices: Optional[List] = None,
+        num_samples: Optional[int] = None,
+        callback_get_label: Optional[Callable] = None,
+    ):
+        # if indices is not provided, all elements in the dataset will be considered
+        self.indices = list(range(len(dataset))) if indices is None else indices
+
+        # define custom callback
+        self.callback_get_label = callback_get_label
+
+        # if num_samples is not provided, draw `len(indices)` samples in each iteration
+        self.num_samples = len(self.indices) if num_samples is None else num_samples
+
+        # distribution of classes in the dataset
+        df = pd.DataFrame()
+        df["label"] = self._get_labels(dataset) if labels is None else labels
+        df.index = self.indices
+        df = df.sort_index()
+
+        label_to_count = df["label"].value_counts()
+
+        weights = 1.0 / label_to_count[df["label"]]
+
+        self.weights = torch.DoubleTensor(weights.to_list())
+
+    def _get_labels(self, dataset):
+        if self.callback_get_label:
+            return self.callback_get_label(dataset)
+        elif isinstance(dataset, torch.utils.data.TensorDataset):
+            return dataset.tensors[1]
+        elif isinstance(dataset, torch.utils.data.Subset):
+            return dataset.dataset.imgs[:][1]
+        elif isinstance(dataset, torch.utils.data.Dataset):
+            return dataset.get_labels()
+        else:
+            raise NotImplementedError
+
+    def __iter__(self):
+        return (
+            self.indices[i]
+            for i in torch.multinomial(self.weights, self.num_samples, replacement=True)
+        )
+
+    def __len__(self):
+        return self.num_samples
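
A minimal usage sketch (not part of the diff) showing how ImbalancedDatasetSampler plugs into a plain PyTorch DataLoader; the toy TensorDataset and label distribution are assumptions for illustration:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from src.datamodules.components.sampler import ImbalancedDatasetSampler

    # toy imbalanced dataset: 90 samples of class 0, 10 samples of class 1
    features = torch.randn(100, 8)
    labels = [0] * 90 + [1] * 10
    dataset = TensorDataset(features, torch.tensor(labels))

    # classes are weighted inversely to their frequency, so drawn batches
    # are roughly class-balanced in expectation
    sampler = ImbalancedDatasetSampler(dataset, labels=labels)
    loader = DataLoader(dataset, sampler=sampler, batch_size=16)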
src/datamodules/datamodule.py
ADDED
@@ -0,0 +1,154 @@
+from typing import Any, Dict, Generic, Optional, Sequence, TypeVar, Union
+
+from pytorch_ie.core import Document
+from pytorch_ie.core.taskmodule import (
+    IterableTaskEncodingDataset,
+    TaskEncoding,
+    TaskEncodingDataset,
+    TaskModule,
+)
+from pytorch_lightning import LightningDataModule
+from torch.utils.data import DataLoader, Sampler
+from typing_extensions import TypeAlias
+
+from .components.sampler import ImbalancedDatasetSampler
+
+DocumentType = TypeVar("DocumentType", bound=Document)
+InputEncoding = TypeVar("InputEncoding")
+TargetEncoding = TypeVar("TargetEncoding")
+DatasetType: TypeAlias = Union[
+    TaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
+    IterableTaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
+]
+
+
+class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, TargetEncoding]):
+    """A simple LightningDataModule for PIE document datasets.
+
+    A DataModule implements 5 key methods:
+    - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
+    - setup (things to do on every accelerator in distributed mode)
+    - train_dataloader (the training dataloader)
+    - val_dataloader (the validation dataloader(s))
+    - test_dataloader (the test dataloader(s))
+
+    This allows you to share a full dataset without explaining how to download,
+    split, transform and process the data.
+
+    Read the docs:
+    https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
+    """
+
+    def __init__(
+        self,
+        taskmodule: TaskModule[DocumentType, InputEncoding, TargetEncoding, Any, Any, Any],
+        dataset: Dict[str, Sequence[DocumentType]],
+        data_config_path: Optional[str] = None,
+        train_split: Optional[str] = "train",
+        val_split: Optional[str] = "validation",
+        test_split: Optional[str] = "test",
+        show_progress_for_encode: bool = False,
+        train_sampler: Optional[str] = None,
+        **dataloader_kwargs,
+    ):
+        super().__init__()
+
+        self.taskmodule = taskmodule
+        self.config_path = data_config_path
+        self.dataset = dataset
+        self.train_split = train_split
+        self.val_split = val_split
+        self.test_split = test_split
+        self.show_progress_for_encode = show_progress_for_encode
+        self.train_sampler_name = train_sampler
+        self.dataloader_kwargs = dataloader_kwargs
+
+        self._data: Dict[str, DatasetType] = {}
+
+    @property
+    def num_train(self) -> int:
+        if self.train_split is None:
+            raise ValueError("no train_split assigned")
+        data_train = self._data.get(self.train_split, None)
+        if data_train is None:
+            raise ValueError("can not get train size if setup() was not yet called")
+        if isinstance(data_train, IterableTaskEncodingDataset):
+            raise TypeError("IterableTaskEncodingDataset has no length")
+        return len(data_train)
+
+    def setup(self, stage: str):
+        if stage == "fit":
+            split_names = [self.train_split, self.val_split]
+        elif stage == "validate":
+            split_names = [self.val_split]
+        elif stage == "test":
+            split_names = [self.test_split]
+        else:
+            raise NotImplementedError(f"not implemented for stage={stage}")
+
+        for split in split_names:
+            if split is None or split not in self.dataset:
+                continue
+            task_encoding_dataset = self.taskmodule.encode(
+                self.dataset[split],
+                encode_target=True,
+                as_dataset=True,
+                show_progress=self.show_progress_for_encode,
+            )
+            if not isinstance(
+                task_encoding_dataset,
+                (TaskEncodingDataset, IterableTaskEncodingDataset),
+            ):
+                raise TypeError(
+                    f"taskmodule.encode did not return a (Iterable)TaskEncodingDataset, but: {type(task_encoding_dataset)}"
+                )
+            self._data[split] = task_encoding_dataset
+
+    def data_split(self, split: Optional[str] = None) -> DatasetType:
+        if split is None or split not in self._data:
+            raise ValueError(f"data for split={split} not available")
+        return self._data[split]
+
+    def get_train_sampler(
+        self,
+        sampler_name: str,
+        dataset: DatasetType,
+    ) -> Sampler:
+        if sampler_name == "imbalanced_dataset":
+            # for now, this works only with targets that have a single entry
+            return ImbalancedDatasetSampler(
+                dataset, callback_get_label=lambda ds: [x.targets[0] for x in ds]
+            )
+        else:
+            raise ValueError(f"unknown sampler name: {sampler_name}")
+
+    def train_dataloader(self):
+        ds = self.data_split(self.train_split)
+        if self.train_sampler_name is not None:
+            sampler = self.get_train_sampler(sampler_name=self.train_sampler_name, dataset=ds)
+        else:
+            sampler = None
+        return DataLoader(
+            dataset=ds,
+            sampler=sampler,
+            collate_fn=self.taskmodule.collate,
+            # don't shuffle streamed datasets or if we use a sampler
+            shuffle=not (isinstance(ds, IterableTaskEncodingDataset) or sampler is not None),
+            **self.dataloader_kwargs,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            dataset=self.data_split(self.val_split),
+            collate_fn=self.taskmodule.collate,
+            shuffle=False,
+            **self.dataloader_kwargs,
+        )
+
+    def test_dataloader(self):
+        return DataLoader(
+            dataset=self.data_split(self.test_split),
+            collate_fn=self.taskmodule.collate,
+            shuffle=False,
+            **self.dataloader_kwargs,
+        )
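
A usage sketch (not part of the diff); `taskmodule`, `dataset`, and `model` are placeholders for a prepared PIE TaskModule, a mapping of split names to PIE documents, and a matching PIE model:

    from pytorch_lightning import Trainer

    from src.datamodules import PieDataModule

    datamodule = PieDataModule(
        taskmodule=taskmodule,               # placeholder: any prepared PIE TaskModule
        dataset=dataset,                     # placeholder: {"train": [...], "validation": [...]}
        train_sampler="imbalanced_dataset",  # optional, see get_train_sampler()
        batch_size=32,                       # forwarded to DataLoader via **dataloader_kwargs
        num_workers=2,                       # likewise
    )
    Trainer(max_epochs=3).fit(model, datamodule=datamodule)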
src/dataset/__init__.py
ADDED
File without changes
src/dataset/processing.py
ADDED
@@ -0,0 +1,29 @@
+from typing import Callable, Type, Union
+
+from pie_datasets import Dataset, DatasetDict
+from pytorch_ie import Document
+from pytorch_ie.utils.hydra import resolve_optional_document_type, resolve_target
+
+
+# TODO: simply use DatasetDict.map() with set_batch_size_to_split_size=True and
+#  batched=True instead when https://github.com/ArneBinder/pie-datasets/pull/155 is merged
+def apply_func_to_splits(
+    dataset: DatasetDict,
+    function: Union[str, Callable],
+    result_document_type: Type[Document],
+    **kwargs
+):
+    resolved_func = resolve_target(function)
+    resolved_document_type = resolve_optional_document_type(document_type=result_document_type)
+    result_dict = dict()
+    split: Dataset
+    for split_name, split in dataset.items():
+        converted_dataset = split.map(
+            function=resolved_func,
+            batched=True,
+            batch_size=len(split),
+            result_document_type=resolved_document_type,
+            **kwargs
+        )
+        result_dict[split_name] = converted_dataset
+    return DatasetDict(result_dict)
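
A hedged usage sketch (not part of the diff); the dataset id and the dotted converter path are hypothetical, only the call shape follows the function above:

    from pie_datasets import load_dataset
    from pytorch_ie.documents import (
        TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    )

    from src.dataset.processing import apply_func_to_splits

    dataset = load_dataset("pie/some_dataset")  # hypothetical dataset id
    converted = apply_func_to_splits(
        dataset=dataset,
        function="src.document.processing.my_converter",  # hypothetical dotted path
        result_document_type=TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    )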
src/demo/__init__.py
ADDED
File without changes
src/demo/annotation_utils.py
ADDED
@@ -0,0 +1,137 @@
+import logging
+from typing import Optional, Sequence, Union
+
+import gradio as gr
+from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger
+
+# this is required to dynamically load the PIE models
+from pie_modules.models import *  # noqa: F403
+from pie_modules.taskmodules import *  # noqa: F403
+from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
+from pytorch_ie import Pipeline
+from pytorch_ie.annotations import LabeledSpan
+from pytorch_ie.auto import AutoPipeline
+from pytorch_ie.documents import (
+    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+)
+
+# this is required to dynamically load the PIE models
+from pytorch_ie.models import *  # noqa: F403
+from pytorch_ie.taskmodules import *  # noqa: F403
+
+logger = logging.getLogger(__name__)
+
+
+def annotate_document(
+    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+    argumentation_model: Pipeline,
+    handle_parts_of_same: bool = False,
+) -> Union[
+    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+]:
+    """Annotate a document with the provided pipeline.
+
+    Args:
+        document: The document to annotate.
+        argumentation_model: The pipeline to use for annotation.
+        handle_parts_of_same: Whether to merge spans that are part of the same entity into a single multi span.
+    """
+
+    # execute prediction pipeline
+    argumentation_model(document)
+
+    if handle_parts_of_same:
+        merger = SpansViaRelationMerger(
+            relation_layer="binary_relations",
+            link_relation_label="parts_of_same",
+            create_multi_spans=True,
+            result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+            result_field_mapping={
+                "labeled_spans": "labeled_multi_spans",
+                "binary_relations": "binary_relations",
+                "labeled_partitions": "labeled_partitions",
+            },
+        )
+        document = merger(document)
+
+    return document
+
+
+def create_document(
+    text: str, doc_id: str, split_regex: Optional[str] = None
+) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
+    """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
+    text.
+
+    Parameters:
+        text: The text to process.
+        doc_id: The ID of the document.
+        split_regex: A regular expression pattern to use for splitting the text into partitions.
+
+    Returns:
+        The processed document.
+    """
+
+    document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
+        id=doc_id, text=text, metadata={}
+    )
+    if split_regex is not None:
+        partitioner = RegexPartitioner(
+            pattern=split_regex, partition_layer_name="labeled_partitions"
+        )
+        document = partitioner(document)
+    else:
+        # add single partition from the whole text (the model only considers text in partitions)
+        document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
+    return document
+
+
+def load_argumentation_model(
+    model_name: str,
+    revision: Optional[str] = None,
+    device: str = "cpu",
+) -> Pipeline:
+    try:
+        # the Pipeline class expects an integer for the device
+        if device == "cuda":
+            pipeline_device = 0
+        elif device.startswith("cuda:"):
+            pipeline_device = int(device.split(":")[1])
+        elif device == "cpu":
+            pipeline_device = -1
+        else:
+            raise gr.Error(f"Invalid device: {device}")
+
+        model = AutoPipeline.from_pretrained(
+            model_name,
+            device=pipeline_device,
+            num_workers=0,
+            taskmodule_kwargs=dict(revision=revision),
+            model_kwargs=dict(revision=revision),
+        )
+        gr.Info(
+            f"Loaded argumentation model: model_name={model_name}, revision={revision}, device={device}"
+        )
+    except Exception as e:
+        raise gr.Error(f"Failed to load argumentation model: {e}")
+
+    return model
+
+
+def set_relation_types(
+    argumentation_model: Pipeline,
+    default: Optional[Sequence[str]] = None,
+) -> gr.Dropdown:
+    if isinstance(argumentation_model.taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
+        relation_types = argumentation_model.taskmodule.labels_per_layer["binary_relations"]
+    else:
+        raise gr.Error("Unsupported taskmodule for relation types")
+
+    return gr.Dropdown(
+        choices=relation_types,
+        label="Argumentative Relation Types",
+        value=default,
+        multiselect=True,
+    )
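
A sketch of the intended call sequence (not part of the diff); the hub model id is a placeholder:

    from src.demo.annotation_utils import (
        annotate_document,
        create_document,
        load_argumentation_model,
    )

    # placeholder model id; any PIE pipeline for ADU/relation extraction works
    pipeline = load_argumentation_model("my-org/my-argumentation-model", device="cpu")

    doc = create_document(
        text="Strong claim. Because of good reasons.",
        doc_id="example-1",
        split_regex=None,  # single partition over the whole text
    )
    annotated = annotate_document(doc, pipeline, handle_parts_of_same=True)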
src/demo/backend_utils.py
ADDED
@@ -0,0 +1,221 @@
+import json
+import logging
+import os
+import tempfile
+from typing import Iterable, List, Optional, Sequence
+
+import gradio as gr
+import pandas as pd
+from pie_datasets import Dataset, IterableDataset, load_dataset
+from pytorch_ie import Pipeline
+from pytorch_ie.documents import (
+    TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
+)
+
+from src.demo.annotation_utils import annotate_document, create_document
+from src.demo.data_utils import load_text_from_arxiv
+from src.demo.rendering_utils import (
+    RENDER_WITH_DISPLACY,
+    RENDER_WITH_PRETTY_TABLE,
+    render_displacy,
+    render_pretty_table,
+)
+from src.demo.retriever_utils import get_text_spans_and_relations_from_document
+from src.langchain_modules import (
+    DocumentAwareSpanRetriever,
+    DocumentAwareSpanRetrieverWithRelations,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def add_annotated_pie_documents(
+    retriever: DocumentAwareSpanRetriever,
+    pie_documents: Sequence[TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions],
+    use_predicted_annotations: bool,
+    verbose: bool = False,
+) -> None:
+    if verbose:
+        gr.Info(f"Create span embeddings for {len(pie_documents)} documents...")
+    num_docs_before = len(retriever.docstore)
+    retriever.add_pie_documents(pie_documents, use_predicted_annotations=use_predicted_annotations)
+    # number of documents that were overwritten
+    num_overwritten_docs = num_docs_before + len(pie_documents) - len(retriever.docstore)
+    # warn if documents were overwritten
+    if num_overwritten_docs > 0:
+        gr.Warning(f"{num_overwritten_docs} documents were overwritten.")
+
+
+def process_texts(
+    texts: Iterable[str],
+    doc_ids: Iterable[str],
+    argumentation_model: Pipeline,
+    retriever: DocumentAwareSpanRetriever,
+    split_regex_escaped: Optional[str],
+    handle_parts_of_same: bool = False,
+    verbose: bool = False,
+) -> None:
+    # check that doc_ids are unique
+    if len(set(doc_ids)) != len(list(doc_ids)):
+        raise gr.Error("Document IDs must be unique.")
+    pie_documents = [
+        create_document(text=text, doc_id=doc_id, split_regex=split_regex_escaped)
+        for text, doc_id in zip(texts, doc_ids)
+    ]
+    if verbose:
+        gr.Info(f"Annotate {len(pie_documents)} documents...")
+    pie_documents = [
+        annotate_document(
+            document=pie_document,
+            argumentation_model=argumentation_model,
+            handle_parts_of_same=handle_parts_of_same,
+        )
+        for pie_document in pie_documents
+    ]
+    add_annotated_pie_documents(
+        retriever=retriever,
+        pie_documents=pie_documents,
+        use_predicted_annotations=True,
+        verbose=verbose,
+    )
+
+
+def add_annotated_pie_documents_from_dataset(
+    retriever: DocumentAwareSpanRetriever, verbose: bool = False, **load_dataset_kwargs
+) -> None:
+    try:
+        gr.Info(
+            "Loading PIE dataset with parameters:\n" + json.dumps(load_dataset_kwargs, indent=2)
+        )
+        dataset = load_dataset(**load_dataset_kwargs)
+        if not isinstance(dataset, (Dataset, IterableDataset)):
+            raise gr.Error("Loaded dataset is not of type PIE (Iterable)Dataset.")
+        dataset_converted = dataset.to_document_type(
+            TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
+        )
+        add_annotated_pie_documents(
+            retriever=retriever,
+            pie_documents=dataset_converted,
+            use_predicted_annotations=False,
+            verbose=verbose,
+        )
+    except Exception as e:
+        raise gr.Error(f"Failed to load dataset: {e}")
+
+
+def wrapped_process_text(
+    doc_id: str, text: str, retriever: DocumentAwareSpanRetriever, **kwargs
+) -> str:
+    try:
+        process_texts(doc_ids=[doc_id], texts=[text], retriever=retriever, **kwargs)
+    except Exception as e:
+        raise gr.Error(f"Failed to process text: {e}")
+    # return the doc_id instead of the document to avoid serialization issues
+    return doc_id
+
+
+def process_uploaded_files(
+    file_names: List[str],
+    retriever: DocumentAwareSpanRetriever,
+    layer_captions: dict[str, str],
+    **kwargs,
+) -> pd.DataFrame:
+    try:
+        doc_ids = []
+        texts = []
+        for file_name in file_names:
+            if file_name.lower().endswith(".txt"):
+                # read the file content
+                with open(file_name, "r", encoding="utf-8") as f:
+                    text = f.read()
+                base_file_name = os.path.basename(file_name)
+                doc_ids.append(base_file_name)
+                texts.append(text)
+            else:
+                raise gr.Error(f"Unsupported file format: {file_name}")
+        process_texts(texts=texts, doc_ids=doc_ids, retriever=retriever, verbose=True, **kwargs)
+    except Exception as e:
+        raise gr.Error(f"Failed to process uploaded files: {e}")
+
+    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)
+
+
+def wrapped_add_annotated_pie_documents_from_dataset(
+    retriever: DocumentAwareSpanRetriever, verbose: bool, layer_captions: dict[str, str], **kwargs
+) -> pd.DataFrame:
+    try:
+        add_annotated_pie_documents_from_dataset(retriever=retriever, verbose=verbose, **kwargs)
+    except Exception as e:
+        raise gr.Error(f"Failed to add annotated PIE documents from dataset: {e}")
+    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)
+
+
+def download_processed_documents(
+    retriever: DocumentAwareSpanRetriever,
+    file_name: str = "retriever_store",
+) -> Optional[str]:
+    if len(retriever.docstore) == 0:
+        gr.Warning("No documents to download.")
+        return None
+
+    # zip the directory
+    file_path = os.path.join(tempfile.gettempdir(), file_name)
+
+    gr.Info(f"Zipping the retriever store to '{file_name}' ...")
+    result_file_path = retriever.save_to_archive(base_name=file_path, format="zip")
+
+    return result_file_path
+
+
+def upload_processed_documents(
+    file_name: str,
+    retriever: DocumentAwareSpanRetriever,
+    layer_captions: dict[str, str],
+) -> pd.DataFrame:
+    # load the documents from the zip file or directory
+    retriever.load_from_disc(file_name)
+    # return the overview of the document store
+    return retriever.docstore.overview(layer_captions=layer_captions, use_predictions=True)
+
+
+def process_text_from_arxiv(
+    arxiv_id: str, retriever: DocumentAwareSpanRetriever, abstract_only: bool = False, **kwargs
+) -> str:
+    try:
+        text, doc_id = load_text_from_arxiv(arxiv_id=arxiv_id, abstract_only=abstract_only)
+    except Exception as e:
+        raise gr.Error(f"Failed to load text from arXiv: {e}")
+    return wrapped_process_text(doc_id=doc_id, text=text, retriever=retriever, **kwargs)
+
+
+def render_annotated_document(
+    retriever: DocumentAwareSpanRetrieverWithRelations,
+    document_id: str,
+    render_with: str,
+    render_kwargs_json: str,
+) -> str:
+    text, spans, span_id2idx, relations = get_text_spans_and_relations_from_document(
+        retriever=retriever, document_id=document_id
+    )
+
+    render_kwargs = json.loads(render_kwargs_json)
+    if render_with == RENDER_WITH_PRETTY_TABLE:
+        html = render_pretty_table(
+            text=text,
+            spans=spans,
+            span_id2idx=span_id2idx,
+            binary_relations=relations,
+            **render_kwargs,
+        )
+    elif render_with == RENDER_WITH_DISPLACY:
+        html = render_displacy(
+            text=text,
+            spans=spans,
+            span_id2idx=span_id2idx,
+            binary_relations=relations,
+            **render_kwargs,
+        )
+    else:
+        raise ValueError(f"Unknown render_with value: {render_with}")
+
+    return html
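
Putting the pieces together, a hedged sketch of the ingestion-and-rendering flow (not part of the diff); `pipeline` and `retriever` are assumed to be configured elsewhere, e.g. via annotation_utils and src.langchain_modules:

    from src.demo.backend_utils import process_texts, render_annotated_document
    from src.demo.rendering_utils import RENDER_WITH_DISPLACY

    process_texts(
        texts=["Strong claim. Because of good reasons."],
        doc_ids=["example-1"],
        argumentation_model=pipeline,   # assumed PIE Pipeline
        retriever=retriever,            # assumed DocumentAwareSpanRetriever
        split_regex_escaped=None,
        verbose=True,
    )
    html = render_annotated_document(
        retriever=retriever,
        document_id="example-1",
        render_with=RENDER_WITH_DISPLACY,
        render_kwargs_json="{}",
    )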
src/demo/data_utils.py
ADDED
@@ -0,0 +1,63 @@
+import logging
+import re
+from typing import Tuple
+
+import arxiv
+import gradio as gr
+import requests
+from bs4 import BeautifulSoup
+
+logger = logging.getLogger(__name__)
+
+
+def clean_spaces(text: str) -> str:
+    # replace all multiple spaces with a single space
+    text = re.sub(" +", " ", text)
+    # reduce more than two newlines to two newlines
+    text = re.sub("\n\n+", "\n\n", text)
+    # remove leading and trailing whitespaces
+    text = text.strip()
+    return text
+
+
+def get_cleaned_arxiv_paper_text(html_content: str) -> str:
+    # parse the HTML content with BeautifulSoup
+    soup = BeautifulSoup(html_content, "html.parser")
+    # get alerts (this is one div with classes "package-alerts" and "ltx_document")
+    alerts = soup.find("div", class_="package-alerts ltx_document")
+    # get the "article" html element
+    article = soup.find("article")
+    article_text = article.get_text()
+    # cleanup the text
+    article_text_clean = clean_spaces(article_text)
+    return article_text_clean
+
+
+def load_text_from_arxiv(arxiv_id: str, abstract_only: bool = False) -> Tuple[str, str]:
+
+    search_by_id = arxiv.Search(id_list=[arxiv_id])
+    try:
+        result = list(arxiv.Client().results(search_by_id))
+    except arxiv.HTTPError as e:
+        raise gr.Error(f"Failed to fetch arXiv data: {e}")
+    if len(result) == 0:
+        raise gr.Error(f"Could not find any paper with arXiv ID '{arxiv_id}'")
+    first_result = result[0]
+    if abstract_only:
+        abstract_clean = first_result.summary.replace("\n", " ")
+        return abstract_clean, first_result.entry_id
+    if "/abs/" not in first_result.entry_id:
+        raise gr.Error(
+            f"Could not create the HTML URL for arXiv ID '{arxiv_id}' because its entry ID has "
+            f"an unexpected format: {first_result.entry_id}"
+        )
+    html_url = first_result.entry_id.replace("/abs/", "/html/")
+    request_result = requests.get(html_url)
+    if request_result.status_code != 200:
+        raise gr.Error(
+            f"Could not fetch the HTML content for arXiv ID '{arxiv_id}', status code: "
+            f"{request_result.status_code}"
+        )
+    html_content = request_result.text
+    text_clean = get_cleaned_arxiv_paper_text(html_content)
+    return text_clean, html_url
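
For illustration (not part of the diff), a minimal call; it needs network access, and the arXiv ID is just an example:

    from src.demo.data_utils import load_text_from_arxiv

    # fetch only the abstract; returns (cleaned text, entry id/url)
    abstract, doc_id = load_text_from_arxiv("1706.03762", abstract_only=True)
    print(doc_id, abstract[:80])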
src/demo/frontend_utils.py
ADDED
@@ -0,0 +1,56 @@
+from typing import Any, Union
+
+import gradio as gr
+import pandas as pd
+
+
+# see https://github.com/gradio-app/gradio/issues/9288#issuecomment-2356163329
+def get_fix_df_height_css(css_class: str, max_height: int) -> str:
+    # e.g. ".qa-pairs .table-wrap {min-height: 170px; max-height: 170px;}"
+    return "." + css_class + " .table-wrap {max-height: " + str(max_height) + "px;}"
+
+
+def escape_regex(regex: str) -> str:
+    # "double escape" the backslashes
+    result = regex.encode("unicode_escape").decode("utf-8")
+    return result
+
+
+def unescape_regex(regex: str) -> str:
+    # reverse of escape_regex
+    result = regex.encode("utf-8").decode("unicode_escape")
+    return result
+
+
+def open_accordion():
+    return gr.Accordion(open=True)
+
+
+def close_accordion():
+    return gr.Accordion(open=False)
+
+
+def change_tab(id: Union[int, str]):
+    return gr.Tabs(selected=id)
+
+
+def get_cell_for_fixed_column_from_df(
+    evt: gr.SelectData,
+    df: pd.DataFrame,
+    column: str,
+) -> Any:
+    """Get the value of the fixed column for the selected row in the DataFrame.
+    This can *not* be done with a lambda function because a lambda would not
+    receive the evt parameter.
+
+    Args:
+        evt: The event object.
+        df: The DataFrame.
+        column: The name of the column.
+
+    Returns:
+        The value of the fixed column for the selected row.
+    """
+    row_idx, col_idx = evt.index
+    doc_id = df.iloc[row_idx][column]
+    return doc_id
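
The escape/unescape helpers form a round trip, which matters because the demo passes user-supplied split regexes through Gradio textboxes; a small check (not part of the diff):

    from src.demo.frontend_utils import escape_regex, unescape_regex

    pattern = "\n\n"                  # regex containing control characters
    escaped = escape_regex(pattern)   # -> "\\n\\n", safe to display in a textbox
    assert unescape_regex(escaped) == pattern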
src/demo/rendering_utils.py
ADDED
@@ -0,0 +1,296 @@
+import json
+import logging
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Sequence, Union
+
+from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
+
+from .rendering_utils_displacy import EntityRenderer
+
+logger = logging.getLogger(__name__)
+
+RENDER_WITH_DISPLACY = "displacy"
+RENDER_WITH_PRETTY_TABLE = "pretty_table"
+AVAILABLE_RENDER_MODES = [RENDER_WITH_DISPLACY, RENDER_WITH_PRETTY_TABLE]
+
+# adjusted from rendering_utils_displacy.TPL_ENT
+TPL_ENT_WITH_ID = """
+<mark class="entity" data-entity-id="{entity_id}" data-slice-idx="{slice_idx}" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
+    {text}
+    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
+</mark>
+"""
+
+HIGHLIGHT_SPANS_JS = """
+() => {
+    function maybeSetColor(entity, colorAttributeKey, colorDictKey) {
+        var color = entity.getAttribute('data-color-' + colorAttributeKey);
+        // if color is a json string, parse it and use the value at colorDictKey
+        try {
+            const colors = JSON.parse(color);
+            color = colors[colorDictKey];
+        } catch (e) {}
+        if (color) {
+            entity.style.backgroundColor = color;
+            entity.style.color = '#000';
+        }
+    }
+
+    function highlightRelationArguments(entityId) {
+        const entities = document.querySelectorAll('.entity');
+        // reset all entities
+        entities.forEach(entity => {
+            const color = entity.getAttribute('data-color-original');
+            entity.style.backgroundColor = color;
+            entity.style.color = '';
+        });
+
+        if (entityId !== null) {
+            var visitedEntities = new Set();
+            // highlight selected entity
+            // get all elements with attribute data-entity-id==entityId
+            const selectedEntityParts = document.querySelectorAll(`[data-entity-id="${entityId}"]`);
+            selectedEntityParts.forEach(selectedEntityPart => {
+                const label = selectedEntityPart.getAttribute('data-label');
+                maybeSetColor(selectedEntityPart, 'selected', label);
+                visitedEntities.add(selectedEntityPart);
+            });
+            // if there is at least one part, get the first one and ...
+            if (selectedEntityParts.length > 0) {
+                const selectedEntity = selectedEntityParts[0];
+
+                // ... highlight tails and ...
+                const relationTailsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-tails'));
+                relationTailsAndLabels.forEach(relationTail => {
+                    const tailEntityId = relationTail['entity-id'];
+                    const tailEntityParts = document.querySelectorAll(`[data-entity-id="${tailEntityId}"]`);
+                    tailEntityParts.forEach(tailEntity => {
+                        const label = relationTail['label'];
+                        maybeSetColor(tailEntity, 'tail', label);
+                        visitedEntities.add(tailEntity);
+                    });
+                });
+                // .. highlight heads
+                const relationHeadsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-heads'));
+                relationHeadsAndLabels.forEach(relationHead => {
+                    const headEntityId = relationHead['entity-id'];
+                    const headEntityParts = document.querySelectorAll(`[data-entity-id="${headEntityId}"]`);
+                    headEntityParts.forEach(headEntity => {
+                        const label = relationHead['label'];
+                        maybeSetColor(headEntity, 'head', label);
+                        visitedEntities.add(headEntity);
+                    });
+                });
+            }
+
+            // highlight other entities
+            entities.forEach(entity => {
+                if (!visitedEntities.has(entity)) {
+                    const label = entity.getAttribute('data-label');
+                    maybeSetColor(entity, 'other', label);
+                }
+            });
+        }
+    }
+    function setHoverAduId(entityId) {
+        // get the textarea element that holds the hover adu id
+        let hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
+        // set the value of the input field
+        hoverAduIdDiv.value = entityId;
+        // trigger an input event to update the state
+        var event = new Event('input');
+        hoverAduIdDiv.dispatchEvent(event);
+    }
+    function setReferenceAduIdFromHover() {
+        // get the hover adu id
+        const hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
+        // get the value of the input field
+        const entityId = hoverAduIdDiv.value;
+        // get the textarea element that holds the reference adu id
+        let referenceAduIdDiv = document.querySelector('#selected_adu_id textarea');
+        // set the value of the input field
+        referenceAduIdDiv.value = entityId;
+        // trigger an input event to update the state
+        var event = new Event('input');
+        referenceAduIdDiv.dispatchEvent(event);
+    }
+
+    const entities = document.querySelectorAll('.entity');
+    entities.forEach(entity => {
+        // make the cursor a pointer
+        entity.style.cursor = 'pointer';
+        const alreadyHasListener = entity.getAttribute('data-has-listener');
+        if (alreadyHasListener) {
+            return;
+        }
+        entity.addEventListener('mouseover', () => {
+            const entityId = entity.getAttribute('data-entity-id');
+            highlightRelationArguments(entityId);
+            setHoverAduId(entityId);
+        });
+        entity.addEventListener('mouseout', () => {
+            highlightRelationArguments(null);
+        });
+        entity.setAttribute('data-has-listener', 'true');
+    });
+    const entityContainer = document.querySelector('.entities');
+    if (entityContainer) {
+        entityContainer.addEventListener('click', () => {
+            setReferenceAduIdFromHover();
+        });
+        // make the cursor a pointer
+        // entityContainer.style.cursor = 'pointer';
+    }
+}
+"""
+
+
+def render_pretty_table(
+    text: str,
+    spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
+    span_id2idx: Dict[str, int],
+    binary_relations: Sequence[BinaryRelation],
+    **render_kwargs,
+):
+    from prettytable import PrettyTable
+
+    t = PrettyTable()
+    t.field_names = ["head", "tail", "relation"]
+    t.align = "l"
+    for relation in binary_relations:
+        t.add_row([str(relation.head), str(relation.tail), relation.label])
+
+    html = t.get_html_string(format=True)
+    html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
+
+    return html
+
+
+def render_displacy(
+    text: str,
+    spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
+    span_id2idx: Dict[str, int],
+    binary_relations: Sequence[BinaryRelation],
+    inject_relations=True,
+    colors_hover=None,
+    entity_options={},
+    **render_kwargs,
+):
+
+    ents: List[Dict[str, Any]] = []
+    for entity_id, idx in span_id2idx.items():
+        labeled_span = spans[idx]
+        # pass the ID as a parameter to the entity. The id is required to fetch the entity annotations
+        # on hover and to inject the relation data.
+        if isinstance(labeled_span, LabeledSpan):
+            ents.append(
+                {
+                    "start": labeled_span.start,
+                    "end": labeled_span.end,
+                    "label": labeled_span.label,
+                    "params": {"entity_id": entity_id, "slice_idx": 0},
+                }
+            )
+        elif isinstance(labeled_span, LabeledMultiSpan):
+            for i, (start, end) in enumerate(labeled_span.slices):
+                ents.append(
+                    {
+                        "start": start,
+                        "end": end,
+                        "label": labeled_span.label,
+                        "params": {"entity_id": entity_id, "slice_idx": i},
+                    }
+                )
+        else:
+            raise ValueError(f"Unsupported labeled span type: {type(labeled_span)}")
+
+    ents_sorted = sorted(ents, key=lambda x: (x["start"], x["end"]))
+    spacy_doc = {
+        "text": text,
+        # the ents MUST be sorted by start and end
+        "ents": ents_sorted,
+        "title": None,
+    }
+
+    # copy to avoid modifying the original options
+    entity_options = entity_options.copy()
+    # use the custom template with the entity ID
+    entity_options["template"] = TPL_ENT_WITH_ID
+    renderer = EntityRenderer(options=entity_options)
+    html = renderer.render([spacy_doc], page=True, minify=True).strip()
+
+    html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
+    if inject_relations:
+        html = inject_relation_data(
+            html,
+            spans=spans,
+            span_id2idx=span_id2idx,
+            binary_relations=binary_relations,
+            additional_colors=colors_hover,
+        )
+    return html
+
+
+def inject_relation_data(
+    html: str,
+    spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
+    span_id2idx: Dict[str, int],
+    binary_relations: Sequence[BinaryRelation],
+    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
+) -> str:
+    from bs4 import BeautifulSoup
+
+    # Parse the HTML using BeautifulSoup
+    soup = BeautifulSoup(html, "html.parser")
+
+    entity2tails = defaultdict(list)
+    entity2heads = defaultdict(list)
+    for relation in binary_relations:
+        entity2heads[relation.tail].append((relation.head, relation.label))
+        entity2tails[relation.head].append((relation.tail, relation.label))
+
+    annotation2id = {spans[span_idx]: span_id for span_id, span_idx in span_id2idx.items()}
+    # Add unique IDs to each entity
+    entities = soup.find_all(class_="entity")
+    for entity in entities:
+        original_color = entity["style"].split("background:")[1].split(";")[0].strip()
+        entity["data-color-original"] = original_color
+        if additional_colors is not None:
+            for key, color in additional_colors.items():
+                entity[f"data-color-{key}"] = (
+                    json.dumps(color) if isinstance(color, dict) else color
+                )
+
+        entity_annotation = spans[span_id2idx[entity["data-entity-id"]]]
+
+        # sanity check.
+        if isinstance(entity_annotation, LabeledSpan):
+            annotation_text = entity_annotation.resolve()[1]
+        elif isinstance(entity_annotation, LabeledMultiSpan):
+            slice_idx = int(entity["data-slice-idx"])
+            annotation_text = entity_annotation.resolve()[1][slice_idx]
+        else:
+            raise ValueError(f"Unsupported entity type: {type(entity_annotation)}")
+        annotation_text_without_newline = annotation_text.replace("\n", "")
+        # Just check the start, because the text has the label attached to the end
+        if not entity.text.startswith(annotation_text_without_newline):
+            logger.warning(f"Entity text mismatch: {entity_annotation} != {entity.text}")
+
+        entity["data-label"] = entity_annotation.label
+        entity["data-relation-tails"] = json.dumps(
+            [
+                {"entity-id": annotation2id[tail], "label": label}
+                for tail, label in entity2tails.get(entity_annotation, [])
+                if tail in annotation2id
+            ]
+        )
+        entity["data-relation-heads"] = json.dumps(
+            [
+                {"entity-id": annotation2id[head], "label": label}
+                for head, label in entity2heads.get(entity_annotation, [])
+                if head in annotation2id
+            ]
+        )
+
+    # Return the modified HTML as a string
+    return str(soup)
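
A hedged sketch of calling the renderer directly (not part of the diff); the span ids are arbitrary, and relation injection is disabled because inject_relation_data() resolves spans against a document, which these detached example annotations lack:

    from pytorch_ie.annotations import BinaryRelation, LabeledSpan

    from src.demo.rendering_utils import render_displacy

    text = "Strong claim. Because of good reasons."
    spans = [
        LabeledSpan(start=0, end=12, label="claim"),
        LabeledSpan(start=14, end=38, label="support"),
    ]
    relations = [BinaryRelation(head=spans[1], tail=spans[0], label="supports")]
    html = render_displacy(
        text=text,
        spans=spans,
        span_id2idx={"adu-0": 0, "adu-1": 1},
        binary_relations=relations,
        inject_relations=False,
    )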
src/demo/rendering_utils_displacy.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This code is mainly taken from
|
2 |
+
# https://github.com/explosion/spaCy/blob/master/spacy/displacy/templates.py, and from
|
3 |
+
# https://github.com/explosion/spaCy/blob/master/spacy/displacy/render.py.
|
4 |
+
|
5 |
+
# Setting explicit height and max-width: none on the SVG is required for
|
6 |
+
# Jupyter to render it properly in a cell
|
7 |
+
|
8 |
+
TPL_DEP_SVG = """
|
9 |
+
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:lang="{lang}" id="{id}" class="displacy" width="{width}" height="{height}" direction="{dir}" style="max-width: none; height: {height}px; color: {color}; background: {bg}; font-family: {font}; direction: {dir}">{content}</svg>
|
10 |
+
"""
|
11 |
+
|
12 |
+
|
13 |
+
TPL_DEP_WORDS = """
|
14 |
+
<text class="displacy-token" fill="currentColor" text-anchor="middle" y="{y}">
|
15 |
+
<tspan class="displacy-word" fill="currentColor" x="{x}">{text}</tspan>
|
16 |
+
    <tspan class="displacy-tag" dy="2em" fill="currentColor" x="{x}">{tag}</tspan>
</text>
"""


TPL_DEP_WORDS_LEMMA = """
<text class="displacy-token" fill="currentColor" text-anchor="middle" y="{y}">
    <tspan class="displacy-word" fill="currentColor" x="{x}">{text}</tspan>
    <tspan class="displacy-lemma" dy="2em" fill="currentColor" x="{x}">{lemma}</tspan>
    <tspan class="displacy-tag" dy="2em" fill="currentColor" x="{x}">{tag}</tspan>
</text>
"""


TPL_DEP_ARCS = """
<g class="displacy-arrow">
    <path class="displacy-arc" id="arrow-{id}-{i}" stroke-width="{stroke}px" d="{arc}" fill="none" stroke="currentColor"/>
    <text dy="1.25em" style="font-size: 0.8em; letter-spacing: 1px">
        <textPath xlink:href="#arrow-{id}-{i}" class="displacy-label" startOffset="50%" side="{label_side}" fill="currentColor" text-anchor="middle">{label}</textPath>
    </text>
    <path class="displacy-arrowhead" d="{head}" fill="currentColor"/>
</g>
"""


TPL_FIGURE = """
<figure style="margin-bottom: 6rem">{content}</figure>
"""

TPL_TITLE = """
<h2 style="margin: 0">{title}</h2>
"""


TPL_ENTS = """
<div class="entities" style="line-height: 2.5; direction: {dir}">{content}</div>
"""


TPL_ENT = """
<mark class="entity" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
    {text}
    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
</mark>
"""

TPL_ENT_RTL = """
<mark class="entity" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em">
    {text}
    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-right: 0.5rem">{label}</span>
</mark>
"""


TPL_PAGE = """
<!DOCTYPE html>
<html lang="{lang}">
    <head>
        <title>displaCy</title>
    </head>

    <body style="font-size: 16px; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Helvetica, Arial, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol'; padding: 4rem 2rem; direction: {dir}">{content}</body>
</html>
"""


DEFAULT_LANG = "en"
DEFAULT_DIR = "ltr"


def minify_html(html):
    """Perform a template-specific, rudimentary HTML minification for displaCy.

    Disclaimer: NOT a general-purpose solution, only removes indentation and
    newlines.

    html (unicode): Markup to minify.
    RETURNS (unicode): "Minified" HTML.
    """
    return html.strip().replace("    ", "").replace("\n", "")


def escape_html(text):
    """Replace <, >, &, " with their HTML encoded representation. Intended to
    prevent HTML errors in rendered displaCy markup.

    text (unicode): The original text.
    RETURNS (unicode): Equivalent text to be safely used within HTML.
    """
    text = text.replace("&", "&amp;")
    text = text.replace("<", "&lt;")
    text = text.replace(">", "&gt;")
    text = text.replace('"', "&quot;")
    return text


class EntityRenderer(object):
    """Render named entities as HTML."""

    style = "ent"
    def __init__(self, options={}):
        """Initialise the entity renderer.

        options (dict): Visualiser-specific options (colors, ents)
        """
        colors = {
            "ORG": "#7aecec",
            "PRODUCT": "#bfeeb7",
            "GPE": "#feca74",
            "LOC": "#ff9561",
            "PERSON": "#aa9cfc",
            "NORP": "#c887fb",
            "FACILITY": "#9cc9cc",
            "EVENT": "#ffeb80",
            "LAW": "#ff8197",
            "LANGUAGE": "#ff8197",
            "WORK_OF_ART": "#f0d0ff",
            "DATE": "#bfe1d9",
            "TIME": "#bfe1d9",
            "MONEY": "#e4e7d2",
            "QUANTITY": "#e4e7d2",
            "ORDINAL": "#e4e7d2",
            "CARDINAL": "#e4e7d2",
            "PERCENT": "#e4e7d2",
        }
        # user_colors = registry.displacy_colors.get_all()
        # for user_color in user_colors.values():
        #     colors.update(user_color)
        colors.update(options.get("colors", {}))
        self.default_color = "#ddd"
        self.colors = colors
        self.ents = options.get("ents", None)
        self.direction = DEFAULT_DIR
        self.lang = DEFAULT_LANG

        template = options.get("template")
        if template:
            self.ent_template = template
        else:
            if self.direction == "rtl":
                self.ent_template = TPL_ENT_RTL
            else:
                self.ent_template = TPL_ENT

    def render(self, parsed, page=False, minify=False):
        """Render complete markup.

        parsed (list): Dependency parses to render.
        page (bool): Render parses wrapped as full HTML page.
        minify (bool): Minify HTML markup.
        RETURNS (unicode): Rendered HTML markup.
        """
        rendered = []
        for i, p in enumerate(parsed):
            if i == 0:
                settings = p.get("settings", {})
                self.direction = settings.get("direction", DEFAULT_DIR)
                self.lang = settings.get("lang", DEFAULT_LANG)
            rendered.append(self.render_ents(p["text"], p["ents"], p.get("title")))
        if page:
            docs = "".join([TPL_FIGURE.format(content=doc) for doc in rendered])
            markup = TPL_PAGE.format(content=docs, lang=self.lang, dir=self.direction)
        else:
            markup = "".join(rendered)
        if minify:
            return minify_html(markup)
        return markup

    def render_ents(self, text, spans, title):
        """Render entities in text.

        text (unicode): Original text.
        spans (list): Individual entity spans and their start, end and label.
        title (unicode or None): Document title set in Doc.user_data['title'].
        """
        markup = ""
        offset = 0
        for span in spans:
            label = span["label"]
            start = span["start"]
            end = span["end"]
            additional_params = span.get("params", {})
            entity = escape_html(text[start:end])
            fragments = text[offset:start].split("\n")
            for i, fragment in enumerate(fragments):
                markup += escape_html(fragment)
                if len(fragments) > 1 and i != len(fragments) - 1:
                    markup += "<br/>"
            if self.ents is None or label.upper() in self.ents:
                color = self.colors.get(label.upper(), self.default_color)
                ent_settings = {"label": label, "text": entity, "bg": color}
                ent_settings.update(additional_params)
                markup += self.ent_template.format(**ent_settings)
            else:
                markup += entity
            offset = end
        fragments = text[offset:].split("\n")
        for i, fragment in enumerate(fragments):
            markup += escape_html(fragment)
            if len(fragments) > 1 and i != len(fragments) - 1:
                markup += "<br/>"
        markup = TPL_ENTS.format(content=markup, dir=self.direction)
        if title:
            markup = TPL_TITLE.format(title=title) + markup
        return markup
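The renderer consumes the same `parsed` structure as spaCy's displaCy: a list of dicts with "text", "ents" (each with "start", "end", "label", and optional "params") and an optional "title". A minimal usage sketch of the vendored class; the example text, offsets, and options are invented:

renderer = EntityRenderer(options={"colors": {"GPE": "#feca74"}})
html = renderer.render(
    [
        {
            "text": "Berlin is the capital of Germany.",
            "ents": [
                {"start": 0, "end": 6, "label": "GPE"},
                {"start": 25, "end": 32, "label": "GPE"},
            ],
            "title": "Example document",
        }
    ],
    page=False,
    minify=True,
)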
src/demo/retrieve_and_dump_all_relevant.py
ADDED
@@ -0,0 +1,101 @@
import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

import argparse
import logging

from src.demo.retriever_utils import (
    retrieve_all_relevant_spans,
    retrieve_all_relevant_spans_for_all_documents,
    retrieve_relevant_spans,
)
from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations

logger = logging.getLogger(__name__)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_path",
        type=str,
        default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to a zip or directory containing a retriever dump.",
    )
    parser.add_argument("-k", "--top_k", type=int, default=10)
    parser.add_argument("-t", "--threshold", type=float, default=0.95)
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--query_doc_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query document.",
    )
    parser.add_argument(
        "--query_span_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query span.",
    )
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    if not args.output_path.endswith(".json"):
        raise ValueError("only JSON output is supported")

    logger.info(f"instantiating retriever from {args.config_path}...")
    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        args.config_path
    )
    logger.info(f"loading data from {args.data_path}...")
    retriever.load_from_disc(args.data_path)

    search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
    logger.info(f"use search_kwargs: {search_kwargs}")

    if args.query_span_id is not None:
        logger.warning(f"retrieving results for single span: {args.query_span_id}")
        all_spans_for_all_documents = retrieve_relevant_spans(
            retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
        )
    elif args.query_doc_id is not None:
        logger.warning(f"retrieving results for single document: {args.query_doc_id}")
        all_spans_for_all_documents = retrieve_all_relevant_spans(
            retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
        )
    else:
        all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
            retriever=retriever, **search_kwargs
        )

    if all_spans_for_all_documents is None:
        logger.warning("no relevant spans found in any document")
        exit(0)

    logger.info(f"dumping results to {args.output_path}...")
    all_spans_for_all_documents.to_json(args.output_path)

    logger.info("done")
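For reference, the same flow as the CLI above can be sketched programmatically; the dump path and output file below are placeholders, only the config path is the script's actual default:

retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
    "configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml"
)
retriever.load_from_disc("path/to/retriever_dump")  # placeholder: zip or directory with a dump
all_results = retrieve_all_relevant_spans_for_all_documents(
    retriever=retriever, k=10, score_threshold=0.95
)
if all_results is not None:
    all_results.to_json("all_relevant_spans.json")  # placeholder output path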
src/demo/retriever_utils.py
ADDED
@@ -0,0 +1,313 @@
import logging
from typing import Dict, Optional, Sequence, Tuple, Union

import gradio as gr
import pandas as pd
from pytorch_ie import Annotation
from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
from typing_extensions import Protocol

from src.langchain_modules import DocumentAwareSpanRetriever
from src.langchain_modules.span_retriever import (
    DocumentAwareSpanRetrieverWithRelations,
    _parse_config,
)

logger = logging.getLogger(__name__)


def get_document_as_dict(retriever: DocumentAwareSpanRetriever, doc_id: str) -> Dict:
    document = retriever.get_document(doc_id=doc_id)
    return retriever.docstore.as_dict(document)


def load_retriever(
    retriever_config_str: str,
    config_format: str,
    device: str = "cpu",
    previous_retriever: Optional[DocumentAwareSpanRetrieverWithRelations] = None,
) -> DocumentAwareSpanRetrieverWithRelations:
    try:
        retriever_config = _parse_config(retriever_config_str, format=config_format)
        # set device for the embeddings pipeline
        retriever_config["vectorstore"]["embedding"]["pipeline_kwargs"]["device"] = device
        result = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config(retriever_config)
        # if a previous retriever is provided, load all documents and vectors from the previous retriever
        if previous_retriever is not None:
            # documents
            all_doc_ids = list(previous_retriever.docstore.yield_keys())
            gr.Info(f"Storing {len(all_doc_ids)} documents from previous retriever...")
            all_docs = previous_retriever.docstore.mget(all_doc_ids)
            result.docstore.mset([(doc.id, doc) for doc in all_docs])
            # spans (with vectors)
            all_span_ids = list(previous_retriever.vectorstore.yield_keys())
            all_spans = previous_retriever.vectorstore.mget(all_span_ids)
            result.vectorstore.mset([(span.id, span) for span in all_spans])

        gr.Info("Retriever loaded successfully.")
        return result
    except Exception as e:
        raise gr.Error(f"Failed to load retriever: {e}")


def retrieve_similar_spans(
    retriever: DocumentAwareSpanRetriever,
    query_span_id: str,
    **kwargs,
) -> pd.DataFrame:
    if not query_span_id.strip():
        raise gr.Error("No query span selected.")
    try:
        retrieval_result = retriever.invoke(input=query_span_id, **kwargs)
        records = []
        for similar_span_doc in retrieval_result:
            pie_doc, metadata = retriever.docstore.unwrap_with_metadata(similar_span_doc)
            span_ann = metadata["attached_span"]
            records.append(
                {
                    "doc_id": pie_doc.id,
                    "span_id": similar_span_doc.id,
                    "score": metadata["relevance_score"],
                    "label": span_ann.label,
                    "text": str(span_ann),
                }
            )
        return (
            pd.DataFrame(records, columns=["doc_id", "score", "label", "text", "span_id"])
            .sort_values(by="score", ascending=False)
            .round(3)
        )
    except Exception as e:
        raise gr.Error(f"Failed to retrieve similar ADUs: {e}")


def retrieve_relevant_spans(
    retriever: DocumentAwareSpanRetriever,
    query_span_id: str,
    relation_label_mapping: Optional[dict[str, str]] = None,
    **kwargs,
) -> pd.DataFrame:
    if not query_span_id.strip():
        raise gr.Error("No query span selected.")
    try:
        relation_label_mapping = relation_label_mapping or {}
        retrieval_result = retriever.invoke(input=query_span_id, return_related=True, **kwargs)
        records = []
        for relevant_span_doc in retrieval_result:
            pie_doc, metadata = retriever.docstore.unwrap_with_metadata(relevant_span_doc)
            span_ann = metadata["attached_span"]
            tail_span_ann = metadata["attached_tail_span"]
            mapped_relation_label = relation_label_mapping.get(
                metadata["relation_label"], metadata["relation_label"]
            )
            records.append(
                {
                    "doc_id": pie_doc.id,
                    "type": mapped_relation_label,
                    "rel_score": metadata["relation_score"],
                    "text": str(tail_span_ann),
                    "span_id": relevant_span_doc.id,
                    "label": tail_span_ann.label,
                    "ref_score": metadata["relevance_score"],
                    "ref_label": span_ann.label,
                    "ref_text": str(span_ann),
                    "ref_span_id": metadata["head_id"],
                }
            )
        return (
            pd.DataFrame(
                records,
                columns=[
                    "type",
                    # omitted for now, we get no valid relation scores for the generative model
                    # "rel_score",
                    "ref_score",
                    "label",
                    "text",
                    "ref_label",
                    "ref_text",
                    "doc_id",
                    "span_id",
                    "ref_span_id",
                ],
            )
            .sort_values(by=["ref_score"], ascending=False)
            .round(3)
        )
    except Exception as e:
        raise gr.Error(f"Failed to retrieve relevant ADUs: {e}")


class RetrieverCallable(Protocol):
    def __call__(
        self,
        retriever: DocumentAwareSpanRetriever,
        query_span_id: str,
        **kwargs,
    ) -> Optional[pd.DataFrame]:
        pass


def _retrieve_for_all_spans(
    retriever: DocumentAwareSpanRetriever,
    query_doc_id: str,
    retrieve_func: RetrieverCallable,
    query_span_id_column: str = "query_span_id",
    **kwargs,
) -> Optional[pd.DataFrame]:
    if not query_doc_id.strip():
        raise gr.Error("No query document selected.")
    try:
        span_id2idx = retriever.get_span_id2idx_from_doc(query_doc_id)
        gr.Info(f"Retrieving results for {len(span_id2idx)} ADUs in document {query_doc_id}...")
        span_results = {
            query_span_id: retrieve_func(
                retriever=retriever,
                query_span_id=query_span_id,
                **kwargs,
            )
            for query_span_id in span_id2idx.keys()
        }
        span_results_not_empty = {
            query_span_id: df
            for query_span_id, df in span_results.items()
            if df is not None and not df.empty
        }

        # add column with query_span_id
        for query_span_id, query_span_result in span_results_not_empty.items():
            query_span_result[query_span_id_column] = query_span_id

        if len(span_results_not_empty) == 0:
            gr.Info(f"No results found for any ADU in document {query_doc_id}.")
            return None
        else:
            result = pd.concat(span_results_not_empty.values(), ignore_index=True)
            gr.Info(f"Retrieved {len(result)} ADUs for document {query_doc_id}.")
            return result
    except Exception as e:
        raise gr.Error(
            f'Failed to retrieve results for all ADUs in document "{query_doc_id}": {e}'
        )


def retrieve_all_similar_spans(
    retriever: DocumentAwareSpanRetriever,
    query_doc_id: str,
    **kwargs,
) -> Optional[pd.DataFrame]:
    return _retrieve_for_all_spans(
        retriever=retriever,
        query_doc_id=query_doc_id,
        retrieve_func=retrieve_similar_spans,
        **kwargs,
    )


def retrieve_all_relevant_spans(
    retriever: DocumentAwareSpanRetriever,
    query_doc_id: str,
    **kwargs,
) -> Optional[pd.DataFrame]:
    return _retrieve_for_all_spans(
        retriever=retriever,
        query_doc_id=query_doc_id,
        retrieve_func=retrieve_relevant_spans,
        **kwargs,
    )


class RetrieverForAllSpansCallable(Protocol):
    def __call__(
        self,
        retriever: DocumentAwareSpanRetriever,
        query_doc_id: str,
        **kwargs,
    ) -> Optional[pd.DataFrame]:
        pass


def _retrieve_for_all_documents(
    retriever: DocumentAwareSpanRetriever,
    retrieve_func: RetrieverForAllSpansCallable,
    query_doc_id_column: str = "query_doc_id",
    **kwargs,
) -> Optional[pd.DataFrame]:
    try:
        all_doc_ids = list(retriever.docstore.yield_keys())
        gr.Info(f"Retrieving results for {len(all_doc_ids)} documents...")
        doc_results = {
            doc_id: retrieve_func(retriever=retriever, query_doc_id=doc_id, **kwargs)
            for doc_id in all_doc_ids
        }
        doc_results_not_empty = {
            doc_id: df for doc_id, df in doc_results.items() if df is not None and not df.empty
        }
        # add column with query_doc_id
        for doc_id, doc_result in doc_results_not_empty.items():
            doc_result[query_doc_id_column] = doc_id

        if len(doc_results_not_empty) == 0:
            gr.Info("No results found for any document.")
            return None
        else:
            result = pd.concat(doc_results_not_empty, ignore_index=True)
            gr.Info(f"Retrieved {len(result)} ADUs for all documents.")
            return result
    except Exception as e:
        raise gr.Error(f"Failed to retrieve results for all documents: {e}")


def retrieve_all_similar_spans_for_all_documents(
    retriever: DocumentAwareSpanRetriever,
    **kwargs,
) -> Optional[pd.DataFrame]:
    return _retrieve_for_all_documents(
        retriever=retriever,
        retrieve_func=retrieve_all_similar_spans,
        **kwargs,
    )


def retrieve_all_relevant_spans_for_all_documents(
    retriever: DocumentAwareSpanRetriever,
    **kwargs,
) -> Optional[pd.DataFrame]:
    return _retrieve_for_all_documents(
        retriever=retriever,
        retrieve_func=retrieve_all_relevant_spans,
        **kwargs,
    )


def get_text_spans_and_relations_from_document(
    retriever: DocumentAwareSpanRetrieverWithRelations, document_id: str
) -> Tuple[
    str,
    Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    Dict[str, int],
    Sequence[BinaryRelation],
]:
    document = retriever.get_document(doc_id=document_id)
    pie_document = retriever.docstore.unwrap(document)
    use_predicted_annotations = retriever.use_predicted_annotations(document)
    spans = retriever.get_base_layer(
        pie_document=pie_document, use_predicted_annotations=use_predicted_annotations
    )
    relations = retriever.get_relation_layer(
        pie_document=pie_document, use_predicted_annotations=use_predicted_annotations
    )
    span_id2idx = retriever.get_span_id2idx_from_doc(document)
    return pie_document.text, spans, span_id2idx, relations


def get_span_annotation(
    retriever: DocumentAwareSpanRetriever,
    span_id: str,
) -> Annotation:
    if span_id.strip() == "":
        raise gr.Error("No span selected.")
    try:
        return retriever.get_span_by_id(span_id=span_id)
    except Exception as e:
        raise gr.Error(f"Failed to retrieve span annotation: {e}")
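A minimal sketch of how these helpers might be driven outside the Gradio demo; the config path matches the retriever config referenced elsewhere in this PR, while the span id is a placeholder and config_format="yaml" is assumed to be a format accepted by _parse_config:

with open(
    "configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml"
) as f:
    retriever_config_str = f.read()
retriever = load_retriever(retriever_config_str, config_format="yaml", device="cpu")
similar = retrieve_similar_spans(
    retriever=retriever,
    query_span_id="<query-span-uuid>",  # placeholder span id
    k=10,
    score_threshold=0.95,
)
print(similar.head())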
src/document/__init__.py
ADDED
File without changes
src/document/processing.py
ADDED
@@ -0,0 +1,223 @@
from __future__ import annotations

import logging
from typing import Any, Dict, Iterable, List, Sequence, Set, Tuple, TypeVar, Union

import networkx as nx
from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
from pie_modules.documents import TextDocumentWithLabeledMultiSpansAndBinaryRelations
from pytorch_ie import AnnotationLayer
from pytorch_ie.core import Document

logger = logging.getLogger(__name__)


D = TypeVar("D", bound=Document)


def _remove_overlapping_entities(
    entities: Iterable[Dict[str, Any]], relations: Iterable[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    sorted_entities = sorted(entities, key=lambda span: span["start"])
    entities_wo_overlap = []
    skipped_entities = []
    last_end = 0
    for entity_dict in sorted_entities:
        if entity_dict["start"] < last_end:
            skipped_entities.append(entity_dict)
        else:
            entities_wo_overlap.append(entity_dict)
            last_end = entity_dict["end"]
    if len(skipped_entities) > 0:
        logger.warning(f"skipped overlapping entities: {skipped_entities}")
    valid_entity_ids = set(entity_dict["_id"] for entity_dict in entities_wo_overlap)
    valid_relations = [
        relation_dict
        for relation_dict in relations
        if relation_dict["head"] in valid_entity_ids and relation_dict["tail"] in valid_entity_ids
    ]
    return entities_wo_overlap, valid_relations


def remove_overlapping_entities(
    doc: D,
    entity_layer_name: str = "entities",
    relation_layer_name: str = "relations",
) -> D:
    # TODO: use document.add_all_annotations_from_other()
    document_dict = doc.asdict()
    entities_wo_overlap, valid_relations = _remove_overlapping_entities(
        entities=document_dict[entity_layer_name]["annotations"],
        relations=document_dict[relation_layer_name]["annotations"],
    )

    document_dict[entity_layer_name] = {
        "annotations": entities_wo_overlap,
        "predictions": [],
    }
    document_dict[relation_layer_name] = {
        "annotations": valid_relations,
        "predictions": [],
    }
    new_doc = type(doc).fromdict(document_dict)

    return new_doc


def _merge_spans_via_relation(
    spans: Sequence[LabeledSpan],
    relations: Sequence[BinaryRelation],
    link_relation_label: str,
    create_multi_spans: bool = True,
) -> Tuple[Union[Set[LabeledSpan], Set[LabeledMultiSpan]], Set[BinaryRelation]]:
    # convert the list of relations to a graph to easily calculate connected components to merge
    g = nx.Graph()
    link_relations = []
    other_relations = []
    for rel in relations:
        if rel.label == link_relation_label:
            link_relations.append(rel)
            # never merge spans that do not have the same label
            if (
                not (isinstance(rel.head, LabeledSpan) or isinstance(rel.tail, LabeledSpan))
                or rel.head.label == rel.tail.label
            ):
                g.add_edge(rel.head, rel.tail)
            else:
                logger.debug(
                    f"spans to merge do not have the same label, do not merge them: {rel.head}, {rel.tail}"
                )
        else:
            other_relations.append(rel)

    span_mapping = {}
    connected_components: Set[LabeledSpan]
    for connected_components in nx.connected_components(g):
        # all spans in a connected component have the same label
        label = list(span.label for span in connected_components)[0]
        connected_components_sorted = sorted(connected_components, key=lambda span: span.start)
        if create_multi_spans:
            new_span = LabeledMultiSpan(
                slices=tuple((span.start, span.end) for span in connected_components_sorted),
                label=label,
            )
        else:
            new_span = LabeledSpan(
                start=min(span.start for span in connected_components_sorted),
                end=max(span.end for span in connected_components_sorted),
                label=label,
            )
        for span in connected_components_sorted:
            span_mapping[span] = new_span
    for span in spans:
        if span not in span_mapping:
            if create_multi_spans:
                span_mapping[span] = LabeledMultiSpan(
                    slices=((span.start, span.end),), label=span.label, score=span.score
                )
            else:
                span_mapping[span] = LabeledSpan(
                    start=span.start, end=span.end, label=span.label, score=span.score
                )

    new_spans = set(span_mapping.values())
    new_relations = set(
        BinaryRelation(
            head=span_mapping[rel.head],
            tail=span_mapping[rel.tail],
            label=rel.label,
            score=rel.score,
        )
        for rel in other_relations
    )

    return new_spans, new_relations


def merge_spans_via_relation(
    document: D,
    relation_layer: str,
    link_relation_label: str,
    use_predicted_spans: bool = False,
    process_predictions: bool = True,
    create_multi_spans: bool = False,
) -> D:

    rel_layer = document[relation_layer]
    span_layer = rel_layer.target_layer
    new_gold_spans, new_gold_relations = _merge_spans_via_relation(
        spans=span_layer,
        relations=rel_layer,
        link_relation_label=link_relation_label,
        create_multi_spans=create_multi_spans,
    )
    if process_predictions:
        new_pred_spans, new_pred_relations = _merge_spans_via_relation(
            spans=span_layer.predictions if use_predicted_spans else span_layer,
            relations=rel_layer.predictions,
            link_relation_label=link_relation_label,
            create_multi_spans=create_multi_spans,
        )
    else:
        assert not use_predicted_spans
        new_pred_spans = set(span_layer.predictions.clear())
        new_pred_relations = set(rel_layer.predictions.clear())

    relation_layer_name = relation_layer
    span_layer_name = document[relation_layer].target_name
    if create_multi_spans:
        doc_dict = document.asdict()
        for f in document.annotation_fields():
            doc_dict.pop(f.name)

        result = TextDocumentWithLabeledMultiSpansAndBinaryRelations.fromdict(doc_dict)
        result.labeled_multi_spans.extend(new_gold_spans)
        result.labeled_multi_spans.predictions.extend(new_pred_spans)
        result.binary_relations.extend(new_gold_relations)
        result.binary_relations.predictions.extend(new_pred_relations)
    else:
        result = document.copy(with_annotations=False)
        result[span_layer_name].extend(new_gold_spans)
        result[span_layer_name].predictions.extend(new_pred_spans)
        result[relation_layer_name].extend(new_gold_relations)
        result[relation_layer_name].predictions.extend(new_pred_relations)

    return result


def remove_partitions_by_labels(
    document: D, partition_layer: str, label_blacklist: List[str]
) -> D:
    document = document.copy()
    layer: AnnotationLayer = document[partition_layer]
    new_partitions = []
    for partition in layer.clear():
        if partition.label not in label_blacklist:
            new_partitions.append(partition)
    layer.extend(new_partitions)
    return document


D_text = TypeVar("D_text", bound=Document)


def replace_substrings_in_text(
    document: D_text, replacements: Dict[str, str], enforce_same_length: bool = True
) -> D_text:
    new_text = document.text
    for old_str, new_str in replacements.items():
        if enforce_same_length and len(old_str) != len(new_str):
            raise ValueError(
                f'Replacement strings must have the same length, but got "{old_str}" -> "{new_str}"'
            )
        new_text = new_text.replace(old_str, new_str)
    result_dict = document.asdict()
    result_dict["text"] = new_text
    result = type(document).fromdict(result_dict)
    result.text = new_text
    return result


def replace_substrings_in_text_with_spaces(document: D_text, substrings: Iterable[str]) -> D_text:
    replacements = {substring: " " * len(substring) for substring in substrings}
    return replace_substrings_in_text(document, replacements=replacements)
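To illustrate the merging helper in isolation: two same-label spans connected by a link relation collapse into a single LabeledMultiSpan. A toy sketch; offsets are invented and the "parts_of_same" label is illustrative (it follows the HANDLE_PARTS_OF_SAME convention used elsewhere in the repo):

from pie_modules.annotations import BinaryRelation, LabeledSpan

s1 = LabeledSpan(start=0, end=5, label="claim")
s2 = LabeledSpan(start=10, end=15, label="claim")
link = BinaryRelation(head=s1, tail=s2, label="parts_of_same")
merged_spans, remaining_relations = _merge_spans_via_relation(
    spans=[s1, s2],
    relations=[link],
    link_relation_label="parts_of_same",
    create_multi_spans=True,
)
# merged_spans == {LabeledMultiSpan(slices=((0, 5), (10, 15)), label="claim")}
# remaining_relations is empty: the only input relation was a link relation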
src/evaluate.py
ADDED
@@ -0,0 +1,137 @@
import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

# ------------------------------------------------------------------------------------ #
# `pyrootutils.setup_root(...)` is an optional line at the top of each entry file
# that helps to make the environment more robust and convenient
#
# the main advantages are:
# - allows you to keep all entry files in "src/" without installing project as a package
# - makes paths and scripts always work no matter what your current work dir is
# - automatically loads environment variables from ".env" file if it exists
#
# how it works:
# - the line above recursively searches for either ".git" or "pyproject.toml" in present
#   and parent dirs, to determine the project root dir
# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from
#   any place without installing project as a package
# - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml"
#   to make all paths always relative to the project root
# - loads environment variables from ".env" file in root dir (if `dotenv=True`)
#
# you can remove `pyrootutils.setup_root(...)` if you:
# 1. either install project as a package or move each entry file to the project root dir
# 2. simply remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml"
# 3. always run entry files from the project root dir
#
# https://github.com/ashleve/pyrootutils
# ------------------------------------------------------------------------------------ #

from typing import Tuple

import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig
from pie_datasets import DatasetDict
from pie_modules.models import *  # noqa: F403
from pie_modules.taskmodules import *  # noqa: F403
from pytorch_ie.core import PyTorchIEModel, TaskModule
from pytorch_ie.models import *  # noqa: F403
from pytorch_ie.taskmodules import *  # noqa: F403
from pytorch_lightning import Trainer

from src import utils
from src.datamodules import PieDataModule
from src.models import *  # noqa: F403
from src.taskmodules import *  # noqa: F403

log = utils.get_pylogger(__name__)


@utils.task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
    """Evaluates a given checkpoint on the datamodule's test set.

    This method is wrapped in an optional @task_wrapper decorator which applies extra utilities
    before and after the call.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """

    # Set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        pl.seed_everything(cfg.seed, workers=True)

    # Init pytorch-ie dataset
    log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
    dataset: DatasetDict = hydra.utils.instantiate(cfg.dataset, _convert_="partial")

    # Init pytorch-ie taskmodule
    log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>")
    taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial")

    # auto-convert the dataset if the taskmodule specifies a document type
    dataset = taskmodule.convert_dataset(dataset)

    # Init pytorch-ie datamodule
    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
    datamodule: PieDataModule = hydra.utils.instantiate(
        cfg.datamodule, dataset=dataset, taskmodule=taskmodule, _convert_="partial"
    )

    # Init pytorch-ie model
    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: PyTorchIEModel = hydra.utils.instantiate(cfg.model, _convert_="partial")

    # Init lightning loggers
    logger = utils.instantiate_dict_entries(cfg, "logger")

    # Init lightning trainer
    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger, _convert_="partial")

    object_dict = {
        "cfg": cfg,
        "taskmodule": taskmodule,
        "dataset": dataset,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg)

    log.info("Starting testing!")
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)

    # for predictions use trainer.predict(...)
    # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)

    metric_dict = trainer.callback_metrics

    return metric_dict, object_dict


@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="evaluate.yaml")
def main(cfg: DictConfig) -> None:
    metric_dict, _ = evaluate(cfg)

    return metric_dict


if __name__ == "__main__":
    utils.replace_sys_args_with_values_from_files()
    utils.prepare_omegaconf()
    main()
src/evaluate_documents.py
ADDED
@@ -0,0 +1,116 @@
import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

# ------------------------------------------------------------------------------------ #
# `pyrootutils.setup_root(...)` is an optional line at the top of each entry file
# that helps to make the environment more robust and convenient
#
# the main advantages are:
# - allows you to keep all entry files in "src/" without installing project as a package
# - makes paths and scripts always work no matter what your current work dir is
# - automatically loads environment variables from ".env" file if it exists
#
# how it works:
# - the line above recursively searches for either ".git" or "pyproject.toml" in present
#   and parent dirs, to determine the project root dir
# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from
#   any place without installing project as a package
# - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml"
#   to make all paths always relative to the project root
# - loads environment variables from ".env" file in root dir (if `dotenv=True`)
#
# you can remove `pyrootutils.setup_root(...)` if you:
# 1. either install project as a package or move each entry file to the project root dir
# 2. simply remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml"
# 3. always run entry files from the project root dir
#
# https://github.com/ashleve/pyrootutils
# ------------------------------------------------------------------------------------ #

from typing import Any, Tuple

import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig
from pie_datasets import DatasetDict
from pytorch_ie.core import DocumentMetric
from pytorch_ie.metrics import *  # noqa: F403

from src import utils
from src.metrics import *  # noqa: F403

log = utils.get_pylogger(__name__)


@utils.task_wrapper
def evaluate_documents(cfg: DictConfig) -> Tuple[dict, dict]:
    """Evaluates serialized PIE documents.

    This method is wrapped in an optional @task_wrapper decorator which applies extra utilities
    before and after the call.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """

    # Set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        pl.seed_everything(cfg.seed, workers=True)

    # Init pytorch-ie dataset
    log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
    dataset: DatasetDict = hydra.utils.instantiate(cfg.dataset, _convert_="partial")

    # Init pytorch-ie metric
    log.info(f"Instantiating metric <{cfg.metric._target_}>")
    metric: DocumentMetric = hydra.utils.instantiate(cfg.metric, _convert_="partial")

    # auto-convert the dataset if the metric specifies a document type
    dataset = metric.convert_dataset(dataset)

    # Init lightning loggers
    loggers = utils.instantiate_dict_entries(cfg, "logger")

    object_dict = {
        "cfg": cfg,
        "dataset": dataset,
        "metric": metric,
        "logger": loggers,
    }

    if loggers:
        log.info("Logging hyperparameters!")
        # send hparams to all loggers
        for logger in loggers:
            logger.log_hyperparams(cfg)

    splits = cfg.get("splits", None)
    if splits is None:
        documents = dataset
    else:
        documents = type(dataset)({k: v for k, v in dataset.items() if k in splits})

    metric_dict = metric(documents)

    return metric_dict, object_dict


@hydra.main(
    version_base="1.2", config_path=str(root / "configs"), config_name="evaluate_documents.yaml"
)
def main(cfg: DictConfig) -> Any:
    metric_dict, _ = evaluate_documents(cfg)
    return metric_dict


if __name__ == "__main__":
    utils.replace_sys_args_with_values_from_files()
    utils.prepare_omegaconf()
    main()
src/hydra_callbacks/__init__.py
ADDED
@@ -0,0 +1 @@
from .save_job_return_value import SaveJobReturnValueCallback
src/hydra_callbacks/save_job_return_value.py
ADDED
@@ -0,0 +1,261 @@
import json
import logging
import os
import pickle
from pathlib import Path
from typing import Any, Dict, Generator, List, Tuple, Union

import numpy as np
import pandas as pd
import torch
from hydra.core.utils import JobReturn
from hydra.experimental.callback import Callback
from omegaconf import DictConfig


def to_py_obj(obj):
    """Convert a PyTorch tensor, Numpy array or python list to a python list.

    Modified version of transformers.utils.generic.to_py_obj.
    """
    if isinstance(obj, dict):
        return {k: to_py_obj(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return type(obj)(to_py_obj(o) for o in obj)
    elif isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    elif isinstance(obj, (np.ndarray, np.number)):  # tolist also works on 0d np arrays
        return obj.tolist()
    else:
        return obj


def list_of_dicts_to_dict_of_lists_recursive(list_of_dicts):
    """Convert a list of dicts to a dict of lists recursively.

    Example:
        # works with nested dicts
        >>> list_of_dicts_to_dict_of_lists_recursive([{"a": 1, "b": {"c": 2}}, {"a": 3, "b": {"c": 4}}])
        {'b': {'c': [2, 4]}, 'a': [1, 3]}
        # works with incomplete dicts
        >>> list_of_dicts_to_dict_of_lists_recursive([{"a": 1, "b": 2}, {"a": 3}])
        {'b': [2, None], 'a': [1, 3]}

    Args:
        list_of_dicts (List[dict]): A list of dicts.

    Returns:
        dict: A dict of lists.
    """
    if isinstance(list_of_dicts, list):
        if len(list_of_dicts) == 0:
            return {}
        elif isinstance(list_of_dicts[0], dict):
            keys = set()
            for d in list_of_dicts:
                if not isinstance(d, dict):
                    raise ValueError("Not all elements of the list are dicts.")
                keys.update(d.keys())
            return {
                k: list_of_dicts_to_dict_of_lists_recursive(
                    [d.get(k, None) for d in list_of_dicts]
                )
                for k in keys
            }
        else:
            return list_of_dicts
    else:
        return list_of_dicts


def _flatten_dict_gen(d, parent_key: Tuple[str, ...] = ()) -> Generator:
    for k, v in d.items():
        new_key = parent_key + (k,)
        if isinstance(v, dict):
            yield from dict(_flatten_dict_gen(v, new_key)).items()
        else:
            yield new_key, v


def flatten_dict(d: Dict[str, Any]) -> Dict[Tuple[str, ...], Any]:
    return dict(_flatten_dict_gen(d))


def unflatten_dict(d: Dict[Tuple[str, ...], Any]) -> Union[Dict[str, Any], Any]:
    """Unflattens a dictionary with nested keys.

    Example:
        >>> d = {("a", "b", "c"): 1, ("a", "b", "d"): 2, ("a", "e"): 3}
        >>> unflatten_dict(d)
        {'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
    """
    result: Dict[str, Any] = {}
    for k, v in d.items():
        if len(k) == 0:
            if len(result) > 1:
                raise ValueError("Cannot unflatten dictionary with multiple root keys.")
            return v
        current = result
        for key in k[:-1]:
            current = current.setdefault(key, {})
        current[k[-1]] = v
    return result


def overrides_to_identifiers(overrides_per_result: List[List[str]], sep: str = "-") -> List[str]:
    """Converts a list of lists of overrides to a list of identifiers, taking only those
    overrides into account that are not identical across all results.

    Example:
        >>> overrides_per_result = [
        ...     ["a=1", "b=2", "c=3"],
        ...     ["a=1", "b=2", "c=4"],
        ...     ["a=1", "b=3", "c=3"],
        ... ]
        >>> overrides_to_identifiers(overrides_per_result)
        ['b=2-c=3', 'b=2-c=4', 'b=3-c=3']

    Args:
        overrides_per_result (List[List[str]]): A list of lists of overrides.
        sep (str, optional): The separator to use between the overrides. Defaults to "-".

    Returns:
        List[str]: A list of identifiers.
    """
    # get the overrides that are not identical for all results
    overrides_per_result_transposed = np.array(overrides_per_result).T.tolist()
    indices = [
        i for i, entries in enumerate(overrides_per_result_transposed) if len(set(entries)) > 1
    ]
    # convert the overrides to identifiers
    identifiers = [
        sep.join([overrides[idx] for idx in indices]) for overrides in overrides_per_result
    ]
    return identifiers


class SaveJobReturnValueCallback(Callback):
    """Save the job return-value in ${output_dir}/{job_return_value_filename}.

    This also works for multi-runs (e.g. sweeps for hyperparameter search). In this case, the
    result will additionally be saved in a common file in the multi-run log directory. If
    integrate_multirun_result=True, the job return-values are also aggregated (e.g. mean, min,
    max) and saved in another file.

    params:
    -------
    filenames: str or List[str] (default: "job_return_value.json")
        The filename(s) of the file(s) to save the job return-value to. If it ends with ".json",
        the return-value will be saved as a json file. If it ends with ".pkl", the return-value
        will be saved as a pickle file; if it ends with ".md", the return-value will be saved as
        a markdown file.
    integrate_multirun_result: bool (default: False)
        If True, the job return-values of all jobs from a multi-run will be rearranged into a
        dict of lists (maybe nested), where the keys are the keys of the job return-values and
        the values are lists of the corresponding values of all jobs. This is useful if you want
        to access specific values of all jobs in a multi-run all at once. Also, aggregated
        values (e.g. mean, min, max) are created for all numeric values and saved in another
        file.
    """

    def __init__(
        self,
        filenames: Union[str, List[str]] = "job_return_value.json",
        integrate_multirun_result: bool = False,
    ) -> None:
        self.log = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
        self.filenames = [filenames] if isinstance(filenames, str) else filenames
        self.integrate_multirun_result = integrate_multirun_result
        self.job_returns: List[JobReturn] = []

    def on_job_end(self, config: DictConfig, job_return: JobReturn, **kwargs: Any) -> None:
        self.job_returns.append(job_return)
        output_dir = Path(config.hydra.runtime.output_dir)  # / Path(config.hydra.output_subdir)
        for filename in self.filenames:
            self._save(obj=job_return.return_value, filename=filename, output_dir=output_dir)

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        if self.integrate_multirun_result:
            # rearrange the job return-values of all jobs from a multi-run into a dict of lists
            # (maybe nested)
            obj = list_of_dicts_to_dict_of_lists_recursive(
                [jr.return_value for jr in self.job_returns]
            )
            # also create an aggregated result
            # convert to python objects to allow selecting numeric columns
            obj_py = to_py_obj(obj)
            obj_flat = flatten_dict(obj_py)
            # create dataframe from flattened dict
            df_flat = pd.DataFrame(obj_flat)
            # select only the numeric values
            df_numbers_only = df_flat.select_dtypes(["number"])
            cols_removed = set(df_flat.columns) - set(df_numbers_only.columns)
            if len(cols_removed) > 0:
                self.log.warning(
                    f"Removed the following columns from the aggregated result because they are not numeric: "
                    f"{cols_removed}"
                )
            if len(df_numbers_only.columns) == 0:
                obj_aggregated = None
            else:
                # aggregate the numeric values
                df_described = df_numbers_only.describe()
                # add the aggregation keys (e.g. mean, min, ...) as innermost keys and convert back to a dict
                obj_flat_aggregated = df_described.T.stack().to_dict()
                # unflatten because _save() works better with nested dicts
                obj_aggregated = unflatten_dict(obj_flat_aggregated)
        else:
            # create a dict of the job return-values of all jobs from a multi-run
            # (_save() works better with nested dicts)
            ids = overrides_to_identifiers([jr.overrides for jr in self.job_returns])
            obj = {identifier: jr.return_value for identifier, jr in zip(ids, self.job_returns)}
            obj_aggregated = None
        output_dir = Path(config.hydra.sweep.dir)
        for filename in self.filenames:
            self._save(
                obj=obj,
                filename=filename,
                output_dir=output_dir,
                multi_run_result=self.integrate_multirun_result,
            )
            # if available, also save the aggregated result
            if obj_aggregated is not None:
                file_base_name, ext = os.path.splitext(filename)
                filename_aggregated = f"{file_base_name}.aggregated{ext}"
                self._save(obj=obj_aggregated, filename=filename_aggregated, output_dir=output_dir)

    def _save(
        self, obj: Any, filename: str, output_dir: Path, multi_run_result: bool = False
    ) -> None:
        self.log.info(f"Saving job_return in {output_dir / filename}")
        output_dir.mkdir(parents=True, exist_ok=True)
        if filename.endswith(".pkl"):
            with open(str(output_dir / filename), "wb") as file:
                pickle.dump(obj, file, protocol=4)
        elif filename.endswith(".json"):
            # Convert PyTorch tensors and numpy arrays to native python types
            obj_py = to_py_obj(obj)
            with open(str(output_dir / filename), "w") as file:
                json.dump(obj_py, file, indent=2)
        elif filename.endswith(".md"):
            # Convert PyTorch tensors and numpy arrays to native python types
            obj_py = to_py_obj(obj)
            obj_py_flat = flatten_dict(obj_py)

            if multi_run_result:
                # In the case of a multi-run, we expect to have multiple values for each key.
                # We therefore just convert the dict to a pandas DataFrame.
                result = pd.DataFrame(obj_py_flat)
            else:
                # In the case of a single job, we expect to have only one value for each key.
                # We therefore convert the dict to a pandas Series and ...
                series = pd.Series(obj_py_flat)
                if len(series.index.levels) > 1:
                    # ... if the Series has multiple index levels, we create a DataFrame by unstacking the last level.
                    result = series.unstack(-1)
                else:
                    # ... otherwise we just unpack the one-entry index values and save the resulting Series.
                    series.index = series.index.get_level_values(0)
                    result = series

            with open(str(output_dir / filename), "w") as file:
                file.write(result.to_markdown())

        else:
            raise ValueError("Unknown file extension")
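A round-trip sketch of the two dict helpers the callback builds on; the nested metrics dict is invented, the behavior follows the docstring examples above:

nested = {"train": {"loss": 0.31, "f1": 0.82}, "val": {"f1": 0.75}}
flat = flatten_dict(nested)
# {("train", "loss"): 0.31, ("train", "f1"): 0.82, ("val", "f1"): 0.75}
assert unflatten_dict(flat) == nested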
src/langchain_modules/span_retriever.py
CHANGED
@@ -4,7 +4,7 @@ import uuid
 from collections import defaultdict
 from copy import copy
 from enum import Enum
-from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Type, Union

 from langchain_core.callbacks import (
     AsyncCallbackManagerForRetrieverRun,
@@ -674,14 +674,14 @@ class DocumentAwareSpanRetriever(BaseRetriever, SerializableStore):

     def add_pie_documents(
         self,
-        documents:
+        documents: Iterable[TextBasedDocument],
         use_predicted_annotations: bool,
         metadata: Optional[Dict[str, Any]] = None,
     ) -> None:
         """Add pie documents to the retriever.

         Args:
-            documents:
+            documents: Iterable of pie documents to add
             use_predicted_annotations: Whether to use the predicted annotations or the gold annotations
             metadata: Optional metadata to add to each document
         """
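The sharpened signature makes the expected input explicit. A hedged usage sketch (assuming retriever is an already configured DocumentAwareSpanRetriever and docs an iterable of TextBasedDocument instances):

    retriever.add_pie_documents(
        docs,
        use_predicted_annotations=True,  # index predicted rather than gold spans
        metadata={"source": "demo"},     # optional metadata attached to each document
    )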
src/metrics/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .coref_sklearn import CorefMetricsSKLearn
+from .coref_torchmetrics import CorefMetricsTorchmetrics
src/metrics/annotation_processor.py
ADDED
@@ -0,0 +1,23 @@
+from typing import Tuple
+
+from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, Span
+from pytorch_ie.core import Annotation
+
+
+def decode_span_without_label(ann: Annotation) -> Tuple[Tuple[int, int], ...]:
+    if isinstance(ann, Span):
+        return ((ann.start, ann.end),)
+    elif isinstance(ann, LabeledMultiSpan):
+        return ann.slices
+    else:
+        raise ValueError("Annotation must be a Span or LabeledMultiSpan")
+
+
+def to_binary_relation_without_argument_labels(
+    ann: Annotation,
+) -> Tuple[Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int], ...], str]:
+    if not isinstance(ann, BinaryRelation):
+        raise ValueError("Annotation must be a BinaryRelation")
+    return (
+        decode_span_without_label(ann.head),
+        decode_span_without_label(ann.tail),
+        ann.label,
+    )
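A quick illustration of what these helpers return (a sketch using pytorch_ie annotation types; the offsets are arbitrary):

    from pytorch_ie.annotations import BinaryRelation, LabeledSpan

    head = LabeledSpan(start=0, end=5, label="PER")
    tail = LabeledSpan(start=10, end=17, label="ORG")
    rel = BinaryRelation(head=head, tail=tail, label="works_at")

    # ((0, 5),) -- a plain Span is normalized to a one-slice tuple, matching
    # the format of LabeledMultiSpan.slices
    print(decode_span_without_label(head))
    # (((0, 5),), ((10, 17),), 'works_at') -- argument labels are dropped
    print(to_binary_relation_without_argument_labels(rel))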
src/metrics/coref_sklearn.py
ADDED
@@ -0,0 +1,162 @@
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+from pandas import MultiIndex
+from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+from pytorch_ie import DocumentMetric
+from pytorch_ie.core.metric import T
+from pytorch_ie.utils.hydra import resolve_target
+
+from src.hydra_callbacks.save_job_return_value import to_py_obj
+
+logger = logging.getLogger(__name__)
+
+
+def get_num_total(targets: List[int], preds: List[float]) -> int:
+    return len(targets)
+
+
+def get_num_positives(targets: List[int], preds: List[float], positive_idx: int = 1) -> int:
+    return len([v for v in targets if v == positive_idx])
+
+
+def discretize(
+    values: List[float], threshold: Union[float, List[float], dict]
+) -> Union[List[float], Dict[Any, List[float]]]:
+    if isinstance(threshold, float):
+        return (np.array(values) >= threshold).astype(int).tolist()
+    if isinstance(threshold, list):
+        return {t: discretize(values=values, threshold=t) for t in threshold}  # type: ignore
+    if isinstance(threshold, dict):
+        thresholds = (
+            np.arange(threshold["start"], threshold["end"], threshold["step"]).round(4).tolist()
+        )
+        return discretize(values, threshold=thresholds)
+    raise TypeError(f"threshold has unknown type: {threshold}")
+
+
+class CorefMetricsSKLearn(DocumentMetric):
+    DOCUMENT_TYPE = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+
+    def __init__(
+        self,
+        metrics: Dict[str, str],
+        thresholds: Optional[Dict[str, float]] = None,
+        default_target_idx: int = 0,
+        default_prediction_score: float = 0.0,
+        show_as_markdown: bool = False,
+        markdown_precision: int = 4,
+        plot: bool = False,
+    ):
+        self.metrics = {name: resolve_target(metric) for name, metric in metrics.items()}
+        self.thresholds = thresholds or {}
+        thresholds_not_in_metrics = {
+            name: t for name, t in self.thresholds.items() if name not in self.metrics
+        }
+        if len(thresholds_not_in_metrics) > 0:
+            logger.warning(
+                f"there are discretizing thresholds that do not have a metric: "
+                f"{thresholds_not_in_metrics}"
+            )
+        self.default_target_idx = default_target_idx
+        self.default_prediction_score = default_prediction_score
+        self.show_as_markdown = show_as_markdown
+        self.markdown_precision = markdown_precision
+        self.plot = plot
+
+        super().__init__()
+
+    def reset(self) -> None:
+        self._preds: List[float] = []
+        self._targets: List[int] = []
+
+    def _update(self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations) -> None:
+        target_args2idx = {
+            (rel.head, rel.tail): int(rel.score) for rel in document.binary_coref_relations
+        }
+        prediction_args2score = {
+            (rel.head, rel.tail): rel.score for rel in document.binary_coref_relations.predictions
+        }
+        all_args = set(target_args2idx) | set(prediction_args2score)
+        all_targets: List[int] = []
+        all_predictions: List[float] = []
+        for args in all_args:
+            target_idx = target_args2idx.get(args, self.default_target_idx)
+            prediction_score = prediction_args2score.get(args, self.default_prediction_score)
+            all_targets.append(target_idx)
+            all_predictions.append(prediction_score)
+        self._preds.extend(all_predictions)
+        self._targets.extend(all_targets)
+
+    def do_plot(self):
+        # Plotting is only implemented for the torchmetrics-based variant
+        # (see CorefMetricsTorchmetrics), where self.metrics is a MetricCollection.
+        raise NotImplementedError()
+
+    def _compute(self) -> T:
+        if self.plot:
+            self.do_plot()
+
+        result = {}
+        for name, metric in self.metrics.items():
+            if name in self.thresholds:
+                preds = discretize(values=self._preds, threshold=self.thresholds[name])
+            else:
+                preds = self._preds
+            if isinstance(preds, dict):
+                metric_results = {
+                    t: metric(self._targets, t_preds) for t, t_preds in preds.items()
+                }
+                # just keep the threshold with the maximum metric value
+                max_t, max_v = max(metric_results.items(), key=lambda k_v: k_v[1])
+                result[f"{name}-{max_t}"] = max_v
+            else:
+                result[name] = metric(self._targets, preds)
+
+        result = to_py_obj(result)
+        if self.show_as_markdown:
+            import pandas as pd
+
+            series = pd.Series(result)
+            if isinstance(series.index, MultiIndex):
+                if len(series.index.levels) > 1:
+                    # in fact, this is not a series anymore
+                    series = series.unstack(-1)
+                else:
+                    series.index = series.index.get_level_values(0)
+            logger.info(
+                f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
+            )
+        return result
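Since the metrics mapping holds dotted import paths that resolve_target resolves to callables, a hedged instantiation sketch (the sklearn paths are real; the threshold values are illustrative, and note that discretize also accepts a start/end/step dict although the thresholds type hint says float):

    metric = CorefMetricsSKLearn(
        metrics={
            "auc": "sklearn.metrics.roc_auc_score",
            "f1": "sklearn.metrics.f1_score",
        },
        # f1 needs hard labels: scan thresholds in [0.1, 1.0) and keep the best one
        thresholds={"f1": {"start": 0.1, "end": 1.0, "step": 0.1}},
        show_as_markdown=True,
    )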
src/metrics/coref_torchmetrics.py
ADDED
@@ -0,0 +1,107 @@
+import logging
+import math
+from typing import Dict
+
+import torch
+from pandas import MultiIndex
+from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+from pytorch_ie import DocumentMetric
+from pytorch_ie.core.metric import T
+from torchmetrics import Metric, MetricCollection
+
+from src.hydra_callbacks.save_job_return_value import to_py_obj
+
+logger = logging.getLogger(__name__)
+
+
+class CorefMetricsTorchmetrics(DocumentMetric):
+    DOCUMENT_TYPE = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+
+    def __init__(
+        self,
+        metrics: Dict[str, Metric],
+        default_target_idx: int = 0,
+        default_prediction_score: float = 0.0,
+        show_as_markdown: bool = False,
+        markdown_precision: int = 4,
+        plot: bool = False,
+    ):
+        self.metrics = MetricCollection(metrics)
+        self.default_target_idx = default_target_idx
+        self.default_prediction_score = default_prediction_score
+        self.show_as_markdown = show_as_markdown
+        self.markdown_precision = markdown_precision
+        self.plot = plot
+
+        super().__init__()
+
+    def reset(self) -> None:
+        self.metrics.reset()
+
+    def _update(self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations) -> None:
+        target_args2idx = {
+            (rel.head, rel.tail): int(rel.score) for rel in document.binary_coref_relations
+        }
+        prediction_args2score = {
+            (rel.head, rel.tail): rel.score for rel in document.binary_coref_relations.predictions
+        }
+        all_args = set(target_args2idx) | set(prediction_args2score)
+        all_targets = []
+        all_predictions = []
+        for args in all_args:
+            target_idx = target_args2idx.get(args, self.default_target_idx)
+            prediction_score = prediction_args2score.get(args, self.default_prediction_score)
+            all_targets.append(target_idx)
+            all_predictions.append(prediction_score)
+        prediction_scores = torch.tensor(all_predictions)
+        target_indices = torch.tensor(all_targets)
+        self.metrics.update(preds=prediction_scores, target=target_indices)
+
+    def do_plot(self):
+        from matplotlib import pyplot as plt
+
+        # Get the number of metrics
+        num_metrics = len(self.metrics)
+
+        # Calculate rows and columns for subplots (aim for a square-like layout)
+        ncols = math.ceil(math.sqrt(num_metrics))
+        nrows = math.ceil(num_metrics / ncols)
+
+        # Create the subplots
+        fig, ax_list = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10))
+
+        # Flatten the ax_list if necessary (in case of multiple rows/columns)
+        ax_list = ax_list.flatten().tolist()
+
+        # Ensure that we pass exactly the number of axes required by the metrics
+        ax_list = ax_list[:num_metrics]
+
+        # Plot the metrics using the list of axes
+        self.metrics.plot(ax=ax_list, together=False)
+
+        # Adjust layout to avoid overlapping plots
+        plt.tight_layout()
+        plt.show()
+
+    def _compute(self) -> T:
+        if self.plot:
+            self.do_plot()
+
+        result = self.metrics.compute()
+
+        result = to_py_obj(result)
+        if self.show_as_markdown:
+            import pandas as pd
+
+            series = pd.Series(result)
+            if isinstance(series.index, MultiIndex):
+                if len(series.index.levels) > 1:
+                    # in fact, this is not a series anymore
+                    series = series.unstack(-1)
+                else:
+                    series.index = series.index.get_level_values(0)
+            logger.info(
+                f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
+            )
+        return result
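A hedged counterpart for the torchmetrics variant, using standard classification metrics (any torchmetrics Metric that consumes preds/target works):

    from torchmetrics.classification import BinaryAUROC, BinaryF1Score

    metric = CorefMetricsTorchmetrics(
        metrics={"auroc": BinaryAUROC(), "f1": BinaryF1Score(threshold=0.5)},
        show_as_markdown=True,
    )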
src/models/__init__.py
CHANGED
@@ -0,0 +1,5 @@
+from .sequence_classification_with_pooler import (
+    SequencePairSimilarityModelWithMaxCosineSim,
+    SequencePairSimilarityModelWithPooler2,
+    SequencePairSimilarityModelWithPoolerAndAdapter,
+)
src/models/components/__init__.py
ADDED
File without changes
src/models/components/pooler.py
ADDED
@@ -0,0 +1,79 @@
+import torch
+from torch import Tensor, cat, nn
+
+
+class SpanMeanPooler(nn.Module):
+    """Pooler that takes the mean hidden state over spans. If the start or end index is negative,
+    a learned embedding is used instead. The indices are expected to have the shape [batch_size,
+    num_indices].
+
+    The resulting embeddings are concatenated, so the output shape is [batch_size, num_indices * input_dim].
+    Note that this is a slightly modified version of pie_modules.models.components.pooler.SpanMaxPooler,
+    i.e. we changed the aggregation method from torch.amax to torch.mean.
+
+    Args:
+        input_dim: The input dimension of the hidden state.
+        num_indices: The number of indices to pool.
+
+    Returns:
+        The pooled hidden states with shape [batch_size, num_indices * input_dim].
+    """
+
+    def __init__(self, input_dim: int, num_indices: int = 2, **kwargs):
+        super().__init__(**kwargs)
+        self.input_dim = input_dim
+        self.num_indices = num_indices
+        self.missing_embeddings = nn.Parameter(torch.empty(num_indices, self.input_dim))
+        nn.init.normal_(self.missing_embeddings)
+
+    def forward(
+        self, hidden_state: Tensor, start_indices: Tensor, end_indices: Tensor, **kwargs
+    ) -> Tensor:
+        batch_size, seq_len, hidden_size = hidden_state.shape
+        if start_indices.shape[1] != self.num_indices:
+            raise ValueError(
+                f"number of start indices [{start_indices.shape[1]}] has to be the same as "
+                f"num_indices [{self.num_indices}]"
+            )
+
+        if end_indices.shape[1] != self.num_indices:
+            raise ValueError(
+                f"number of end indices [{end_indices.shape[1]}] has to be the same as "
+                f"num_indices [{self.num_indices}]"
+            )
+
+        # check that start_indices are before end_indices
+        mask_both_positive = (start_indices >= 0) & (end_indices >= 0)
+        mask_start_before_end = start_indices < end_indices
+        mask_valid = mask_start_before_end | ~mask_both_positive
+        if not torch.all(mask_valid):
+            raise ValueError(
+                f"values in start_indices have to be smaller than respective values in "
+                f"end_indices, but start_indices=\n{start_indices}\n and end_indices=\n{end_indices}"
+            )
+
+        # times num_indices due to concat
+        result = torch.zeros(
+            batch_size, hidden_size * self.num_indices, device=hidden_state.device
+        )
+        for batch_idx in range(batch_size):
+            current_start_indices = start_indices[batch_idx]
+            current_end_indices = end_indices[batch_idx]
+            current_embeddings = [
+                (
+                    torch.mean(
+                        hidden_state[
+                            batch_idx, current_start_indices[i] : current_end_indices[i], :
+                        ],
+                        dim=0,
+                    )
+                    if current_start_indices[i] >= 0 and current_end_indices[i] >= 0
+                    else self.missing_embeddings[i]
+                )
+                for i in range(self.num_indices)
+            ]
+            result[batch_idx] = cat(current_embeddings, 0)
+
+        return result
+
+    @property
+    def output_dim(self) -> int:
+        return self.input_dim * self.num_indices
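A quick shape check for the pooler (toy tensors; index -1 marks a missing span that falls back to the learned embedding):

    import torch

    pooler = SpanMeanPooler(input_dim=8, num_indices=2)
    hidden_state = torch.randn(2, 16, 8)             # [batch_size, seq_len, hidden]
    start_indices = torch.tensor([[1, 5], [0, -1]])  # -1: second span missing in sample 2
    end_indices = torch.tensor([[3, 9], [4, -1]])
    pooled = pooler(hidden_state, start_indices=start_indices, end_indices=end_indices)
    print(pooled.shape)  # torch.Size([2, 16]) == [batch_size, num_indices * input_dim]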
src/models/sequence_classification_with_pooler.py
ADDED
@@ -0,0 +1,166 @@
+import abc
+import logging
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from adapters import AutoAdapterModel
+from pie_modules.models import SequencePairSimilarityModelWithPooler
+from pie_modules.models.components.pooler import MENTION_POOLING
+from pie_modules.models.sequence_classification_with_pooler import (
+    InputType,
+    OutputType,
+    SequenceClassificationModelWithPooler,
+    SequenceClassificationModelWithPoolerBase,
+    TargetType,
+    separate_arguments_by_prefix,
+)
+from pytorch_ie import PyTorchIEModel
+from torch import FloatTensor, Tensor
+from transformers import AutoConfig, PreTrainedModel
+from transformers.modeling_outputs import SequenceClassifierOutput
+
+from src.models.components.pooler import SpanMeanPooler
+
+logger = logging.getLogger(__name__)
+
+
+class SequenceClassificationModelWithPoolerBase2(
+    SequenceClassificationModelWithPoolerBase, abc.ABC
+):
+    def setup_pooler(self, input_dim: int) -> Tuple[Callable, int]:
+        aggregate = self.pooler_config.get("aggregate", "max")
+        if self.pooler_config["type"] == MENTION_POOLING and aggregate != "max":
+            if aggregate == "mean":
+                pooler_config = dict(self.pooler_config)
+                pooler_config.pop("type")
+                pooler_config.pop("aggregate")
+                pooler = SpanMeanPooler(input_dim=input_dim, **pooler_config)
+                return pooler, pooler.output_dim
+            else:
+                raise ValueError(f"Unknown aggregation method: {aggregate}")
+        else:
+            return super().setup_pooler(input_dim)
+
+
+class SequenceClassificationModelWithPoolerAndAdapterBase(
+    SequenceClassificationModelWithPoolerBase2, abc.ABC
+):
+    def __init__(self, adapter_name_or_path: Optional[str] = None, **kwargs):
+        self.adapter_name_or_path = adapter_name_or_path
+        super().__init__(**kwargs)
+
+    def setup_base_model(self) -> PreTrainedModel:
+        if self.adapter_name_or_path is None:
+            return super().setup_base_model()
+        else:
+            config = AutoConfig.from_pretrained(self.model_name_or_path)
+            if self.is_from_pretrained:
+                model = AutoAdapterModel.from_config(config=config)
+            else:
+                model = AutoAdapterModel.from_pretrained(self.model_name_or_path, config=config)
+            # load the adapter in any case (it looks like it is not saved in the state or loaded
+            # from a serialized state)
+            logger.info(f"load adapter: {self.adapter_name_or_path}")
+            model.load_adapter(self.adapter_name_or_path, source="hf", set_active=True)
+            return model
+
+
+@PyTorchIEModel.register()
+class SequencePairSimilarityModelWithPooler2(
+    SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerBase2
+):
+    pass
+
+
+@PyTorchIEModel.register()
+class SequencePairSimilarityModelWithPoolerAndAdapter(
+    SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
+):
+    pass
+
+
+@PyTorchIEModel.register()
+class SequenceClassificationModelWithPoolerAndAdapter(
+    SequenceClassificationModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
+):
+    pass
+
+
+def get_max_cosine_sim(embeddings: Tensor, embeddings_pair: Tensor) -> Tensor:
+    # Normalize the embeddings
+    embeddings_normalized = F.normalize(embeddings, p=2, dim=1)  # Shape: (n, k)
+    embeddings_normalized_pair = F.normalize(embeddings_pair, p=2, dim=1)  # Shape: (m, k)
+
+    # Compute the cosine similarity matrix
+    cosine_sim = torch.mm(embeddings_normalized, embeddings_normalized_pair.T)  # Shape: (n, m)
+
+    # Get the overall maximum cosine similarity value (a scalar tensor)
+    max_cosine_sim = torch.max(cosine_sim)
+    return max_cosine_sim
+
+
+def get_span_embeddings(
+    embeddings: FloatTensor, start_indices: Tensor, end_indices: Tensor
+) -> List[FloatTensor]:
+    result = []
+    for embeds, starts, ends in zip(embeddings, start_indices, end_indices):
+        span_embeds = embeds[starts[0] : ends[0]]
+        result.append(span_embeds)
+    return result
+
+
+@PyTorchIEModel.register()
+class SequencePairSimilarityModelWithMaxCosineSim(SequencePairSimilarityModelWithPooler):
+    def get_pooled_output(self, model_inputs, pooler_inputs) -> List[FloatTensor]:
+        output = self.model(**model_inputs)
+        hidden_state = output.last_hidden_state
+        # unlike the base class, we do not pool + dropout here, but return the
+        # raw per-token embeddings of each span
+        span_embeds = get_span_embeddings(hidden_state, **pooler_inputs)
+        return span_embeds
+
+    def forward(
+        self,
+        inputs: InputType,
+        targets: Optional[TargetType] = None,
+        return_hidden_states: bool = False,
+    ) -> OutputType:
+        sanitized_inputs = separate_arguments_by_prefix(
+            # Note that the order of the prefixes is important because one is a prefix of the
+            # other, so we need to start with the longer one!
+            arguments=inputs,
+            prefixes=["pooler_pair_", "pooler_"],
+        )
+
+        span_embeddings = self.get_pooled_output(
+            model_inputs=sanitized_inputs["remaining"]["encoding"],
+            pooler_inputs=sanitized_inputs["pooler_"],
+        )
+        span_embeddings_pair = self.get_pooled_output(
+            model_inputs=sanitized_inputs["remaining"]["encoding_pair"],
+            pooler_inputs=sanitized_inputs["pooler_pair_"],
+        )
+
+        logits_list = [
+            get_max_cosine_sim(span_embeds, span_embeds_pair)
+            for span_embeds, span_embeds_pair in zip(span_embeddings, span_embeddings_pair)
+        ]
+        logits = torch.stack(logits_list)
+
+        result = {"logits": logits}
+        if targets is not None:
+            labels = targets["scores"]
+            loss = self.loss_fct(logits, labels)
+            result["loss"] = loss
+        if return_hidden_states:
+            raise NotImplementedError("return_hidden_states is not yet implemented")
+
+        return SequenceClassifierOutput(**result)
+
+
+@PyTorchIEModel.register()
+class SequencePairSimilarityModelWithMaxCosineSimAndAdapter(
+    SequencePairSimilarityModelWithMaxCosineSim, SequencePairSimilarityModelWithPoolerAndAdapter
+):
+    pass
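To make the max-cosine scoring concrete, a toy check of get_max_cosine_sim (random values; n and m token embeddings of width k):

    import torch

    span_a = torch.randn(4, 16)  # n=4 token embeddings of the first span
    span_b = torch.randn(6, 16)  # m=6 token embeddings of the second span
    score = get_max_cosine_sim(span_a, span_b)
    print(score.shape, float(score))  # torch.Size([]) -- best token-pair similarity, in [-1, 1]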
src/models/utils/__init__.py
CHANGED
@@ -1 +1,5 @@
-from .loading import
+from .loading import (
+    load_model_from_pie_model,
+    load_model_with_adapter,
+    load_tokenizer_from_pie_taskmodule,
+)
src/models/utils/loading.py
CHANGED
@@ -23,10 +23,10 @@ def load_tokenizer_from_pie_taskmodule(taskmodule_kwargs: Dict[str, Any]) -> Pre


 def load_model_with_adapter(
-
-) ->
-    from adapters import AutoAdapterModel
+    model_kwargs: Dict[str, Any], adapter_kwargs: Dict[str, Any]
+) -> PreTrainedModel:
+    from adapters import AutoAdapterModel

     model = AutoAdapterModel.from_pretrained(**model_kwargs)
     model.load_adapter(set_active=True, **adapter_kwargs)
-    return model
+    return model
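A hedged call sketch for the completed signature (both identifiers are placeholders, not models shipped with this PR):

    model = load_model_with_adapter(
        model_kwargs={"pretrained_model_name_or_path": "bert-base-uncased"},
        adapter_kwargs={"adapter_name_or_path": "some-org/some-adapter", "source": "hf"},
    )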
src/pipeline/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .ner_re_pipeline import NerRePipeline
+from .span_retrieval_based_re_pipeline import SpanRetrievalBasedRelationExtractionPipeline
src/pipeline/ner_re_pipeline.py
ADDED
@@ -0,0 +1,208 @@
+from __future__ import annotations
+
+import logging
+from functools import partial
+from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, TypeVar, Union
+
+from pie_modules.utils import resolve_type
+from pytorch_ie import AutoPipeline, WithDocumentTypeMixin
+from pytorch_ie.core import Document
+
+logger = logging.getLogger(__name__)
+
+
+D = TypeVar("D", bound=Document)
+
+
+def clear_annotation_layers(doc: D, layer_names: List[str], predictions: bool = False) -> None:
+    for layer_name in layer_names:
+        if predictions:
+            doc[layer_name].predictions.clear()
+        else:
+            doc[layer_name].clear()
+
+
+def move_annotations_from_predictions(doc: D, layer_names: List[str]) -> None:
+    for layer_name in layer_names:
+        annotations = list(doc[layer_name].predictions)
+        # remove any previous annotations
+        doc[layer_name].clear()
+        # each annotation can be attached to just one annotation container,
+        # so we need to clear the predictions
+        doc[layer_name].predictions.clear()
+        doc[layer_name].extend(annotations)
+
+
+def move_annotations_to_predictions(doc: D, layer_names: List[str]) -> None:
+    for layer_name in layer_names:
+        annotations = list(doc[layer_name])
+        # each annotation can be attached to just one annotation container,
+        # so we need to clear the layer
+        doc[layer_name].clear()
+        # remove any previous annotations
+        doc[layer_name].predictions.clear()
+        doc[layer_name].predictions.extend(annotations)
+
+
+def add_annotations_from_other_documents(
+    docs: Iterable[D],
+    other_docs: Sequence[Document],
+    layer_names: List[str],
+    from_predictions: bool = False,
+    to_predictions: bool = False,
+    clear_before: bool = True,
+) -> None:
+    for i, doc in enumerate(docs):
+        other_doc = other_docs[i]
+        # copy to not modify the input
+        other_doc = type(other_doc).fromdict(other_doc.asdict())
+
+        for layer_name in layer_names:
+            if clear_before:
+                doc[layer_name].clear()
+            other_layer = other_doc[layer_name]
+            if from_predictions:
+                other_layer = other_layer.predictions
+            other_annotations = list(other_layer)
+            other_layer.clear()
+            if to_predictions:
+                doc[layer_name].predictions.extend(other_annotations)
+            else:
+                doc[layer_name].extend(other_annotations)
+
+
+def process_pipeline_steps(
+    documents: Sequence[Document],
+    processors: Dict[str, Callable[[Sequence[Document]], Optional[Sequence[Document]]]],
+) -> Sequence[Document]:
+    # call the processors in the order they are provided
+    for step_name, processor in processors.items():
+        logger.info(f"process {step_name} ...")
+        processed_documents = processor(documents)
+        if processed_documents is not None:
+            documents = processed_documents
+
+    return documents
+
+
+def process_documents(
+    documents: List[Document], processor: Callable[..., Optional[Document]], **kwargs
+) -> List[Document]:
+    result = []
+    for doc in documents:
+        processed_doc = processor(doc, **kwargs)
+        if processed_doc is not None:
+            result.append(processed_doc)
+        else:
+            result.append(doc)
+    return result
+
+
+class DummyTaskmodule(WithDocumentTypeMixin):
+    def __init__(self, document_type: Optional[Union[Type[Document], str]]):
+        if isinstance(document_type, str):
+            self._document_type = resolve_type(document_type, expected_super_type=Document)
+        else:
+            self._document_type = document_type
+
+    @property
+    def document_type(self) -> Optional[Type[Document]]:
+        return self._document_type
+
+
+class NerRePipeline:
+    def __init__(
+        self,
+        ner_model_path: str,
+        re_model_path: str,
+        entity_layer: str,
+        relation_layer: str,
+        device: Optional[int] = None,
+        batch_size: Optional[int] = None,
+        show_progress_bar: Optional[bool] = None,
+        document_type: Optional[Union[Type[Document], str]] = None,
+        **processor_kwargs,
+    ):
+        self.taskmodule = DummyTaskmodule(document_type)
+        self.ner_model_path = ner_model_path
+        self.re_model_path = re_model_path
+        self.processor_kwargs = processor_kwargs or {}
+        self.entity_layer = entity_layer
+        self.relation_layer = relation_layer
+        # set some values for the inference processors, if provided
+        for inference_pipeline in ["ner_pipeline", "re_pipeline"]:
+            if inference_pipeline not in self.processor_kwargs:
+                self.processor_kwargs[inference_pipeline] = {}
+            if "device" not in self.processor_kwargs[inference_pipeline] and device is not None:
+                self.processor_kwargs[inference_pipeline]["device"] = device
+            if (
+                "batch_size" not in self.processor_kwargs[inference_pipeline]
+                and batch_size is not None
+            ):
+                self.processor_kwargs[inference_pipeline]["batch_size"] = batch_size
+            if (
+                "show_progress_bar" not in self.processor_kwargs[inference_pipeline]
+                and show_progress_bar is not None
+            ):
+                self.processor_kwargs[inference_pipeline]["show_progress_bar"] = show_progress_bar
+
+    def __call__(self, documents: Sequence[Document], inplace: bool = False) -> Sequence[Document]:
+        input_docs: Sequence[Document]
+        # we need to keep the original documents to add the gold data back
+        original_docs: Sequence[Document]
+        if inplace:
+            input_docs = documents
+            original_docs = [doc.copy() for doc in documents]
+        else:
+            input_docs = [doc.copy() for doc in documents]
+            original_docs = documents
+
+        docs_with_predictions = process_pipeline_steps(
+            documents=input_docs,
+            processors={
+                "clear_annotations": partial(
+                    process_documents,
+                    processor=clear_annotation_layers,
+                    layer_names=[self.entity_layer, self.relation_layer],
+                    **self.processor_kwargs.get("clear_annotations", {}),
+                ),
+                "ner_pipeline": AutoPipeline.from_pretrained(
+                    self.ner_model_path, **self.processor_kwargs.get("ner_pipeline", {})
+                ),
+                "use_predicted_entities": partial(
+                    process_documents,
+                    processor=move_annotations_from_predictions,
+                    layer_names=[self.entity_layer],
+                    **self.processor_kwargs.get("use_predicted_entities", {}),
+                ),
+                # "create_candidate_relations": partial(
+                #     process_documents,
+                #     processor=CandidateRelationAdder(
+                #         **self.processor_kwargs.get("create_candidate_relations", {})
+                #     ),
+                # ),
+                "re_pipeline": AutoPipeline.from_pretrained(
+                    self.re_model_path, **self.processor_kwargs.get("re_pipeline", {})
+                ),
+                # otherwise we can not move the entities back to predictions
+                "clear_candidate_relations": partial(
+                    process_documents,
+                    processor=clear_annotation_layers,
+                    layer_names=[self.relation_layer],
+                    **self.processor_kwargs.get("clear_candidate_relations", {}),
+                ),
+                "move_entities_to_predictions": partial(
+                    process_documents,
+                    processor=move_annotations_to_predictions,
+                    layer_names=[self.entity_layer],
+                    **self.processor_kwargs.get("move_entities_to_predictions", {}),
+                ),
+                "re_add_gold_data": partial(
+                    add_annotations_from_other_documents,
+                    other_docs=original_docs,
+                    layer_names=[self.entity_layer, self.relation_layer],
+                    **self.processor_kwargs.get("re_add_gold_data", {}),
+                ),
+            },
+        )
+        return docs_with_predictions
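A hedged end-to-end sketch (model paths are placeholders; the layer names must match the annotation layers of the document type the models were trained on):

    pipeline = NerRePipeline(
        ner_model_path="models/my-ner-model",  # hypothetical
        re_model_path="models/my-re-model",    # hypothetical
        entity_layer="labeled_spans",
        relation_layer="binary_relations",
        batch_size=32,
        show_progress_bar=True,
    )
    docs_with_predictions = pipeline(documents, inplace=False)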
src/pipeline/span_retrieval_based_re_pipeline.py
ADDED
@@ -0,0 +1,130 @@
+import logging
+from typing import Optional, Sequence, Type
+
+from langchain_core.documents import Document as LCDocument
+from pie_datasets import Dataset, IterableDataset
+from pytorch_ie import Document, WithDocumentTypeMixin
+from pytorch_ie.annotations import BinaryRelation, LabeledSpan
+from pytorch_ie.documents import TextBasedDocument
+
+from src.langchain_modules import DocumentAwareSpanRetriever
+
+logger = logging.getLogger(__name__)
+
+
+class DummyTaskmodule(WithDocumentTypeMixin):
+    def __init__(self, document_type: Type[Document]):
+        self._document_type = document_type
+
+    @property
+    def document_type(self) -> Optional[Type[Document]]:
+        return self._document_type
+
+
+class SpanRetrievalBasedRelationExtractionPipeline:
+    """Pipeline for adding binary relations between spans based on span retrieval within the
+    same document.
+
+    This pipeline retrieves spans for all existing spans as query and adds binary relations
+    between the query spans and the retrieved spans.
+
+    Args:
+        retriever: The span retriever to use for retrieving spans.
+        relation_label: The label to use for the binary relations.
+        relation_layer_name: The name of the annotation layer to add the binary relations to.
+        use_predicted_annotations: Whether to use the predicted or the gold span annotations.
+        load_store_path: If provided, the retriever store(s) will be loaded from this path
+            before processing.
+        save_store_path: If provided, the retriever store(s) will be saved to this path
+            after processing.
+        fast_dev_run: Whether to run the pipeline in fast dev mode, i.e. only processing the
+            first 2 documents.
+    """
+
+    def __init__(
+        self,
+        retriever: DocumentAwareSpanRetriever,
+        relation_label: str,
+        relation_layer_name: str = "binary_relations",
+        use_predicted_annotations: bool = False,
+        load_store_path: Optional[str] = None,
+        save_store_path: Optional[str] = None,
+        fast_dev_run: bool = False,
+    ):
+        self.retriever = retriever
+        if not self.retriever.retrieve_from_same_document:
+            raise NotImplementedError("Retriever must retrieve from the same document")
+        self.relation_label = relation_label
+        self.relation_layer_name = relation_layer_name
+        self.use_predicted_annotations = use_predicted_annotations
+        self.load_store_path = load_store_path
+        self.save_store_path = save_store_path
+        if self.load_store_path is not None:
+            self.retriever.load_from_directory(path=self.load_store_path)
+
+        self.fast_dev_run = fast_dev_run
+
+    # to make auto-conversion work: we request documents of type pipeline.taskmodule.document_type
+    # from the dataset
+    @property
+    def taskmodule(self) -> DummyTaskmodule:
+        return DummyTaskmodule(self.retriever.pie_document_type)
+
+    def _construct_similarity_relations(
+        self,
+        query_results: list[LCDocument],
+        query_span: LabeledSpan,
+    ) -> list[BinaryRelation]:
+        return [
+            BinaryRelation(
+                head=query_span,
+                tail=lc_doc.metadata["attached_span"],
+                label=self.relation_label,
+                score=float(lc_doc.metadata["relevance_score"]),
+            )
+            for lc_doc in query_results
+        ]
+
+    def _process_single_document(
+        self,
+        document: Document,
+    ) -> TextBasedDocument:
+        if not isinstance(document, TextBasedDocument):
+            raise ValueError("Document must be a TextBasedDocument")
+
+        self.retriever.add_pie_documents(
+            [document], use_predicted_annotations=self.use_predicted_annotations
+        )
+
+        all_new_rels = []
+        spans = self.retriever.get_base_layer(
+            document, use_predicted_annotations=self.use_predicted_annotations
+        )
+        span_id2idx = self.retriever.get_span_id2idx_from_doc(document.id)
+        for span_id, span_idx in span_id2idx.items():
+            query_span = spans[span_idx]
+            query_result = self.retriever.invoke(input=span_id)
+            query_rels = self._construct_similarity_relations(query_result, query_span=query_span)
+            all_new_rels.extend(query_rels)
+
+        if self.relation_layer_name not in document:
+            raise ValueError(f"Document does not have a layer named {self.relation_layer_name}")
+        document[self.relation_layer_name].predictions.extend(all_new_rels)
+
+        if self.retriever.retrieve_from_same_document and self.save_store_path is None:
+            self.retriever.delete_documents([document.id])
+
+        return document
+
+    def __call__(self, documents: Sequence[Document], inplace: bool = False) -> Sequence[Document]:
+        if inplace:
+            raise NotImplementedError("Inplace processing is not supported yet")
+
+        if self.fast_dev_run:
+            logger.warning("Fast dev run enabled, only processing the first 2 documents")
+            documents = documents[:2]
+
+        if not isinstance(documents, (Dataset, IterableDataset)):
+            documents = Dataset.from_documents(documents)
+
+        mapped_documents = documents.map(self._process_single_document)
+
+        if self.save_store_path is not None:
+            self.retriever.save_to_directory(path=self.save_store_path)
+
+        return mapped_documents
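A hedged wiring sketch (the retriever construction is elided; it must be a DocumentAwareSpanRetriever with retrieve_from_same_document enabled, and the relation label is illustrative):

    pipeline = SpanRetrievalBasedRelationExtractionPipeline(
        retriever=retriever,                 # assumed: configured DocumentAwareSpanRetriever
        relation_label="semantically_same",  # illustrative label
        relation_layer_name="binary_relations",
        use_predicted_annotations=True,
    )
    docs_with_similarity_relations = pipeline(documents)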
src/predict.py
ADDED
@@ -0,0 +1,183 @@
+import pyrootutils
+
+root = pyrootutils.setup_root(
+    search_from=__file__,
+    indicator=[".project-root"],
+    pythonpath=True,
+    dotenv=True,
+)
+
+# ------------------------------------------------------------------------------------ #
+# `pyrootutils.setup_root(...)` is an optional line at the top of each entry file
+# that helps to make the environment more robust and convenient
+#
+# the main advantages are:
+# - allows you to keep all entry files in "src/" without installing project as a package
+# - makes paths and scripts always work no matter what your current work dir is
+# - automatically loads environment variables from ".env" file if it exists
+#
+# how it works:
+# - the line above recursively searches for either ".git" or "pyproject.toml" in present
+#   and parent dirs, to determine the project root dir
+# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from
+#   any place without installing project as a package
+# - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml"
+#   to make all paths always relative to the project root
+# - loads environment variables from ".env" file in root dir (if `dotenv=True`)
+#
+# you can remove `pyrootutils.setup_root(...)` if you:
+# 1. either install project as a package or move each entry file to the project root dir
+# 2. simply remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml"
+# 3. always run entry files from the project root dir
+#
+# https://github.com/ashleve/pyrootutils
+# ------------------------------------------------------------------------------------ #
+
+import os
+import timeit
+from collections.abc import Iterable, Sequence
+from typing import Any, Dict, Optional, Tuple, Union
+
+import hydra
+import pytorch_lightning as pl
+from omegaconf import DictConfig, OmegaConf
+from pie_datasets import Dataset, DatasetDict
+from pie_modules.models import *  # noqa: F403
+from pie_modules.taskmodules import *  # noqa: F403
+from pytorch_ie import Document, Pipeline
+from pytorch_ie.models import *  # noqa: F403
+from pytorch_ie.taskmodules import *  # noqa: F403
+
+from src import utils
+from src.models import *  # noqa: F403
+from src.serializer.interface import DocumentSerializer
+from src.taskmodules import *  # noqa: F403
+
+log = utils.get_pylogger(__name__)
+
+
+def document_batch_iter(
+    dataset: Union[Sequence[Document], Iterable[Document]], batch_size: int
+) -> Iterable[Sequence[Document]]:
+    if isinstance(dataset, Sequence):
+        for i in range(0, len(dataset), batch_size):
+            yield dataset[i : i + batch_size]
+    elif isinstance(dataset, Iterable):
+        docs = []
+        for doc in dataset:
+            docs.append(doc)
+            if len(docs) == batch_size:
+                yield docs
+                docs = []
+        if docs:
+            yield docs
+    else:
+        raise ValueError(f"Unsupported dataset type: {type(dataset)}")
+
+
+@utils.task_wrapper
+def predict(cfg: DictConfig) -> Tuple[dict, dict]:
+    """Contains a minimal example of the prediction pipeline. Uses a pretrained model to annotate
+    documents from a dataset and serializes them.
+
+    Args:
+        cfg (DictConfig): Configuration composed by Hydra.
+
+    Returns:
+        A tuple of the result dict and the object dict.
+    """
+
+    # Set seed for random number generators in pytorch, numpy and python.random
+    if cfg.get("seed"):
+        pl.seed_everything(cfg.seed, workers=True)
+
+    # Init pytorch-ie dataset
+    log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
+    dataset: DatasetDict = hydra.utils.instantiate(cfg.dataset, _convert_="partial")
+
+    # Init pytorch-ie pipeline
+    # The pipeline, and therefore the inference step, is optional to allow for easy testing
+    # of the dataset creation and processing.
+    pipeline: Optional[Pipeline] = None
+    if cfg.get("pipeline") and cfg.pipeline.get("_target_"):
+        log.info(f"Instantiating pipeline <{cfg.pipeline._target_}> from {cfg.model_name_or_path}")
+        pipeline = hydra.utils.instantiate(cfg.pipeline, _convert_="partial")
+
+        # Per default, the model is loaded with .from_pretrained() which already loads the
+        # weights. However, ckpt_path can be used to load different weights from any checkpoint.
+        if cfg.ckpt_path is not None:
+            pipeline.model = pipeline.model.load_from_checkpoint(
+                checkpoint_path=cfg.ckpt_path
+            ).to(pipeline.device)
+
+        # auto-convert the dataset if the taskmodule specifies a document type
+        dataset = pipeline.taskmodule.convert_dataset(dataset)
+
+    # Init the serializer
+    serializer: Optional[DocumentSerializer] = None
+    if cfg.get("serializer") and cfg.serializer.get("_target_"):
+        log.info(f"Instantiating serializer <{cfg.serializer._target_}>")
+        serializer = hydra.utils.instantiate(cfg.serializer, _convert_="partial")
+
+    # select the dataset split for prediction
+    dataset_predict = dataset[cfg.dataset_split]
+
+    object_dict = {
+        "cfg": cfg,
+        "dataset": dataset,
+        "pipeline": pipeline,
+        "serializer": serializer,
+    }
+    result: Dict[str, Any] = {}
+    if pipeline is not None:
+        log.info("Starting inference!")
+        prediction_time = 0.0
+    else:
+        log.warning("No prediction pipeline is defined, skip inference!")
+        prediction_time = None
+    document_batch_size = cfg.get("document_batch_size", None)
+    for docs_batch in (
+        document_batch_iter(dataset_predict, document_batch_size)
+        if document_batch_size
+        else [dataset_predict]
+    ):
+        if pipeline is not None:
+            t_start = timeit.default_timer()
+            docs_batch = pipeline(docs_batch, inplace=False)
+            prediction_time += timeit.default_timer() - t_start  # type: ignore
+
+        # serialize the documents
+        if serializer is not None:
+            # the serializer should not return the serialized documents, but write them to disk
+            # and instead return some metadata such as the path to the serialized documents
+            serializer_result = serializer(docs_batch)
+            if "serializer" in result and result["serializer"] != serializer_result:
+                log.warning(
+                    f"serializer result changed from {result['serializer']} to "
+                    f"{serializer_result} during prediction. Only the last result is returned."
+                )
+            result["serializer"] = serializer_result
+
+    if prediction_time is not None:
+        result["prediction_time"] = prediction_time
+
+    # serialize config with resolved paths
+    if cfg.get("config_out_path"):
+        config_out_dir = os.path.dirname(cfg.config_out_path)
+        os.makedirs(config_out_dir, exist_ok=True)
+        OmegaConf.save(config=cfg, f=cfg.config_out_path)
+        result["config"] = cfg.config_out_path
+
+    return result, object_dict
+
+
+@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="predict.yaml")
+def main(cfg: DictConfig):
+    result_dict, _ = predict(cfg)
+    return result_dict
+
+
+if __name__ == "__main__":
+    utils.replace_sys_args_with_values_from_files()
+    utils.prepare_omegaconf()
+    main()
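From the file above, document_batch_iter handles both sized sequences and plain iterables; a quick check with integers standing in for documents (ignoring the Document type hint for illustration):

    batches = list(document_batch_iter(list(range(10)), batch_size=4))
    print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] -- the tail batch may be smaller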
src/serializer/__init__.py
ADDED
@@ -0,0 +1 @@
+from .json import JsonSerializer, JsonSerializer2
src/serializer/interface.py
ADDED
@@ -0,0 +1,16 @@
+from abc import ABC, abstractmethod
+from typing import Any, Sequence
+
+from pytorch_ie.core import Document
+
+
+class DocumentSerializer(ABC):
+    """This defines the interface for a document serializer.
+
+    The serializer should not return the serialized documents, but write them to disk and
+    instead return some metadata such as the path to the serialized documents.
+    """
+
+    @abstractmethod
+    def __call__(self, documents: Sequence[Document]) -> Any:
+        pass
src/serializer/json.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from typing import Dict, List, Optional, Sequence, Type, TypeVar
|
4 |
+
|
5 |
+
from pie_datasets import Dataset, DatasetDict, IterableDataset
from pie_datasets.core.dataset_dict import METADATA_FILE_NAME
from pytorch_ie.core import Document
from pytorch_ie.utils.hydra import resolve_optional_document_type, serialize_document_type

from src.serializer.interface import DocumentSerializer
from src.utils import get_pylogger

log = get_pylogger(__name__)

D = TypeVar("D", bound=Document)


def as_json_lines(file_name: str) -> bool:
    if file_name.lower().endswith(".jsonl"):
        return True
    elif file_name.lower().endswith(".json"):
        return False
    else:
        raise Exception(f"unknown file extension: {file_name}")


class JsonSerializer(DocumentSerializer):
    def __init__(self, **kwargs):
        self.default_kwargs = kwargs

    @classmethod
    def write(
        cls,
        documents: Sequence[Document],
        path: str,
        file_name: str = "documents.jsonl",
        metadata_file_name: str = METADATA_FILE_NAME,
        split: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, str]:
        realpath = os.path.realpath(path)
        log.info(f'serialize documents to "{realpath}" ...')
        os.makedirs(realpath, exist_ok=True)

        # dump metadata including the document_type
        if len(documents) == 0:
            raise Exception("cannot serialize empty list of documents")
        document_type = type(documents[0])
        metadata = {"document_type": serialize_document_type(document_type)}
        full_metadata_file_name = os.path.join(realpath, metadata_file_name)
        if os.path.exists(full_metadata_file_name):
            # load previous metadata
            with open(full_metadata_file_name) as f:
                previous_metadata = json.load(f)
            if previous_metadata != metadata:
                raise ValueError(
                    f"metadata file {full_metadata_file_name} already exists, "
                    "but the content does not match the current metadata"
                    f"\nprevious metadata: {previous_metadata}"
                    f"\ncurrent metadata: {metadata}"
                )
        else:
            with open(full_metadata_file_name, "w") as f:
                json.dump(metadata, f, indent=2)

        if split is not None:
            realpath = os.path.join(realpath, split)
            os.makedirs(realpath, exist_ok=True)
        full_file_name = os.path.join(realpath, file_name)
        if as_json_lines(file_name):
            # if the file already exists, append to it
            mode = "a" if os.path.exists(full_file_name) else "w"
            with open(full_file_name, mode) as f:
                for doc in documents:
                    f.write(json.dumps(doc.asdict(), **kwargs) + "\n")
        else:
            docs_list = [doc.asdict() for doc in documents]
            if os.path.exists(full_file_name):
                # load previous documents
                with open(full_file_name) as f:
                    previous_doc_list = json.load(f)
                docs_list = previous_doc_list + docs_list
            with open(full_file_name, "w") as f:
                json.dump(docs_list, fp=f, **kwargs)
        return {"path": realpath, "file_name": file_name, "metadata_file_name": metadata_file_name}

    @classmethod
    def read(
        cls,
        path: str,
        document_type: Optional[Type[D]] = None,
        file_name: str = "documents.jsonl",
        metadata_file_name: str = METADATA_FILE_NAME,
        split: Optional[str] = None,
    ) -> List[D]:
        realpath = os.path.realpath(path)
        log.info(f'load documents from "{realpath}" ...')

        # try to load metadata including the document_type
        full_metadata_file_name = os.path.join(realpath, metadata_file_name)
        if os.path.exists(full_metadata_file_name):
            with open(full_metadata_file_name) as f:
                metadata = json.load(f)
            document_type = resolve_optional_document_type(metadata.get("document_type"))

        if document_type is None:
            raise Exception("document_type is required to load serialized documents")

        if split is not None:
            realpath = os.path.join(realpath, split)
        full_file_name = os.path.join(realpath, file_name)
        documents = []
        if as_json_lines(str(file_name)):
            with open(full_file_name) as f:
                for line in f:
                    json_dict = json.loads(line)
                    documents.append(document_type.fromdict(json_dict))
        else:
            with open(full_file_name) as f:
                json_list = json.load(f)
            for json_dict in json_list:
                documents.append(document_type.fromdict(json_dict))
        return documents

    def read_with_defaults(self, **kwargs) -> List[D]:
        all_kwargs = {**self.default_kwargs, **kwargs}
        return self.read(**all_kwargs)

    def write_with_defaults(self, **kwargs) -> Dict[str, str]:
        all_kwargs = {**self.default_kwargs, **kwargs}
        return self.write(**all_kwargs)

    def __call__(self, documents: Sequence[Document], **kwargs) -> Dict[str, str]:
        return self.write_with_defaults(documents=documents, **kwargs)


class JsonSerializer2(DocumentSerializer):
    def __init__(self, **kwargs):
        self.default_kwargs = kwargs

    @classmethod
    def write(
        cls,
        documents: Sequence[Document],
        path: str,
        split: str = "train",
    ) -> Dict[str, str]:
        if not isinstance(documents, (Dataset, IterableDataset)):
            documents = Dataset.from_documents(documents)
        dataset_dict = DatasetDict({split: documents})
        dataset_dict.to_json(path=path)
        return {"path": path, "split": split}

    @classmethod
    def read(
        cls,
        path: str,
        document_type: Optional[Type[D]] = None,
        split: Optional[str] = None,
    ) -> Dataset[Document]:
        dataset_dict = DatasetDict.from_json(
            data_dir=path, document_type=document_type, split=split
        )
        if split is not None:
            return dataset_dict[split]
        if len(dataset_dict) == 1:
            return dataset_dict[list(dataset_dict.keys())[0]]
        raise ValueError(f"multiple splits found in dataset_dict: {list(dataset_dict.keys())}")

    def read_with_defaults(self, **kwargs) -> Sequence[D]:
        all_kwargs = {**self.default_kwargs, **kwargs}
        return self.read(**all_kwargs)

    def write_with_defaults(self, **kwargs) -> Dict[str, str]:
        all_kwargs = {**self.default_kwargs, **kwargs}
        return self.write(**all_kwargs)

    def __call__(self, documents: Sequence[Document], **kwargs) -> Dict[str, str]:
        return self.write_with_defaults(documents=documents, **kwargs)
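
A minimal usage sketch for the serializer above; the output path and the `docs` variable are made-up examples (`docs` stands for any non-empty list of pytorch-ie documents of one document type):

serializer = JsonSerializer(path="outputs/docs", split="train")
info = serializer(docs)  # writes train/documents.jsonl plus a metadata.json holding the document type
reloaded = serializer.read_with_defaults()  # reads the documents back, resolving the type from metadata.json
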
src/start_demo.py
ADDED
@@ -0,0 +1,578 @@
import hydra
import pyrootutils
from omegaconf import DictConfig, OmegaConf, SCMode

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

import json
import logging

import gradio as gr
import torch
import yaml

from src.demo.annotation_utils import load_argumentation_model
from src.demo.backend_utils import (
    download_processed_documents,
    process_text_from_arxiv,
    process_uploaded_files,
    render_annotated_document,
    upload_processed_documents,
    wrapped_add_annotated_pie_documents_from_dataset,
    wrapped_process_text,
)
from src.demo.frontend_utils import (
    change_tab,
    escape_regex,
    get_cell_for_fixed_column_from_df,
    get_fix_df_height_css,
    open_accordion,
    unescape_regex,
)
from src.demo.rendering_utils import AVAILABLE_RENDER_MODES, HIGHLIGHT_SPANS_JS
from src.demo.retriever_utils import (
    get_document_as_dict,
    get_span_annotation,
    load_retriever,
    retrieve_all_relevant_spans,
    retrieve_all_similar_spans,
    retrieve_relevant_spans,
    retrieve_similar_spans,
)


def load_yaml_config(path: str) -> str:
    with open(path, "r") as file:
        yaml_string = file.read()
    config = yaml.safe_load(yaml_string)
    return yaml.dump(config)


def resolve_config(cfg) -> dict:
    return OmegaConf.to_container(cfg, resolve=True, structured_config_mode=SCMode.DICT)


@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="demo.yaml")
def main(cfg: DictConfig) -> None:

    # configure logging
    logging.basicConfig()

    # resolve everything in the config to prevent any issues with json serialization etc.
    cfg = resolve_config(cfg)

    example_text = cfg["example_text"]

    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    default_retriever_config_str = load_yaml_config(cfg["default_retriever_config_path"])

    default_model_name = cfg["default_model_name"]
    default_model_revision = cfg["default_model_revision"]
    handle_parts_of_same = cfg["handle_parts_of_same"]

    default_arxiv_id = cfg["default_arxiv_id"]
    default_load_pie_dataset_kwargs_str = json.dumps(
        cfg["default_load_pie_dataset_kwargs"], indent=2
    )

    default_render_mode = cfg["default_render_mode"]
    if default_render_mode not in AVAILABLE_RENDER_MODES:
        raise ValueError(
            f"Invalid default render mode '{default_render_mode}'. "
            f"Choose one of {AVAILABLE_RENDER_MODES}."
        )
    default_render_kwargs = cfg["default_render_kwargs"]

    default_split_regex = cfg["default_split_regex"]
    # map from render mode to the corresponding caption (for better readability)
    render_mode2caption = {
        render_mode: cfg["render_mode_captions"].get(render_mode, render_mode)
        for render_mode in AVAILABLE_RENDER_MODES
    }
    render_caption2mode = {v: k for k, v in render_mode2caption.items()}
    default_min_similarity = cfg["default_min_similarity"]
    layer_caption_mapping = cfg["layer_caption_mapping"]
    relation_name_mapping = cfg["relation_name_mapping"]

    gr.Info("Loading models ...")
    argumentation_model = load_argumentation_model(
        model_name=default_model_name,
        revision=default_model_revision,
        device=default_device,
    )
    retriever = load_retriever(
        default_retriever_config_str, device=default_device, config_format="yaml"
    )

    with gr.Blocks(css=get_fix_df_height_css(css_class="df-docstore", max_height=300)) as demo:
        # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that gradio calls it
        # models_state = gr.State((argumentation_model, embedding_model))
        argumentation_model_state = gr.State((argumentation_model,))
        retriever_state = gr.State((retriever,))

        with gr.Row():
            with gr.Tabs() as left_tabs:
                with gr.Tab("User Input", id="user_input") as user_input_tab:
                    doc_id = gr.Textbox(
                        label="Document ID",
                        value="user_input",
                    )
                    doc_text = gr.Textbox(
                        label="Text",
                        lines=20,
                        value=example_text,
                    )

                    with gr.Accordion("Model Configuration", open=False):
                        with gr.Accordion("argumentation structure", open=True):
                            model_name = gr.Textbox(
                                label="Model Name",
                                value=default_model_name,
                            )
                            model_revision = gr.Textbox(
                                label="Model Revision",
                                value=default_model_revision,
                            )
                            load_arg_model_btn = gr.Button("Load Argumentation Model")

                        with gr.Accordion("retriever", open=True):
                            retriever_config = gr.Code(
                                language="yaml",
                                label="Retriever Configuration",
                                value=default_retriever_config_str,
                                lines=len(default_retriever_config_str.split("\n")),
                            )
                            load_retriever_btn = gr.Button("Load Retriever")

                        device = gr.Textbox(
                            label="Device (e.g. 'cuda' or 'cpu')",
                            value=default_device,
                        )
                        load_arg_model_btn.click(
                            fn=lambda _model_name, _model_revision, _device: (
                                load_argumentation_model(
                                    model_name=_model_name,
                                    revision=_model_revision,
                                    device=_device,
                                ),
                            ),
                            inputs=[model_name, model_revision, device],
                            outputs=argumentation_model_state,
                        )
                        load_retriever_btn.click(
                            fn=lambda _retriever_config, _device, _previous_retriever: (
                                load_retriever(
                                    retriever_config_str=_retriever_config,
                                    device=_device,
                                    previous_retriever=_previous_retriever[0],
                                    config_format="yaml",
                                ),
                            ),
                            inputs=[retriever_config, device, retriever_state],
                            outputs=retriever_state,
                        )

                    split_regex_escaped = gr.Textbox(
                        label="Regex to partition the text",
                        placeholder="Regular expression pattern to split the text into partitions",
                        value=escape_regex(default_split_regex),
                    )

                    predict_btn = gr.Button("Analyse")

                with gr.Tab("Analysed Document", id="analysed_document") as analysed_document_tab:
                    selected_document_id = gr.Textbox(
                        label="Document ID", max_lines=1, interactive=False
                    )
                    rendered_output = gr.HTML(label="Rendered Output")

                    with gr.Accordion("Render Options", open=False):
                        render_as = gr.Dropdown(
                            label="Render with",
                            choices=list(render_mode2caption.values()),
                            value=render_mode2caption[default_render_mode],
                        )
                        render_kwargs = gr.Code(
                            language="json",
                            label="Render Arguments",
                            lines=len(json.dumps(default_render_kwargs, indent=2).split("\n")),
                            value=json.dumps(default_render_kwargs, indent=2),
                        )
                        render_btn = gr.Button("Re-render")

                    with gr.Accordion("See plain result ...", open=False):
                        get_document_json_btn = gr.Button("Fetch annotated document as JSON")
                        document_json = gr.JSON(label="Model Output")

            with gr.Tabs() as right_tabs:
                with gr.Tab("Retrieval", id="retrieval") as retrieval_tab:
                    with gr.Accordion(
                        "Indexed Documents", open=False
                    ) as processed_documents_accordion:
                        processed_documents_df = gr.DataFrame(
                            headers=["id", "num_adus", "num_relations"],
                            interactive=False,
                            elem_classes="df-docstore",
                        )
                        gr.Markdown("Data Snapshot:")
                        with gr.Row():
                            download_processed_documents_btn = gr.DownloadButton("Download")
                            upload_processed_documents_btn = gr.UploadButton(
                                "Upload", file_types=["file"]
                            )

                    # currently not used
                    # relation_types = set_relation_types(
                    #     argumentation_model_state.value[0], default=["supports_reversed", "contradicts_reversed"]
                    # )

                    # Dummy textbox to hold the hovered adu id. On click on the rendered output,
                    # its content will be copied to selected_adu_id which will trigger the retrieval.
                    hover_adu_id = gr.Textbox(
                        label="ID (hover)",
                        elem_id="hover_adu_id",
                        interactive=False,
                        visible=False,
                    )
                    selected_adu_id = gr.Textbox(
                        label="ID (selected)",
                        elem_id="selected_adu_id",
                        interactive=False,
                        visible=False,
                    )
                    selected_adu_text = gr.Textbox(label="Selected ADU", interactive=False)

                    with gr.Accordion("Relevant ADUs from other documents", open=True):
                        relevant_adus_df = gr.DataFrame(
                            headers=[
                                "relation",
                                "adu",
                                "reference_adu",
                                "doc_id",
                                "sim_score",
                                "rel_score",
                            ],
                            interactive=False,
                        )

                    with gr.Accordion("Retrieval Configuration", open=False):
                        min_similarity = gr.Slider(
                            label="Minimum Similarity",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=default_min_similarity,
                        )
                        top_k = gr.Slider(
                            label="Top K",
                            minimum=2,
                            maximum=50,
                            step=1,
                            value=10,
                        )
                        retrieve_similar_adus_btn = gr.Button(
                            "Retrieve *similar* ADUs for *selected* ADU"
                        )
                        similar_adus_df = gr.DataFrame(
                            headers=["doc_id", "adu_id", "score", "text"], interactive=False
                        )
                        retrieve_all_similar_adus_btn = gr.Button(
                            "Retrieve *similar* ADUs for *all* ADUs in the document"
                        )
                        all_similar_adus_df = gr.DataFrame(
                            headers=["doc_id", "query_adu_id", "adu_id", "score", "text"],
                            interactive=False,
                        )
                        retrieve_all_relevant_adus_btn = gr.Button(
                            "Retrieve *relevant* ADUs for *all* ADUs in the document"
                        )
                        all_relevant_adus_df = gr.DataFrame(
                            headers=["doc_id", "adu_id", "score", "text"], interactive=False
                        )

                with gr.Tab("Import Documents", id="import_documents") as import_documents_tab:
                    upload_btn = gr.UploadButton(
                        "Batch Analyse Texts",
                        file_types=["text"],
                        file_count="multiple",
                    )

                    with gr.Accordion("Import text from arXiv", open=False):
                        arxiv_id = gr.Textbox(
                            label="arXiv paper ID",
                            placeholder=f"e.g. {default_arxiv_id}",
                            max_lines=1,
                        )
                        load_arxiv_only_abstract = gr.Checkbox(label="abstract only", value=False)
                        load_arxiv_btn = gr.Button(
                            "Load & Analyse from arXiv", variant="secondary"
                        )

                    with gr.Accordion(
                        "Import argument structure annotated PIE dataset", open=False
                    ):
                        load_pie_dataset_kwargs_str = gr.Code(
                            language="json",
                            label="Parameters for Loading the PIE Dataset",
                            value=default_load_pie_dataset_kwargs_str,
                            lines=len(default_load_pie_dataset_kwargs_str.split("\n")),
                        )
                        load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")

        render_event_kwargs = dict(
            fn=lambda _retriever, _document_id, _render_as, _render_kwargs: render_annotated_document(
                retriever=_retriever[0],
                document_id=_document_id,
                render_with=render_caption2mode[_render_as],
                render_kwargs_json=_render_kwargs,
            ),
            inputs=[retriever_state, selected_document_id, render_as, render_kwargs],
            outputs=rendered_output,
        )

        show_overview_kwargs = dict(
            fn=lambda _retriever: _retriever[0].docstore.overview(
                layer_captions=layer_caption_mapping, use_predictions=True
            ),
            inputs=[retriever_state],
            outputs=[processed_documents_df],
        )
        predict_btn.click(
            fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
        ).then(
            fn=lambda _doc_text, _doc_id, _argumentation_model, _retriever, _split_regex_escaped: wrapped_process_text(
                text=_doc_text,
                doc_id=_doc_id,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=(
                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
                ),
                handle_parts_of_same=handle_parts_of_same,
            ),
            inputs=[
                doc_text,
                doc_id,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[selected_document_id],
            api_name="predict",
        ).success(
            **show_overview_kwargs
        ).success(
            **render_event_kwargs
        )
        render_btn.click(**render_event_kwargs, api_name="render")

        load_arxiv_btn.click(
            fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
        ).then(
            fn=lambda _arxiv_id, _load_arxiv_only_abstract, _argumentation_model, _retriever, _split_regex_escaped: process_text_from_arxiv(
                arxiv_id=_arxiv_id.strip() or default_arxiv_id,
                abstract_only=_load_arxiv_only_abstract,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=(
                    unescape_regex(_split_regex_escaped) if _split_regex_escaped else None
                ),
                handle_parts_of_same=handle_parts_of_same,
            ),
            inputs=[
                arxiv_id,
                load_arxiv_only_abstract,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[selected_document_id],
            api_name="predict_arxiv",  # distinct name; "predict" is already taken by the text analysis endpoint
        ).success(
            **show_overview_kwargs
        )

        load_pie_dataset_btn.click(
            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
            fn=lambda _retriever, _load_pie_dataset_kwargs_str: wrapped_add_annotated_pie_documents_from_dataset(
                retriever=_retriever[0],
                verbose=True,
                layer_captions=layer_caption_mapping,
                **json.loads(_load_pie_dataset_kwargs_str),
            ),
            inputs=[retriever_state, load_pie_dataset_kwargs_str],
            outputs=[processed_documents_df],
        )

        selected_document_id.change(
            fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
        ).then(**render_event_kwargs)

        get_document_json_btn.click(
            fn=lambda _retriever, _document_id: get_document_as_dict(
                retriever=_retriever[0], doc_id=_document_id
            ),
            inputs=[retriever_state, selected_document_id],
            outputs=[document_json],
        )

        upload_btn.upload(
            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
            fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_files(
                file_names=_file_names,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=unescape_regex(_split_regex_escaped),
                handle_parts_of_same=handle_parts_of_same,
                layer_captions=layer_caption_mapping,
            ),
            inputs=[
                upload_btn,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[processed_documents_df],
        )
        processed_documents_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[processed_documents_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        download_processed_documents_btn.click(
            fn=lambda _retriever: download_processed_documents(
                _retriever[0], file_name="processed_documents"
            ),
            inputs=[retriever_state],
            outputs=[download_processed_documents_btn],
        )
        upload_processed_documents_btn.upload(
            fn=lambda file_name, _retriever: upload_processed_documents(
                file_name, retriever=_retriever[0], layer_captions=layer_caption_mapping
            ),
            inputs=[upload_processed_documents_btn, retriever_state],
            outputs=[processed_documents_df],
        )

        retrieve_relevant_adus_event_kwargs = dict(
            fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_relevant_spans(
                retriever=_retriever[0],
                query_span_id=_selected_adu_id,
                k=_top_k,
                score_threshold=_min_similarity,
                relation_label_mapping=relation_name_mapping,
                # columns=relevant_adus.headers
            ),
            inputs=[
                retriever_state,
                selected_adu_id,
                min_similarity,
                top_k,
            ],
            outputs=[relevant_adus_df],
        )
        relevant_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[relevant_adus_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        selected_adu_id.change(
            fn=lambda _retriever, _selected_adu_id: get_span_annotation(
                retriever=_retriever[0], span_id=_selected_adu_id
            ),
            inputs=[retriever_state, selected_adu_id],
            outputs=[selected_adu_text],
        ).success(**retrieve_relevant_adus_event_kwargs)

        retrieve_similar_adus_btn.click(
            fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k: retrieve_similar_spans(
                retriever=_retriever[0],
                query_span_id=_selected_adu_id,
                k=_top_k,
                score_threshold=_min_similarity,
            ),
            inputs=[
                retriever_state,
                selected_adu_id,
                min_similarity,
                top_k,
            ],
            outputs=[similar_adus_df],
        )
        similar_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[similar_adus_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        retrieve_all_similar_adus_btn.click(
            fn=lambda _retriever, _document_id, _min_similarity, _top_k: retrieve_all_similar_spans(
                retriever=_retriever[0],
                query_doc_id=_document_id,
                k=_top_k,
                score_threshold=_min_similarity,
                query_span_id_column="query_span_id",
            ),
            inputs=[
                retriever_state,
                selected_document_id,
                min_similarity,
                top_k,
            ],
            outputs=[all_similar_adus_df],
        )

        retrieve_all_relevant_adus_btn.click(
            fn=lambda _retriever, _document_id, _min_similarity, _top_k: retrieve_all_relevant_spans(
                retriever=_retriever[0],
                query_doc_id=_document_id,
                k=_top_k,
                score_threshold=_min_similarity,
                query_span_id_column="query_span_id",
            ),
            inputs=[
                retriever_state,
                selected_document_id,
                min_similarity,
                top_k,
            ],
            outputs=[all_relevant_adus_df],
        )

        # select the query span id from the "retrieve all" result data frames
        all_similar_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[all_similar_adus_df, gr.State("query_span_id")],
            outputs=[selected_adu_id],
        )
        all_relevant_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[all_relevant_adus_df, gr.State("query_span_id")],
            outputs=[selected_adu_id],
        )

        # argumentation_model_state.change(
        #     fn=lambda _argumentation_model: set_relation_types(_argumentation_model[0]),
        #     inputs=[argumentation_model_state],
        #     outputs=[relation_types],
        # )

        rendered_output.change(fn=None, js=HIGHLIGHT_SPANS_JS, inputs=[], outputs=[])

    demo.launch()


if __name__ == "__main__":
    main()
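
Once the demo is running, the endpoint registered above as api_name="predict" can also be scripted. A minimal sketch using the separate gradio_client package; the URL and the input values are placeholders, and the exact positional arguments depend on which components gradio exposes over the API (gr.State inputs are not part of it):

from gradio_client import Client

client = Client("http://127.0.0.1:7860")  # placeholder address of the locally running demo
result = client.predict(
    "Some text to analyse.",  # doc_text
    "my_doc",                 # doc_id
    "\\n\\n",                 # split_regex_escaped (example value)
    api_name="/predict",
)
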
src/taskmodules/__init__.py
ADDED
@@ -0,0 +1,8 @@
from .cross_text_binary_coref import CrossTextBinaryCorefTaskModuleWithOptionalContext
from .cross_text_binary_coref_nli import CrossTextBinaryCorefTaskModuleByNli
from .re_text_classification_with_indices import (
    CrossTextBinaryCorefByRETextClassificationTaskModule,
    RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers,
)

CrossTextBinaryCorefTaskModule2 = CrossTextBinaryCorefByRETextClassificationTaskModule
src/taskmodules/components/__init__.py
ADDED
File without changes
src/taskmodules/cross_text_binary_coref.py
ADDED
@@ -0,0 +1,116 @@
import logging
from typing import Optional, Sequence, TypeVar, Union

from pie_modules.taskmodules import CrossTextBinaryCorefTaskModule
from pie_modules.taskmodules.cross_text_binary_coref import (
    DocumentType,
    SpanDoesNotFitIntoAvailableWindow,
    TaskEncodingType,
)
from pie_modules.utils.tokenization import SpanNotAlignedWithTokenException
from pytorch_ie.annotations import Span
from pytorch_ie.core import TaskEncoding, TaskModule

logger = logging.getLogger(__name__)


S = TypeVar("S", bound=Span)


def shift_span(span: S, offset: int) -> S:
    return span.copy(start=span.start + offset, end=span.end + offset)


@TaskModule.register()
class CrossTextBinaryCorefTaskModuleWithOptionalContext(CrossTextBinaryCorefTaskModule):
    """Same as CrossTextBinaryCorefTaskModule, but:

    - optionally without context.
    """

    def __init__(
        self,
        without_context: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.without_context = without_context

    def encode_input(
        self,
        document: DocumentType,
        is_training: bool = False,
    ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
        if self.without_context:
            return self.encode_input_without_context(document)
        else:
            return super().encode_input(document)

    def encode_input_without_context(
        self, document: DocumentType
    ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
        self.collect_all_relations(kind="available", relations=document.binary_coref_relations)
        tokenizer_kwargs = dict(
            padding=False,
            truncation=False,
            add_special_tokens=False,
        )

        task_encodings = []
        for coref_rel in document.binary_coref_relations:

            # TODO: This can miss instances if both texts are the same. We could check that
            # coref_rel.head is in document.labeled_spans (same for the tail), but would this
            # slow down the encoding?
            if not (
                coref_rel.head.target == document.text
                or coref_rel.tail.target == document.text_pair
            ):
                raise ValueError(
                    f"It is expected that coref relations go from (head) spans over 'text' "
                    f"to (tail) spans over 'text_pair', but this is not the case for this "
                    f"relation (i.e. it points into the other direction): {coref_rel.resolve()}"
                )
            encoding = self.tokenizer(text=str(coref_rel.head), **tokenizer_kwargs)
            encoding_pair = self.tokenizer(text=str(coref_rel.tail), **tokenizer_kwargs)

            try:
                current_encoding, token_span = self.truncate_encoding_around_span(
                    encoding=encoding, char_span=shift_span(coref_rel.head, -coref_rel.head.start)
                )
                current_encoding_pair, token_span_pair = self.truncate_encoding_around_span(
                    encoding=encoding_pair,
                    char_span=shift_span(coref_rel.tail, -coref_rel.tail.start),
                )
            except SpanNotAlignedWithTokenException as e:
                logger.warning(
                    f"Could not get token offsets for argument ({e.span}) of coref relation: "
                    f"{coref_rel.resolve()}. Skip it."
                )
                self.collect_relation(kind="skipped_args_not_aligned", relation=coref_rel)
                continue
            except SpanDoesNotFitIntoAvailableWindow as e:
                logger.warning(
                    f"Argument span [{e.span}] does not fit into available token window "
                    f"({self.available_window}). Skip it."
                )
                self.collect_relation(
                    kind="skipped_span_does_not_fit_into_window", relation=coref_rel
                )
                continue

            task_encodings.append(
                TaskEncoding(
                    document=document,
                    inputs={
                        "encoding": current_encoding,
                        "encoding_pair": current_encoding_pair,
                        "pooler_start_indices": token_span.start,
                        "pooler_end_indices": token_span.end,
                        "pooler_pair_start_indices": token_span_pair.start,
                        "pooler_pair_end_indices": token_span_pair.end,
                    },
                    metadata={"candidate_annotation": coref_rel},
                )
            )
            self.collect_relation("used", coref_rel)
        return task_encodings
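
A minimal instantiation sketch for the taskmodule above; all constructor arguments other than without_context are forwarded to the base CrossTextBinaryCorefTaskModule, so the tokenizer name here is only an assumed example:

taskmodule = CrossTextBinaryCorefTaskModuleWithOptionalContext(
    tokenizer_name_or_path="bert-base-cased",  # example value, forwarded to the base class
    without_context=True,  # encode just the head/tail span texts, no surrounding text
)
task_encodings = taskmodule.encode(documents)  # documents: iterable of DocumentType
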
src/taskmodules/cross_text_binary_coref_nli.py
ADDED
@@ -0,0 +1,166 @@
import logging
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, TypedDict, Union

import torch
from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
from pie_modules.taskmodules.common.mixins import RelationStatisticsMixin
from pytorch_ie import Annotation
from pytorch_ie.core import TaskEncoding, TaskModule
from transformers import AutoTokenizer
from typing_extensions import TypeAlias

logger = logging.getLogger(__name__)

InputEncodingType: TypeAlias = Dict[str, Any]
TargetEncodingType: TypeAlias = Sequence[float]
DocumentType: TypeAlias = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations

TaskEncodingType: TypeAlias = TaskEncoding[
    DocumentType,
    InputEncodingType,
    TargetEncodingType,
]


class TaskOutputType(TypedDict, total=False):
    label_pair: Tuple[str, str]
    entailment_probability_pair: Tuple[float, float]


ModelInputType: TypeAlias = Dict[str, torch.Tensor]
ModelTargetType: TypeAlias = Dict[str, torch.Tensor]
ModelOutputType: TypeAlias = Dict[str, torch.Tensor]

TaskModuleType: TypeAlias = TaskModule[
    # _InputEncoding, _TargetEncoding, _TaskBatchEncoding, _ModelBatchOutput, _TaskOutput
    DocumentType,
    InputEncodingType,
    TargetEncodingType,
    Tuple[ModelInputType, Optional[ModelTargetType]],
    ModelTargetType,
    TaskOutputType,
]


@TaskModule.register()
class CrossTextBinaryCorefTaskModuleByNli(RelationStatisticsMixin, TaskModuleType):
    """This taskmodule processes documents of type
    TextPairDocumentWithLabeledSpansAndBinaryCorefRelations in preparation for a sequence
    classification model trained for NLI. The assumption is that if the entailment class is
    predicted for both directions, a coreference relation exists between the two spans.

    It simply tokenizes and encodes the head and tail texts of the coreference relations as text
    pairs, i.e. no context of head and tail is considered. During decoding, coreference relations
    are created if the entailment class (see parameter entailment_label) is predicted for both
    directions and the average probability is used as the score.
    """

    DOCUMENT_TYPE = DocumentType

    def __init__(
        self,
        tokenizer_name_or_path: str,
        labels: List[str],
        entailment_label: str,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.save_hyperparameters()

        self.labels = labels
        self.entailment_label = entailment_label
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

    def _post_prepare(self):
        self.id_to_label = dict(enumerate(self.labels))
        self.label_to_id = {v: k for k, v in self.id_to_label.items()}
        self.entailment_idx = self.label_to_id[self.entailment_label]

    def encode(self, documents: Union[DocumentType, Iterable[DocumentType]], **kwargs):
        self.reset_statistics()
        result = super().encode(documents=documents, **kwargs)
        self.show_statistics()
        return result

    def encode_input(
        self,
        document: DocumentType,
        is_training: bool = False,
    ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
        self.collect_all_relations(kind="available", relations=document.binary_coref_relations)
        result = []
        for coref_rel in document.binary_coref_relations:
            head_text = str(coref_rel.head)
            tail_text = str(coref_rel.tail)
            task_encoding = TaskEncoding(
                document=document,
                inputs={"text": [head_text, tail_text], "text_pair": [tail_text, head_text]},
                metadata={"candidate_annotation": coref_rel},
            )
            result.append(task_encoding)
            self.collect_relation("used", coref_rel)
        return result

    def encode_target(
        self,
        task_encoding: TaskEncodingType,
    ) -> Optional[TargetEncodingType]:
        raise NotImplementedError()

    def collate(
        self,
        task_encodings: Sequence[
            TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType]
        ],
    ) -> Tuple[ModelInputType, Optional[ModelTargetType]]:
        all_texts = []
        all_texts_pair = []
        for task_encoding in task_encodings:
            all_texts.extend(task_encoding.inputs["text"])
            all_texts_pair.extend(task_encoding.inputs["text_pair"])
        inputs = self.tokenizer(
            text=all_texts,
            text_pair=all_texts_pair,
            truncation=True,
            padding=True,
            return_tensors="pt",
        )
        if not task_encodings[0].has_targets:
            return inputs, None
        raise NotImplementedError()

    def unbatch_output(self, model_output: ModelTargetType) -> Sequence[TaskOutputType]:
        probs_tensor = model_output["probabilities"]
        labels_tensor = model_output["labels"]

        bs, num_classes = probs_tensor.size()
        # Reshape the probs tensor to (bs/2, 2, num_classes)
        probs_paired = probs_tensor.view(bs // 2, 2, num_classes).detach().cpu().tolist()

        # Reshape the labels tensor to (bs/2, 2)
        labels_paired = labels_tensor.view(bs // 2, 2).detach().cpu().tolist()

        result = []
        for (label_id, label_id_pair), (probs_list, probs_list_pair) in zip(
            labels_paired, probs_paired
        ):
            task_output: TaskOutputType = {
                "label_pair": (self.id_to_label[label_id], self.id_to_label[label_id_pair]),
                "entailment_probability_pair": (
                    probs_list[self.entailment_idx],
                    probs_list_pair[self.entailment_idx],
                ),
            }
            result.append(task_output)
        return result

    def create_annotations_from_output(
        self,
        task_encoding: TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType],
        task_output: TaskOutputType,
    ) -> Iterator[Tuple[str, Annotation]]:
        if all(label == self.entailment_label for label in task_output["label_pair"]):
            probs = task_output["entailment_probability_pair"]
            score = (probs[0] + probs[1]) / 2
            new_coref_rel = task_encoding.metadata["candidate_annotation"].copy(score=score)
            yield "binary_coref_relations", new_coref_rel
|
src/taskmodules/re_text_classification_with_indices.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from itertools import chain
|
3 |
+
from typing import Dict, Optional, Sequence, Type
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from pie_modules.annotations import BinaryCorefRelation
|
7 |
+
from pie_modules.document.processing.text_pair import shift_span
|
8 |
+
from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
|
9 |
+
from pie_modules.taskmodules import RETextClassificationWithIndicesTaskModule
|
10 |
+
from pie_modules.taskmodules.common import TaskModuleWithDocumentConverter
|
11 |
+
from pie_modules.taskmodules.re_text_classification_with_indices import MarkerFactory
|
12 |
+
from pie_modules.taskmodules.re_text_classification_with_indices import (
|
13 |
+
ModelTargetType as REModelTargetType,
|
14 |
+
)
|
15 |
+
from pie_modules.taskmodules.re_text_classification_with_indices import (
|
16 |
+
TaskOutputType as RETaskOutputType,
|
17 |
+
)
|
18 |
+
from pytorch_ie import Document, TaskModule
|
19 |
+
from pytorch_ie.annotations import LabeledSpan
|
20 |
+
from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations
|
21 |
+
|
22 |
+
|
23 |
+
class SharpBracketMarkerFactory(MarkerFactory):
|
24 |
+
def _get_marker(self, role: str, is_start: bool, label: Optional[str] = None) -> str:
|
25 |
+
result = "<"
|
26 |
+
if not is_start:
|
27 |
+
result += "/"
|
28 |
+
result += self._get_role_marker(role)
|
29 |
+
if label is not None:
|
30 |
+
result += f":{label}"
|
31 |
+
result += ">"
|
32 |
+
return result
|
33 |
+
|
34 |
+
def get_append_marker(self, role: str, label: Optional[str] = None) -> str:
|
35 |
+
role_marker = self._get_role_marker(role)
|
36 |
+
if label is None:
|
37 |
+
return f"<{role_marker}>"
|
38 |
+
else:
|
39 |
+
return f"<{role_marker}={label}>"
|
40 |
+
|
41 |
+
|
42 |
+
@TaskModule.register()
|
43 |
+
class RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers(
|
44 |
+
RETextClassificationWithIndicesTaskModule
|
45 |
+
):
|
46 |
+
def __init__(self, use_sharp_marker: bool = False, **kwargs):
|
47 |
+
super().__init__(**kwargs)
|
48 |
+
self.use_sharp_marker = use_sharp_marker
|
49 |
+
|
50 |
+
def get_marker_factory(self) -> MarkerFactory:
|
51 |
+
if self.use_sharp_marker:
|
52 |
+
return SharpBracketMarkerFactory(role_to_marker=self.argument_role_to_marker)
|
53 |
+
else:
|
54 |
+
return MarkerFactory(role_to_marker=self.argument_role_to_marker)
|
55 |
+
|
56 |
+
|
57 |
+
def construct_text_document_from_text_pair_coref_document(
|
58 |
+
document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
|
59 |
+
glue_text: str,
|
60 |
+
no_relation_label: str,
|
61 |
+
relation_label_mapping: Optional[Dict[str, str]] = None,
|
62 |
+
add_span_mapping_to_metadata: bool = False,
|
63 |
+
) -> TextDocumentWithLabeledSpansAndBinaryRelations:
|
64 |
+
if document.text == document.text_pair:
|
65 |
+
new_doc = TextDocumentWithLabeledSpansAndBinaryRelations(
|
66 |
+
id=document.id, metadata=copy.deepcopy(document.metadata), text=document.text
|
67 |
+
)
|
68 |
+
old2new_spans: Dict[LabeledSpan, LabeledSpan] = {}
|
69 |
+
new2new_spans: Dict[LabeledSpan, LabeledSpan] = {}
|
70 |
+
for old_span in chain(document.labeled_spans, document.labeled_spans_pair):
|
71 |
+
new_span = old_span.copy()
|
72 |
+
# when detaching / copying the span, it may be the same as a previous span from the other
|
73 |
+
new_span = new2new_spans.get(new_span, new_span)
|
74 |
+
new2new_spans[new_span] = new_span
|
75 |
+
old2new_spans[old_span] = new_span
|
76 |
+
else:
|
77 |
+
new_doc = TextDocumentWithLabeledSpansAndBinaryRelations(
|
78 |
+
text=document.text + glue_text + document.text_pair,
|
79 |
+
id=document.id,
|
80 |
+
metadata=copy.deepcopy(document.metadata),
|
81 |
+
)
|
82 |
+
old2new_spans = {}
|
83 |
+
old2new_spans.update({span: span.copy() for span in document.labeled_spans})
|
84 |
+
offset = len(document.text) + len(glue_text)
|
85 |
+
old2new_spans.update(
|
86 |
+
{span: shift_span(span.copy(), offset) for span in document.labeled_spans_pair}
|
87 |
+
)
|
88 |
+
|
89 |
+
# sort to make order deterministic
|
90 |
+
new_doc.labeled_spans.extend(
|
91 |
+
sorted(old2new_spans.values(), key=lambda s: (s.start, s.end, s.label))
|
92 |
+
)
|
93 |
+
for old_rel in document.binary_coref_relations:
|
94 |
+
label = old_rel.label if old_rel.score > 0.0 else no_relation_label
|
95 |
+
if relation_label_mapping is not None:
|
96 |
+
label = relation_label_mapping.get(label, label)
|
97 |
+
new_rel = old_rel.copy(
|
98 |
+
head=old2new_spans[old_rel.head],
|
99 |
+
tail=old2new_spans[old_rel.tail],
|
100 |
+
label=label,
|
101 |
+
score=1.0,
|
102 |
+
)
|
103 |
+
new_doc.binary_relations.append(new_rel)
|
104 |
+
|
105 |
+
if add_span_mapping_to_metadata:
|
106 |
+
new_doc.metadata["span_mapping"] = old2new_spans
|
107 |
+
return new_doc
|
108 |
+
|
109 |
+
|
110 |
+
@TaskModule.register()
|
111 |
+
class CrossTextBinaryCorefByRETextClassificationTaskModule(
|
112 |
+
TaskModuleWithDocumentConverter,
|
113 |
+
RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers,
|
114 |
+
):
|
115 |
+
def __init__(
|
116 |
+
self,
|
117 |
+
coref_relation_label: str,
|
118 |
+
relation_annotation: str = "binary_relations",
|
119 |
+
probability_threshold: float = 0.0,
|
120 |
+
**kwargs,
|
121 |
+
):
|
122 |
+
if relation_annotation != "binary_relations":
|
123 |
+
raise ValueError(
|
124 |
+
f"{type(self).__name__} requires relation_annotation='binary_relations', "
|
125 |
+
f"but it is: {relation_annotation}"
|
126 |
+
)
|
127 |
+
super().__init__(relation_annotation=relation_annotation, **kwargs)
|
128 |
+
self.coref_relation_label = coref_relation_label
|
129 |
+
self.probability_threshold = probability_threshold
|
130 |
+
|
131 |
+
@property
|
132 |
+
def document_type(self) -> Optional[Type[Document]]:
|
133 |
+
return TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
|
134 |
+
|
135 |
+
def _get_glue_text(self) -> str:
|
136 |
+
result = self.tokenizer.decode(self._get_glue_token_ids())
|
137 |
+
return result
|
138 |
+
|
139 |
+
def _convert_document(
|
140 |
+
self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
|
141 |
+
) -> TextDocumentWithLabeledSpansAndBinaryRelations:
|
142 |
+
return construct_text_document_from_text_pair_coref_document(
|
143 |
+
document,
|
144 |
+
glue_text=self._get_glue_text(),
|
145 |
+
relation_label_mapping={"coref": self.coref_relation_label},
|
146 |
+
no_relation_label=self.none_label,
|
147 |
+
add_span_mapping_to_metadata=True,
|
148 |
+
)
|
149 |
+
|
150 |
+
def _integrate_predictions_from_converted_document(
|
151 |
+
self,
|
152 |
+
document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
|
153 |
+
converted_document: TextDocumentWithLabeledSpansAndBinaryRelations,
|
154 |
+
) -> None:
|
155 |
+
original2converted_span = converted_document.metadata["span_mapping"]
|
156 |
+
new2original_span = {
|
157 |
+
converted_s: orig_s for orig_s, converted_s in original2converted_span.items()
|
158 |
+
}
|
159 |
+
|
160 |
+
for rel in converted_document.binary_relations.predictions:
|
161 |
+
original_head = new2original_span[rel.head]
|
162 |
+
original_tail = new2original_span[rel.tail]
|
163 |
+
if rel.label != self.coref_relation_label:
|
164 |
+
raise ValueError(f"unexpected label: {rel.label}")
|
165 |
+
if rel.score >= self.probability_threshold:
|
166 |
+
original_predicted_rel = BinaryCorefRelation(
|
167 |
+
head=original_head, tail=original_tail, label="coref", score=rel.score
|
168 |
+
)
|
169 |
+
document.binary_coref_relations.predictions.append(original_predicted_rel)
|
170 |
+
|
171 |
+
def unbatch_output(self, model_output: REModelTargetType) -> Sequence[RETaskOutputType]:
|
172 |
+
coref_relation_idx = self.label_to_id[self.coref_relation_label]
|
173 |
+
# we are just concerned with the coref class, so we overwrite the labels field
|
174 |
+
model_output = copy.copy(model_output)
|
175 |
+
model_output["labels"] = torch.ones_like(model_output["labels"]) * coref_relation_idx
|
176 |
+
return super().unbatch_output(model_output=model_output)
|
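
To illustrate the marker format, a small sketch of what SharpBracketMarkerFactory produces; the role-to-marker mapping ("head" -> "H", "tail" -> "T") is an assumption for illustration, since the base MarkerFactory decides the actual role markers:

factory = SharpBracketMarkerFactory(role_to_marker={"head": "H", "tail": "T"})  # assumed mapping
factory._get_marker(role="head", is_start=True, label="person")  # "<H:person>"
factory._get_marker(role="head", is_start=False)                 # "</H>"
factory.get_append_marker(role="tail", label="location")         # "<T=location>"
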
src/train.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pyrootutils
|
2 |
+
|
3 |
+
root = pyrootutils.setup_root(
|
4 |
+
search_from=__file__,
|
5 |
+
indicator=[".project-root"],
|
6 |
+
pythonpath=True,
|
7 |
+
dotenv=True,
|
8 |
+
)
|
9 |
+
|
10 |
+
# ------------------------------------------------------------------------------------ #
|
11 |
+
# `pyrootutils.setup_root(...)` is an optional line at the top of each entry file
|
12 |
+
# that helps to make the environment more robust and convenient
|
13 |
+
#
|
14 |
+
# the main advantages are:
|
15 |
+
# - allows you to keep all entry files in "src/" without installing project as a package
|
16 |
+
# - makes paths and scripts always work no matter where is your current work dir
|
17 |
+
# - automatically loads environment variables from ".env" file if exists
|
18 |
+
#
|
19 |
+
# how it works:
|
20 |
+
# - the line above recursively searches for either ".git" or "pyproject.toml" in present
|
21 |
+
# and parent dirs, to determine the project root dir
|
22 |
+
# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from
|
23 |
+
# any place without installing project as a package
|
24 |
+
# - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml"
|
25 |
+
# to make all paths always relative to the project root
|
26 |
+
# - loads environment variables from ".env" file in root dir (if `dotenv=True`)
|
27 |
+
#
|
28 |
+
# you can remove `pyrootutils.setup_root(...)` if you:
|
29 |
+
# 1. either install project as a package or move each entry file to the project root dir
|
30 |
+
# 2. simply remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml"
|
31 |
+
# 3. always run entry files from the project root dir
|
32 |
+
#
|
33 |
+
# https://github.com/ashleve/pyrootutils
|
34 |
+
# ------------------------------------------------------------------------------------ #
|
35 |
+
|
36 |
+
import os.path
|
37 |
+
from typing import Any, Dict, List, Optional, Tuple
|
38 |
+
|
39 |
+
import hydra
|
40 |
+
import pytorch_lightning as pl
|
41 |
+
from omegaconf import DictConfig
|
42 |
+
from pie_datasets import DatasetDict
|
43 |
+
from pie_modules.models import * # noqa: F403
|
44 |
+
from pie_modules.models import SimpleGenerativeModel
|
45 |
+
from pie_modules.models.interface import RequiresTaskmoduleConfig
|
46 |
+
from pie_modules.taskmodules import * # noqa: F403
|
47 |
+
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
|
48 |
+
from pytorch_ie.core import PyTorchIEModel, TaskModule
|
49 |
+
from pytorch_ie.models import * # noqa: F403
|
50 |
+
from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
|
51 |
+
from pytorch_ie.taskmodules import * # noqa: F403
|
52 |
+
from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize
|
53 |
+
from pytorch_lightning import Callback, Trainer
|
54 |
+
from pytorch_lightning.loggers import Logger
|
55 |
+
|
56 |
+
from src import utils
|
57 |
+
from src.datamodules import PieDataModule
|
58 |
+
from src.models import * # noqa: F403
|
59 |
+
from src.taskmodules import * # noqa: F403
|
60 |
+
|
61 |
+
log = utils.get_pylogger(__name__)
|
62 |
+
|
63 |
+
|
64 |
+
def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]:
|
65 |
+
"""Safely retrieves value of the metric logged in LightningModule."""
|
66 |
+
|
67 |
+
if not metric_name:
|
68 |
+
log.info("Metric name is None! Skipping metric value retrieval...")
|
69 |
+
return None
|
70 |
+
|
71 |
+
if metric_name not in metric_dict:
|
72 |
+
raise Exception(
|
73 |
+
f"Metric value not found! <metric_name={metric_name}>\n"
|
74 |
+
"Make sure metric name logged in LightningModule is correct!\n"
|
75 |
+
"Make sure `optimized_metric` name in `hparams_search` config is correct!"
|
76 |
+
)
|
77 |
+
|
78 |
+
metric_value = metric_dict[metric_name].item()
|
79 |
+
log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
|
80 |
+
|
81 |
+
return metric_value
|
82 |
+
|
83 |
+
|
84 |
+
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[dict, dict]:
    """Trains the model. Can additionally evaluate on a test set, using the best weights obtained
    during training.

    This method is wrapped in an optional @task_wrapper decorator which applies extra utilities
    before and after the call.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """

    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        pl.seed_everything(cfg.seed, workers=True)

    # Init pytorch-ie taskmodule
    log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>")
    taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial")

    # Init pytorch-ie dataset
    log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
    dataset: DatasetDict = hydra.utils.instantiate(
        cfg.dataset,
        _convert_="partial",
    )

    # auto-convert the dataset if the taskmodule specifies a document type
    dataset = taskmodule.convert_dataset(dataset)

    # Init pytorch-ie datamodule
    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
    datamodule: PieDataModule = hydra.utils.instantiate(
        cfg.datamodule, dataset=dataset, taskmodule=taskmodule, _convert_="partial"
    )
    # Use the train dataset split to prepare the taskmodule
    taskmodule.prepare(dataset[datamodule.train_split])

    # Init the pytorch-ie model
    log.info(f"Instantiating model <{cfg.model._target_}>")
    # get additional model arguments
    additional_model_kwargs: Dict[str, Any] = {}
    model_cls = hydra.utils.get_class(cfg.model["_target_"])
    # NOTE: MODIFY THE additional_model_kwargs IF YOUR MODEL REQUIRES ANY MORE PARAMETERS FROM
    # THE TASKMODULE! SEE EXAMPLES BELOW.
    if issubclass(model_cls, RequiresNumClasses):
        additional_model_kwargs["num_classes"] = len(taskmodule.label_to_id)
    if issubclass(model_cls, RequiresModelNameOrPath):
        if "model_name_or_path" not in cfg.model:
            raise Exception(
                f"Please specify model_name_or_path in the model config for {model_cls.__name__}."
            )
    if isinstance(taskmodule, ChangesTokenizerVocabSize):
        additional_model_kwargs["tokenizer_vocab_size"] = len(taskmodule.tokenizer)
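
    # The pooler may be configured as a plain string (e.g. "cls_token") or as a mapping like
    # {"type": "mention_pooling"}; for index-based pooler types, the number of index groups
    # to pool is derived from the taskmodule below.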
    pooler_config = cfg["model"].get("pooler")
    if pooler_config is not None:
        if isinstance(pooler_config, str):
            pooler_config = {"type": pooler_config}
        pooler_config = dict(pooler_config)
        if pooler_config["type"] in ["start_tokens", "mention_pooling"]:
            # NOTE: This is very hacky, we should create a new interface class,
            # e.g. RequiresPoolerNumIndices
            if hasattr(taskmodule, "argument_role2idx"):
                pooler_config["num_indices"] = len(taskmodule.argument_role2idx)
            else:
                pooler_config["num_indices"] = 1
        elif pooler_config["type"] == "cls_token":
            pass
        else:
            raise Exception(
                f"unknown pooler type: {pooler_config['type']}. Please adjust the train.py "
                f"script for that type."
            )
        additional_model_kwargs["pooler"] = pooler_config

    if issubclass(model_cls, RequiresTaskmoduleConfig):
        additional_model_kwargs["taskmodule_config"] = taskmodule.config

    if model_cls == SimpleGenerativeModel:
        # There may already be some base_model_config entries in the model config. We also need
        # to convert the base_model_config to a dict, because it is an OmegaConf object which
        # does not accept additional entries.
        base_model_config = (
            dict(cfg.model.base_model_config) if "base_model_config" in cfg.model else {}
        )
        if isinstance(taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
            base_model_config.update(
                dict(
                    bos_token_id=taskmodule.bos_id,
                    eos_token_id=taskmodule.eos_id,
                    pad_token_id=taskmodule.eos_id,
                    target_token_ids=taskmodule.target_token_ids,
                    embedding_weight_mapping=taskmodule.label_embedding_weight_mapping,
                )
            )
        additional_model_kwargs["base_model_config"] = base_model_config

    # initialize the model
    model: PyTorchIEModel = hydra.utils.instantiate(
        cfg.model, _convert_="partial", **additional_model_kwargs
    )

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = utils.instantiate_dict_entries(cfg, key="callbacks")

    log.info("Instantiating loggers...")
    logger: List[Logger] = utils.instantiate_dict_entries(cfg, key="logger")

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)

    object_dict = {
        "cfg": cfg,
        "dataset": dataset,
        "taskmodule": taskmodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg)

    if cfg.model_save_dir is not None:
        log.info(f"Save taskmodule to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
        taskmodule.save_pretrained(save_directory=cfg.model_save_dir, push_to_hub=cfg.push_to_hub)
    else:
        log.warning("the taskmodule is not saved because no save_dir is specified")

    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

    train_metrics = trainer.callback_metrics
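
    # best_model_path is the empty string if no checkpoint was written
    # (e.g. when training was skipped or with fast_dev_run)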
    best_ckpt_path = trainer.checkpoint_callback.best_model_path
    if best_ckpt_path != "":
        log.info(f"Best ckpt path: {best_ckpt_path}")
        best_checkpoint_file = os.path.basename(best_ckpt_path)
        utils.log_hyperparameters(
            logger=logger,
            best_checkpoint=best_checkpoint_file,
            checkpoint_dir=trainer.checkpoint_callback.dirpath,
        )

    if not cfg.trainer.get("fast_dev_run"):
        if cfg.model_save_dir is not None:
            if best_ckpt_path == "":
                log.warning("Best ckpt not found! Using current weights for saving...")
            else:
                model = type(model).load_from_checkpoint(best_ckpt_path)

            log.info(f"Save model to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
            model.save_pretrained(save_directory=cfg.model_save_dir, push_to_hub=cfg.push_to_hub)
        else:
            log.warning("the model is not saved because no save_dir is specified")

    if cfg.get("validate"):
        log.info("Starting validation!")
        if best_ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for validation...")
        trainer.validate(model=model, datamodule=datamodule, ckpt_path=best_ckpt_path or None)
    elif cfg.get("train"):
        log.warning(
            "Validation after training is skipped! This means that the finally reported "
            "validation scores are the values from the *last* checkpoint, not from the *best* "
            "checkpoint (which is the one that gets saved)!"
        )

    if cfg.get("test"):
        log.info("Starting testing!")
        if best_ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
        trainer.test(model=model, datamodule=datamodule, ckpt_path=best_ckpt_path or None)

    test_metrics = trainer.callback_metrics

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    # add model_save_dir to the result so that it gets dumped to job_return_value.json
    # if we use hydra_callbacks.SaveJobReturnValueCallback
    if cfg.get("model_save_dir") is not None:
        metric_dict["model_save_dir"] = cfg.model_save_dir

    return metric_dict, object_dict


@hydra.main(version_base="1.2", config_path=str(root / "configs"), config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    # train the model
    metric_dict, _ = train(cfg)

    # safely retrieve metric value for hydra-based hyperparameter optimization
    if cfg.get("optimized_metric") is not None:
        metric_value = get_metric_value(
            metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
        )

        # return optimized metric
        return metric_value
    else:
        return metric_dict


if __name__ == "__main__":
    utils.replace_sys_args_with_values_from_files()
    utils.prepare_omegaconf()
    main()
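
Since train.py is a Hydra entry point, runs are configured via command-line overrides. A hypothetical invocation (the override values are examples; the available config groups live in configs/train.yaml and the configs/ tree, which are outside this diff):

python src/train.py seed=42 validate=true model_save_dir=models/my_run push_to_hub=false
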
src/utils/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .config_utils import execute_pipeline, instantiate_dict_entries, prepare_omegaconf
from .data_utils import download_and_unzip, filter_dataframe_and_get_column
from .logging_utils import close_loggers, get_pylogger, log_hyperparameters
from .rich_utils import enforce_tags, print_config_tree
from .span_utils import distance
from .task_utils import extras, replace_sys_args_with_values_from_files, save_file, task_wrapper
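
These re-exports let call sites import from the package directly, as the scripts above do; a minimal sketch:

from src import utils

log = utils.get_pylogger(__name__)  # resolves to src.utils.logging_utils.get_pylogger
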
src/utils/config_utils.py
ADDED
@@ -0,0 +1,71 @@
from copy import copy
from typing import Any, List, Optional

from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

from src.utils.logging_utils import get_pylogger

logger = get_pylogger(__name__)


def execute_pipeline(
    input: Any,
    setup: Optional[Any] = None,
    **processors,
) -> Any:
    if setup is not None and callable(setup):
        setup()
    result = input
    for processor_name, processor_config in processors.items():
        if not isinstance(processor_config, dict) or "_processor_" not in processor_config:
            continue
        logger.info(f"call processor: {processor_name}")
        config = copy(processor_config)
        if not config.pop("_enabled_", True):
            logger.warning(f"skip processor because it is disabled: {processor_name}")
            continue
        # rename the key "_processor_" to "_target_"
        if "_target_" in config:
            raise ValueError(
                f"processor {processor_name} has a key '_target_', which is not allowed"
            )
        config["_target_"] = config.pop("_processor_")
        # IMPORTANT: We pass result as the first argument after the config instead of adding it
        # to the config. By doing so, we prevent it from getting converted into an OmegaConf
        # object, which would be converted back to a simple dict, breaking all the DatasetDict
        # methods.
        tmp_result = instantiate(config, result, _convert_="partial")
        if tmp_result is not None:
            result = tmp_result
        else:
            logger.warning(f'processor "{processor_name}" did not return a result')
    return result
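
# Example (stdlib callables as stand-in processors; any dotted path works as "_processor_"):
#   execute_pipeline(input=[3, 1, 2], sort={"_processor_": "builtins.sorted"})
#   -> [1, 2, 3]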


def instantiate_dict_entries(
    config: DictConfig, key: str, entry_description: Optional[str] = None
) -> List:
    entries: List = []
    key_config = config.get(key)

    if not key_config:
        logger.warning(f"{key} config is empty.")
        return entries

    if not isinstance(key_config, DictConfig):
        raise TypeError(f"{key} config must be a DictConfig!")

    for _, entry_conf in key_config.items():
        if isinstance(entry_conf, DictConfig) and "_target_" in entry_conf:
            logger.info(f"Instantiating {entry_description or key} <{entry_conf._target_}>")
            entries.append(instantiate(entry_conf, _convert_="partial"))

    return entries


def prepare_omegaconf():
    # register the "replace" resolver (used to replace "/" with "-" in names so that they can
    # be used as e.g. wandb project names)
    if not OmegaConf.has_resolver("replace"):
        OmegaConf.register_new_resolver("replace", lambda s, x, y: s.replace(x, y))
    else:
        logger.warning("OmegaConf resolver 'replace' is already registered")
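
A minimal sketch of the "replace" resolver in action (the config keys and values are hypothetical):

from omegaconf import OmegaConf

from src.utils import prepare_omegaconf

prepare_omegaconf()
cfg = OmegaConf.create({"model_name": "org/model", "project": "${replace:${model_name},'/','-'}"})
print(cfg.project)  # -> "org-model"
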
src/utils/data_utils.py
ADDED
@@ -0,0 +1,33 @@
import os
from typing import List
from urllib.request import urlretrieve
from zipfile import ZipFile

import pandas as pd

from src.utils.logging_utils import get_pylogger

log = get_pylogger(__name__)


def filter_dataframe_and_get_column(
    dataframe: pd.DataFrame, filter_column: str, filter_value: str, select_column: str
) -> List[str]:
    return dataframe[dataframe[filter_column] == filter_value][select_column].tolist()


def download_and_unzip(
    url: str, target_path: str, force_download: bool = False, remove_tmp_file: bool = False
):
    log.warning(f"download zip file from {url} to {target_path} ...")
    if not (url.startswith("http://") or url.startswith("https://")):
        raise ValueError(f"url needs to point to a http(s) address, but it is: {url}")
    tmp_file = os.path.join(target_path, os.path.basename(url))
    if os.path.exists(tmp_file) and not force_download:
        log.warning(f"tmp file {tmp_file} already exists, skip downloading {url}")
    else:
        urlretrieve(url, tmp_file)  # nosec
    with ZipFile(tmp_file, "r") as zfile:
        zfile.extractall(target_path)
    if remove_tmp_file:
        os.remove(tmp_file)
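
A usage sketch (the URL and column names are hypothetical):

import pandas as pd

from src.utils import download_and_unzip, filter_dataframe_and_get_column

# fetch and extract an annotation archive into the data directory (hypothetical URL)
download_and_unzip("https://example.com/annotations.zip", target_path="data", remove_tmp_file=True)

df = pd.DataFrame({"split": ["train", "train", "test"], "doc_id": ["d1", "d2", "d3"]})
train_ids = filter_dataframe_and_get_column(
    df, filter_column="split", filter_value="train", select_column="doc_id"
)
# train_ids == ["d1", "d2"]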