File size: 2,726 Bytes
22d8ab7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
from io import BytesIO
import numpy as np
import lmdb
from PIL import Image
from skimage import color
import torch
from torch.utils.data import Dataset
from data.tps_transformation import tps_transform
def RGB2Lab(inputs):
    """Convert an RGB image array to CIE Lab using ``skimage.color.rgb2lab``."""
    lab_image = color.rgb2lab(inputs)
    return lab_image
def Normalize(inputs):
    """Centre the first channel of an H x W x C array and cast to float32.

    Channel 0 has 50 subtracted from it; channels 1 and 2 are passed through
    unchanged, and any channels beyond the third are dropped.  According to
    the original author's note the intended output ranges are
    L in [-50, 50] and ab in [-128, 128].
    """
    lightness = inputs[:, :, 0:1] - 50
    chroma = inputs[:, :, 1:3]  # a/b channels are left as they are
    stacked = np.concatenate([lightness, chroma], axis=2)
    return stacked.astype(np.float32)
def selfnormalize(inputs):
    """Divide a tensor by its own value range (max minus min).

    The minimum is not subtracted first, so the result is only rescaled,
    not shifted.  A constant tensor has a zero range and therefore yields
    inf/nan values from the division.
    """
    value_range = torch.max(inputs) - torch.min(inputs)
    return inputs / value_range
def to_gray(inputs):
    """Build a 3-channel uint8 grayscale image from channel 0 of *inputs*.

    Channel 0 is replicated three times along the last axis, mapped with
    ``(value + 50) / 100 * 255``, clipped to [0, 255] and cast to uint8.
    """
    lightness = inputs[:, :, :1]
    tripled = np.concatenate([lightness] * 3, axis=2)
    scaled = (tripled + 50) / 100 * 255
    return np.clip(scaled, 0, 255).astype(np.uint8)
def numpy2tensor(inputs):
    """Turn an H x W x C NumPy array into a C x H x W torch tensor."""
    channels_first = inputs.transpose(2, 0, 1)
    return torch.from_numpy(channels_first)
class MultiResolutionDataset(Dataset):
    """Map-style dataset that reads encoded images from an LMDB database.

    Images are stored under keys of the form ``"<resolution>-<index>"`` with
    the index zero-padded to five digits; the number of samples is stored
    under the ``"length"`` key.  Each sample is a triple
    ``(img, img_ref, img_lab)`` of channel-first tensors.
    """

    def __init__(self, path, transform, resolution=256):
        """Open the LMDB environment and read the dataset length.

        Args:
            path: Location of the LMDB database.
            transform: Callable applied to the reference image.  It receives
                a PIL image and its result is converted with ``np.array``.
            resolution: Resolution prefix used when building lookup keys.

        Raises:
            IOError: If the environment cannot be opened or it has no
                ``"length"`` entry.
        """
        super().__init__()

        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            # A single formatted argument keeps the message readable; two
            # positional args would be interpreted as (errno, strerror).
            raise IOError(f'Cannot open lmdb dataset: {path}')

        with self.env.begin(write=False) as txn:
            raw_length = txn.get('length'.encode('utf-8'))

        if raw_length is None:
            raise IOError(f"lmdb dataset has no 'length' entry: {path}")
        self.length = int(raw_length.decode('utf-8'))

        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        """Return the number of samples recorded in the database."""
        return self.length

    def __getitem__(self, index):
        """Load sample *index* and build the training triple.

        Returns:
            A tuple ``(img, img_ref, img_lab)``:

            * ``img`` -- the source image as a float32 RGB tensor in
              [0, 255], laid out as (3, H, W).
            * ``img_ref`` -- a noisy, TPS-warped copy of the source passed
              through ``self.transform``, as a float32 channel-first tensor.
            * ``img_lab`` -- the source image in Lab space with L shifted by
              -50, as a float32 (3, H, W) tensor.

        Raises:
            IndexError: If *index* is outside ``[0, len(self))``.
            IOError: If the database has no entry for an in-range index.
        """
        # An explicit IndexError keeps the sequence protocol intact (plain
        # ``for sample in dataset`` stops at the end) and avoids an opaque
        # image-decoding error for a key that cannot exist.
        if not 0 <= index < self.length:
            raise IndexError(
                f'index {index} is out of range for dataset of length {self.length}'
            )

        key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
        with self.env.begin(write=False) as txn:
            img_bytes = txn.get(key)

        if img_bytes is None:
            raise IOError(f'lmdb dataset has no entry for key {key!r}')

        # Force three channels so grayscale / palette / RGBA entries do not
        # break the RGB -> Lab conversion below.
        img = Image.open(BytesIO(img_bytes)).convert('RGB')
        img_src = np.array(img)  # [0,255] uint8, H x W x 3

        # Perturb the reference with uniform noise in [-5, 5) before warping.
        noise = np.random.uniform(-5, 5, np.shape(img_src))
        img_ref = np.clip(img_src + noise, 0, 255)
        img_ref = tps_transform(img_ref)
        # Clip again and quantise: the warp output is not assumed to stay
        # inside [0, 255].
        img_ref = np.clip(img_ref, 0, 255).astype('uint8')
        img_ref = Image.fromarray(img_ref)
        img_ref = np.array(self.transform(img_ref))

        img_lab = Normalize(RGB2Lab(img_src))  # L shifted by -50, float32

        img = img_src.astype('float32')  # [0,255] float32 RGB
        img_ref = img_ref.astype('float32')

        img = numpy2tensor(img)
        img_ref = numpy2tensor(img_ref)
        img_lab = numpy2tensor(img_lab)

        return img, img_ref, img_lab
|