sunshineatnoon
Add application file
1b2a9b1
raw
history blame
No virus
4.55 kB
"""This package includes all the modules related to data loading and preprocessing.
To add a custom dataset class called 'dummy', add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' that inherits from BaseDataset.
You need to implement four functions:
-- <__init__>: initialize the class; first call BaseDataset.__init__(self, opt).
-- <__len__>: return the size of the dataset.
-- <__getitem__>: return a single data point.
-- <modify_commandline_options>: (optional) add dataset-specific options and set default options.
Now you can use the dataset class by specifying the flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import torch.utils.data
from swapae.data.base_dataset import BaseDataset
import swapae.util as util
def find_dataset_using_name(dataset_name):
    """Import ``swapae.data.<dataset_name>_dataset`` and return its dataset class.

    The module must define a subclass of BaseDataset whose name, compared
    case-insensitively, equals ``dataset_name`` with underscores removed
    followed by ``dataset`` (e.g. ``'dummy'`` -> ``DummyDataset``).

    Args:
        dataset_name: the ``--dataset_mode`` value; selects the module to import.

    Returns:
        The matching BaseDataset subclass itself (not an instance).

    Raises:
        NotImplementedError: if the module defines no matching BaseDataset subclass.
    """
    dataset_filename = "swapae.data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        # issubclass() raises TypeError for non-class objects, so a same-named
        # module-level function or variable must be skipped, not crash lookup.
        if (name.lower() == target_dataset_name.lower()
                and isinstance(cls, type)
                and issubclass(cls, BaseDataset)):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
    return dataset
def get_option_setter(dataset_name):
    """Resolve the dataset class for ``dataset_name`` and hand back (without
    calling) its static ``modify_commandline_options`` method."""
    return find_dataset_using_name(dataset_name).modify_commandline_options
def create_dataset(opt):
    """Build the data loader described by ``opt``.

    Returns a ConfigurableDataLoader, already initialized for ``opt.phase``.
    """
    loader = ConfigurableDataLoader(opt)
    return loader
class DataPrefetcher():
    """Iterator that copies the next item to the GPU on a side CUDA stream.

    While the caller works on the current item, the following one is already
    being transferred with ``non_blocking=True`` on ``self.stream``. Before
    handing an item out, ``__next__`` makes the current stream wait on that
    side stream.

    Requires CUDA: a ``torch.cuda.Stream`` is created in ``__init__``. Each item
    must provide ``.cuda(non_blocking=...)``. That holds for a tensor; dict or
    tuple batches are not handled (TODO confirm what callers pass).
    """

    def __init__(self, dataset):
        # ``dataset`` is kept for len(). Items are drawn from iter(dataset),
        # which is the object itself for an iterator and a fresh iterator for
        # a re-iterable container.
        self.dataset = dataset
        self._iterator = iter(dataset)
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        """Fetch the next item and start its asynchronous copy to the GPU.

        Sets ``self.next_input`` to None once the source is exhausted.
        """
        try:
            self.next_input = next(self._iterator)
        except StopIteration:
            self.next_input = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)

    def __next__(self):
        """Return the prefetched item and start prefetching the one after it.

        Raises:
            StopIteration: once the source is exhausted.
        """
        batch = self.next_input
        if batch is None:
            # Follow the iterator protocol. Returning None forever made
            # ``for`` loops over this object never terminate.
            raise StopIteration
        current = torch.cuda.current_stream()
        # Do not let the consumer touch the batch before its copy finishes.
        current.wait_stream(self.stream)
        if hasattr(batch, "record_stream"):
            # The batch was allocated on the side stream but is used on the
            # current one. Record that use so the caching allocator does not
            # reuse its memory while the current stream still needs it.
            batch.record_stream(current)
        self.preload()
        return batch

    def __iter__(self):
        return self

    def __len__(self):
        return len(self.dataset)
class ConfigurableDataLoader():
def __init__(self, opt):
self.opt = opt
self.initialize(opt.phase)
def initialize(self, phase):
opt = self.opt
self.phase = phase
if hasattr(self, "dataloader"):
del self.dataloader
dataset_class = find_dataset_using_name(opt.dataset_mode)
dataset = dataset_class(util.copyconf(opt, phase=phase, isTrain=phase == "train"))
shuffle = phase == "train" if opt.shuffle_dataset is None else opt.shuffle_dataset == "true"
print("dataset [%s] of size %d was created. shuffled=%s" % (type(dataset).__name__, len(dataset), shuffle))
#dataset = DataPrefetcher(dataset)
self.opt = opt
self.dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=opt.batch_size,
shuffle=shuffle,
num_workers=int(opt.num_gpus),
drop_last=phase == "train",
)
#self.dataloader = dataset
self.dataloader_iterator = iter(self.dataloader)
self.repeat = phase == "train"
self.length = len(dataset)
self.underlying_dataset = dataset
def set_phase(self, target_phase):
if self.phase != target_phase:
self.initialize(target_phase)
def __iter__(self):
self.dataloader_iterator = iter(self.dataloader)
return self
def __len__(self):
return self.length
def __next__(self):
try:
return next(self.dataloader_iterator)
except StopIteration:
if self.repeat:
self.dataloader_iterator = iter(self.dataloader)
return next(self.dataloader_iterator)
else:
raise StopIteration