import random
from pathlib import Path

import numpy as np
import torch
import torchvision as thv
from einops import rearrange
from torch.utils.data import Dataset

from basicsr.data.realesrgan_dataset import RealESRGANDataset
from ResizeRight.resize_right import resize
from utils import util_common
from utils import util_image
from utils import util_sisr

from .ffhq_degradation_dataset import FFHQDegradationDataset


def get_transforms(transform_type, out_size, sf):
    """Build the torchvision transform pipeline for ``transform_type``.

    Every pipeline ends with ``ToTensor`` followed by ``Normalize`` using
    mean 0.5 and std 0.5 on each of three channels; the variants differ only
    in what runs before that:

    * ``'default'``: ``util_image.SpatialAug()``
    * ``'face'``: nothing
    * ``'bicubic'``: ``util_sisr.Bicubic(1 / sf)``

    Args:
        transform_type: One of ``'default'``, ``'face'`` or ``'bicubic'``.
        out_size: Accepted for interface compatibility; not used here.
        sf: Scale factor; only read for ``'bicubic'`` (as ``1 / sf``).

    Returns:
        A ``torchvision.transforms.Compose`` instance.

    Raises:
        ValueError: If ``transform_type`` is not a supported value.
    """
    if transform_type == 'default':
        leading = [util_image.SpatialAug()]
    elif transform_type == 'face':
        leading = []
    elif transform_type == 'bicubic':
        leading = [util_sisr.Bicubic(1/sf)]
    else:
        # The original formatted an undefined name here, which turned the
        # intended ValueError into a NameError.
        raise ValueError(f'Unexpected transform_type {transform_type}')

    return thv.transforms.Compose([
        *leading,
        thv.transforms.ToTensor(),
        thv.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])


def create_dataset(dataset_config):
    """Instantiate the dataset described by ``dataset_config``.

    Args:
        dataset_config: Mapping with a ``'type'`` key selecting the dataset
            class and a ``'params'`` key holding its arguments. For
            ``'gfpgan'`` and ``'realesrgan'`` the params mapping is passed as
            a single positional argument; for ``'face'``, ``'bicubic'`` and
            ``'folder'`` it is expanded into keyword arguments.

    Returns:
        The constructed dataset object.

    Raises:
        NotImplementedError: If ``dataset_config['type']`` is not recognised.
    """
    dataset_type = dataset_config['type']
    if dataset_type == 'gfpgan':
        dataset = FFHQDegradationDataset(dataset_config['params'])
    elif dataset_type == 'face':
        dataset = BaseDatasetFace(**dataset_config['params'])
    elif dataset_type == 'bicubic':
        dataset = DatasetBicubic(**dataset_config['params'])
    elif dataset_type == 'folder':
        dataset = BaseDataFolder(**dataset_config['params'])
    elif dataset_type == 'realesrgan':
        dataset = RealESRGANDataset(dataset_config['params'])
    else:
        raise NotImplementedError(dataset_type)

    return dataset


class BaseDatasetFace(Dataset):
    """Image dataset whose file list is read from two text files.

    The paths returned by ``util_common.readline_txt`` for ``celeba_txt`` and
    ``ffhq_txt`` are concatenated (CelebA entries first).
    """

    def __init__(self, celeba_txt=None, ffhq_txt=None, out_size=256,
                 transform_type='face', sf=None, length=None):
        """Collect the file list and build the transform.

        Args:
            celeba_txt: Passed to ``util_common.readline_txt``.
            ffhq_txt: Passed to ``util_common.readline_txt``.
            out_size: Forwarded to ``get_transforms``.
            transform_type: Forwarded to ``get_transforms``.
            sf: Forwarded to ``get_transforms``.
            length: Value reported by ``__len__``; defaults to the number of
                collected files. It is not checked against the file count.
        """
        super().__init__()

        self.files_names = (
            util_common.readline_txt(celeba_txt)
            + util_common.readline_txt(ffhq_txt)
        )
        self.length = len(self.files_names) if length is None else length
        self.transform = get_transforms(transform_type, out_size, sf)

    def __len__(self):
        """Return the configured dataset length."""
        return self.length

    def __getitem__(self, index):
        """Load image ``index`` as uint8 RGB, transform it, and wrap it in a dict."""
        im_path = self.files_names[index]
        im = util_image.imread(im_path, chn='rgb', dtype='uint8')
        im = self.transform(im)
        return {'image': im}


