Spaces:
Runtime error
Runtime error
| import random | |
| import numpy as np | |
| from pathlib import Path | |
| from ResizeRight.resize_right import resize | |
| from einops import rearrange | |
| import torch | |
| import torchvision as thv | |
| from torch.utils.data import Dataset | |
| from utils import util_sisr | |
| from utils import util_image | |
| from utils import util_common | |
| from basicsr.data.realesrgan_dataset import RealESRGANDataset | |
| from .ffhq_degradation_dataset import FFHQDegradationDataset | |
def get_transforms(transform_type, out_size, sf):
    """Build the torchvision preprocessing pipeline selected by ``transform_type``.

    Every pipeline ends with ``ToTensor`` followed by
    ``Normalize(mean=0.5, std=0.5)`` per channel. For uint8 input, ToTensor
    yields values in [0, 1], so the normalization maps them to [-1, 1].
    The variants differ only in the step that runs first:

    * ``'default'`` -- ``util_image.SpatialAug()`` (project augmentation).
    * ``'face'``    -- no extra step.
    * ``'bicubic'`` -- ``util_sisr.Bicubic(1/sf)``. This is presumably a
      downscale by ``sf``; confirm in utils.util_sisr.

    Args:
        transform_type: One of ``'default'``, ``'face'`` or ``'bicubic'``.
        out_size: Currently unused. It is kept for call-site compatibility.
        sf: Scale factor. Only the ``'bicubic'`` variant reads it.

    Returns:
        A ``torchvision.transforms.Compose`` instance.

    Raises:
        ValueError: If ``transform_type`` is not one of the values above.
    """
    if transform_type == 'default':
        head = [util_image.SpatialAug()]
    elif transform_type == 'face':
        head = []
    elif transform_type == 'bicubic':
        head = [util_sisr.Bicubic(1/sf)]
    else:
        raise ValueError(f'Unexpected transform_type {transform_type}')
    return thv.transforms.Compose(head + [
        thv.transforms.ToTensor(),
        thv.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
def create_dataset(dataset_config):
    """Instantiate the dataset described by ``dataset_config``.

    ``dataset_config['type']`` selects the dataset class and
    ``dataset_config['params']`` configures it. The ``'gfpgan'`` and
    ``'realesrgan'`` datasets receive the params mapping as a single
    positional argument. The other datasets receive it unpacked as keyword
    arguments.

    Raises:
        KeyError: If ``'type'`` (or, for a known type, ``'params'``) is missing.
        NotImplementedError: If the type is not recognised.
    """
    kind = dataset_config['type']
    # The factories are lambdas so each class name is only resolved when its
    # type is actually requested.
    factories = (
        ('gfpgan', lambda params: FFHQDegradationDataset(params)),
        ('face', lambda params: BaseDatasetFace(**params)),
        ('bicubic', lambda params: DatasetBicubic(**params)),
        ('folder', lambda params: BaseDataFolder(**params)),
        ('realesrgan', lambda params: RealESRGANDataset(params)),
    )
    for name, build in factories:
        if kind == name:
            return build(dataset_config['params'])
    raise NotImplementedError(kind)
class BaseDatasetFace(Dataset):
    """Face-image dataset built from the file lists in two text files.

    Args:
        celeba_txt, ffhq_txt: Files passed to ``util_common.readline_txt``.
            The results are concatenated (CelebA first), so the helper
            presumably returns a list of image paths; confirm in
            utils.util_common.
        out_size: Forwarded to ``get_transforms``.
        transform_type: Transform variant forwarded to ``get_transforms``.
        sf: Scale factor forwarded to ``get_transforms``.
        length: If given, overrides the length reported by ``__len__``.
            ``__getitem__`` still indexes the file list directly, so a value
            larger than the list leads to IndexError.
    """

    def __init__(self, celeba_txt=None,
                 ffhq_txt=None,
                 out_size=256,
                 transform_type='face',
                 sf=None,
                 length=None):
        super().__init__()
        celeba_files = util_common.readline_txt(celeba_txt)
        ffhq_files = util_common.readline_txt(ffhq_txt)
        self.files_names = celeba_files + ffhq_files
        self.length = len(self.files_names) if length is None else length
        self.transform = get_transforms(transform_type, out_size, sf)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return ``{'image': transformed}`` for the index-th file (read as uint8 RGB)."""
        image = util_image.imread(self.files_names[index], chn='rgb', dtype='uint8')
        return {'image': self.transform(image)}
class DatasetBicubic(Dataset):
    """Paired dataset whose low-quality image is a resized copy of the GT image.

    Args:
        files_txt: Text file read via ``util_common.readline_txt``. It is
            used only when ``val_dir`` is None.
        val_dir: If given, the images are collected from this directory
            (non-recursively) instead of ``files_txt``.
        ext: Extension (without the dot) matched in ``val_dir``.
        sf: Scale factor. The LQ image is ``resize(gt, scale_factors=1/sf)``
            (ResizeRight; the interpolation is that function's default,
            presumably cubic, so confirm). It must be set before indexing.
        up_back: If true, the LQ image is resized back up by ``sf``.
        need_gt_path: If true, items also carry ``'gt_path'``.
        length: If given, overrides the length reported by ``__len__``.
    """

    def __init__(self,
                 files_txt=None,
                 val_dir=None,
                 ext='png',
                 sf=None,
                 up_back=False,
                 need_gt_path=False,
                 length=None):
        super().__init__()
        if val_dir is None:
            self.files_names = util_common.readline_txt(files_txt)
        else:
            # Path.glob yields entries in filesystem-dependent order. Sort them
            # so item indices (and validation outputs) are reproducible.
            self.files_names = sorted(str(x) for x in Path(val_dir).glob(f"*.{ext}"))
        self.sf = sf
        self.up_back = up_back
        self.need_gt_path = need_gt_path
        self.length = len(self.files_names) if length is None else length

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return float32 CHW tensors ``{'lq', 'gt'}`` (plus ``'gt_path'`` if requested)."""
        im_path = self.files_names[index]
        im_gt = util_image.imread(im_path, chn='rgb', dtype='float32')
        im_lq = resize(im_gt, scale_factors=1/self.sf)
        if self.up_back:
            im_lq = resize(im_lq, scale_factors=self.sf)
        # Images are read HWC; the model side expects CHW.
        im_lq = torch.from_numpy(rearrange(im_lq, 'h w c -> c h w')).type(torch.float32)
        im_gt = torch.from_numpy(rearrange(im_gt, 'h w c -> c h w')).type(torch.float32)
        if self.need_gt_path:
            return {'lq': im_lq, 'gt': im_gt, 'gt_path': im_path}
        return {'lq': im_lq, 'gt': im_gt}
class BaseDataFolder(Dataset):
    """Folder-backed dataset returning normalized CHW float32 numpy images.

    Args:
        dir_path: Directory scanned (non-recursively) for input images.
        dir_path_gt: Directory holding GT images with the same file names as
            the inputs, or ``None`` to skip loading GT.
        need_gt_path: If true, each item also carries the input path under
            the ``'path'`` key.
        length: Optional cap on the number of files kept (the first
            ``length`` paths in sorted order).
        ext: A single extension string, or a list/tuple of extensions
            (without the leading dot), to collect from ``dir_path``.
        mean, std: Forwarded to ``util_image.normalize_np``.

    Raises:
        TypeError: If ``ext`` is not a str, list or tuple.
    """

    def __init__(
            self,
            dir_path,
            dir_path_gt,
            need_gt_path=True,
            length=None,
            ext=('png', 'jpg', 'jpeg', 'JPEG', 'bmp'),
            mean=0.5,
            std=0.5,
            ):
        super().__init__()
        if isinstance(ext, str):
            extensions = (ext,)
        elif isinstance(ext, (list, tuple)):
            extensions = tuple(ext)
        else:
            raise TypeError(
                f'ext must be a str, list or tuple, got {type(ext).__name__}')
        folder = Path(dir_path)
        # Path.glob order is filesystem-dependent, so sort the paths to make
        # indexing and the `length` truncation reproducible. The set removes
        # duplicates that occur where glob matching is case-insensitive
        # (e.g. Windows), because there 'jpeg' and 'JPEG' match the same files.
        files_path = sorted({
            str(x)
            for current_ext in extensions
            for x in folder.glob(f'*.{current_ext}')
        })
        self.files_path = files_path if length is None else files_path[:length]
        self.dir_path_gt = dir_path_gt
        self.need_gt_path = need_gt_path
        self.mean = mean
        self.std = std

    def __len__(self):
        return len(self.files_path)

    def __getitem__(self, index):
        """Load, normalize and return the index-th image.

        Returns a dict whose ``'image'`` and ``'lq'`` entries are independent
        float32 copies of the normalized input in CHW layout. It adds
        ``'path'`` when ``need_gt_path`` is set, and ``'gt'`` (processed the
        same way) when ``dir_path_gt`` is set.
        """
        im_path = self.files_path[index]
        im = util_image.imread(im_path, chn='rgb', dtype='float32')
        im = util_image.normalize_np(im, mean=self.mean, std=self.std, reverse=False)
        im = rearrange(im, 'h w c -> c h w')
        out_dict = {'image': im.astype(np.float32), 'lq': im.astype(np.float32)}
        if self.need_gt_path:
            out_dict['path'] = im_path
        if self.dir_path_gt is not None:
            # GT is paired with the input purely by file name.
            gt_path = str(Path(self.dir_path_gt) / Path(im_path).name)
            im_gt = util_image.imread(gt_path, chn='rgb', dtype='float32')
            im_gt = util_image.normalize_np(im_gt, mean=self.mean, std=self.std, reverse=False)
            im_gt = rearrange(im_gt, 'h w c -> c h w')
            out_dict['gt'] = im_gt.astype(np.float32)
        return out_dict