Spaces:
Runtime error
Runtime error
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT License. | |
from data.base_dataset import BaseDataset, get_params, get_transform | |
from PIL import Image | |
import util.util as util | |
import os | |
import torch | |
class FaceTestDataset(BaseDataset): | |
def modify_commandline_options(parser, is_train): | |
parser.add_argument( | |
"--no_pairing_check", | |
action="store_true", | |
help="If specified, skip sanity check of correct label-image file pairing", | |
) | |
# parser.set_defaults(contain_dontcare_label=False) | |
# parser.set_defaults(no_instance=True) | |
return parser | |
def initialize(self, opt): | |
self.opt = opt | |
image_path = os.path.join(opt.dataroot, opt.old_face_folder) | |
label_path = os.path.join(opt.dataroot, opt.old_face_label_folder) | |
image_list = os.listdir(image_path) | |
image_list = sorted(image_list) | |
# image_list=image_list[:opt.max_dataset_size] | |
self.label_paths = label_path ## Just the root dir | |
self.image_paths = image_list ## All the image name | |
self.parts = [ | |
"skin", | |
"hair", | |
"l_brow", | |
"r_brow", | |
"l_eye", | |
"r_eye", | |
"eye_g", | |
"l_ear", | |
"r_ear", | |
"ear_r", | |
"nose", | |
"mouth", | |
"u_lip", | |
"l_lip", | |
"neck", | |
"neck_l", | |
"cloth", | |
"hat", | |
] | |
size = len(self.image_paths) | |
self.dataset_size = size | |
def __getitem__(self, index): | |
params = get_params(self.opt, (-1, -1)) | |
image_name = self.image_paths[index] | |
image_path = os.path.join(self.opt.dataroot, self.opt.old_face_folder, image_name) | |
image = Image.open(image_path) | |
image = image.convert("RGB") | |
transform_image = get_transform(self.opt, params) | |
image_tensor = transform_image(image) | |
img_name = image_name[:-4] | |
transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) | |
full_label = [] | |
cnt = 0 | |
for each_part in self.parts: | |
part_name = img_name + "_" + each_part + ".png" | |
part_url = os.path.join(self.label_paths, part_name) | |
if os.path.exists(part_url): | |
label = Image.open(part_url).convert("RGB") | |
label_tensor = transform_label(label) ## 3 channels and pixel [0,1] | |
full_label.append(label_tensor[0]) | |
else: | |
current_part = torch.zeros((self.opt.load_size, self.opt.load_size)) | |
full_label.append(current_part) | |
cnt += 1 | |
full_label_tensor = torch.stack(full_label, 0) | |
input_dict = { | |
"label": full_label_tensor, | |
"image": image_tensor, | |
"path": image_path, | |
} | |
return input_dict | |
def __len__(self): | |
return self.dataset_size | |