# CycleGAN / data/__init__.py
# Author: Yanguan (commit 58da73e)
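"""Different models need different kinds of datasets.
For example, supervised models train on paired data, whereas unsupervised models do not require their training data to be paired.
This package includes all the modules related to data loading and preprocessing.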
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
You need to implement four functions:
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
-- <__len__>: return the size of dataset.
-- <__getitem__>: get a data point from data loader.
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
Now you can use the dataset class by specifying a flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import pickle
from pathlib import Path

import torch.utils.data

from .base_dataset import BaseDataset
from .one_dataset import *
# Names exported by ``from data import *``. Entries must be *strings*: listing
# the class object itself makes star-imports raise TypeError.
__all__ = ["OneDataset"]
def find_dataset_by_name(dataset_name: str):
    """Find the dataset class for ``dataset_name`` by dynamic import.

    Imports the module ``data.[dataset_name]_dataset`` and returns the class in
    it whose name equals ``[dataset_name]dataset`` case-insensitively, with
    underscores removed from ``dataset_name`` (e.g. ``"one"`` -> ``OneDataset``).
    The class must be a subclass of BaseDataset. The class itself is returned;
    it is not instantiated here.

    Raises:
        ImportError -- (e.g. ModuleNotFoundError) if the module cannot be imported.
        NotImplementedError -- if the module has no matching BaseDataset subclass.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    target_dataset_name = dataset_name.replace("_", "") + "dataset"
    dataset = None
    # No early break: if several names match, the last one in module order wins.
    for name, cls in datasetlib.__dict__.items():
        # The isinstance guard matters: issubclass() raises TypeError for a
        # non-class attribute (e.g. a function) whose name happens to match.
        if (
            name.lower() == target_dataset_name.lower()
            and isinstance(cls, type)
            and issubclass(cls, BaseDataset)
        ):
            dataset = cls

    if dataset is None:
        raise NotImplementedError(f"In {dataset_filename}.py, there should be a subclass of BaseDataset with class " f"name that matches {target_dataset_name} in lowercase.")
    return dataset
def get_option_setter(dataset_name):
    """Look up the dataset class registered under ``dataset_name`` and hand
    back its static ``modify_commandline_options`` hook (no instance is made)."""
    return find_dataset_by_name(dataset_name).modify_commandline_options
def create_dataset(opt):
    """Build the data loader described by ``opt`` and return it ready to iterate.

    Thin wrapper around CustomDatasetDataLoader. This is the main entry point
    that 'train.py'/'test.py' use to obtain their data.

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    return CustomDatasetDataLoader(opt).load_data()
class CustomDatasetDataLoader:
    """Wrapper around a dataset object that performs multi-process data loading.

    The dataset instance is cached with ``pickle`` under ``datasets/`` so that
    later runs can skip rebuilding it.
    """

    def __init__(self, opt):
        """Initialize this class.

        Step 1: load the dataset from the pickle cache, or create it with the
                class selected by ``opt.dataset_mode`` (paired / unpaired data)
                and write the cache.
        Step 2: create a multi-process torch DataLoader over it.

        Parameters:
            opt -- experiment options. ``__init__`` reads ``name``,
                   ``dataset_mode``, ``phase`` (optional), ``batch_size``,
                   ``serial_batches`` and ``num_threads``. ``__len__`` and
                   ``__iter__`` later read ``max_dataset_size``. On a cache
                   miss, ``opt`` is also passed to the dataset class.
        """
        self.opt = opt
        self.dataset = self._load_or_build_dataset(opt)
        print("dataset [%s] was created" % type(self.dataset).__name__)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads),
        )

    @staticmethod
    def _cache_file(opt):
        """Return the pickle cache path for the dataset described by ``opt``.

        The key combines the experiment name, the dataset mode and, when
        present, the phase. This keeps a cache built for one split (e.g.
        training) from being silently reused for another split or mode.
        NOTE: other options that affect dataset construction are not part of
        the key; delete the file to force a rebuild after changing them.
        """
        parts = [str(opt.name), str(opt.dataset_mode)]
        phase = getattr(opt, "phase", None)
        if phase:
            parts.append(str(phase))
        return Path("datasets") / ("_".join(parts) + ".pkl")

    def _load_or_build_dataset(self, opt):
        """Return the cached dataset if it is readable; otherwise build and cache it."""
        cache_file = self._cache_file(opt)
        if cache_file.exists():
            print("pickle load dataset...")
            # NOTE: pickle.load can execute arbitrary code. The cache is meant
            # to be a file this class wrote itself; never point it at untrusted data.
            try:
                with cache_file.open("rb") as f:
                    return pickle.load(f)
            except (EOFError, pickle.UnpicklingError, AttributeError, ImportError) as err:
                # Truncated file, or a cache pickled against class definitions
                # that have since changed: fall through and rebuild it.
                print(f"cached dataset {cache_file} could not be loaded ({err!r}); rebuilding...")

        # Pick the dataset class (paired / unpaired) and build the dataset.
        dataset_class = find_dataset_by_name(opt.dataset_mode)
        dataset = dataset_class(opt)
        # Cache it for the next run (the pickle can be large).
        print("pickle dump dataset...")
        self._write_cache(dataset, cache_file)
        return dataset

    @staticmethod
    def _write_cache(dataset, cache_file):
        """Pickle ``dataset`` to ``cache_file``. On failure, print a warning and do not raise.

        The data is written to a temporary sibling file and then renamed over
        ``cache_file``. An interrupted dump therefore never leaves a truncated
        file at ``cache_file``. Caching is only an optimization, so an
        unpicklable dataset or an unwritable directory is reported rather than
        aborting the run.
        """
        tmp_file = cache_file.with_name(cache_file.name + ".tmp")
        try:
            cache_file.parent.mkdir(parents=True, exist_ok=True)
            with tmp_file.open("wb") as f:
                pickle.dump(dataset, f)
            tmp_file.replace(cache_file)
        except (OSError, pickle.PicklingError, TypeError, AttributeError) as err:
            print(f"warning: could not cache dataset to {cache_file} ({err!r})")
            try:
                tmp_file.unlink()  # best-effort removal of the partial temp file
            except OSError:
                pass

    def load_data(self):
        """Print the effective number of samples and return this loader for iteration."""
        print(f"The number of training images = {len(self)}")
        return self

    def __iter__(self):
        """Yield batches from the DataLoader, stopping once about ``max_dataset_size`` samples are reached."""
        for i, data in enumerate(self.dataloader):
            # Stop once the samples already yielded reach the configured cap.
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data

    def __len__(self):
        """Return the number of samples, capped at ``opt.max_dataset_size``."""
        return min(len(self.dataset), self.opt.max_dataset_size)