# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random


class BaseDataset(data.Dataset):
    """Minimal ``torch.utils.data.Dataset`` subclass with placeholder hooks.

    ``name`` returns a fixed identifier and ``initialize`` does nothing here;
    presumably concrete datasets override both (confirm against subclasses).
    """

    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        """Return the identifier string of this dataset class."""
        return 'BaseDataset'

    def initialize(self, opt):
        """Accept the options object ``opt``; the base class ignores it."""
        pass


def get_params(opt, size):
    """Draw the random crop position and flip flag for one sample.

    NOTE(review): the middle of this function (the ``scale_width_and_crop``
    branch and the random draws) was rebuilt from a damaged source; verify
    it against the upstream implementation before relying on it.

    Args:
        opt: options object; reads ``resize_or_crop``, ``loadSize`` and
            ``fineSize``.
        size: ``(width, height)`` of the image before any transform.

    Returns:
        A dict ``{'crop_pos': (x, y), 'flip': bool}`` in the form consumed by
        ``get_transform``.
    """
    w, h = size
    new_h = h
    new_w = w
    if opt.resize_or_crop == 'resize_and_crop':
        new_h = new_w = opt.loadSize

    if opt.resize_or_crop == 'scale_width_and_crop':
        # we scale the shorter side into 256
        if w < h:
            new_w = opt.loadSize
            new_h = opt.loadSize * h // w
        else:
            new_h = opt.loadSize
            new_w = opt.loadSize * w // h

    # Clamp the upper bound at 0 so an image that is not larger than the
    # crop window still yields a valid (0) offset.
    x = random.randint(0, max(0, new_w - opt.fineSize))
    y = random.randint(0, max(0, new_h - opt.fineSize))

    flip = random.random() > 0.5
    return {'crop_pos': (x, y), 'flip': flip}


def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
    """Build the composed image transform selected by ``opt``.

    Args:
        opt: options object; reads ``resize_or_crop``, ``loadSize``,
            ``fineSize``, ``isTrain``, ``no_flip`` and, depending on the
            branch taken, ``test_random_crop``, ``n_downsample_global``,
            ``netG`` and ``n_local_enhancers``.
        params: dict with ``'crop_pos'`` and ``'flip'`` entries, as returned
            by ``get_params``; looked up lazily when the transform runs.
        method: interpolation constant handed to the resize operations.
        normalize: when true, append a mean-0.5 / std-0.5 ``Normalize`` step.
            Inside this function the name shadows the module-level
            ``normalize()`` helper.

    Returns:
        A ``transforms.Compose`` over the selected steps.
    """
    # ``Scale`` is the legacy name of ``Resize`` and is absent from current
    # torchvision releases; prefer ``Resize`` and fall back only on very old
    # installations.
    resize_cls = getattr(transforms, 'Resize', None) or getattr(transforms, 'Scale')

    transform_list = []
    if 'resize' in opt.resize_or_crop:
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(resize_cls(osize, method))
    elif 'scale_width' in opt.resize_or_crop:
        # Here we want the shorter side to match 256; resizing with an int
        # size does exactly that (this replaced a __scale_width-based Lambda).
        # NOTE(review): 256 is hard-coded while get_params uses opt.loadSize;
        # confirm the two are meant to agree.
        transform_list.append(resize_cls(256, method))

    if 'crop' in opt.resize_or_crop:
        if opt.isTrain:
            transform_list.append(transforms.Lambda(
                lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
        else:
            # when testing, for ablation study, choose center_crop directly.
            if opt.test_random_crop:
                transform_list.append(transforms.RandomCrop(opt.fineSize))
            else:
                transform_list.append(transforms.CenterCrop(opt.fineSize))

    if opt.resize_or_crop == 'none':
        # Without resize/crop, snap both sides to a multiple of ``base``.
        base = float(2 ** opt.n_downsample_global)
        if opt.netG == 'local':
            base *= (2 ** opt.n_local_enhancers)
        transform_list.append(transforms.Lambda(
            lambda img: __make_power_2(img, base, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(
            lambda img: __flip(img, params['flip'])))

    transform_list += [transforms.ToTensor()]

    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def normalize():
    """Return a three-channel ``Normalize`` with mean 0.5 and std 0.5."""
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize ``img`` so each side is the nearest multiple of ``base``.

    Returns ``img`` itself when both sides already are such multiples.
    """
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)


def __scale_width(img, target_width, method=Image.BICUBIC):
    """Resize ``img`` to ``target_width``, scaling the height proportionally.

    Returns ``img`` itself when its width already equals ``target_width``.
    """
    ow, oh = img.size
    if ow == target_width:
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), method)


def __crop(img, pos, size):
    """Crop a ``size`` x ``size`` window whose top-left corner is ``pos``.

    The crop is applied when either side of ``img`` exceeds ``size``;
    otherwise ``img`` is returned unchanged.
    """
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if ow > tw or oh > th:
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img


def __flip(img, flip):
    """Mirror ``img`` left-to-right when ``flip`` is truthy."""
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img