import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms


class BaseDataset(data.Dataset):
    """Skeleton dataset whose hooks all return neutral defaults.

    Every method here is a placeholder: ``name`` returns a fixed label,
    ``modify_commandline_options`` hands the parser back untouched,
    ``initialize`` does nothing and ``__len__`` reports zero items.
    Presumably concrete datasets subclass this and override the hooks.
    """

    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        """Return the human-readable name of this dataset class."""
        return 'BaseDataset'

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Return ``parser`` unchanged; ``is_train`` is not used here."""
        return parser

    def initialize(self, opt):
        """No-op hook that receives the options object."""
        pass

    def __len__(self):
        """Return 0: the base dataset holds no items."""
        return 0


def get_transform(opt):
    """Build the image preprocessing pipeline selected by ``opt``.

    Attributes read from ``opt``: ``resize_or_crop``, ``loadSize``,
    ``fineSize``, ``isTrain`` and ``no_flip``.

    Supported ``opt.resize_or_crop`` values:

    * ``'resize_and_crop'``      - resize to ``loadSize`` x ``loadSize``,
      then take a random ``fineSize`` crop.
    * ``'crop'``                 - random ``fineSize`` crop only.
    * ``'scale_width'``          - scale so the width equals ``fineSize``.
    * ``'scale_width_and_crop'`` - scale so the width equals ``loadSize``,
      then take a random ``fineSize`` crop.
    * ``'none'``                 - only round both sides up to a multiple
      of 4.

    A random horizontal flip is appended when ``opt.isTrain`` is set and
    ``opt.no_flip`` is not. The pipeline always ends with ``ToTensor`` and
    a three-channel ``Normalize`` using mean 0.5 and std 0.5.

    Returns:
        A ``transforms.Compose`` wrapping the assembled steps.

    Raises:
        ValueError: if ``opt.resize_or_crop`` is not one of the values above.
    """
    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        # Resize takes (height, width). Both sides go to loadSize so that the
        # fineSize crop below can jitter along both axes; resizing the width
        # straight to fineSize would leave no horizontal room for the crop
        # and would distort the aspect ratio.
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize)))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'none':
        # __adjust already has the (img) -> img signature, so no wrapper
        # lambda is needed.
        transform_list.append(transforms.Lambda(__adjust))
    else:
        raise ValueError('--resize_or_crop %s is not a valid option.'
                         % opt.resize_or_crop)

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())

    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


# just modify the width and height to be multiple of 4
def __adjust(img):
    """Round both sides of ``img`` up to the next multiple of 4.

    Args:
        img: object exposing ``size`` as ``(width, height)`` and a
            ``resize((w, h), resample)`` method (a PIL image in practice).

    Returns:
        ``img`` itself when both sides are already multiples of 4, otherwise
        the bicubically resized image.
    """
    ow, oh = img.size

    # the size needs to be a multiple of this number,
    # because going through generator network may change img size
    # and eventually cause size mismatch error
    mult = 4
    if ow % mult == 0 and oh % mult == 0:
        return img

    # Round each side up to the nearest multiple of ``mult``.
    w = ((ow - 1) // mult + 1) * mult
    h = ((oh - 1) // mult + 1) * mult

    # At least one side differs here (see the early return above), so the
    # one-time notice always applies.
    __print_size_warning(ow, oh, w, h)

    return img.resize((w, h), Image.BICUBIC)


def __scale_width(img, target_width):
    """Scale ``img`` to ``target_width``, keeping the aspect ratio.

    The height follows the aspect ratio and is then rounded up to a
    multiple of 4.

    Args:
        img: object exposing ``size`` as ``(width, height)`` and a
            ``resize((w, h), resample)`` method (a PIL image in practice).
        target_width: desired output width; must be a multiple of 4.

    Returns:
        ``img`` itself when it already has the target width and a height
        that is a multiple of 4, otherwise the bicubically resized image.

    Raises:
        AssertionError: if ``target_width`` is not a multiple of 4.
    """
    ow, oh = img.size

    # the size needs to be a multiple of this number,
    # because going through generator network may change img size
    # and eventually cause size mismatch error
    mult = 4
    assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult
    if ow == target_width and oh % mult == 0:
        return img

    w = target_width
    target_height = int(target_width * oh / ow)
    # Round the proportional height up to a multiple of ``mult``. For very
    # wide images it can truncate to 0, so clamp to one multiple to avoid
    # asking for a zero-height resize.
    h = max(mult, ((target_height - 1) // mult + 1) * mult)

    if target_height != h:
        # Report the size the image was actually loaded with, which is what
        # the warning text says it prints.
        __print_size_warning(ow, oh, w, h)

    return img.resize((w, h), Image.BICUBIC)


def __print_size_warning(ow, oh, w, h):
    """Print a notice about size adjustment, at most once per process.

    The "already printed" flag is stored as an attribute on this function
    object, so later calls are silent.

    Args:
        ow, oh: width and height reported as the loaded image size.
        w, h: width and height the image was adjusted to.
    """
    if not hasattr(__print_size_warning, 'has_printed'):
        print("The image size needs to be a multiple of 4. "
              "The loaded image size was (%d, %d), so it was adjusted to "
              "(%d, %d). This adjustment will be done to all images "
              "whose sizes are not multiples of 4" % (ow, oh, w, h))
        __print_size_warning.has_printed = True