# OneRestore / makedataset.py
# gy65896's picture
# Upload 51 files
# 73ba284 verified
# raw
# history blame
# 5.51 kB
# -*- coding: utf-8 -*-
"""
Created on Wed Feb 12 20:00:46 2020
@author: Administrator
"""
import os
import os.path
import random
import numpy as np
import cv2
import h5py
import torch
import torch.utils.data as udata
import argparse
from PIL import Image
class Dataset(udata.Dataset):
    r"""Map-style dataset that serves the entries of an HDF5 file as tensors.

    Every top-level key of the file is treated as one sample. The file is
    opened only for the duration of each read instead of being kept open on
    the instance.

    Args:
        file: path of the HDF5 file to read.
        trainrgb: stored as ``self.trainrgb``; not used by this class itself.
        trainsyn: stored as ``self.trainsyn``; not used by this class itself.
        shuffle: if True, shuffle the key order once, at construction time,
            with the ``random`` module.
    """

    def __init__(self, file, trainrgb=True, trainsyn=True, shuffle=False):
        super(Dataset, self).__init__()
        self.trainrgb = trainrgb
        self.trainsyn = trainsyn
        self.train_haze = file
        # The context manager closes the file even if listing the keys fails.
        with h5py.File(self.train_haze, 'r') as h5f:
            self.keys = list(h5f.keys())
        if shuffle:
            random.shuffle(self.keys)

    def __len__(self):
        """Return the number of samples (top-level HDF5 keys)."""
        return len(self.keys)

    def __getitem__(self, index):
        """Read the sample at position ``index`` and return it as a ``torch.Tensor``."""
        key = self.keys[index]
        # Copy the data into memory before the file is closed on block exit.
        with h5py.File(self.train_haze, 'r') as h5f:
            data = np.array(h5f[key])
        return torch.Tensor(data)
def data_augmentation(clear, mode):
    r"""Apply one of eight flip/rotation variants to a 4-D array.

    The two trailing axes of ``clear`` are moved to the front, transformed
    with ``np.rot90`` / ``np.flipud`` (both act on the leading two axes),
    and then moved back to the end before returning.

    Args:
        clear: 4-D array; the transformation is applied to its last two axes.
        mode: int. Choice of transformation to apply
            0 - no transformation
            1 - flip up and down
            2 - rotate counterclockwise 90 degrees
            3 - rotate 90 degrees, then flip up and down
            4 - rotate 180 degrees
            5 - rotate 180 degrees, then flip up and down
            6 - rotate 270 degrees
            7 - rotate 270 degrees, then flip up and down

    Returns:
        The transformed array with the original axis order restored.

    Raises:
        Exception: if ``mode`` is not one of 0..7.
    """
    spatial_first = np.transpose(clear, (2, 3, 0, 1))
    if mode not in range(8):
        raise Exception('Invalid choice of image transformation')
    # Each mode is "mode // 2" quarter turns, followed by a flip when odd.
    quarter_turns, flip = divmod(int(mode), 2)
    if quarter_turns:
        spatial_first = np.rot90(spatial_first, k=quarter_turns)
    if flip:
        spatial_first = np.flipud(spatial_first)
    return np.transpose(spatial_first, (2, 3, 0, 1))
