import glob
import io
import os
import random
import re
import sqlite3
from io import BytesIO
from uuid import uuid4

import h5py
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import RandomCrop
from torchvision.transforms.functional import to_tensor
class ImageH5Data(Dataset):
    def __init__(self, h5py_file, folder_name):
        self.data = h5py.File(h5py_file, 'r')[folder_name]
        self.data_hr = self.data['train_hr']
        self.data_lr = self.data['train_lr']
        self.len_imgs = len(self.data_hr)
        self.h5py_file = h5py_file
        self.folder_name = folder_name

    def __len__(self):
        # with h5py.File(self.h5py_file, 'r') as f:
        #     return len(f[self.folder_name]['train_lr'])
        return self.len_imgs

    def __getitem__(self, index):
        # with h5py.File(self.h5py_file, 'r') as f:
        #     data_lr = f[self.folder_name]['train_lr'][index]
        #     data_hr = f[self.folder_name]['train_hr'][index]
        #
        #     return data_lr, data_hr
        return self.data_lr[index], self.data_hr[index]
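
# Usage sketch (file and group names below are hypothetical): wrap the
# HDF5-backed dataset in a DataLoader; keeping num_workers=0 avoids sharing
# the already-open h5py handle across worker processes.
#
#   train_data = ImageH5Data('train_patches.h5', 'shrink_2_noise_1')
#   loader = torch.utils.data.DataLoader(train_data, batch_size=16,
#                                        shuffle=True, num_workers=0)
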
class ImageData(Dataset):
    def __init__(self,
                 img_folder,
                 patch_size=96,
                 shrink_size=2,
                 noise_level=1,
                 down_sample_method=None,
                 color_mod='RGB',
                 dummy_len=None):

        self.img_folder = img_folder
        all_img = glob.glob(self.img_folder + "/**", recursive=True)
        self.img = list(filter(lambda x: x.endswith('png') or x.endswith("jpg") or x.endswith("jpeg"), all_img))
        self.total_img = len(self.img)
        self.dummy_len = dummy_len if dummy_len is not None else self.total_img
        self.random_cropper = RandomCrop(size=patch_size)
        self.color_mod = color_mod
        self.img_augmenter = ImageAugment(shrink_size, noise_level, down_sample_method)

    def get_img_patches(self, img_file):
        img_pil = Image.open(img_file).convert("RGB")
        img_patch = self.random_cropper(img_pil)
        lr_hr_patches = self.img_augmenter.process(img_patch)
        return lr_hr_patches

    def __len__(self):
        return self.dummy_len  # len(self.img)

    def __getitem__(self, index):
        # the index is ignored: a random source image is drawn each call, so
        # dummy_len only controls how many samples make up one "epoch"
        idx = random.choice(range(0, self.total_img))
        img = self.img[idx]
        patch = self.get_img_patches(img)
        if self.color_mod == 'RGB':
            lr_img = patch[0].convert("RGB")
            hr_img = patch[1].convert("RGB")
        elif self.color_mod == 'YCbCr':
            lr_img, _, _ = patch[0].convert('YCbCr').split()
            hr_img, _, _ = patch[1].convert('YCbCr').split()
        else:
            raise KeyError('Either RGB or YCbCr')
        return to_tensor(lr_img), to_tensor(hr_img)
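
# Usage sketch (the folder path is hypothetical): sample random 96x96 HR
# patches and their degraded LR counterparts straight from an image folder.
#
#   train_data = ImageData('./dataset/train', patch_size=96, shrink_size=2,
#                          noise_level=1, color_mod='RGB')
#   lr_patch, hr_patch = train_data[0]   # index ignored; a random image is drawn
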
class Image2Sqlite(ImageData):
    def __getitem__(self, item):
        img = self.img[item]
        lr_hr_patch = self.get_img_patches(img)
        if self.color_mod == 'RGB':
            lr_img = lr_hr_patch[0].convert("RGB")
            hr_img = lr_hr_patch[1].convert("RGB")
        elif self.color_mod == 'YCbCr':
            lr_img, _, _ = lr_hr_patch[0].convert('YCbCr').split()
            hr_img, _, _ = lr_hr_patch[1].convert('YCbCr').split()
        else:
            raise KeyError('Either RGB or YCbCr')
        lr_byte = self.convert_to_bytevalue(lr_img)
        hr_byte = self.convert_to_bytevalue(hr_img)
        return [lr_byte, hr_byte]

    @staticmethod
    def convert_to_bytevalue(pil_img):
        # encode the PIL image as PNG bytes for storage as a sqlite BLOB
        img_byte = io.BytesIO()
        pil_img.save(img_byte, format='PNG')
        return img_byte.getvalue()
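
# Sketch of dumping patch pairs into SQLite (the database file, table, and
# column names here are assumptions; they must match what ImageDBData is
# later queried with):
#
#   db_data = Image2Sqlite('./dataset/train', patch_size=96)
#   with sqlite3.connect('train_patches.db') as conn:
#       conn.execute("CREATE TABLE IF NOT EXISTS images (lr_img BLOB, hr_img BLOB)")
#       for i in range(len(db_data)):
#           conn.execute("INSERT INTO images VALUES (?, ?)", db_data[i])
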
class ImageDBData(Dataset):
    def __init__(self, db_file, db_table="images", lr_col="lr_img", hr_col="hr_img", max_images=None):
        self.db_file = db_file
        self.db_table = db_table
        self.lr_col = lr_col
        self.hr_col = hr_col
        self.total_images = self.get_num_rows(max_images)
        # self.lr_hr_images = self.get_all_images()

    def __len__(self):
        return self.total_images

    # def get_all_images(self):
    #     with sqlite3.connect(self.db_file) as conn:
    #         cursor = conn.cursor()
    #         cursor.execute(f"SELECT * FROM {self.db_table} LIMIT {self.total_images}")
    #         return cursor.fetchall()

    def get_num_rows(self, max_images):
        with sqlite3.connect(self.db_file) as conn:
            cursor = conn.cursor()
            cursor.execute(f"SELECT MAX(ROWID) FROM {self.db_table}")
            db_rows = cursor.fetchone()[0]
            if max_images:
                return min(max_images, db_rows)
            else:
                return db_rows

    def __getitem__(self, item):
        # lr, hr = self.lr_hr_images[item]
        # lr = Image.open(io.BytesIO(lr))
        # hr = Image.open(io.BytesIO(hr))
        # return to_tensor(lr), to_tensor(hr)
        # note: sqlite ROWID starts at 1
        with sqlite3.connect(self.db_file) as conn:
            cursor = conn.cursor()
            cursor.execute(f"SELECT {self.lr_col}, {self.hr_col} FROM {self.db_table} WHERE ROWID={item + 1}")
            lr, hr = cursor.fetchone()
            lr = Image.open(io.BytesIO(lr)).convert("RGB")
            hr = Image.open(io.BytesIO(hr)).convert("RGB")
            # lr = np.array(lr)  # use scale [0, 255] instead of [0, 1]
            # hr = np.array(hr)
            return to_tensor(lr), to_tensor(hr)
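
# Usage sketch (the db file name is hypothetical): read back the patch pairs
# stored by Image2Sqlite as [0, 1]-scaled tensors.
#
#   db_data = ImageDBData('train_patches.db', db_table='images')
#   lr_tensor, hr_tensor = db_data[0]
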
class ImagePatchData(Dataset):
    def __init__(self, lr_folder, hr_folder):
        self.lr_folder = lr_folder
        self.hr_folder = hr_folder
        self.lr_imgs = glob.glob(os.path.join(lr_folder, "**"))
        self.total_imgs = len(self.lr_imgs)

    def __len__(self):
        return self.total_imgs

    def __getitem__(self, item):
        lr_file = self.lr_imgs[item]
        # the HR file is assumed to share the LR file's name, with 'lr'
        # replaced by 'hr' in the directory part of the path
        hr_path = re.sub("lr", 'hr', os.path.dirname(lr_file))
        filename = os.path.basename(lr_file)
        hr_file = os.path.join(hr_path, filename)
        return to_tensor(Image.open(lr_file)), to_tensor(Image.open(hr_file))
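
# Usage sketch (folder names are hypothetical): the LR and HR folders are
# expected to mirror each other, differing only by 'lr' vs 'hr' in the path.
#
#   patch_data = ImagePatchData('./patches/lr', './patches/hr')
#   lr_tensor, hr_tensor = patch_data[0]
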
class ImageAugment:
    def __init__(self,
                 shrink_size=2,
                 noise_level=1,
                 down_sample_method=None
                 ):
        # noise_level (int): 0: no noise; 1: 75-95% JPEG quality; 2: 50-75% JPEG quality
        if noise_level == 0:
            self.noise_level = [0, 0]
        elif noise_level == 1:
            self.noise_level = [5, 25]
        elif noise_level == 2:
            self.noise_level = [25, 50]
        else:
            raise KeyError("Noise level should be 0, 1, or 2")
        self.shrink_size = shrink_size
        self.down_sample_method = down_sample_method

    def shrink_img(self, hr_img):
        if self.down_sample_method is None:
            resample_method = random.choice([Image.BILINEAR, Image.BICUBIC, Image.LANCZOS])
        else:
            resample_method = self.down_sample_method
        img_w, img_h = tuple(map(lambda x: int(x / self.shrink_size), hr_img.size))
        lr_img = hr_img.resize((img_w, img_h), resample_method)
        return lr_img

    def add_jpeg_noise(self, hr_img):
        quality = 100 - round(random.uniform(*self.noise_level))
        lr_img = BytesIO()
        hr_img.save(lr_img, format='JPEG', quality=quality)
        lr_img.seek(0)
        lr_img = Image.open(lr_img)
        return lr_img

    def process(self, hr_patch_pil):
        lr_patch_pil = self.shrink_img(hr_patch_pil)
        if self.noise_level[1] > 0:
            lr_patch_pil = self.add_jpeg_noise(lr_patch_pil)
        return lr_patch_pil, hr_patch_pil

    def up_sample(self, img, resample):
        width, height = img.size
        return img.resize((self.shrink_size * width, self.shrink_size * height), resample=resample)
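

if __name__ == "__main__":
    # Minimal smoke test (a sketch; "example.png" is a hypothetical image that
    # must be at least 96x96): crop a patch, degrade it, and check the sizes.
    augmenter = ImageAugment(shrink_size=2, noise_level=1)
    img = Image.open("example.png").convert("RGB")
    patch = RandomCrop(size=96)(img)
    lr, hr = augmenter.process(patch)
    print(lr.size, hr.size)  # expected: (48, 48) (96, 96)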