import copy
from os.path import join as pjoin
from typing import Any, Callable

from torch.utils.data import DataLoader

from .humanml.dataset import Text2MotionDatasetV2


class BASEDataModule:
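    """Base data module shared by the dataset-specific modules.

    Subclasses are expected to provide ``self.cfg``, ``self.name``,
    ``self.Dataset`` and ``self.hparams``; the ``*_dataset`` attributes and
    their dataloaders are then built lazily from that configuration.
    """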
def __init__(self, collate_fn: Callable, batch_size: int,
num_workers: int, persistent_workers: bool) -> None:
        super().__init__()
self.dataloader_options = {
"batch_size": batch_size,
"num_workers": num_workers,
"collate_fn": collate_fn,
"persistent_workers": persistent_workers
}
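        # Set externally by evaluation code (presumably for multimodality
        # metrics); test_dataloader() drops to batch_size=1 while it is True.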
self.is_mm = False

    def get_sample_set(self, overrides: dict) -> Text2MotionDatasetV2:
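        """Build a dataset on the EVAL split, with ``overrides`` merged
        into a copy of the stored hyperparameters."""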
sample_params = copy.deepcopy(self.hparams)
sample_params.update(overrides)
        split_file = pjoin(
            getattr(self.cfg.DATASET, self.name.upper()).SPLIT_ROOT,
            self.cfg.EVAL.SPLIT + ".txt",
        )
return self.Dataset(split_file=split_file, **sample_params)

    def __getattr__(self, item: str) -> Any:
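        """Lazily construct and cache ``train_dataset`` / ``val_dataset`` /
        ``test_dataset`` on first access."""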
if item.endswith("_dataset") and not item.startswith("_"):
subset = item[:-len("_dataset")]
item_c = "_" + item
if item_c not in self.__dict__:
                # "val" reads its split from cfg.EVAL; "train"/"test" read
                # theirs from cfg.TRAIN / cfg.TEST.
                subset = subset.upper() if subset != "val" else "EVAL"
                split = getattr(self.cfg, subset).SPLIT
                split_file = pjoin(
                    getattr(self.cfg.DATASET, self.name.upper()).SPLIT_ROOT,
                    split + ".txt",
                )
self.__dict__[item_c] = self.Dataset(split_file=split_file,
split=split,
**self.hparams)
return getattr(self, item_c)
classname = self.__class__.__name__
raise AttributeError(f"'{classname}' object has no attribute '{item}'")

    def train_dataloader(self) -> DataLoader:
return DataLoader(self.train_dataset, shuffle=True, **self.dataloader_options)

    def val_dataloader(self) -> DataLoader:
dataloader_options = self.dataloader_options.copy()
dataloader_options["batch_size"] = self.cfg.EVAL.BATCH_SIZE
dataloader_options["num_workers"] = self.cfg.EVAL.NUM_WORKERS
dataloader_options["shuffle"] = False
return DataLoader(self.val_dataset, **dataloader_options)

    def test_dataloader(self) -> DataLoader:
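        """Test dataloader; batch size is forced to 1 while ``self.is_mm`` is set."""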
dataloader_options = self.dataloader_options.copy()
dataloader_options["batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE
dataloader_options["num_workers"] = self.cfg.TEST.NUM_WORKERS
dataloader_options["shuffle"] = False
return DataLoader(self.test_dataset, **dataloader_options)
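

# Usage sketch (illustrative only): the subclass name, ``my_collate`` and the
# config object below are hypothetical and not defined in this file; they only
# show which attributes a subclass is expected to set.
#
#     class HumanMLDataModule(BASEDataModule):
#         def __init__(self, cfg):
#             super().__init__(collate_fn=my_collate, batch_size=32,
#                              num_workers=4, persistent_workers=True)
#             self.cfg = cfg
#             self.name = "humanml3d"          # looked up as cfg.DATASET.HUMANML3D
#             self.Dataset = Text2MotionDatasetV2
#             self.hparams = {}                # extra keyword args for the Dataset
#
#     dm = HumanMLDataModule(cfg)
#     loader = dm.train_dataloader()           # dm.train_dataset is built lazily via __getattr__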