import copy
from os.path import join as pjoin
from typing import Any, Callable

from torch.utils.data import DataLoader

from .humanml.dataset import Text2MotionDatasetV2


class BASEDataModule:
    """Base data module that lazily builds datasets and their DataLoaders.

    The methods below read ``self.cfg``, ``self.name``, ``self.hparams`` and
    ``self.Dataset``, none of which is assigned in this class.
    NOTE(review): presumably subclasses set them before any dataset or
    dataloader is requested -- confirm against the concrete data modules.
    """

    def __init__(self, collate_fn: Callable, batch_size: int, num_workers: int,
                 persistent_workers: bool) -> None:
        """Store the DataLoader keyword arguments shared by every split.

        Args:
            collate_fn: Callable forwarded to ``DataLoader`` as ``collate_fn``.
            batch_size: Batch size used by the training loader.
            num_workers: Worker count used by the training loader.
            persistent_workers: Forwarded to ``DataLoader``; dropped for any
                loader that ends up with zero workers.
        """
        super().__init__()
        self.dataloader_options = {
            "batch_size": batch_size,
            "num_workers": num_workers,
            "collate_fn": collate_fn,
            "persistent_workers": persistent_workers,
        }
        # Switches the test loader to batch size 1 (see test_dataloader).
        self.is_mm = False

    def _split_file(self, split: str) -> str:
        """Return ``<SPLIT_ROOT>/<split>.txt`` for the current dataset.

        ``SPLIT_ROOT`` is read from ``cfg.DATASET.<NAME>``, where ``<NAME>``
        is ``self.name`` upper-cased.
        """
        # getattr replaces the former eval() of an f-string: same attribute
        # lookup, but no code is executed from a runtime-built string.
        dataset_cfg = getattr(self.cfg.DATASET, self.name.upper())
        return pjoin(dataset_cfg.SPLIT_ROOT, split + ".txt")

    def _build_loader(self, dataset: Any, options: dict) -> DataLoader:
        """Create a ``DataLoader`` from ``options``.

        ``DataLoader`` raises ``ValueError`` when ``persistent_workers`` is
        true while ``num_workers`` is 0, so the flag is switched off in that
        case. ``options`` itself is left untouched.
        """
        if options.get("persistent_workers") and not options.get("num_workers"):
            options = {**options, "persistent_workers": False}
        return DataLoader(dataset, **options)

    def get_sample_set(self, overrides: dict) -> Text2MotionDatasetV2:
        """Build a dataset on the evaluation split with overridden parameters.

        Args:
            overrides: Entries applied on top of a deep copy of
                ``self.hparams`` (``self.hparams`` itself is not modified).

        Returns:
            ``self.Dataset`` instantiated with the ``cfg.EVAL.SPLIT`` split
            file and the merged parameters.
        """
        sample_params = copy.deepcopy(self.hparams)
        sample_params.update(overrides)
        split_file = self._split_file(self.cfg.EVAL.SPLIT)
        return self.Dataset(split_file=split_file, **sample_params)

    def __getattr__(self, item: str) -> Any:
        """Lazily create and cache ``<subset>_dataset`` attributes.

        Python calls this only after normal attribute lookup has failed. For
        a public name ending in ``_dataset`` the dataset is built on first
        access and stored in ``self.__dict__`` under ``"_" + item``, so later
        accesses reuse the same object.

        Raises:
            AttributeError: For any other missing attribute.
        """
        if item.endswith("_dataset") and not item.startswith("_"):
            cache_key = "_" + item
            if cache_key not in self.__dict__:
                subset = item[:-len("_dataset")]
                # "val" is configured under cfg.EVAL; every other subset
                # uses its own upper-cased name (e.g. TRAIN, TEST).
                section = "EVAL" if subset == "val" else subset.upper()
                # Read the split name once and reuse it for both the file
                # path and the Dataset argument.
                split = getattr(self.cfg, section).SPLIT
                self.__dict__[cache_key] = self.Dataset(
                    split_file=self._split_file(split),
                    split=split,
                    **self.hparams,
                )
            return getattr(self, cache_key)
        classname = self.__class__.__name__
        raise AttributeError(f"'{classname}' object has no attribute '{item}'")

    def train_dataloader(self) -> DataLoader:
        """Return a shuffled loader over ``self.train_dataset``."""
        options = {**self.dataloader_options, "shuffle": True}
        return self._build_loader(self.train_dataset, options)

    def val_dataloader(self) -> DataLoader:
        """Return an unshuffled loader over ``self.val_dataset``.

        Batch size and worker count come from ``cfg.EVAL``.
        """
        options = self.dataloader_options.copy()
        options["batch_size"] = self.cfg.EVAL.BATCH_SIZE
        options["num_workers"] = self.cfg.EVAL.NUM_WORKERS
        options["shuffle"] = False
        return self._build_loader(self.val_dataset, options)

    def test_dataloader(self) -> DataLoader:
        """Return an unshuffled loader over ``self.test_dataset``.

        Batch size is 1 when ``self.is_mm`` is true, otherwise
        ``cfg.TEST.BATCH_SIZE``; worker count comes from ``cfg.TEST``.
        """
        options = self.dataloader_options.copy()
        options["batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE
        options["num_workers"] = self.cfg.TEST.NUM_WORKERS
        options["shuffle"] = False
        return self._build_loader(self.test_dataset, options)