David Piscasio commited on
Commit
7369193
1 Parent(s): 3ce13dc

Added data folder

Browse files
data/__init__.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This package includes all the modules related to data loading and preprocessing
2
+
3
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4
+ You need to implement four functions:
5
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6
+ -- <__len__>: return the size of dataset.
7
+ -- <__getitem__>: get a data point from data loader.
8
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9
+
10
+ Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11
+ See our template dataset class 'template_dataset.py' for more details.
12
+ """
13
+ import importlib
14
+ import torch.utils.data
15
+ from data.base_dataset import BaseDataset
16
+
17
+
18
+ def find_dataset_using_name(dataset_name):
19
+ """Import the module "data/[dataset_name]_dataset.py".
20
+
21
+ In the file, the class called DatasetNameDataset() will
22
+ be instantiated. It has to be a subclass of BaseDataset,
23
+ and it is case-insensitive.
24
+ """
25
+ dataset_filename = "data." + dataset_name + "_dataset"
26
+ datasetlib = importlib.import_module(dataset_filename)
27
+
28
+ dataset = None
29
+ target_dataset_name = dataset_name.replace('_', '') + 'dataset'
30
+ for name, cls in datasetlib.__dict__.items():
31
+ if name.lower() == target_dataset_name.lower() \
32
+ and issubclass(cls, BaseDataset):
33
+ dataset = cls
34
+
35
+ if dataset is None:
36
+ raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
37
+
38
+ return dataset
39
+
40
+
41
+ def get_option_setter(dataset_name):
42
+ """Return the static method <modify_commandline_options> of the dataset class."""
43
+ dataset_class = find_dataset_using_name(dataset_name)
44
+ return dataset_class.modify_commandline_options
45
+
46
+
47
+ def create_dataset(opt):
48
+ """Create a dataset given the option.
49
+
50
+ This function wraps the class CustomDatasetDataLoader.
51
+ This is the main interface between this package and 'train.py'/'test.py'
52
+
53
+ Example:
54
+ >>> from data import create_dataset
55
+ >>> dataset = create_dataset(opt)
56
+ """
57
+ data_loader = CustomDatasetDataLoader(opt)
58
+ dataset = data_loader.load_data()
59
+ return dataset
60
+
61
+
62
+ class CustomDatasetDataLoader():
63
+ """Wrapper class of Dataset class that performs multi-threaded data loading"""
64
+
65
+ def __init__(self, opt):
66
+ """Initialize this class
67
+
68
+ Step 1: create a dataset instance given the name [dataset_mode]
69
+ Step 2: create a multi-threaded data loader.
70
+ """
71
+ self.opt = opt
72
+ dataset_class = find_dataset_using_name(opt.dataset_mode)
73
+ self.dataset = dataset_class(opt)
74
+ print("dataset [%s] was created" % type(self.dataset).__name__)
75
+ self.dataloader = torch.utils.data.DataLoader(
76
+ self.dataset,
77
+ batch_size=opt.batch_size,
78
+ shuffle=not opt.serial_batches,
79
+ num_workers=int(opt.num_threads))
80
+
81
+ def load_data(self):
82
+ return self
83
+
84
+ def __len__(self):
85
+ """Return the number of data in the dataset"""
86
+ return min(len(self.dataset), self.opt.max_dataset_size)
87
+
88
+ def __iter__(self):
89
+ """Return a batch of data"""
90
+ for i, data in enumerate(self.dataloader):
91
+ if i * self.opt.batch_size >= self.opt.max_dataset_size:
92
+ break
93
+ yield data
data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (4.03 kB). View file
 
data/__pycache__/base_dataset.cpython-38.pyc ADDED
Binary file (5.9 kB). View file
 
data/__pycache__/image_folder.cpython-38.pyc ADDED
Binary file (2.53 kB). View file
 
data/__pycache__/single_dataset.cpython-38.pyc ADDED
Binary file (2.01 kB). View file
 
data/aligned_dataset.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from data.base_dataset import BaseDataset, get_params, get_transform
3
+ from data.image_folder import make_dataset
4
+ from PIL import Image
5
+
6
+
7
+ class AlignedDataset(BaseDataset):
8
+ """A dataset class for paired image dataset.
9
+
10
+ It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}.
11
+ During test time, you need to prepare a directory '/path/to/data/test'.
12
+ """
13
+
14
+ def __init__(self, opt):
15
+ """Initialize this dataset class.
16
+
17
+ Parameters:
18
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
19
+ """
20
+ BaseDataset.__init__(self, opt)
21
+ self.dir_AB = os.path.join(opt.dataroot, opt.phase) # get the image directory
22
+ self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size)) # get image paths
23
+ assert(self.opt.load_size >= self.opt.crop_size) # crop_size should be smaller than the size of loaded image
24
+ self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
25
+ self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc
26
+
27
+ def __getitem__(self, index):
28
+ """Return a data point and its metadata information.
29
+
30
+ Parameters:
31
+ index - - a random integer for data indexing
32
+
33
+ Returns a dictionary that contains A, B, A_paths and B_paths
34
+ A (tensor) - - an image in the input domain
35
+ B (tensor) - - its corresponding image in the target domain
36
+ A_paths (str) - - image paths
37
+ B_paths (str) - - image paths (same as A_paths)
38
+ """
39
+ # read a image given a random integer index
40
+ AB_path = self.AB_paths[index]
41
+ AB = Image.open(AB_path).convert('RGB')
42
+ # split AB image into A and B
43
+ w, h = AB.size
44
+ w2 = int(w / 2)
45
+ A = AB.crop((0, 0, w2, h))
46
+ B = AB.crop((w2, 0, w, h))
47
+
48
+ # apply the same transform to both A and B
49
+ transform_params = get_params(self.opt, A.size)
50
+ A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1))
51
+ B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1))
52
+
53
+ A = A_transform(A)
54
+ B = B_transform(B)
55
+
56
+ return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
57
+
58
+ def __len__(self):
59
+ """Return the total number of images in the dataset."""
60
+ return len(self.AB_paths)
data/base_dataset.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2
+
3
+ It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4
+ """
5
+ import random
6
+ import numpy as np
7
+ import torch.utils.data as data
8
+ from PIL import Image
9
+ import torchvision.transforms as transforms
10
+ from abc import ABC, abstractmethod
11
+
12
+
13
+ class BaseDataset(data.Dataset, ABC):
14
+ """This class is an abstract base class (ABC) for datasets.
15
+
16
+ To create a subclass, you need to implement the following four functions:
17
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18
+ -- <__len__>: return the size of dataset.
19
+ -- <__getitem__>: get a data point.
20
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21
+ """
22
+
23
+ def __init__(self, opt):
24
+ """Initialize the class; save the options in the class
25
+
26
+ Parameters:
27
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28
+ """
29
+ self.opt = opt
30
+ self.root = opt.dataroot
31
+
32
+ @staticmethod
33
+ def modify_commandline_options(parser, is_train):
34
+ """Add new dataset-specific options, and rewrite default values for existing options.
35
+
36
+ Parameters:
37
+ parser -- original option parser
38
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
39
+
40
+ Returns:
41
+ the modified parser.
42
+ """
43
+ return parser
44
+
45
+ @abstractmethod
46
+ def __len__(self):
47
+ """Return the total number of images in the dataset."""
48
+ return 0
49
+
50
+ @abstractmethod
51
+ def __getitem__(self, index):
52
+ """Return a data point and its metadata information.
53
+
54
+ Parameters:
55
+ index - - a random integer for data indexing
56
+
57
+ Returns:
58
+ a dictionary of data with their names. It ususally contains the data itself and its metadata information.
59
+ """
60
+ pass
61
+
62
+
63
+ def get_params(opt, size):
64
+ w, h = size
65
+ new_h = h
66
+ new_w = w
67
+ if opt.preprocess == 'resize_and_crop':
68
+ new_h = new_w = opt.load_size
69
+ elif opt.preprocess == 'scale_width_and_crop':
70
+ new_w = opt.load_size
71
+ new_h = opt.load_size * h // w
72
+
73
+ x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
74
+ y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
75
+
76
+ flip = random.random() > 0.5
77
+
78
+ return {'crop_pos': (x, y), 'flip': flip}
79
+
80
+
81
+ def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
82
+ transform_list = []
83
+ if grayscale:
84
+ transform_list.append(transforms.Grayscale(1))
85
+ if 'resize' in opt.preprocess:
86
+ osize = [opt.load_size, opt.load_size]
87
+ transform_list.append(transforms.Resize(osize, method))
88
+ elif 'scale_width' in opt.preprocess:
89
+ transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
90
+
91
+ if 'crop' in opt.preprocess:
92
+ if params is None:
93
+ transform_list.append(transforms.RandomCrop(opt.crop_size))
94
+ else:
95
+ transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
96
+
97
+ if opt.preprocess == 'none':
98
+ transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
99
+
100
+ if not opt.no_flip:
101
+ if params is None:
102
+ transform_list.append(transforms.RandomHorizontalFlip())
103
+ elif params['flip']:
104
+ transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
105
+
106
+ if convert:
107
+ transform_list += [transforms.ToTensor()]
108
+ if grayscale:
109
+ transform_list += [transforms.Normalize((0.5,), (0.5,))]
110
+ else:
111
+ transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
112
+ return transforms.Compose(transform_list)
113
+
114
+
115
+ def __make_power_2(img, base, method=Image.BICUBIC):
116
+ ow, oh = img.size
117
+ h = int(round(oh / base) * base)
118
+ w = int(round(ow / base) * base)
119
+ if h == oh and w == ow:
120
+ return img
121
+
122
+ __print_size_warning(ow, oh, w, h)
123
+ return img.resize((w, h), method)
124
+
125
+
126
+ def __scale_width(img, target_size, crop_size, method=Image.BICUBIC):
127
+ ow, oh = img.size
128
+ if ow == target_size and oh >= crop_size:
129
+ return img
130
+ w = target_size
131
+ h = int(max(target_size * oh / ow, crop_size))
132
+ return img.resize((w, h), method)
133
+
134
+
135
+ def __crop(img, pos, size):
136
+ ow, oh = img.size
137
+ x1, y1 = pos
138
+ tw = th = size
139
+ if (ow > tw or oh > th):
140
+ return img.crop((x1, y1, x1 + tw, y1 + th))
141
+ return img
142
+
143
+
144
+ def __flip(img, flip):
145
+ if flip:
146
+ return img.transpose(Image.FLIP_LEFT_RIGHT)
147
+ return img
148
+
149
+
150
+ def __print_size_warning(ow, oh, w, h):
151
+ """Print warning information about image size(only print once)"""
152
+ if not hasattr(__print_size_warning, 'has_printed'):
153
+ print("The image size needs to be a multiple of 4. "
154
+ "The loaded image size was (%d, %d), so it was adjusted to "
155
+ "(%d, %d). This adjustment will be done to all images "
156
+ "whose sizes are not multiples of 4" % (ow, oh, w, h))
157
+ __print_size_warning.has_printed = True
data/colorization_dataset.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from data.base_dataset import BaseDataset, get_transform
3
+ from data.image_folder import make_dataset
4
+ from skimage import color # require skimage
5
+ from PIL import Image
6
+ import numpy as np
7
+ import torchvision.transforms as transforms
8
+
9
+
10
+ class ColorizationDataset(BaseDataset):
11
+ """This dataset class can load a set of natural images in RGB, and convert RGB format into (L, ab) pairs in Lab color space.
12
+
13
+ This dataset is required by pix2pix-based colorization model ('--model colorization')
14
+ """
15
+ @staticmethod
16
+ def modify_commandline_options(parser, is_train):
17
+ """Add new dataset-specific options, and rewrite default values for existing options.
18
+
19
+ Parameters:
20
+ parser -- original option parser
21
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
22
+
23
+ Returns:
24
+ the modified parser.
25
+
26
+ By default, the number of channels for input image is 1 (L) and
27
+ the number of channels for output image is 2 (ab). The direction is from A to B
28
+ """
29
+ parser.set_defaults(input_nc=1, output_nc=2, direction='AtoB')
30
+ return parser
31
+
32
+ def __init__(self, opt):
33
+ """Initialize this dataset class.
34
+
35
+ Parameters:
36
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
37
+ """
38
+ BaseDataset.__init__(self, opt)
39
+ self.dir = os.path.join(opt.dataroot, opt.phase)
40
+ self.AB_paths = sorted(make_dataset(self.dir, opt.max_dataset_size))
41
+ assert(opt.input_nc == 1 and opt.output_nc == 2 and opt.direction == 'AtoB')
42
+ self.transform = get_transform(self.opt, convert=False)
43
+
44
+ def __getitem__(self, index):
45
+ """Return a data point and its metadata information.
46
+
47
+ Parameters:
48
+ index - - a random integer for data indexing
49
+
50
+ Returns a dictionary that contains A, B, A_paths and B_paths
51
+ A (tensor) - - the L channel of an image
52
+ B (tensor) - - the ab channels of the same image
53
+ A_paths (str) - - image paths
54
+ B_paths (str) - - image paths (same as A_paths)
55
+ """
56
+ path = self.AB_paths[index]
57
+ im = Image.open(path).convert('RGB')
58
+ im = self.transform(im)
59
+ im = np.array(im)
60
+ lab = color.rgb2lab(im).astype(np.float32)
61
+ lab_t = transforms.ToTensor()(lab)
62
+ A = lab_t[[0], ...] / 50.0 - 1.0
63
+ B = lab_t[[1, 2], ...] / 110.0
64
+ return {'A': A, 'B': B, 'A_paths': path, 'B_paths': path}
65
+
66
+ def __len__(self):
67
+ """Return the total number of images in the dataset."""
68
+ return len(self.AB_paths)
data/image_folder.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A modified image folder class
2
+
3
+ We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4
+ so that this class can load images from both current directory and its subdirectories.
5
+ """
6
+
7
+ import torch.utils.data as data
8
+
9
+ from PIL import Image
10
+ import os
11
+
12
+ IMG_EXTENSIONS = [
13
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
14
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
15
+ '.tif', '.TIF', '.tiff', '.TIFF',
16
+ ]
17
+
18
+
19
+ def is_image_file(filename):
20
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
21
+
22
+
23
+ def make_dataset(dir, max_dataset_size=float("inf")):
24
+ images = []
25
+ assert os.path.isdir(dir), '%s is not a valid directory' % dir
26
+
27
+ for root, _, fnames in sorted(os.walk(dir)):
28
+ for fname in fnames:
29
+ if is_image_file(fname):
30
+ path = os.path.join(root, fname)
31
+ images.append(path)
32
+ return images[:min(max_dataset_size, len(images))]
33
+
34
+
35
+ def default_loader(path):
36
+ return Image.open(path).convert('RGB')
37
+
38
+
39
+ class ImageFolder(data.Dataset):
40
+
41
+ def __init__(self, root, transform=None, return_paths=False,
42
+ loader=default_loader):
43
+ imgs = make_dataset(root)
44
+ if len(imgs) == 0:
45
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
46
+ "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
47
+
48
+ self.root = root
49
+ self.imgs = imgs
50
+ self.transform = transform
51
+ self.return_paths = return_paths
52
+ self.loader = loader
53
+
54
+ def __getitem__(self, index):
55
+ path = self.imgs[index]
56
+ img = self.loader(path)
57
+ if self.transform is not None:
58
+ img = self.transform(img)
59
+ if self.return_paths:
60
+ return img, path
61
+ else:
62
+ return img
63
+
64
+ def __len__(self):
65
+ return len(self.imgs)
data/single_dataset.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data.base_dataset import BaseDataset, get_transform
2
+ from data.image_folder import make_dataset
3
+ from PIL import Image
4
+
5
+
6
+ class SingleDataset(BaseDataset):
7
+ """This dataset class can load a set of images specified by the path --dataroot /path/to/data.
8
+
9
+ It can be used for generating CycleGAN results only for one side with the model option '-model test'.
10
+ """
11
+
12
+ def __init__(self, opt):
13
+ """Initialize this dataset class.
14
+
15
+ Parameters:
16
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
17
+ """
18
+ BaseDataset.__init__(self, opt)
19
+ self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
20
+ input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
21
+ self.transform = get_transform(opt, grayscale=(input_nc == 1))
22
+
23
+ def __getitem__(self, index):
24
+ """Return a data point and its metadata information.
25
+
26
+ Parameters:
27
+ index - - a random integer for data indexing
28
+
29
+ Returns a dictionary that contains A and A_paths
30
+ A(tensor) - - an image in one domain
31
+ A_paths(str) - - the path of the image
32
+ """
33
+ A_path = self.A_paths[index]
34
+ A_img = Image.open(A_path).convert('RGB')
35
+ A = self.transform(A_img)
36
+ return {'A': A, 'A_paths': A_path}
37
+
38
+ def __len__(self):
39
+ """Return the total number of images in the dataset."""
40
+ return len(self.A_paths)
data/template_dataset.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dataset class template
2
+
3
+ This module provides a template for users to implement custom datasets.
4
+ You can specify '--dataset_mode template' to use this dataset.
5
+ The class name should be consistent with both the filename and its dataset_mode option.
6
+ The filename should be <dataset_mode>_dataset.py
7
+ The class name should be <Dataset_mode>Dataset.py
8
+ You need to implement the following functions:
9
+ -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10
+ -- <__init__>: Initialize this dataset class.
11
+ -- <__getitem__>: Return a data point and its metadata information.
12
+ -- <__len__>: Return the number of images.
13
+ """
14
+ from data.base_dataset import BaseDataset, get_transform
15
+ # from data.image_folder import make_dataset
16
+ # from PIL import Image
17
+
18
+
19
+ class TemplateDataset(BaseDataset):
20
+ """A template dataset class for you to implement custom datasets."""
21
+ @staticmethod
22
+ def modify_commandline_options(parser, is_train):
23
+ """Add new dataset-specific options, and rewrite default values for existing options.
24
+
25
+ Parameters:
26
+ parser -- original option parser
27
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
28
+
29
+ Returns:
30
+ the modified parser.
31
+ """
32
+ parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
33
+ parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
34
+ return parser
35
+
36
+ def __init__(self, opt):
37
+ """Initialize this dataset class.
38
+
39
+ Parameters:
40
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
41
+
42
+ A few things can be done here.
43
+ - save the options (have been done in BaseDataset)
44
+ - get image paths and meta information of the dataset.
45
+ - define the image transformation.
46
+ """
47
+ # save the option and dataset root
48
+ BaseDataset.__init__(self, opt)
49
+ # get the image paths of your dataset;
50
+ self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
51
+ # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
52
+ self.transform = get_transform(opt)
53
+
54
+ def __getitem__(self, index):
55
+ """Return a data point and its metadata information.
56
+
57
+ Parameters:
58
+ index -- a random integer for data indexing
59
+
60
+ Returns:
61
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
62
+
63
+ Step 1: get a random image path: e.g., path = self.image_paths[index]
64
+ Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
65
+ Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image)
66
+ Step 4: return a data point as a dictionary.
67
+ """
68
+ path = 'temp' # needs to be a string
69
+ data_A = None # needs to be a tensor
70
+ data_B = None # needs to be a tensor
71
+ return {'data_A': data_A, 'data_B': data_B, 'path': path}
72
+
73
+ def __len__(self):
74
+ """Return the total number of images."""
75
+ return len(self.image_paths)
data/unaligned_dataset.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from data.base_dataset import BaseDataset, get_transform
3
+ from data.image_folder import make_dataset
4
+ from PIL import Image
5
+ import random
6
+
7
+
8
+ class UnalignedDataset(BaseDataset):
9
+ """
10
+ This dataset class can load unaligned/unpaired datasets.
11
+
12
+ It requires two directories to host training images from domain A '/path/to/data/trainA'
13
+ and from domain B '/path/to/data/trainB' respectively.
14
+ You can train the model with the dataset flag '--dataroot /path/to/data'.
15
+ Similarly, you need to prepare two directories:
16
+ '/path/to/data/testA' and '/path/to/data/testB' during test time.
17
+ """
18
+
19
+ def __init__(self, opt):
20
+ """Initialize this dataset class.
21
+
22
+ Parameters:
23
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
24
+ """
25
+ BaseDataset.__init__(self, opt)
26
+ self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA'
27
+ self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB'
28
+
29
+ self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
30
+ self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
31
+ self.A_size = len(self.A_paths) # get the size of dataset A
32
+ self.B_size = len(self.B_paths) # get the size of dataset B
33
+ btoA = self.opt.direction == 'BtoA'
34
+ input_nc = self.opt.output_nc if btoA else self.opt.input_nc # get the number of channels of input image
35
+ output_nc = self.opt.input_nc if btoA else self.opt.output_nc # get the number of channels of output image
36
+ self.transform_A = get_transform(self.opt, grayscale=(input_nc == 1))
37
+ self.transform_B = get_transform(self.opt, grayscale=(output_nc == 1))
38
+
39
+ def __getitem__(self, index):
40
+ """Return a data point and its metadata information.
41
+
42
+ Parameters:
43
+ index (int) -- a random integer for data indexing
44
+
45
+ Returns a dictionary that contains A, B, A_paths and B_paths
46
+ A (tensor) -- an image in the input domain
47
+ B (tensor) -- its corresponding image in the target domain
48
+ A_paths (str) -- image paths
49
+ B_paths (str) -- image paths
50
+ """
51
+ A_path = self.A_paths[index % self.A_size] # make sure index is within then range
52
+ if self.opt.serial_batches: # make sure index is within then range
53
+ index_B = index % self.B_size
54
+ else: # randomize the index for domain B to avoid fixed pairs.
55
+ index_B = random.randint(0, self.B_size - 1)
56
+ B_path = self.B_paths[index_B]
57
+ A_img = Image.open(A_path).convert('RGB')
58
+ B_img = Image.open(B_path).convert('RGB')
59
+ # apply image transformation
60
+ A = self.transform_A(A_img)
61
+ B = self.transform_B(B_img)
62
+
63
+ return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
64
+
65
+ def __len__(self):
66
+ """Return the total number of images in the dataset.
67
+
68
+ As we have two datasets with potentially different number of images,
69
+ we take a maximum of
70
+ """
71
+ return max(self.A_size, self.B_size)