class DatasetBicubic(Dataset):
    """Paired low-quality / ground-truth dataset built by rescaling each image.

    The low-quality image is produced from the ground truth with
    ``ResizeRight``'s ``resize`` at scale ``1 / sf``.
    """

    def __init__(self, files_txt=None, val_dir=None, ext='png', sf=None,
                 up_back=False, need_gt_path=False, length=None):
        """Collect the ground-truth file list.

        Args:
            files_txt: Text file passed to ``util_common.readline_txt``; only
                used when ``val_dir`` is ``None``.
            val_dir: Directory globbed for ``*.{ext}``; takes precedence over
                ``files_txt`` when given.
            ext: File extension (without the dot) used with ``val_dir``.
            sf: Scale factor. ``__getitem__`` computes ``1 / sf``, so it must
                be a non-zero number by the time items are requested.
            up_back: If true, the downscaled image is resized again by ``sf``.
            need_gt_path: If true, items also carry the ``'gt_path'`` key.
            length: Value reported by ``__len__``; defaults to the number of
                collected files. It is not checked against the file count.
        """
        super().__init__()
        if val_dir is None:
            self.files_names = util_common.readline_txt(files_txt)
        else:
            self.files_names = [str(x) for x in Path(val_dir).glob(f"*.{ext}")]
        self.sf = sf
        self.up_back = up_back
        self.need_gt_path = need_gt_path

        self.length = len(self.files_names) if length is None else length

    def __len__(self):
        """Return the configured dataset length."""
        return self.length

    def __getitem__(self, index):
        """Return ``{'lq', 'gt'}`` float32 tensors (plus ``'gt_path'`` if requested).

        Both arrays are rearranged from ``h w c`` to ``c h w`` before being
        converted to tensors.
        """
        im_path = self.files_names[index]
        im_gt = util_image.imread(im_path, chn='rgb', dtype='float32')

        im_lq = resize(im_gt, scale_factors=1/self.sf)
        if self.up_back:
            im_lq = resize(im_lq, scale_factors=self.sf)

        im_lq = rearrange(im_lq, 'h w c -> c h w')
        im_lq = torch.from_numpy(im_lq).type(torch.float32)

        im_gt = rearrange(im_gt, 'h w c -> c h w')
        im_gt = torch.from_numpy(im_gt).type(torch.float32)

        if self.need_gt_path:
            return {'lq': im_lq, 'gt': im_gt, 'gt_path': im_path}
        return {'lq': im_lq, 'gt': im_gt}


class BaseDataFolder(Dataset):
    """Dataset over the image files of one directory, with optional ground truth.

    Items are NumPy float32 arrays in ``c h w`` layout, normalised through
    ``util_image.normalize_np`` with the configured ``mean`` and ``std``.
    """

    def __init__(
            self,
            dir_path,
            dir_path_gt,
            need_gt_path=True,
            length=None,
            ext=('png', 'jpg', 'jpeg', 'JPEG', 'bmp'),
            mean=0.5,
            std=0.5,
            ):
        """Collect the image paths found directly inside ``dir_path``.

        Args:
            dir_path: Directory globbed for ``*.{ext}``.
            dir_path_gt: Directory holding ground-truth images with the same
                file names, or ``None`` to return no ``'gt'`` entry.
            need_gt_path: If true, items also carry the input path under
                ``'path'``.
            length: If given, only the first ``length`` collected paths are
                kept.
            ext: A single extension string, or a list/tuple of extensions
                (without the dot). The default is a tuple so that no mutable
                object is shared between calls.
            mean: Forwarded to ``util_image.normalize_np``.
            std: Forwarded to ``util_image.normalize_np``.
        """
        super(BaseDataFolder, self).__init__()
        if isinstance(ext, str):
            files_path = [str(x) for x in Path(dir_path).glob(f'*.{ext}')]
        else:
            assert isinstance(ext, (list, tuple))
            files_path = []
            for current_ext in ext:
                files_path.extend([str(x) for x in Path(dir_path).glob(f'*.{current_ext}')])
        self.files_path = files_path if length is None else files_path[:length]
        self.dir_path_gt = dir_path_gt
        self.need_gt_path = need_gt_path
        self.mean = mean
        self.std = std

    def __len__(self):
        """Return the number of collected image paths."""
        return len(self.files_path)

    def __getitem__(self, index):
        """Return a dict with ``'image'`` and ``'lq'`` (the same normalised input).

        ``'path'`` is added when ``need_gt_path`` is true, and ``'gt'`` when a
        ground-truth directory was configured; the ground-truth file is looked
        up by the input file's name inside ``dir_path_gt``.
        """
        im_path = self.files_path[index]
        im = util_image.imread(im_path, chn='rgb', dtype='float32')
        im = util_image.normalize_np(im, mean=self.mean, std=self.std, reverse=False)
        im = rearrange(im, 'h w c -> c h w')
        out_dict = {'image': im.astype(np.float32), 'lq': im.astype(np.float32)}

        if self.need_gt_path:
            out_dict['path'] = im_path

        if self.dir_path_gt is not None:
            gt_path = str(Path(self.dir_path_gt) / Path(im_path).name)
            im_gt = util_image.imread(gt_path, chn='rgb', dtype='float32')
            im_gt = util_image.normalize_np(im_gt, mean=self.mean, std=self.std, reverse=False)
            im_gt = rearrange(im_gt, 'h w c -> c h w')
            out_dict['gt'] = im_gt.astype(np.float32)

        return out_dict