|
import glob |
|
|
|
import h5py |
|
from PIL import Image |
|
from torchvision.transforms import RandomCrop |
|
from torchvision.transforms.functional import to_tensor |
|
from tqdm import tqdm |
|
|
|
from Dataloader import ImageAugment |
|
|
|
patch_size = 128 |
|
shrink_size = 2 |
|
noise_level = 1 |
|
patches_per_img = 20 |
|
images = glob.glob("dataset/train/*") |
|
|
|
database = h5py.File("train_images.hdf5", "w") |
|
|
|
dat_group = database.create_group("shrink_2_noise_level_1_downsample_random_rgb") |
|
|
|
storage_lr = dat_group.create_dataset( |
|
"train_lr", |
|
shape=( |
|
patches_per_img * len(images), |
|
3, |
|
patch_size // shrink_size, |
|
patch_size // shrink_size, |
|
), |
|
dtype="float32", |
|
|
|
) |
|
storage_hr = dat_group.create_dataset( |
|
"train_hr", |
|
shape=(patches_per_img * len(images), 3, patch_size, patch_size), |
|
|
|
dtype="float32", |
|
) |
|
|
|
random_cropper = RandomCrop(size=patch_size) |
|
img_augmenter = ImageAugment(shrink_size, noise_level, down_sample_method=None) |
|
|
|
|
|
def get_img_patches(img_pil): |
|
img_patch = random_cropper(img_pil) |
|
lr_hr_patches = img_augmenter.process(img_patch) |
|
return lr_hr_patches |
|
|
|
|
|
counter = 0 |
|
for img in tqdm(images): |
|
img_pil = Image.open(img).convert("RGB") |
|
for i in range(patches_per_img): |
|
patch = get_img_patches(img_pil) |
|
storage_lr[counter] = to_tensor(patch[0].convert("RGB")).numpy() |
|
storage_hr[counter] = to_tensor(patch[1].convert("RGB")).numpy() |
|
counter += 1 |
|
database.close() |
|
|