# -*- coding: utf-8 -*- """ Created on Wed Feb 12 20:00:46 2020 @author: Administrator """ import os import os.path import random import numpy as np import cv2 import h5py import torch import torch.utils.data as udata import argparse from PIL import Image class Dataset(udata.Dataset): r"""Implements torch.utils.data.Dataset """ def __init__(self, file, trainrgb=True,trainsyn = True, shuffle=False): super(Dataset, self).__init__() self.trainrgb = trainrgb self.trainsyn = trainsyn self.train_haze = file h5f = h5py.File(self.train_haze, 'r') self.keys = list(h5f.keys()) if shuffle: random.shuffle(self.keys) h5f.close() def __len__(self): return len(self.keys) def __getitem__(self, index): h5f = h5py.File(self.train_haze, 'r') key = self.keys[index] data = np.array(h5f[key]) h5f.close() return torch.Tensor(data) def data_augmentation(clear, mode): r"""Performs dat augmentation of the input image Args: image: a cv2 (OpenCV) image mode: int. Choice of transformation to apply to the image 0 - no transformation 1 - flip up and down 2 - rotate counterwise 90 degree 3 - rotate 90 degree and flip up and down 4 - rotate 180 degree 5 - rotate 180 degree and flip 6 - rotate 270 degree 7 - rotate 270 degree and flip """ clear = np.transpose(clear, (2, 3, 0, 1)) if mode == 0: # original clear = clear elif mode == 1: # flip up and down clear = np.flipud(clear) elif mode == 2: # rotate counterwise 90 degree clear = np.rot90(clear) elif mode == 3: # rotate 90 degree and flip up and down clear = np.rot90(clear) clear = np.flipud(clear) elif mode == 4: # rotate 180 degree clear = np.rot90(clear, k=2) elif mode == 5: # rotate 180 degree and flip clear = np.rot90(clear, k=2) clear = np.flipud(clear) elif mode == 6: # rotate 270 degree clear = np.rot90(clear, k=3) elif mode == 7: # rotate 270 degree and flip clear = np.rot90(clear, k=3) clear = np.flipud(clear) else: raise Exception('Invalid choice of image transformation') return np.transpose(clear, (2, 3, 0, 1)) def img_to_patches(img,win,stride,Syn=True): typ, chl, raw, col = img.shape chl = int(chl) num_raw = np.ceil((raw-win)/stride+1).astype(np.uint8) num_col = np.ceil((col-win)/stride+1).astype(np.uint8) count = 0 total_process = int(num_col)*int(num_raw) img_patches = np.zeros([typ, chl, win, win, total_process]) if Syn: for i in range(num_raw): for j in range(num_col): if stride * i + win <= raw and stride * j + win <=col: img_patches[:,:,:,:,count] = img[:, :, stride*i : stride*i + win, stride*j : stride*j + win] elif stride * i + win > raw and stride * j + win<=col: img_patches[:,:,:,:,count] = img[:, :,raw-win : raw,stride * j : stride * j + win] elif stride * i + win <= raw and stride*j + win>col: img_patches[:,:,:,:,count] = img[:, :,stride*i : stride*i + win, col-win : col] else: img_patches[:,:,:,:,count] = img[:, :,raw-win : raw,col-win : col] img_patches[:,:,:,:,count] = data_augmentation(img_patches[:, :, :, :, count], np.random.randint(0, 7)) count +=1 return img_patches def read_img(img): return np.array(Image.open(img))/255. def Train_data(args): file_list = os.listdir(f'{args.train_path}/{args.gt_name}') with h5py.File(args.data_name, 'w') as h5f: count = 0 for i in range(len(file_list)): print(file_list[i]) img_list = [] img_list.append(read_img(f'{args.train_path}/{args.gt_name}/{file_list[i]}')) for j in args.degradation_name: img_list.append(read_img(f'{args.train_path}/{j}/{file_list[i]}')) img = np.stack(img_list,0) img = img_to_patches(img.transpose(0, 3, 1, 2), args.patch_size, args.stride) for nx in range(img.shape[4]): data = img[:,:,:,:,nx] print(count, data.shape) h5f.create_dataset(str(count), data=data) count += 1 h5f.close() if __name__ == "__main__": parser = argparse.ArgumentParser(description = "Building the training patch database") parser.add_argument("--patch-size", type = int, default=256, help="Patch size") parser.add_argument("--stride", type = int, default=200, help="Size of stride") parser.add_argument("--train-path", type = str, default='./data/CDD-11_train', help="Train path") parser.add_argument("--data-name", type = str, default='dataset.h5', help="Data name") parser.add_argument("--gt-name", type = str, default='clear', help="HQ name") parser.add_argument("--degradation-name", type = list, default=['low','haze','rain','snow',\ 'low_haze','low_rain','low_snow','haze_rain','haze_snow','low_haze_rain','low_haze_snow'], help="LQ name") args = parser.parse_args() Train_data(args)