BOPBTL / Face_Enhancement /data /face_dataset.py
manhkhanhUIT's picture
Add code
7fab858
raw
history blame contribute delete
No virus
3.03 kB
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from data.base_dataset import BaseDataset, get_params, get_transform
from PIL import Image
import util.util as util
import os
import torch
class FaceTestDataset(BaseDataset):
@staticmethod
def modify_commandline_options(parser, is_train):
parser.add_argument(
"--no_pairing_check",
action="store_true",
help="If specified, skip sanity check of correct label-image file pairing",
)
# parser.set_defaults(contain_dontcare_label=False)
# parser.set_defaults(no_instance=True)
return parser
def initialize(self, opt):
self.opt = opt
image_path = os.path.join(opt.dataroot, opt.old_face_folder)
label_path = os.path.join(opt.dataroot, opt.old_face_label_folder)
image_list = os.listdir(image_path)
image_list = sorted(image_list)
# image_list=image_list[:opt.max_dataset_size]
self.label_paths = label_path ## Just the root dir
self.image_paths = image_list ## All the image name
self.parts = [
"skin",
"hair",
"l_brow",
"r_brow",
"l_eye",
"r_eye",
"eye_g",
"l_ear",
"r_ear",
"ear_r",
"nose",
"mouth",
"u_lip",
"l_lip",
"neck",
"neck_l",
"cloth",
"hat",
]
size = len(self.image_paths)
self.dataset_size = size
def __getitem__(self, index):
params = get_params(self.opt, (-1, -1))
image_name = self.image_paths[index]
image_path = os.path.join(self.opt.dataroot, self.opt.old_face_folder, image_name)
image = Image.open(image_path)
image = image.convert("RGB")
transform_image = get_transform(self.opt, params)
image_tensor = transform_image(image)
img_name = image_name[:-4]
transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
full_label = []
cnt = 0
for each_part in self.parts:
part_name = img_name + "_" + each_part + ".png"
part_url = os.path.join(self.label_paths, part_name)
if os.path.exists(part_url):
label = Image.open(part_url).convert("RGB")
label_tensor = transform_label(label) ## 3 channels and pixel [0,1]
full_label.append(label_tensor[0])
else:
current_part = torch.zeros((self.opt.load_size, self.opt.load_size))
full_label.append(current_part)
cnt += 1
full_label_tensor = torch.stack(full_label, 0)
input_dict = {
"label": full_label_tensor,
"image": image_tensor,
"path": image_path,
}
return input_dict
def __len__(self):
return self.dataset_size