Spaces:
Runtime error
Runtime error
import os.path | |
from data.base_dataset import BaseDataset, get_params, get_transform, normalize | |
from data.image_folder import make_dataset | |
from PIL import Image | |
class AlignedDataset(BaseDataset):
    """Dataset of label maps (A) paired with real images (B) from parallel folders.

    Files are paired purely by position after sorting each folder's listing:
    the i-th label map is matched with the i-th image, instance map and
    feature map. Folder names are ``<phase>_A`` / ``<phase>_B`` when
    ``opt.label_nc == 0`` and ``<phase>_label`` / ``<phase>_img`` otherwise.
    Optional ``<phase>_inst`` and ``<phase>_feat`` folders hold instance maps
    and precomputed instance-wise features.
    """

    def initialize(self, opt):
        """Collect the sorted file lists this dataset reads from ``opt.dataroot``.

        Args:
            opt: option namespace. This method reads ``dataroot``, ``phase``,
                ``label_nc``, ``isTrain``, ``use_encoded_image``,
                ``no_instance`` and ``load_features``.
        """
        self.opt = opt
        self.root = opt.dataroot

        ### input A (label maps)
        dir_A = '_A' if self.opt.label_nc == 0 else '_label'
        self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
        self.A_paths = sorted(make_dataset(self.dir_A))

        ### input B (real images): only needed for training or image encoding
        if opt.isTrain or opt.use_encoded_image:
            dir_B = '_B' if self.opt.label_nc == 0 else '_img'
            self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
            self.B_paths = sorted(make_dataset(self.dir_B))

        ### instance maps
        if not opt.no_instance:
            self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
            self.inst_paths = sorted(make_dataset(self.dir_inst))

        ### load precomputed instance-wise encoded features
        if opt.load_features:
            self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
            print('----------- loading features from %s ----------' % self.dir_feat)
            self.feat_paths = sorted(make_dataset(self.dir_feat))

        self.dataset_size = len(self.A_paths)

    def __getitem__(self, index):
        """Load sample ``index`` and return it as a dict.

        Every input is transformed with the same ``params``, which are drawn
        once from the label map's size. This presumably keeps random
        crops/flips aligned across inputs; see ``get_params`` to confirm.

        Returns:
            dict with keys ``'label'``, ``'inst'``, ``'image'``, ``'feat'`` and
            ``'path'`` (the label-map path). Disabled inputs hold the integer
            placeholder ``0``:
            - ``'image'`` when neither ``isTrain`` nor ``use_encoded_image``
              is set;
            - ``'inst'`` when ``no_instance`` is set;
            - ``'feat'`` when ``load_features`` is not set.
        """
        ### input A (label maps)
        A_path = self.A_paths[index]
        # Context managers release file handles deterministically instead of
        # whenever the lazily-loaded PIL image is garbage-collected. All pixel
        # access happens inside the `with`.
        with Image.open(A_path) as A:
            params = get_params(self.opt, A.size)
            if self.opt.label_nc == 0:
                transform_A = get_transform(self.opt, params)
                A_tensor = transform_A(A.convert('RGB'))
            else:
                transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
                # Presumably undoes a [0, 1] tensor scaling to recover the
                # integer label ids. Confirm against get_transform.
                A_tensor = transform_A(A) * 255.0

        # Instance ids and feature maps hold discrete values, so they need a
        # nearest-neighbour, non-normalizing transform.
        # - label_nc != 0: transform_A is already exactly that, so reuse it.
        # - label_nc == 0: transform_A is the default RGB transform, which
        #   would interpolate instance ids and normalize features before
        #   norm() is applied again below.
        if self.opt.label_nc == 0:
            transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
        else:
            transform_label = transform_A

        B_tensor = inst_tensor = feat_tensor = 0

        ### input B (real images)
        if self.opt.isTrain or self.opt.use_encoded_image:
            B_path = self.B_paths[index]
            # convert() returns a fully loaded copy, so it outlives the file.
            with Image.open(B_path) as B_file:
                B = B_file.convert('RGB')
            transform_B = get_transform(self.opt, params)
            B_tensor = transform_B(B)

        ### if using instance maps
        if not self.opt.no_instance:
            inst_path = self.inst_paths[index]
            with Image.open(inst_path) as inst:
                inst_tensor = transform_label(inst)

        ### precomputed features
        if self.opt.load_features:
            feat_path = self.feat_paths[index]
            with Image.open(feat_path) as feat_file:
                feat = feat_file.convert('RGB')
            norm = normalize()
            feat_tensor = norm(transform_label(feat))

        input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
                      'feat': feat_tensor, 'path': A_path}
        return input_dict

    def __len__(self):
        """Return the number of label maps, rounded down to a multiple of ``batchSize``.

        Trailing samples that would form an incomplete batch are dropped, so
        this returns 0 when there are fewer label maps than ``opt.batchSize``.
        """
        return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize

    def name(self):
        """Return the identifier of this dataset class."""
        return 'AlignedDataset'