"""This package includes all the modules related to data loading and preprocessing To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. You need to implement four functions: -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). -- <__len__>: return the size of dataset. -- <__getitem__>: get a data point from data loader. -- : (optionally) add dataset-specific options and set default options. Now you can use the dataset class by specifying flag '--dataset_mode dummy'. See our template dataset class 'template_dataset.py' for more details. """ import numpy as np import importlib import torch.utils.data from face3d.data.base_dataset import BaseDataset def find_dataset_using_name(dataset_name): """Import the module "data/[dataset_name]_dataset.py". In the file, the class called DatasetNameDataset() will be instantiated. It has to be a subclass of BaseDataset, and it is case-insensitive. """ dataset_filename = "data." + dataset_name + "_dataset" datasetlib = importlib.import_module(dataset_filename) dataset = None target_dataset_name = dataset_name.replace('_', '') + 'dataset' for name, cls in datasetlib.__dict__.items(): if name.lower() == target_dataset_name.lower() \ and issubclass(cls, BaseDataset): dataset = cls if dataset is None: raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) return dataset def get_option_setter(dataset_name): """Return the static method of the dataset class.""" dataset_class = find_dataset_using_name(dataset_name) return dataset_class.modify_commandline_options def create_dataset(opt, rank=0): """Create a dataset given the option. This function wraps the class CustomDatasetDataLoader. This is the main interface between this package and 'train.py'/'test.py' Example: >>> from data import create_dataset >>> dataset = create_dataset(opt) """ data_loader = CustomDatasetDataLoader(opt, rank=rank) dataset = data_loader.load_data() return dataset class CustomDatasetDataLoader(): """Wrapper class of Dataset class that performs multi-threaded data loading""" def __init__(self, opt, rank=0): """Initialize this class Step 1: create a dataset instance given the name [dataset_mode] Step 2: create a multi-threaded data loader. """ self.opt = opt dataset_class = find_dataset_using_name(opt.dataset_mode) self.dataset = dataset_class(opt) self.sampler = None print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__)) if opt.use_ddp and opt.isTrain: world_size = opt.world_size self.sampler = torch.utils.data.distributed.DistributedSampler( self.dataset, num_replicas=world_size, rank=rank, shuffle=not opt.serial_batches ) self.dataloader = torch.utils.data.DataLoader( self.dataset, sampler=self.sampler, num_workers=int(opt.num_threads / world_size), batch_size=int(opt.batch_size / world_size), drop_last=True) else: self.dataloader = torch.utils.data.DataLoader( self.dataset, batch_size=opt.batch_size, shuffle=(not opt.serial_batches) and opt.isTrain, num_workers=int(opt.num_threads), drop_last=True ) def set_epoch(self, epoch): self.dataset.current_epoch = epoch if self.sampler is not None: self.sampler.set_epoch(epoch) def load_data(self): return self def __len__(self): """Return the number of data in the dataset""" return min(len(self.dataset), self.opt.max_dataset_size) def __iter__(self): """Return a batch of data""" for i, data in enumerate(self.dataloader): if i * self.opt.batch_size >= self.opt.max_dataset_size: break yield data