BOPBTL / Face_Enhancement /data /pix2pix_dataset.py
manhkhanhUIT's picture
Add code
7fab858
raw
history blame contribute delete
No virus
3.91 kB
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from data.base_dataset import BaseDataset, get_params, get_transform
from PIL import Image
import util.util as util
import os
class Pix2pixDataset(BaseDataset):
    """Base dataset serving paired label / image / instance samples.

    Subclasses provide the file lists by overriding :meth:`get_paths`. This
    class sorts and truncates those lists, optionally sanity-checks that each
    label file is paired with the image file of the same base name, and
    converts one index into a dict of transformed data in :meth:`__getitem__`.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register the ``--no_pairing_check`` flag on ``parser`` and return it.

        ``is_train`` is part of the signature but is not used by this method.
        """
        parser.add_argument(
            "--no_pairing_check",
            action="store_true",
            help="If specified, skip sanity check of correct label-image file pairing",
        )
        return parser

    def initialize(self, opt):
        """Collect, sort, truncate and (optionally) validate the file lists.

        Reads ``opt.no_instance``, ``opt.max_dataset_size`` and
        ``opt.no_pairing_check``. Stores ``opt``, ``label_paths``,
        ``image_paths``, ``instance_paths`` and ``dataset_size`` on ``self``.

        Raises:
            AssertionError: if the pairing check is enabled and a label/image
                pair has different base filenames.
        """
        self.opt = opt

        label_paths, image_paths, instance_paths = self.get_paths(opt)

        # The return value of natural_sort is not used, so the lists are
        # expected to be sorted in place -- NOTE(review): confirm in util.util.
        util.natural_sort(label_paths)
        util.natural_sort(image_paths)
        if not opt.no_instance:
            util.natural_sort(instance_paths)

        label_paths = label_paths[: opt.max_dataset_size]
        image_paths = image_paths[: opt.max_dataset_size]
        instance_paths = instance_paths[: opt.max_dataset_size]

        if not opt.no_pairing_check:
            for path1, path2 in zip(label_paths, image_paths):
                if not self.paths_match(path1, path2):
                    # Raised explicitly (rather than via ``assert``) so the
                    # check is not stripped when Python runs with -O.
                    raise AssertionError(
                        "The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this."
                        % (path1, path2)
                    )

        self.label_paths = label_paths
        self.image_paths = image_paths
        self.instance_paths = instance_paths

        # The dataset length is defined by the number of label files.
        self.dataset_size = len(self.label_paths)

    def get_paths(self, opt):
        """Return ``(label_paths, image_paths, instance_paths)`` for ``opt``.

        This base implementation always fails; subclasses must override it.

        Raises:
            AssertionError: always, to signal the missing override.
        """
        # Explicit raise so a missing override cannot silently yield an empty
        # dataset when assertions are disabled (python -O).
        raise AssertionError(
            "A subclass of Pix2pixDataset must override self.get_paths(self, opt)"
        )

    def paths_match(self, path1, path2):
        """Return True when both paths share the same filename, ignoring
        directory and extension."""
        filename1_without_ext = os.path.splitext(os.path.basename(path1))[0]
        filename2_without_ext = os.path.splitext(os.path.basename(path2))[0]
        return filename1_without_ext == filename2_without_ext

    def __getitem__(self, index):
        """Load and transform the sample stored at ``index``.

        Returns:
            dict with keys ``"label"``, ``"instance"``, ``"image"`` and
            ``"path"`` (the image file path). ``"instance"`` is the integer
            ``0`` when ``opt.no_instance`` is set.

        Raises:
            AssertionError: if the label and image paths at ``index`` do not
                match and ``--no_pairing_check`` was not given.
        """
        # Label Image
        label_path = self.label_paths[index]
        label = Image.open(label_path)
        # The same params are reused for label, image and instance so that all
        # three go through transforms built from one parameter set.
        params = get_params(self.opt, label.size)
        transform_label = get_transform(
            self.opt, params, method=Image.NEAREST, normalize=False
        )
        # NOTE(review): the * 255 presumably undoes a [0, 1] scaling applied by
        # the transform so label ids are recovered -- confirm in get_transform.
        label_tensor = transform_label(label) * 255.0
        label_tensor[label_tensor == 255] = self.opt.label_nc  # 'unknown' is opt.label_nc

        # input image (real images)
        image_path = self.image_paths[index]
        # Honour --no_pairing_check here as well: initialize() skips the check
        # when the flag is set, so the per-item check must not re-impose it.
        # A missing attribute is treated as "check enabled".
        if not getattr(self.opt, "no_pairing_check", False):
            if not self.paths_match(label_path, image_path):
                raise AssertionError(
                    "The label_path %s and image_path %s don't match."
                    % (label_path, image_path)
                )
        image = Image.open(image_path)
        image = image.convert("RGB")

        transform_image = get_transform(self.opt, params)
        image_tensor = transform_image(image)

        # if using instance maps
        if self.opt.no_instance:
            instance_tensor = 0
        else:
            instance_path = self.instance_paths[index]
            instance = Image.open(instance_path)
            if instance.mode == "L":
                # 8-bit greyscale instance maps are rescaled by 255 and cast
                # with .long(); other modes are used as transformed.
                instance_tensor = transform_label(instance) * 255
                instance_tensor = instance_tensor.long()
            else:
                instance_tensor = transform_label(instance)

        input_dict = {
            "label": label_tensor,
            "instance": instance_tensor,
            "image": image_tensor,
            "path": image_path,
        }

        # Give subclasses a chance to modify the final output. The return
        # value of postprocess() is discarded, so overrides must mutate
        # input_dict in place.
        self.postprocess(input_dict)

        return input_dict

    def postprocess(self, input_dict):
        """Hook for subclasses to adjust a sample; the default returns it
        unchanged."""
        return input_dict

    def __len__(self):
        """Return the number of samples recorded by :meth:`initialize`."""
        return self.dataset_size