import logging
from typing import Any, Dict, Generic, Optional, Sequence, TypeVar, Union

from pytorch_ie.core import Document
from pytorch_ie.core.taskmodule import (
    IterableTaskEncodingDataset,
    TaskEncoding,
    TaskEncodingDataset,
    TaskModule,
)
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Sampler
from typing_extensions import TypeAlias

from .components.sampler import ImbalancedDatasetSampler

DocumentType = TypeVar("DocumentType", bound=Document)
InputEncoding = TypeVar("InputEncoding")
TargetEncoding = TypeVar("TargetEncoding")
DatasetType: TypeAlias = Union[
    TaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
    IterableTaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
]

logger = logging.getLogger(__name__)


class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, TargetEncoding]):
    """A simple LightningDataModule for PIE document datasets.

    A DataModule implements five key methods:
    - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
    - setup (things to do on every accelerator in distributed mode)
    - train_dataloader (the training dataloader)
    - val_dataloader (the validation dataloader(s))
    - test_dataloader (the test dataloader(s))

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
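
    Example (a minimal sketch; assumes an already configured taskmodule and a
    dataset dict with the default split names, and that extra keyword arguments
    such as ``batch_size`` are forwarded to the DataLoaders):

        datamodule = PieDataModule(taskmodule=taskmodule, dataset=dataset, batch_size=32)
        datamodule.setup(stage="fit")
        train_loader = datamodule.train_dataloader()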
    """

    def __init__(
        self,
        taskmodule: TaskModule[DocumentType, InputEncoding, TargetEncoding, Any, Any, Any],
        dataset: Dict[str, Sequence[DocumentType]],
        data_config_path: Optional[str] = None,
        train_split: Optional[str] = "train",
        val_split: Optional[str] = "validation",
        test_split: Optional[str] = "test",
        show_progress_for_encode: bool = False,
        train_sampler: Optional[str] = None,
        dont_shuffle_train: bool = False,
        **dataloader_kwargs,
    ):
        super().__init__()

        self.taskmodule = taskmodule
        self.config_path = data_config_path
        self.dataset = dataset
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.show_progress_for_encode = show_progress_for_encode
        self.train_sampler_name = train_sampler
        self.dataloader_kwargs = dataloader_kwargs
        self.dont_shuffle_train = dont_shuffle_train

        # holds the encoded splits, populated by setup()
        self._data: Dict[str, DatasetType] = {}

    @property
    def num_train(self) -> int:
        if self.train_split is None:
            raise ValueError("no train_split assigned")
        data_train = self._data.get(self.train_split, None)
        if data_train is None:
            raise ValueError("cannot get the train dataset size before setup() was called")
        if isinstance(data_train, IterableTaskEncodingDataset):
            raise TypeError("IterableTaskEncodingDataset has no length")
        return len(data_train)

    def setup(self, stage: str):
        if stage == "fit":
            split_names = [self.train_split, self.val_split]
        elif stage == "validate":
            split_names = [self.val_split]
        elif stage == "test":
            split_names = [self.test_split]
        else:
            raise NotImplementedError(f"not implemented for stage={stage}")

        for split in split_names:
            if split is None or split not in self.dataset:
                continue
            # encode the documents of the split into task encodings (including targets)
            task_encoding_dataset = self.taskmodule.encode(
                self.dataset[split],
                encode_target=True,
                as_dataset=True,
                show_progress=self.show_progress_for_encode,
            )
            if not isinstance(
                task_encoding_dataset,
                (TaskEncodingDataset, IterableTaskEncodingDataset),
            ):
                raise TypeError(
                    f"taskmodule.encode did not return a (Iterable)TaskEncodingDataset, but: {type(task_encoding_dataset)}"
                )
            self._data[split] = task_encoding_dataset

    def data_split(self, split: Optional[str] = None) -> DatasetType:
        if split is None or split not in self._data:
            raise ValueError(f"data for split={split} not available")
        return self._data[split]

    def get_train_sampler(
        self,
        sampler_name: str,
        dataset: DatasetType,
    ) -> Sampler:
        if sampler_name == "imbalanced_dataset":
            # the sampler expects a callback that returns a label for each task
            # encoding; here we use the first target entry as the label
            return ImbalancedDatasetSampler(
                dataset, callback_get_label=lambda ds: [x.targets[0] for x in ds]
            )
        else:
            raise ValueError(f"unknown sampler name: {sampler_name}")

    def train_dataloader(self):
        ds = self.data_split(self.train_split)
        if self.train_sampler_name is not None:
            sampler = self.get_train_sampler(sampler_name=self.train_sampler_name, dataset=ds)
        else:
            sampler = None

        # don't shuffle if disabled explicitly, if a sampler is used (mutually
        # exclusive with shuffle), or if the dataset is iterable-style
        shuffle = not self.dont_shuffle_train and not (
            isinstance(ds, IterableTaskEncodingDataset) or sampler is not None
        )
        if not shuffle:
            logger.warning("not shuffling train dataloader")
        return DataLoader(
            dataset=ds,
            sampler=sampler,
            collate_fn=self.taskmodule.collate,
            shuffle=shuffle,
            **self.dataloader_kwargs,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.data_split(self.val_split),
            collate_fn=self.taskmodule.collate,
            shuffle=False,
            **self.dataloader_kwargs,
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.data_split(self.test_split),
            collate_fn=self.taskmodule.collate,
            shuffle=False,
            **self.dataloader_kwargs,
        )
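
# Usage sketch for the optional train sampler (hypothetical setup; `taskmodule`
# and `dataset` must be created elsewhere):
#
#     datamodule = PieDataModule(
#         taskmodule=taskmodule,
#         dataset=dataset,
#         train_sampler="imbalanced_dataset",  # currently the only supported name
#         num_workers=4,  # any further kwargs are forwarded to the DataLoaders
#     )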