File size: 2,726 Bytes
22d8ab7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from io import BytesIO

import numpy as np
import lmdb
from PIL import Image
from skimage import color
import torch
from torch.utils.data import Dataset
from data.tps_transformation import tps_transform

def RGB2Lab(inputs):
    """Convert an RGB image array to CIE-Lab using ``skimage.color.rgb2lab``."""
    lab_image = color.rgb2lab(inputs)
    return lab_image

def Normalize(inputs):
    """Centre the L channel of a Lab image on zero and cast to float32.

    Channel 0 (L) has 50 subtracted; channels 1-2 (a, b) pass through
    unchanged; any channels beyond index 2 are dropped. For skimage Lab
    input (L in [0, 100]) the result has L in [-50, 50].

    Args:
        inputs: Array indexed as ``[row, col, channel]``.

    Returns:
        float32 array with exactly three channels along axis 2.
    """
    centred_l = inputs[:, :, :1] - 50
    ab_channels = inputs[:, :, 1:3]
    stacked = np.concatenate([centred_l, ab_channels], axis=2)
    return stacked.astype('float32')

def selfnormalize(inputs):
    """Scale a tensor by the reciprocal of its value range (``max - min``).

    The minimum is *not* subtracted, so the output is not necessarily in
    [0, 1]; only its spread becomes 1.

    When every element is equal the range is zero; dividing by it would
    produce inf/NaN, so in that case the tensor is divided by 1 instead
    (returned unscaled).

    Args:
        inputs: Non-empty tensor.

    Returns:
        A new tensor ``inputs / range`` (or ``inputs / 1`` for zero range).
    """
    value_range = torch.max(inputs) - torch.min(inputs)
    # Guard the degenerate constant-tensor case to avoid inf/NaN output.
    value_range = torch.where(
        value_range == 0, torch.ones_like(value_range), value_range
    )
    return inputs / value_range

def to_gray(inputs):
    """Render channel 0 of a centred-L Lab array as a 3-channel uint8 image.

    Only ``inputs[:, :, 0]`` is used. It is mapped linearly from [-50, 50]
    to [0, 255] via ``(L + 50) / 100 * 255``, clipped to [0, 255],
    truncated to uint8, and repeated into three identical channels.
    """
    lightness = inputs[:, :, :1]
    three_channel = np.repeat(lightness, 3, axis=2)
    scaled = (three_channel + 50) / 100 * 255
    return np.clip(scaled, 0, 255).astype('uint8')

def numpy2tensor(inputs):
    """Move axis 2 of a 3-D ndarray to the front and wrap it as a tensor.

    E.g. an image array laid out (H, W, C) becomes a (C, H, W) tensor.
    ``torch.from_numpy`` shares memory with the (transposed) array.
    """
    channels_first = inputs.transpose((2, 0, 1))
    return torch.from_numpy(channels_first)

class MultiResolutionDataset(Dataset):
    """LMDB-backed dataset yielding ``(image, distorted reference, Lab)`` triples.

    The LMDB environment must contain a ``b'length'`` entry holding the
    sample count as a decimal string, plus one encoded image per sample
    stored under ``f'{resolution}-{index zero-padded to 5 digits}'``.

    Args:
        path: Filesystem path of the LMDB environment (opened read-only).
        transform: Callable applied to the warped reference PIL image; its
            result is converted with ``np.array``.
        resolution: Prefix used to build per-sample keys.
    """

    def __init__(self, path, transform, resolution=256):
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        with self.env.begin(write=False) as txn:
            raw_length = txn.get('length'.encode('utf-8'))
        # txn.get returns None for a missing key; fail with a clear message
        # instead of an AttributeError on None.decode.
        if raw_length is None:
            raise IOError(f"lmdb dataset at {path!r} has no 'length' entry")
        self.length = int(raw_length.decode('utf-8'))

        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        """Return the sample count read from the ``length`` entry."""
        return self.length

    def __getitem__(self, index):
        """Load sample ``index`` and build its training triple.

        Returns:
            ``(img, img_ref, img_lab)``, each a float32 channel-first tensor:
            ``img`` is the source RGB image in [0, 255]; ``img_ref`` is a
            noised, TPS-warped and ``transform``-ed copy of the source;
            ``img_lab`` is the source converted to Lab with L shifted by -50.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
            KeyError: if the LMDB database has no entry for ``index``.
        """
        # Bounds check first so plain ``for x in dataset`` iteration stops
        # cleanly instead of failing on a missing LMDB key.
        if not 0 <= index < self.length:
            raise IndexError(
                f'index {index} out of range for dataset of length {self.length}'
            )

        key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
        with self.env.begin(write=False) as txn:
            img_bytes = txn.get(key)
        if img_bytes is None:
            raise KeyError(f'no lmdb entry for key {key!r}')

        # Force 3-channel RGB: rgb2lab and the axis-2 transpose below both
        # need exactly three channels (grayscale/RGBA input would fail).
        img_src = np.array(Image.open(BytesIO(img_bytes)).convert('RGB'))  # uint8 (H, W, 3)

        # Add per-element uniform noise in [-5, 5), then clip to [0, 255].
        noise = np.random.uniform(-5, 5, np.shape(img_src))
        img_ref = np.clip(img_src + noise, 0, 255)

        # Thin-plate-spline warp (project helper); its output range/dtype is
        # not guaranteed here, so clip and cast before building a PIL image.
        img_ref = tps_transform(img_ref)
        img_ref = np.clip(img_ref, 0, 255).astype('uint8')
        # NOTE(review): assumes ``transform`` keeps values in [0, 255] and
        # returns something np.array turns into (H', W', 3) — confirm.
        img_ref = np.array(self.transform(Image.fromarray(img_ref)))

        img_lab = Normalize(RGB2Lab(img_src))  # L shifted to [-50, 50], a/b unchanged

        img = numpy2tensor(img_src.astype('float32'))  # (3, H, W) RGB in [0, 255]
        img_ref = numpy2tensor(img_ref.astype('float32'))  # (3, H', W') per sample
        img_lab = numpy2tensor(img_lab)  # (3, H, W)

        return img, img_ref, img_lab