File size: 2,799 Bytes
6b1e9f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import copy
from os.path import join as pjoin
from typing import Any, Callable

from torch.utils.data import DataLoader

from .humanml.dataset import Text2MotionDatasetV2


class BASEDataModule:
    """Base data module that lazily builds dataset splits and wraps them in DataLoaders.

    This base class does not assign ``self.cfg``, ``self.name``, ``self.hparams``
    or ``self.Dataset``; they are read here, so subclasses are presumably
    responsible for setting them (NOTE(review): confirm in subclasses).
    """

    def __init__(self, collate_fn: Callable, batch_size: int,
                 num_workers: int, persistent_workers: bool) -> None:
        super().__init__()
        # Shared DataLoader keyword arguments. The val/test loaders copy these
        # and override batch_size / num_workers from the config.
        self.dataloader_options = {
            "batch_size": batch_size,
            "num_workers": num_workers,
            "collate_fn": collate_fn,
            "persistent_workers": persistent_workers
        }
        # When True, test_dataloader uses a batch size of 1.
        self.is_mm = False

    def _cfg_lookup(self, dotted_path: str) -> Any:
        """Resolve a dotted attribute path (e.g. ``"EVAL.SPLIT"``) on ``self.cfg``.

        Plain attribute traversal, equivalent to ``self.cfg.EVAL.SPLIT`` but
        without ``eval`` of a formatted string.
        """
        node = self.cfg
        for part in dotted_path.split("."):
            node = getattr(node, part)
        return node

    def _split_file(self, split: str) -> str:
        """Return ``<cfg.DATASET.<NAME>.SPLIT_ROOT>/<split>.txt``."""
        split_root = self._cfg_lookup(f"DATASET.{self.name.upper()}.SPLIT_ROOT")
        return pjoin(split_root, split + ".txt")

    def _loader_options(self, **overrides: Any) -> dict:
        """Return a copy of ``self.dataloader_options`` with ``overrides`` applied.

        ``persistent_workers`` is kept only when ``num_workers > 0``.
        ``torch.utils.data.DataLoader`` raises ``ValueError`` for persistent
        workers without worker processes. Previously this happened whenever
        ``cfg.EVAL.NUM_WORKERS`` / ``cfg.TEST.NUM_WORKERS`` (or the train
        ``num_workers``) was 0. ``self.dataloader_options`` itself is not modified.
        """
        options = {**self.dataloader_options, **overrides}
        options["persistent_workers"] = bool(
            options.get("persistent_workers")
            and (options.get("num_workers") or 0) > 0)
        return options

    def get_sample_set(self, overrides: dict) -> "Text2MotionDatasetV2":
        """Build a dataset on the ``cfg.EVAL.SPLIT`` split with hparams overridden.

        Args:
            overrides: keyword arguments merged over a deep copy of
                ``self.hparams`` (``self.hparams`` itself is not mutated).

        Returns:
            ``self.Dataset(split_file=..., **merged_params)``.
        """
        sample_params = copy.deepcopy(self.hparams)
        sample_params.update(overrides)
        split_file = self._split_file(self.cfg.EVAL.SPLIT)
        return self.Dataset(split_file=split_file, **sample_params)

    def __getattr__(self, item: str) -> Any:
        """Lazily build and cache ``<subset>_dataset`` attributes.

        ``train_dataset`` and ``test_dataset`` read ``cfg.TRAIN.SPLIT`` and
        ``cfg.TEST.SPLIT``. ``val_dataset`` reads ``cfg.EVAL.SPLIT``. The
        dataset is built once and cached in ``self.__dict__["_<item>"]``.

        Raises:
            AttributeError: for any other missing attribute.
        """
        if item.endswith("_dataset") and not item.startswith("_"):
            cache_key = "_" + item
            if cache_key not in self.__dict__:
                subset = item[:-len("_dataset")]
                # The validation split is configured under cfg.EVAL, not cfg.VAL.
                section = "EVAL" if subset == "val" else subset.upper()
                split = self._cfg_lookup(f"{section}.SPLIT")
                self.__dict__[cache_key] = self.Dataset(
                    split_file=self._split_file(split),
                    split=split,
                    **self.hparams)
            return self.__dict__[cache_key]
        classname = self.__class__.__name__
        raise AttributeError(f"'{classname}' object has no attribute '{item}'")

    def train_dataloader(self) -> DataLoader:
        """Shuffled DataLoader over ``train_dataset`` using the constructor options."""
        return DataLoader(self.train_dataset, shuffle=True, **self._loader_options())

    def val_dataloader(self) -> DataLoader:
        """Unshuffled DataLoader over ``val_dataset`` using ``cfg.EVAL`` batch/worker settings."""
        options = self._loader_options(
            batch_size=self.cfg.EVAL.BATCH_SIZE,
            num_workers=self.cfg.EVAL.NUM_WORKERS,
            shuffle=False)
        return DataLoader(self.val_dataset, **options)

    def test_dataloader(self) -> DataLoader:
        """Unshuffled DataLoader over ``test_dataset``.

        Uses batch size 1 when ``self.is_mm`` is set, otherwise
        ``cfg.TEST.BATCH_SIZE``. Workers come from ``cfg.TEST.NUM_WORKERS``.
        """
        options = self._loader_options(
            batch_size=1 if self.is_mm else self.cfg.TEST.BATCH_SIZE,
            num_workers=self.cfg.TEST.NUM_WORKERS,
            shuffle=False)
        return DataLoader(self.test_dataset, **options)