|
from functools import partial |
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union |
|
|
|
import hydra |
|
import lightning as pl |
|
import torch |
|
from lightning.pytorch.trainer.states import RunningStage |
|
from omegaconf import DictConfig |
|
from torch.utils.data import DataLoader, Dataset |
|
|
|
from relik.common.log import get_logger |
|
from relik.retriever.data.base.datasets import BaseDataset |
|
|
|
logger = get_logger(__name__)

# Short stage names accepted by `PredictionCallback(stages=...)`, translated to
# the matching Lightning `RunningStage` before `RunningStage(...)` is applied.
STAGES_COMPATIBILITY_MAP = dict(
    train=RunningStage.TRAINING,
    val=RunningStage.VALIDATING,
    test=RunningStage.TESTING,
)

# Stages assigned to `PredictionCallback` when no `stages` argument is passed.
DEFAULT_STAGES = {
    RunningStage.VALIDATING, RunningStage.TESTING,
    RunningStage.SANITY_CHECKING, RunningStage.PREDICTING,
}
|
|
|
|
|
class PredictionCallback(pl.Callback):
    """
    Base Lightning callback that computes predictions at the end of the
    validation and test epochs and forwards them to other callbacks.

    Subclasses implement `__call__`, which receives the trainer and the
    LightningModule and returns the predictions. Those predictions are then
    passed to every object in `other_callbacks`.

    Args:
        batch_size (`int`, optional, defaults to 32):
            Batch size, stored on the instance for use by subclasses.
        stages (`Optional[Set[Union[str, RunningStage]]]`):
            Stages the callback is meant to run on. The short aliases in
            `STAGES_COMPATIBILITY_MAP` (`"train"`, `"val"`, `"test"`) and any
            `RunningStage` value are accepted. Defaults to `DEFAULT_STAGES`.
            This base class only stores them as `self.stages`; it does not
            check them in its epoch-end hooks.
        other_callbacks (`Optional[Union[List[DictConfig], List[NLPTemplateCallback]]]`):
            Callbacks invoked with the computed predictions. `DictConfig`
            entries are instantiated with Hydra (`_recursive_=False`).
        datasets (`Optional[Union[DictConfig, BaseDataset]]`):
            Stored on the instance as `self.datasets`.
        dataloaders (`Optional[Union[DataLoader, List[DataLoader]]]`):
            Stored on the instance as `self.dataloaders`.
    """

    def __init__(
        self,
        batch_size: int = 32,
        stages: Optional[Set[Union[str, RunningStage]]] = None,
        other_callbacks: Optional[
            Union[List[DictConfig], List["NLPTemplateCallback"]]
        ] = None,
        datasets: Optional[Union[DictConfig, BaseDataset]] = None,
        dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        *args,
        **kwargs,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.datasets = datasets
        self.dataloaders = dataloaders

        if stages is None:
            stages = DEFAULT_STAGES
        # Translate short aliases (e.g. "val") first. Building a set also
        # collapses duplicates such as {"val", "validate"}.
        stages = {STAGES_COMPATIBILITY_MAP.get(stage, stage) for stage in stages}
        self.stages = [RunningStage(stage) for stage in stages]

        # Build a new list instead of assigning into `other_callbacks`. This
        # avoids mutating the caller's list; an OmegaConf ListConfig may also
        # refuse to store arbitrary objects.
        self.other_callbacks = [
            hydra.utils.instantiate(callback, _recursive_=False)
            if isinstance(callback, DictConfig)
            else callback
            for callback in (other_callbacks or [])
        ]

    @torch.no_grad()
    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        *args,
        **kwargs,
    ) -> Any:
        """
        Compute the predictions. Must be implemented by subclasses.

        Args:
            trainer (`pl.Trainer`):
                The running trainer.
            pl_module (`pl.LightningModule`):
                The module being evaluated.

        Returns:
            `Any`: The predictions forwarded to `other_callbacks`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def _predict_and_notify(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Run `self(...)` and forward its predictions to every other callback."""
        predictions = self(trainer, pl_module)
        for callback in self.other_callbacks:
            callback(
                trainer=trainer,
                pl_module=pl_module,
                callback=self,
                predictions=predictions,
            )

    def on_validation_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ):
        """Lightning hook: compute and dispatch predictions after validation."""
        self._predict_and_notify(trainer, pl_module)

    def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Lightning hook: compute and dispatch predictions after testing."""
        self._predict_and_notify(trainer, pl_module)

    @staticmethod
    def _get_datasets_and_dataloaders(
        dataset: Optional[Union[Dataset, DictConfig, List[Dataset]]],
        dataloader: Optional[Union[DataLoader, List[DataLoader]]],
        trainer: pl.Trainer,
        dataloader_kwargs: Optional[Dict[str, Any]] = None,
        collate_fn: Optional[Callable] = None,
        collate_fn_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[List[Dataset], List[DataLoader]]:
        """
        Get the datasets and dataloaders from the datamodule or from the dataset provided.

        Args:
            dataset (`Optional[Union[Dataset, DictConfig, List[Dataset]]]`):
                The dataset(s) to use. A `DictConfig` is instantiated with Hydra.
                If `None`, the datasets and dataloaders are taken from the
                trainer for the current stage (test or validation), and
                `dataloader` is ignored.
            dataloader (`Optional[Union[DataLoader, List[DataLoader]]]`):
                The dataloader(s) to use with `dataset`. If `None`, one
                `DataLoader` is built for each dataset.
            trainer (`pl.Trainer`):
                The trainer that contains the datamodule.
            dataloader_kwargs (`Optional[Dict[str, Any]]`):
                Keyword arguments for each built `DataLoader`. The dict is
                copied and never modified. Its `collate_fn` entry is
                overwritten.
            collate_fn (`Optional[Callable]`):
                The collate function to use for every built dataloader. If
                `None`, each dataset's own `collate_fn` is used.
            collate_fn_kwargs (`Optional[Dict[str, Any]]`):
                Keyword arguments bound to each dataset's `collate_fn` (only
                used when `collate_fn` is `None`).

        Returns:
            `Tuple[List[Dataset], List[DataLoader]]`: The datasets and dataloaders.
        """
        if dataset is None:
            # Fall back to what the trainer/datamodule hold for the current stage.
            if trainer.state.stage == RunningStage.TESTING:
                return trainer.datamodule.test_datasets, trainer.test_dataloaders
            return trainer.datamodule.val_datasets, trainer.val_dataloaders

        if isinstance(dataset, DictConfig):
            dataset = hydra.utils.instantiate(dataset, _recursive_=False)
        datasets = list(dataset) if isinstance(dataset, (list, tuple)) else [dataset]

        if dataloader is not None:
            dataloaders = (
                [dataloader] if isinstance(dataloader, DataLoader) else dataloader
            )
            return datasets, dataloaders

        # `None` is a valid default, so normalise before unpacking with `**`.
        bound_collate_kwargs = collate_fn_kwargs or {}
        dataloaders = []
        for current_dataset in datasets:
            # Fresh copy per loader: never mutate the caller's dict.
            loader_kwargs = dict(dataloader_kwargs or {})
            loader_kwargs["collate_fn"] = collate_fn or partial(
                current_dataset.collate_fn, **bound_collate_kwargs
            )
            dataloaders.append(DataLoader(current_dataset, **loader_kwargs))
        return datasets, dataloaders
|
|
|
|
|
class NLPTemplateCallback:
    """
    Interface for objects that consume the predictions of a `PredictionCallback`.

    `PredictionCallback` invokes each entry of its `other_callbacks` with the
    keyword arguments `trainer`, `pl_module`, `callback` and `predictions`.
    Concrete subclasses override `__call__` to act on those predictions.
    """

    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        callback: PredictionCallback,
        predictions: Dict[str, Any],
        *args,
        **kwargs,
    ) -> Any:
        """
        Handle the predictions produced by `callback`.

        Args:
            trainer (`pl.Trainer`): The running trainer.
            pl_module (`pl.LightningModule`): The module being evaluated.
            callback (`PredictionCallback`): The callback that produced the predictions.
            predictions (`Dict[str, Any]`): The predictions to process.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
|
|