import logging
from typing import Any, Dict, Generic, Optional, Sequence, TypeVar, Union

from pytorch_ie.core import Document
from pytorch_ie.core.taskmodule import (
IterableTaskEncodingDataset,
TaskEncoding,
TaskEncodingDataset,
TaskModule,
)
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Sampler
from typing_extensions import TypeAlias

from .components.sampler import ImbalancedDatasetSampler

DocumentType = TypeVar("DocumentType", bound=Document)
InputEncoding = TypeVar("InputEncoding")
TargetEncoding = TypeVar("TargetEncoding")

DatasetType: TypeAlias = Union[
TaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
IterableTaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
]

logger = logging.getLogger(__name__)


class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, TargetEncoding]):
"""A simple LightningDataModule for PIE document datasets.
A DataModule implements 5 key methods:
- prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
- setup (things to do on every accelerator in distributed mode)
- train_dataloader (the training dataloader)
- val_dataloader (the validation dataloader(s))
- test_dataloader (the test dataloader(s))
This allows you to share a full dataset without explaining how to download,
split, transform and process the data.
Read the docs:
https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
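
    A minimal usage sketch (the taskmodule, the documents, and the dataloader kwargs
    below are placeholders, not fixed by this class):

        datamodule = PieDataModule(
            taskmodule=my_taskmodule,
            dataset={"train": train_docs, "validation": val_docs},
            batch_size=32,
        )
        datamodule.setup(stage="fit")
        train_loader = datamodule.train_dataloader()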
"""
def __init__(
self,
taskmodule: TaskModule[DocumentType, InputEncoding, TargetEncoding, Any, Any, Any],
dataset: Dict[str, Sequence[DocumentType]],
data_config_path: Optional[str] = None,
train_split: Optional[str] = "train",
val_split: Optional[str] = "validation",
test_split: Optional[str] = "test",
show_progress_for_encode: bool = False,
train_sampler: Optional[str] = None,
dont_shuffle_train: bool = False,
**dataloader_kwargs,
):
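        """
        Args:
            taskmodule: the taskmodule used to encode the documents and to collate batches
            dataset: a mapping from split names to sequences of documents
            data_config_path: optional path to the data config (stored, not used directly here)
            train_split: name of the train split in the dataset
            val_split: name of the validation split in the dataset
            test_split: name of the test split in the dataset
            show_progress_for_encode: whether to show a progress bar when encoding the documents
            train_sampler: optional name of a sampler for the train dataloader, see
                get_train_sampler() for the available options
            dont_shuffle_train: if True, do not shuffle the train dataloader
            **dataloader_kwargs: additional keyword arguments passed to all DataLoaders,
                e.g. batch_size or num_workers
        """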
super().__init__()
self.taskmodule = taskmodule
self.config_path = data_config_path
self.dataset = dataset
self.train_split = train_split
self.val_split = val_split
self.test_split = test_split
self.show_progress_for_encode = show_progress_for_encode
self.train_sampler_name = train_sampler
self.dataloader_kwargs = dataloader_kwargs
self.dont_shuffle_train = dont_shuffle_train
        self._data: Dict[str, DatasetType] = {}

@property
def num_train(self) -> int:
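        """Return the number of encoded train examples (requires setup() to have been called)."""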
if self.train_split is None:
raise ValueError("no train_split assigned")
data_train = self._data.get(self.train_split, None)
if data_train is None:
raise ValueError("can not get train size if setup() was not yet called")
if isinstance(data_train, IterableTaskEncodingDataset):
raise TypeError("IterableTaskEncodingDataset has no length")
        return len(data_train)

def setup(self, stage: str):
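        """Encode the documents of all splits relevant for the given stage with the taskmodule
        and cache the resulting task encoding datasets."""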
if stage == "fit":
split_names = [self.train_split, self.val_split]
elif stage == "validate":
split_names = [self.val_split]
elif stage == "test":
split_names = [self.test_split]
else:
            raise NotImplementedError(f"not implemented for stage={stage}")
for split in split_names:
if split is None or split not in self.dataset:
continue
task_encoding_dataset = self.taskmodule.encode(
self.dataset[split],
encode_target=True,
as_dataset=True,
show_progress=self.show_progress_for_encode,
)
if not isinstance(
task_encoding_dataset,
(TaskEncodingDataset, IterableTaskEncodingDataset),
):
raise TypeError(
f"taskmodule.encode did not return a (Iterable)TaskEncodingDataset, but: {type(task_encoding_dataset)}"
)
            self._data[split] = task_encoding_dataset

def data_split(self, split: Optional[str] = None) -> DatasetType:
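        """Return the encoded data for the given split. Raises a ValueError if the data is
        not (yet) available, e.g. because setup() was not called."""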
if split is None or split not in self._data:
raise ValueError(f"data for split={split} not available")
        return self._data[split]

def get_train_sampler(
self,
sampler_name: str,
dataset: DatasetType,
) -> Sampler:
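        """Create a sampler for the train dataloader. Currently, only "imbalanced_dataset" is
        supported, which uses the ImbalancedDatasetSampler to re-balance the label distribution."""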
if sampler_name == "imbalanced_dataset":
            # for now, this works only with targets that have a single entry
return ImbalancedDatasetSampler(
dataset, callback_get_label=lambda ds: [x.targets[0] for x in ds]
)
else:
raise ValueError(f"unknown sampler name: {sampler_name}")
def train_dataloader(self):
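        """Create the train dataloader. Shuffling is enabled unless a sampler is used, the
        dataset is an IterableTaskEncodingDataset, or dont_shuffle_train is set."""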
ds = self.data_split(self.train_split)
if self.train_sampler_name is not None:
sampler = self.get_train_sampler(sampler_name=self.train_sampler_name, dataset=ds)
else:
sampler = None
        # don't shuffle if the dataset is streamed (iterable), if a sampler is used,
        # or if shuffling was explicitly disabled via dont_shuffle_train
shuffle = not self.dont_shuffle_train and not (
isinstance(ds, IterableTaskEncodingDataset) or sampler is not None
)
if not shuffle:
logger.warning("not shuffling train dataloader")
return DataLoader(
dataset=ds,
sampler=sampler,
collate_fn=self.taskmodule.collate,
shuffle=shuffle,
**self.dataloader_kwargs,
        )

def val_dataloader(self):
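        """Create the validation dataloader (never shuffled)."""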
return DataLoader(
dataset=self.data_split(self.val_split),
collate_fn=self.taskmodule.collate,
shuffle=False,
**self.dataloader_kwargs,
        )

def test_dataloader(self):
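        """Create the test dataloader (never shuffled)."""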
return DataLoader(
dataset=self.data_split(self.test_split),
collate_fn=self.taskmodule.collate,
shuffle=False,
**self.dataloader_kwargs,
)