Spaces:
Running
on
A10G
Running
on
A10G
"""This package includes all the modules related to data loading and preprocessing | |
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. | |
You need to implement four functions: | |
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). | |
-- <__len__>: return the size of dataset. | |
-- <__getitem__>: get a data point from data loader. | |
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options. | |
Now you can use the dataset class by specifying flag '--dataset_mode dummy'. | |
See our template dataset class 'template_dataset.py' for more details. | |
""" | |
import numpy as np | |
import importlib | |
import torch.utils.data | |
from face3d.data.base_dataset import BaseDataset | |
def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py" and return its dataset class.

    In that module, the class called DatasetNameDataset() is looked up. It has
    to be a subclass of BaseDataset. The name comparison is case-insensitive,
    and underscores in ``dataset_name`` are dropped before matching
    (e.g. ``'foo_bar'`` matches ``FooBarDataset``).

    Args:
        dataset_name: the dataset mode, e.g. the value of ``--dataset_mode``.

    Returns:
        The matching BaseDataset subclass (the class itself, not an instance).
        If several attributes match, the last one in the module wins.

    Raises:
        NotImplementedError: if the module defines no matching BaseDataset subclass.
        ImportError: propagated from ``importlib.import_module`` when the
            module cannot be imported.
    """
    # NOTE(review): the module path is rooted at a top-level "data" package,
    # while BaseDataset is imported from "face3d.data". This only resolves
    # when a top-level "data" package is importable; confirm against sys.path.
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    wanted = target_dataset_name.lower()
    dataset = None
    for name, cls in datasetlib.__dict__.items():
        # issubclass() raises TypeError for non-class objects, so attributes
        # whose name merely happens to match (functions, constants,
        # instances) must be filtered out before the inheritance test.
        if (name.lower() == wanted
                and isinstance(cls, type)
                and issubclass(cls, BaseDataset)):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset
def get_option_setter(dataset_name):
    """Resolve the dataset class for ``dataset_name`` and hand back its hook.

    The returned object is the class's static ``modify_commandline_options``
    callable itself; it is not invoked here.
    """
    return find_dataset_using_name(dataset_name).modify_commandline_options
def create_dataset(opt, rank=0):
    """Build the data-loading wrapper for ``opt`` and return it ready to iterate.

    This is the main entry point of the package for the train/test scripts.
    It constructs a CustomDatasetDataLoader and returns the result of its
    ``load_data()``.

    Args:
        opt: parsed options, forwarded unchanged to CustomDatasetDataLoader.
        rank: process rank, forwarded to CustomDatasetDataLoader.

    Example (the import path depends on how this package is installed):
        dataset = create_dataset(opt)
    """
    return CustomDatasetDataLoader(opt, rank=rank).load_data()
class CustomDatasetDataLoader:
    """Pair a dataset instance with a ``torch.utils.data.DataLoader`` over it.

    The dataset class is resolved from ``opt.dataset_mode``.

    - DDP mode: when both ``opt.use_ddp`` and ``opt.isTrain`` are truthy, a
      ``DistributedSampler`` shards the data across ``opt.world_size``
      processes. The per-process batch size and worker count are
      ``opt.batch_size`` and ``opt.num_threads`` divided by the world size
      (truncated to int).
    - Otherwise: a plain loader is built that shuffles only in training mode,
      and only when ``opt.serial_batches`` is off.

    Both loaders drop the last incomplete batch.
    """

    def __init__(self, opt, rank=0):
        """Instantiate the dataset and build its data loader.

        Args:
            opt: options namespace. Read here: dataset_mode, use_ddp, isTrain,
                serial_batches, batch_size, num_threads, and world_size (DDP only).
            rank: index of this process. It appears in the log line and is
                passed to the distributed sampler under DDP.
        """
        self.opt = opt
        self.dataset = find_dataset_using_name(opt.dataset_mode)(opt)
        self.sampler = None
        print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))

        if not (opt.use_ddp and opt.isTrain):
            # Single-process loading: shuffle only while training, and only
            # when in-order (serial) batches were not requested.
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=opt.batch_size,
                shuffle=(not opt.serial_batches) and opt.isTrain,
                num_workers=int(opt.num_threads),
                drop_last=True,
            )
            return

        # Distributed training: ordering is delegated to the sampler (no
        # shuffle flag is given to the loader). opt.batch_size and
        # opt.num_threads are split evenly across the processes.
        n_procs = opt.world_size
        self.sampler = torch.utils.data.distributed.DistributedSampler(
            self.dataset,
            num_replicas=n_procs,
            rank=rank,
            shuffle=not opt.serial_batches,
        )
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            sampler=self.sampler,
            num_workers=int(opt.num_threads / n_procs),
            batch_size=int(opt.batch_size / n_procs),
            drop_last=True,
        )

    def load_data(self):
        """Return this wrapper itself; it is directly iterable."""
        return self

    def set_epoch(self, epoch):
        """Store ``epoch`` on the dataset and forward it to the sampler, if any."""
        self.dataset.current_epoch = epoch
        sampler = self.sampler
        if sampler is None:
            return
        sampler.set_epoch(epoch)

    def __len__(self):
        """Return ``len(self.dataset)`` capped at ``opt.max_dataset_size``."""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Yield batches from the loader until ``index * opt.batch_size`` reaches the cap.

        The cap is ``opt.max_dataset_size``. The comparison uses the global
        ``opt.batch_size`` (not the per-process batch size used under DDP).
        """
        for index, batch in enumerate(self.dataloader):
            consumed = index * self.opt.batch_size
            if consumed >= self.opt.max_dataset_size:
                return
            yield batch