|
"""不同的模型使用不同的数据集 |
|
|
|
比如有监督模型使用的都是成对的训练数据、无监督模型使用的数据集不必使用成对的数据 |
|
This package includes all the modules related to data loading and preprocessing.
|
|
|
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' that inherits from BaseDataset.
|
You need to implement four functions: |
|
-- <__init__>: initialize the class; first call BaseDataset.__init__(self, opt).
|
-- <__len__>: return the size of dataset. |
|
-- <__getitem__>: return a data point given an index.
|
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options. |
|
|
|
Now you can use the dataset class by specifying a flag '--dataset_mode dummy'. |
|
See our template dataset class 'template_dataset.py' for more details; a minimal sketch follows.
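
A minimal sketch (illustrative only; the '--dummy_option' flag and the
(parser, is_train) signature follow the usual template and are assumptions here):

    from data.base_dataset import BaseDataset

    class DummyDataset(BaseDataset):

        @staticmethod
        def modify_commandline_options(parser, is_train):
            # (optionally) add dataset-specific options with defaults
            parser.add_argument('--dummy_option', type=int, default=0)
            return parser

        def __init__(self, opt):
            BaseDataset.__init__(self, opt)  # first call the base initializer
            self.paths = []                  # collect data paths/records here

        def __len__(self):
            return len(self.paths)           # size of the dataset

        def __getitem__(self, index):
            return {'path': self.paths[index]}  # one data point as a dict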
|
""" |
|
|
|
import importlib

import pickle

from pathlib import Path

import torch.utils.data

from .base_dataset import BaseDataset

from .one_dataset import OneDataset
|
|
|
__all__ = ["OneDataset"]
|
|
|
|
|
def find_dataset_by_name(dataset_name: str): |
|
"""按照数据集名称来寻找所对应的dataset类进行动态导入 |
|
Import the module "data/[dataset_name]_dataset.py". |
|
|
|
In the file, the class called DatasetNameDataset() will |
|
be instantiated. It has to be a subclass of BaseDataset, |
|
and it is case-insensitive. |
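
    Example (hypothetical module; assumes data/dummy_dataset.py defines DummyDataset):
        >>> dataset_cls = find_dataset_by_name("dummy")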
|
""" |
|
dataset_filename = "data." + dataset_name + "_dataset" |
|
datasetlib = importlib.import_module(dataset_filename) |
|
|
|
dataset = None |
|
target_dataset_name = dataset_name.replace("_", "") + "dataset" |
|
for name, cls in datasetlib.__dict__.items(): |
|
        if name.lower() == target_dataset_name.lower() and isinstance(cls, type) and issubclass(cls, BaseDataset):
|
dataset = cls |
|
|
|
if dataset is None: |
|
        raise NotImplementedError(
            f"In {dataset_filename}.py, there should be a subclass of BaseDataset "
            f"whose class name matches {target_dataset_name} (case-insensitively)."
        )
|
return dataset |
|
|
|
|
|
def get_option_setter(dataset_name): |
|
"""Return the static method <modify_commandline_options> of the dataset class.""" |
|
dataset_class = find_dataset_by_name(dataset_name) |
|
return dataset_class.modify_commandline_options |
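

# Example (a sketch of the typical call site in an options parser; the `parser`
# argument is an argparse.ArgumentParser and is an assumption here):
#
#     option_setter = get_option_setter(opt.dataset_mode)
#     parser = option_setter(parser, is_train=True)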
|
|
|
|
|
def create_dataset(opt): |
|
"""Create a dataset given the option. |
|
|
|
This function wraps the class CustomDatasetDataLoader. |
|
    This is the main interface between this package and 'train.py'/'test.py'.
|
|
|
Example: |
|
>>> from data import create_dataset |
|
>>> dataset = create_dataset(opt) |
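
    The returned wrapper is directly iterable (a minimal sketch):
        >>> for data in dataset:
        ...     pass  # each `data` is one batch from the underlying dataset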
|
""" |
|
data_loader = CustomDatasetDataLoader(opt) |
|
dataset = data_loader.load_data() |
|
return dataset |
|
|
|
|
|
class CustomDatasetDataLoader: |
|
"""Wrapper class of Dataset class that performs multi-threading data loading""" |
|
|
|
def __init__(self, opt): |
|
"""Initialize this class |
|
|
|
Step 1: create a dataset instance given the name [dataset_mode] |
|
Step 2: create a multi-threading data loader. |
|
""" |
|
self.opt = opt |
|
dataset_file = f"datasets/{opt.name}.pkl" |
|
if not Path(dataset_file).exists(): |
|
|
|
dataset_class = find_dataset_by_name(opt.dataset_mode) |
|
|
|
self.dataset = dataset_class(opt) |
|
|
|
|
|
print("pickle dump dataset...") |
|
pickle.dump(self.dataset, open(dataset_file, 'wb')) |
|
else: |
|
print("pickle load dataset...") |
|
self.dataset = pickle.load(open(dataset_file, 'rb')) |
|
print("dataset [%s] was created" % type(self.dataset).__name__) |
|
|
|
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,  # --serial_batches disables shuffling
            num_workers=int(opt.num_threads),
        )
|
|
|
def load_data(self): |
|
print(f"The number of training images = {len(self)}") |
|
return self |
|
|
|
def __iter__(self): |
|
"""Return a batch of data""" |
|
for i, data in enumerate(self.dataloader): |
|
if i * self.opt.batch_size >= self.opt.max_dataset_size: |
|
break |
|
yield data |
|
|
|
def __len__(self): |
|
"""Return the number of data in the dataset""" |
|
return min(len(self.dataset), self.opt.max_dataset_size) |
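

# Usage sketch for the wrapper above (`opt` comes from this project's option
# parser; the exact fields it carries are assumptions here):
#
#     loader = CustomDatasetDataLoader(opt)
#     dataset = loader.load_data()
#     for data in dataset:  # stops after opt.max_dataset_size samples
#         ...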
|
|