MotionLCM / mld /data /base.py
wxDai's picture
init
6b1e9f7
raw
history blame
2.8 kB
import copy
from os.path import join as pjoin
from typing import Any, Callable
from torch.utils.data import DataLoader
from .humanml.dataset import Text2MotionDatasetV2
class BASEDataModule:
    """Build datasets and ``DataLoader`` objects for train / val / test.

    This base class reads, but never assigns, the attributes below; they are
    expected to be set (e.g. by a subclass) before any dataset or loader is
    requested:

    * ``cfg`` -- config object providing ``DATASET.<NAME>.SPLIT_ROOT``,
      ``<SUBSET>.SPLIT`` for each requested subset, and ``EVAL`` / ``TEST``
      ``BATCH_SIZE`` and ``NUM_WORKERS``.
    * ``name`` -- dataset name; ``name.upper()`` selects the ``cfg.DATASET``
      entry.
    * ``hparams`` -- dict of keyword arguments forwarded to ``Dataset``.
    * ``Dataset`` -- callable invoked as ``Dataset(split_file=..., **kwargs)``.

    Any attribute named ``<subset>_dataset`` (``train_dataset``,
    ``val_dataset``, ``test_dataset``, ...) is built lazily on first access
    and cached in ``self.__dict__`` under ``_<subset>_dataset``.
    """

    def __init__(self, collate_fn: Callable, batch_size: int,
                 num_workers: int, persistent_workers: bool) -> None:
        """Store the default ``DataLoader`` keyword arguments.

        Args:
            collate_fn: batch collation function passed to every loader.
            batch_size: batch size for the training loader.
            num_workers: worker count for the training loader.
            persistent_workers: requested ``persistent_workers`` flag; it is
                cleared for any loader whose effective ``num_workers`` is 0.
        """
        super().__init__()
        self.dataloader_options = {
            "batch_size": batch_size,
            "num_workers": num_workers,
            "collate_fn": collate_fn,
            "persistent_workers": persistent_workers,
        }
        # When True, ``test_dataloader`` uses a batch size of 1
        # (presumably for multimodality evaluation -- confirm with callers).
        self.is_mm = False

    def _split_file(self, split: str) -> str:
        """Return the path ``<SPLIT_ROOT>/<split>.txt`` for this dataset."""
        # Plain attribute lookups instead of ``eval`` on an f-string: same
        # result, but no text is ever evaluated as code.
        dataset_cfg = getattr(self.cfg.DATASET, self.name.upper())
        return pjoin(dataset_cfg.SPLIT_ROOT, split + ".txt")

    def _loader_options(self, **overrides: Any) -> dict:
        """Return a copy of ``dataloader_options`` with ``overrides`` applied.

        ``DataLoader`` raises ``ValueError`` when ``persistent_workers`` is
        True while ``num_workers`` is 0 (e.g. an eval/test config with no
        workers), so the flag is cleared in that case. The stored
        ``self.dataloader_options`` is never modified.
        """
        options = {**self.dataloader_options, **overrides}
        if options.get("persistent_workers") and not options.get("num_workers"):
            options["persistent_workers"] = False
        return options

    def get_sample_set(self, overrides: dict) -> "Text2MotionDatasetV2":
        """Build a dataset on the ``cfg.EVAL.SPLIT`` split with custom params.

        Args:
            overrides: keyword arguments that replace or extend ``hparams``
                for this dataset only; ``self.hparams`` is not modified.

        Returns:
            ``self.Dataset(split_file=..., **params)``. Unlike the lazy
            ``*_dataset`` attributes, no ``split`` keyword is passed unless
            ``overrides`` supplies one.
        """
        # Deep copy so the update below cannot leak into the shared hparams.
        sample_params = copy.deepcopy(self.hparams)
        sample_params.update(overrides)
        return self.Dataset(split_file=self._split_file(self.cfg.EVAL.SPLIT),
                            **sample_params)

    def __getattr__(self, item: str) -> Any:
        """Lazily build and cache ``<subset>_dataset`` attributes.

        Python only calls this when normal attribute lookup fails. The subset
        ``val`` reads the ``EVAL`` config section; any other subset reads its
        upper-cased name (``train`` -> ``TRAIN``, ``test`` -> ``TEST``).

        Raises:
            AttributeError: for any other missing attribute.
        """
        # Names starting with "_" are excluded so the private cache keys
        # themselves are never treated as dataset requests.
        if item.endswith("_dataset") and not item.startswith("_"):
            cache_key = "_" + item
            if cache_key not in self.__dict__:
                subset = item[:-len("_dataset")]
                section = "EVAL" if subset == "val" else subset.upper()
                split = getattr(self.cfg, section).SPLIT
                self.__dict__[cache_key] = self.Dataset(
                    split_file=self._split_file(split),
                    split=split,
                    **self.hparams)
            return getattr(self, cache_key)
        classname = self.__class__.__name__
        raise AttributeError(f"'{classname}' object has no attribute '{item}'")

    def train_dataloader(self) -> DataLoader:
        """Return a shuffling loader over ``train_dataset``."""
        return DataLoader(self.train_dataset, shuffle=True,
                          **self._loader_options())

    def val_dataloader(self) -> DataLoader:
        """Return a non-shuffling loader over ``val_dataset`` (``cfg.EVAL``)."""
        options = self._loader_options(
            batch_size=self.cfg.EVAL.BATCH_SIZE,
            num_workers=self.cfg.EVAL.NUM_WORKERS,
            shuffle=False,
        )
        return DataLoader(self.val_dataset, **options)

    def test_dataloader(self) -> DataLoader:
        """Return a non-shuffling loader over ``test_dataset`` (``cfg.TEST``).

        The batch size is 1 when ``is_mm`` is set, else ``cfg.TEST.BATCH_SIZE``.
        """
        options = self._loader_options(
            batch_size=1 if self.is_mm else self.cfg.TEST.BATCH_SIZE,
            num_workers=self.cfg.TEST.NUM_WORKERS,
            shuffle=False,
        )
        return DataLoader(self.test_dataset, **options)