def img_to_patches(img, win, stride, Syn=True):
    """Cut a stack of aligned images into square, randomly augmented patches.

    Windows are placed every ``stride`` pixels along both spatial axes; a
    window that would run past the border is shifted back so that it ends
    exactly at the border.

    Args:
        img: 4-D array whose shape is unpacked as (typ, chl, rows, cols).
        win: side length of each square patch.
        stride: step between the top-left corners of neighbouring patches.
        Syn: when True (default) the patches are extracted and each one is
            passed through ``data_augmentation`` with a random mode; when
            False the returned array is left zero-filled.

    Returns:
        Array of shape (typ, chl, win, win, n_patches) created with
        ``np.zeros``, where n_patches = n_rows * n_cols and
        n_rows = ceil((rows - win) / stride) + 1 (likewise for columns).

    Raises:
        ValueError: if ``win`` is larger than either spatial dimension.
    """
    typ, chl, rows, cols = img.shape
    chl = int(chl)
    if win > rows or win > cols:
        raise ValueError(
            f'patch size {win} exceeds image size {rows}x{cols}'
        )
    # Keep the counts as Python ints: a uint8 count wraps around once an
    # axis needs more than 255 patches, which silently truncates the output.
    n_rows = int(np.ceil((rows - win) / stride)) + 1
    n_cols = int(np.ceil((cols - win) / stride)) + 1
    img_patches = np.zeros([typ, chl, win, win, n_rows * n_cols])
    if not Syn:
        return img_patches

    count = 0
    for i in range(n_rows):
        # Clamp so the last window ends exactly at the bottom border.
        top = min(stride * i, rows - win)
        for j in range(n_cols):
            # Same clamping for the right border.
            left = min(stride * j, cols - win)
            patch = img[:, :, top:top + win, left:left + win]
            # randint's upper bound is exclusive, so 8 is needed for all of
            # the modes 0..7 to be reachable.
            img_patches[:, :, :, :, count] = data_augmentation(
                patch, np.random.randint(0, 8)
            )
            count += 1
    return img_patches
def read_img(img):
    """Load an image with PIL and return its pixel array divided by 255.

    Args:
        img: whatever ``PIL.Image.open`` accepts (the callers in this file
            pass a path string).

    Returns:
        NumPy array of the decoded pixels scaled by 1/255.
    """
    pixels = np.array(Image.open(img))
    return pixels / 255.
def Train_data(args):
    """Build the HDF5 patch database from paired clear/degraded images.

    For every file name found in ``<train_path>/<gt_name>`` the clear image
    and the same-named image from each degradation folder are stacked (clear
    image first), cut into patches with ``img_to_patches``, and each patch
    stack is written to ``args.data_name`` as its own dataset keyed by a
    running integer ("0", "1", ...).

    Args:
        args: namespace providing ``train_path``, ``gt_name``,
            ``degradation_name``, ``data_name``, ``patch_size`` and
            ``stride``.
    """
    gt_dir = os.path.join(args.train_path, args.gt_name)
    # os.listdir returns entries in arbitrary order; sorting makes the
    # numbering of the stored patches reproducible between runs.
    file_list = sorted(os.listdir(gt_dir))
    # The with-block closes the file on exit, so no explicit close() is needed.
    with h5py.File(args.data_name, 'w') as h5f:
        count = 0
        for file_name in file_list:
            print(file_name)
            img_list = [read_img(os.path.join(gt_dir, file_name))]
            img_list.extend(
                read_img(os.path.join(args.train_path, name, file_name))
                for name in args.degradation_name
            )
            img = np.stack(img_list, 0)
            # Move the last axis of the stack to position 1 before patching
            # (assumes each image is (H, W, C) -- TODO confirm).
            patches = img_to_patches(
                img.transpose(0, 3, 1, 2), args.patch_size, args.stride
            )
            for nx in range(patches.shape[4]):
                data = patches[:, :, :, :, nx]
                print(count, data.shape)
                h5f.create_dataset(str(count), data=data)
                count += 1
def build_parser():
    """Create the command-line parser for the patch-database builder.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--patch-size``,
        ``--stride``, ``--train-path``, ``--data-name``, ``--gt-name`` and
        ``--degradation-name``.
    """
    parser = argparse.ArgumentParser(description="Building the training patch database")
    parser.add_argument("--patch-size", type=int, default=256, help="Patch size")
    parser.add_argument("--stride", type=int, default=200, help="Size of stride")
    parser.add_argument("--train-path", type=str, default='./data/CDD-11_train', help="Train path")
    parser.add_argument("--data-name", type=str, default='dataset.h5', help="Data name")
    parser.add_argument("--gt-name", type=str, default='clear', help="HQ name")
    # nargs='+' collects one folder name per token; with type=list a value
    # such as "low" would be split into the characters ['l', 'o', 'w'].
    parser.add_argument(
        "--degradation-name", type=str, nargs='+',
        default=['low', 'haze', 'rain', 'snow',
                 'low_haze', 'low_rain', 'low_snow', 'haze_rain', 'haze_snow',
                 'low_haze_rain', 'low_haze_snow'],
        help="LQ name",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    Train_data(args)