Spaces:
Runtime error
Runtime error
File size: 4,550 Bytes
1b2a9b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
"""This package includes all the modules related to data loading and preprocessing.
To add a custom dataset class called 'dummy', add a file called 'dummy_dataset.py' to this package and define in it a subclass 'DummyDataset' that inherits from BaseDataset.
You need to implement four functions:
-- <__init__>: initialize the class; first call BaseDataset.__init__(self, opt).
-- <__len__>: return the size of the dataset.
-- <__getitem__>: return a single data point for the data loader.
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
Now you can use the dataset class by specifying the flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import torch.utils.data
from swapae.data.base_dataset import BaseDataset
import swapae.util as util
def find_dataset_using_name(dataset_name):
    """Import ``swapae.data.<dataset_name>_dataset`` and return its dataset class.

    Inside that module, the selected attribute is the one whose name equals
    ``<dataset_name>`` with underscores removed plus the suffix ``dataset``,
    compared case-insensitively (e.g. ``"dummy"`` -> ``DummyDataset``). It must
    be a class that subclasses BaseDataset. If several attributes match, the
    last one in the module's namespace order wins.

    Args:
        dataset_name: the dataset identifier, e.g. the ``--dataset_mode`` value.

    Returns:
        The matching BaseDataset subclass (the class itself, not an instance).

    Raises:
        ModuleNotFoundError: if the dataset module cannot be found.
        NotImplementedError: if the module has no matching BaseDataset subclass.
    """
    dataset_filename = "swapae.data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    target_lower = target_dataset_name.lower()

    dataset = None
    for name, cls in datasetlib.__dict__.items():
        # isinstance(cls, type) must come first: issubclass() raises TypeError
        # when a non-class attribute (function, instance, constant) happens to
        # have a matching name, instead of simply not matching.
        if (name.lower() == target_lower
                and isinstance(cls, type)
                and issubclass(cls, BaseDataset)):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset
def get_option_setter(dataset_name):
    """Return the ``modify_commandline_options`` static method of a dataset.

    The dataset class is resolved through ``find_dataset_using_name``; the
    returned callable lets the dataset add its own options and defaults.
    """
    return find_dataset_using_name(dataset_name).modify_commandline_options
def create_dataset(opt):
    """Build the phase-aware data loader described by ``opt``.

    Returns a ConfigurableDataLoader initialised for ``opt.phase``.
    """
    loader = ConfigurableDataLoader(opt)
    return loader
class DataPrefetcher():
    """Iterator wrapper that prefetches the next item onto the GPU.

    While the caller works on the current item, the following one is taken
    from ``dataset`` and copied with ``.cuda(non_blocking=True)`` inside a
    dedicated ``torch.cuda.Stream``, so the copy can overlap other GPU work.

    Args:
        dataset: an *iterator* (``next()`` is called on it directly) whose
            items support ``.cuda(non_blocking=...)``. This presumably means
            tensors; TODO confirm for other batch types. ``len(dataset)`` must
            work for ``len(prefetcher)`` to work.

    Note: ``__init__`` creates a CUDA stream, so a CUDA device is required.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        """Fetch the next item and start its asynchronous copy to the GPU.

        Sets ``self.next_input`` to ``None`` once ``dataset`` is exhausted.
        """
        try:
            self.next_input = next(self.dataset)
        except StopIteration:
            self.next_input = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)

    def __next__(self):
        """Return the prefetched item and begin prefetching the one after it.

        Raises:
            StopIteration: once the wrapped iterator is exhausted.
        """
        # Honour the iterator protocol: yielding None indefinitely would make
        # `for batch in prefetcher` loop forever after the data runs out.
        if self.next_input is None:
            raise StopIteration
        # The current stream must not use the item before the side-stream
        # copy that produced it has finished.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.next_input
        self.preload()
        return batch

    def __iter__(self):
        return self

    def __len__(self):
        """Length of the wrapped dataset (not the number of items remaining)."""
        return len(self.dataset)
class ConfigurableDataLoader():
    """Phase-aware iterator around ``torch.utils.data.DataLoader``.

    The dataset named by ``opt.dataset_mode`` is built for a phase. In the
    ``"train"`` phase iteration is endless: an exhausted DataLoader is simply
    restarted. In any other phase iteration stops after a single pass.
    """

    def __init__(self, opt):
        self.opt = opt
        self.initialize(opt.phase)

    def initialize(self, phase):
        """(Re)create the dataset and DataLoader for ``phase``.

        Shuffling defaults to ``phase == "train"``. When
        ``opt.shuffle_dataset`` is set, only the string ``"true"`` turns it
        on. The final incomplete batch is dropped only in the ``"train"``
        phase.
        """
        opt = self.opt
        is_train = phase == "train"
        self.phase = phase
        # Release the previous DataLoader, if any, before building a new one.
        if hasattr(self, "dataloader"):
            del self.dataloader

        dataset_cls = find_dataset_using_name(opt.dataset_mode)
        dataset = dataset_cls(util.copyconf(opt, phase=phase, isTrain=is_train))

        if opt.shuffle_dataset is None:
            shuffle = is_train
        else:
            shuffle = opt.shuffle_dataset == "true"
        print("dataset [%s] of size %d was created. shuffled=%s" % (type(dataset).__name__, len(dataset), shuffle))

        self.dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=opt.batch_size,
            shuffle=shuffle,
            # NOTE(review): worker count is tied to the GPU-count option; confirm intent.
            num_workers=int(opt.num_gpus),
            drop_last=is_train,
        )
        self.dataloader_iterator = iter(self.dataloader)
        self.repeat = is_train
        self.length = len(dataset)
        self.underlying_dataset = dataset

    def set_phase(self, target_phase):
        """Rebuild the loader only when switching to a different phase."""
        if self.phase == target_phase:
            return
        self.initialize(target_phase)

    def __iter__(self):
        # Every new iteration starts a fresh pass over the DataLoader.
        self.dataloader_iterator = iter(self.dataloader)
        return self

    def __len__(self):
        """Number of items in the underlying dataset (not the number of batches)."""
        return self.length

    def __next__(self):
        """Return the next batch, restarting the DataLoader when ``self.repeat`` is set."""
        try:
            return next(self.dataloader_iterator)
        except StopIteration:
            if not self.repeat:
                raise StopIteration
        # Exhausted in repeat mode: begin a new pass.
        self.dataloader_iterator = iter(self.dataloader)
        return next(self.dataloader_iterator)
|