import random
import numpy as np
from pathlib import Path
from ResizeRight.resize_right import resize
from einops import rearrange

import torch
import torchvision as thv
from torch.utils.data import Dataset

from utils import util_sisr
from utils import util_image
from utils import util_common

from basicsr.data.realesrgan_dataset import RealESRGANDataset
from .ffhq_degradation_dataset import FFHQDegradationDataset

def get_transforms(transform_type, out_size, sf):
    if transform_type == 'default':
        transform = thv.transforms.Compose([
            util_image.SpatialAug(),
            thv.transforms.ToTensor(),
            thv.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])
    elif transform_type == 'face':
        transform = thv.transforms.Compose([
            thv.transforms.ToTensor(),
            thv.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])
    elif transform_type == 'bicubic':
        transform = thv.transforms.Compose([
            util_sisr.Bicubic(1/sf),
            thv.transforms.ToTensor(),
            thv.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])
    else:
        raise ValueError(f'Unexpected transform_type: {transform_type}')
    return transform
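
# Illustrative usage of get_transforms (not part of the original module; the file
# name below is a placeholder). The pipeline expects an HWC uint8 RGB numpy array,
# e.g. as returned by util_image.imread, and yields a CHW float tensor normalized
# to [-1, 1]. Note that `out_size` is accepted but not used by the current variants.
#
#     im = util_image.imread('face.png', chn='rgb', dtype='uint8')
#     transform = get_transforms('face', out_size=512, sf=None)
#     im_tensor = transform(im)   # torch.FloatTensor, shape (3, H, W), range [-1, 1]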

def create_dataset(dataset_config):
    if dataset_config['type'] == 'gfpgan':
        dataset = FFHQDegradationDataset(dataset_config['params'])
    elif dataset_config['type'] == 'face':
        dataset = BaseDatasetFace(**dataset_config['params'])
    elif dataset_config['type'] == 'bicubic':
        dataset = DatasetBicubic(**dataset_config['params'])
    elif dataset_config['type'] == 'folder':
        dataset = BaseDataFolder(**dataset_config['params'])
    elif dataset_config['type'] == 'realesrgan':
        dataset = RealESRGANDataset(dataset_config['params'])
    else:
        raise NotImplementedError(dataset_config['type'])

    return dataset
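
# Illustrative config for create_dataset (paths are placeholders, not from the
# original repo). The 'type' key selects the dataset class; 'params' is forwarded
# to its constructor (unpacked as keyword arguments for the classes in this file,
# passed as an options dict for the gfpgan / realesrgan datasets).
#
#     dataset_config = {
#         'type': 'folder',
#         'params': {
#             'dir_path': '/path/to/lq_images',
#             'dir_path_gt': '/path/to/gt_images',
#             'need_gt_path': False,
#         },
#     }
#     dataset = create_dataset(dataset_config)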

class BaseDatasetFace(Dataset):
    def __init__(self, celeba_txt=None,
                       ffhq_txt=None,
                       out_size=256,
                       transform_type='face',
                       sf=None,
                       length=None):
        super().__init__()
        self.files_names = util_common.readline_txt(celeba_txt) + util_common.readline_txt(ffhq_txt)
        if length is None:
            self.length = len(self.files_names)
        else:
            self.length = length

        self.transform = get_transforms(transform_type, out_size, sf)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        im_path = self.files_names[index]
        im = util_image.imread(im_path, chn='rgb', dtype='uint8')
        im = self.transform(im)
        return {'image': im}
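
# Illustrative usage of BaseDatasetFace (the txt paths are placeholders): each
# txt file is assumed to list one image path per line, as read by
# util_common.readline_txt.
#
#     face_set = BaseDatasetFace(
#         celeba_txt='./datalist/celeba_train.txt',
#         ffhq_txt='./datalist/ffhq_train.txt',
#         out_size=512,
#         transform_type='face',
#     )
#     sample = face_set[0]        # {'image': tensor of shape (3, H, W) in [-1, 1]}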

class DatasetBicubic(Dataset):
    def __init__(self,
                 files_txt=None,
                 val_dir=None,
                 ext='png',
                 sf=None,
                 up_back=False,
                 need_gt_path=False,
                 length=None):
        super().__init__()
        if val_dir is None:
            self.files_names = util_common.readline_txt(files_txt)
        else:
            self.files_names = [str(x) for x in Path(val_dir).glob(f"*.{ext}")]
        self.sf = sf
        self.up_back = up_back
        self.need_gt_path = need_gt_path

        if length is None:
            self.length = len(self.files_names)
        else:
            self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        im_path = self.files_names[index]

        im_gt = util_image.imread(im_path, chn='rgb', dtype='float32')
        im_lq = resize(im_gt, scale_factors=1/self.sf)
        if self.up_back:
            im_lq = resize(im_lq, scale_factors=self.sf)

        im_lq = rearrange(im_lq, 'h w c -> c h w')
        im_lq = torch.from_numpy(im_lq).type(torch.float32)

        im_gt = rearrange(im_gt, 'h w c -> c h w')
        im_gt = torch.from_numpy(im_gt).type(torch.float32)

        if self.need_gt_path:
            return {'lq': im_lq, 'gt': im_gt, 'gt_path': im_path}
        else:
            return {'lq': im_lq, 'gt': im_gt}
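
# Illustrative validation loader built on DatasetBicubic (directory and scale
# factor are placeholders). With up_back=True, the bicubically downsampled image
# is resized back to the ground-truth resolution before being returned as 'lq'.
#
#     from torch.utils.data import DataLoader
#     val_set = DatasetBicubic(val_dir='/path/to/val_hr', ext='png', sf=4, up_back=True)
#     val_loader = DataLoader(val_set, batch_size=1, shuffle=False)
#     batch = next(iter(val_loader))   # {'lq': (1, 3, H, W), 'gt': (1, 3, H, W)}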

class BaseDataFolder(Dataset):
    def __init__(
            self,
            dir_path,
            dir_path_gt,
            need_gt_path=True,
            length=None,
            ext=['png', 'jpg', 'jpeg', 'JPEG', 'bmp'],
            mean=0.5,
            std=0.5,
            ):
        super(BaseDataFolder, self).__init__()

        if isinstance(ext, str):
            files_path = [str(x) for x in Path(dir_path).glob(f'*.{ext}')]
        else:
            assert isinstance(ext, (list, tuple))
            files_path = []
            for current_ext in ext:
                files_path.extend([str(x) for x in Path(dir_path).glob(f'*.{current_ext}')])

        self.files_path = files_path if length is None else files_path[:length]
        self.dir_path_gt = dir_path_gt
        self.need_gt_path = need_gt_path
        self.mean = mean
        self.std = std

    def __len__(self):
        return len(self.files_path)

    def __getitem__(self, index):
        im_path = self.files_path[index]
        im = util_image.imread(im_path, chn='rgb', dtype='float32')
        im = util_image.normalize_np(im, mean=self.mean, std=self.std, reverse=False)
        im = rearrange(im, 'h w c -> c h w')
        out_dict = {'image': im.astype(np.float32), 'lq': im.astype(np.float32)}

        if self.need_gt_path:
            out_dict['path'] = im_path

        if self.dir_path_gt is not None:
            gt_path = str(Path(self.dir_path_gt) / Path(im_path).name)
            im_gt = util_image.imread(gt_path, chn='rgb', dtype='float32')
            im_gt = util_image.normalize_np(im_gt, mean=self.mean, std=self.std, reverse=False)
            im_gt = rearrange(im_gt, 'h w c -> c h w')
            out_dict['gt'] = im_gt.astype(np.float32)

        return out_dict
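
if __name__ == '__main__':
    # Minimal smoke test (illustrative only, not part of the original module; the
    # directory below is a placeholder and must point at a real image folder for
    # any samples to be found).
    dataset = BaseDataFolder(
        dir_path='./testdata/lq',
        dir_path_gt=None,
        need_gt_path=True,
    )
    print(f'Found {len(dataset)} images.')
    if len(dataset) > 0:
        sample = dataset[0]
        # Print tensor shapes for array-valued entries and raw values otherwise.
        print({k: getattr(v, 'shape', v) for k, v in sample.items()})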