from typing import Any, List, Optional, Sequence, Union

import hydra
import lightning as pl
import torch
from lightning.pytorch.utilities.types import EVAL_DATALOADERS
from omegaconf import DictConfig
from torch.utils.data import DataLoader

from relik.common.log import get_logger
from relik.retriever.data.datasets import GoldenRetrieverDataset

logger = get_logger(__name__)


class GoldenRetrieverPLDataModule(pl.LightningDataModule):
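    """
    Lightning ``DataModule`` that wraps GoldenRetriever datasets and builds
    the train/val/test dataloaders. Datasets can be passed directly or
    instantiated lazily from a Hydra ``DictConfig`` in :meth:`setup`.
    """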
    def __init__(
        self,
        train_dataset: Optional[GoldenRetrieverDataset] = None,
        val_datasets: Optional[Sequence[GoldenRetrieverDataset]] = None,
        test_datasets: Optional[Sequence[GoldenRetrieverDataset]] = None,
        num_workers: Optional[Union[DictConfig, int]] = None,
        datasets: Optional[DictConfig] = None,
        *args,
        **kwargs,
    ):
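        """
        Args:
            train_dataset: dataset used to build the train dataloader.
            val_datasets: one dataset per validation split.
            test_datasets: one dataset per test split.
            num_workers: either a single int used for every split or a
                ``DictConfig`` with ``train``/``val``/``test`` keys.
            datasets: Hydra config used to instantiate any dataset that was
                not passed explicitly (see :meth:`setup`).
        """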
        super().__init__()
        self.datasets = datasets
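        # normalize num_workers into a per-split config (train/val/test)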
        if num_workers is None:
            num_workers = 0
        if isinstance(num_workers, int):
            num_workers = DictConfig(
                {"train": num_workers, "val": num_workers, "test": num_workers}
            )
        self.num_workers = num_workers
        # data
        self.train_dataset: Optional[GoldenRetrieverDataset] = train_dataset
        self.val_datasets: Optional[Sequence[GoldenRetrieverDataset]] = val_datasets
        self.test_datasets: Optional[Sequence[GoldenRetrieverDataset]] = test_datasets

    def prepare_data(self, *args, **kwargs):
        """
        Prepare the data before training, e.g. download or tokenize it.
        Lightning calls this method only once, on a single process.
        """
        pass

    def setup(self, stage: Optional[str] = None):
        if stage == "fit" or stage is None:
            # there is usually a single train dataset; if you need more
            # train loaders, follow the same logic as the val and test splits
            if self.train_dataset is None:
                self.train_dataset = hydra.utils.instantiate(self.datasets.train)
            if self.val_datasets is None:
                self.val_datasets = [
                    hydra.utils.instantiate(dataset_cfg)
                    for dataset_cfg in self.datasets.val
                ]
        if stage == "test" or stage is None:
            if self.test_datasets is None:
                self.test_datasets = [
                    hydra.utils.instantiate(dataset_cfg)
                    for dataset_cfg in self.datasets.test
                ]

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        torch_dataset = self.train_dataset.to_torch_dataset()
        return DataLoader(
            torch_dataset,
            # the dataset yields pre-built batches, so batching (and any
            # shuffling) is handled by the dataset itself
            shuffle=False,
            batch_size=None,
            num_workers=self.num_workers.train,
            pin_memory=True,
            collate_fn=lambda x: x,  # identity: batches are already collated
        )

    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
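        # one DataLoader per validation split, returned in config order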
        dataloaders = []
        for dataset in self.val_datasets:
            torch_dataset = dataset.to_torch_dataset()
            dataloaders.append(
                DataLoader(
                    torch_dataset,
                    shuffle=False,
                    batch_size=None,
                    num_workers=self.num_workers.val,
                    pin_memory=True,
                    collate_fn=lambda x: x,
                )
            )
        return dataloaders

    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        dataloaders = []
        for dataset in self.test_datasets:
            torch_dataset = dataset.to_torch_dataset()
            dataloaders.append(
                DataLoader(
                    torch_dataset,
                    shuffle=False,
                    batch_size=None,
                    num_workers=self.num_workers.test,
                    pin_memory=True,
                    collate_fn=lambda x: x,
                )
            )
        return dataloaders

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        # prediction is not supported by this data module
        raise NotImplementedError

    def transfer_batch_to_device(
        self, batch: Any, device: torch.device, dataloader_idx: int
    ) -> Any:
        # pass-through override, kept as a hook for custom batch types
        return super().transfer_batch_to_device(batch, device, dataloader_idx)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.datasets=}, {self.num_workers=})"
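
# Minimal usage sketch (assumptions: `cfg` is a Hydra config whose `datasets`
# node has `train`/`val`/`test` entries, and `model` is a LightningModule;
# both names are illustrative, not part of this file):
#
#     datamodule = GoldenRetrieverPLDataModule(
#         datasets=cfg.datasets, num_workers=4
#     )
#     trainer = pl.Trainer(max_steps=10_000)
#     trainer.fit(model, datamodule=datamodule)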