InPeerReview commited on
Commit
7575913
·
verified ·
1 Parent(s): 83e745f

Upload 2 files

Browse files
Files changed (2) hide show
  1. dataset/Transforms.py +217 -0
  2. dataset/dataset.py +39 -0
dataset/Transforms.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy
2
+ import numpy as np
3
+ import torch
4
+ import random
5
+ import cv2
6
+
7
+
8
+ class Scale(object):
9
+ """
10
+ Resize the given image to a fixed scale
11
+ """
12
+
13
+ def __init__(self, wi, he):
14
+ '''
15
+ :param wi: width after resizing
16
+ :param he: height after reszing
17
+ '''
18
+ self.w = wi
19
+ self.h = he
20
+
21
+ # modified from torchvision to add support for max size
22
+
23
+ def __call__(self, img, label):
24
+ '''
25
+ :param img: RGB image
26
+ :param label: semantic label image
27
+ :return: resized images
28
+ '''
29
+ # bilinear interpolation for RGB image
30
+ img = cv2.resize(img, (self.w, self.h))
31
+ # nearest neighbour interpolation for label image
32
+ label = cv2.resize(label, (self.w, self.h), interpolation=cv2.INTER_NEAREST)
33
+ return [img, label]
34
+
35
+
36
+ class Resize(object):
37
+ def __init__(self, min_size, max_size, strict=False):
38
+ if not isinstance(min_size, (list, tuple)):
39
+ min_size = (min_size,)
40
+ self.min_size = min_size
41
+ self.max_size = max_size
42
+ self.strict = strict
43
+
44
+ # modified from torchvision to add support for max size
45
+ def get_size(self, image_size):
46
+ w, h = image_size
47
+ if not self.strict:
48
+ size = random.choice(self.min_size)
49
+ max_size = self.max_size
50
+ if max_size is not None:
51
+ min_original_size = float(min((w, h)))
52
+ max_original_size = float(max((w, h)))
53
+ if max_original_size / min_original_size * size > max_size:
54
+ size = int(round(max_size * min_original_size / max_original_size))
55
+
56
+ if (w <= h and w == size) or (h <= w and h == size):
57
+ return (h, w)
58
+
59
+ if w < h:
60
+ ow = size
61
+ oh = int(size * h / w)
62
+ else:
63
+ oh = size
64
+ ow = int(size * w / h)
65
+
66
+ return (oh, ow)
67
+ else:
68
+ if w < h:
69
+ return (self.max_size, self.min_size[0])
70
+ else:
71
+ return (self.min_size[0], self.max_size)
72
+
73
+ def __call__(self, image, label):
74
+ size = self.get_size(image.shape[:2])
75
+ image = cv2.resize(image, size)
76
+ # I confirm that the output size is right, not reversed
77
+ label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST)
78
+ return (image, label)
79
+
80
+
81
+ class RandomCropResize(object):
82
+ """
83
+ Randomly crop and resize the given image with a probability of 0.5
84
+ """
85
+
86
+ def __init__(self, crop_area):
87
+ '''
88
+ :param crop_area: area to be cropped (this is the max value and we select between 0 and crop area
89
+ '''
90
+ self.cw = crop_area
91
+ self.ch = crop_area
92
+
93
+ def __call__(self, img, label):
94
+ if random.random() < 0.5:
95
+ h, w = img.shape[:2]
96
+ x1 = random.randint(0, self.ch)
97
+ y1 = random.randint(0, self.cw)
98
+
99
+ img_crop = img[y1:h - y1, x1:w - x1]
100
+ label_crop = label[y1:h - y1, x1:w - x1]
101
+
102
+ img_crop = cv2.resize(img_crop, (w, h))
103
+ label_crop = cv2.resize(label_crop, (w, h), interpolation=cv2.INTER_NEAREST)
104
+
105
+ return img_crop, label_crop
106
+ else:
107
+ return [img, label]
108
+
109
+
110
+ class RandomFlip(object):
111
+ """
112
+ Randomly flip the given Image with a probability of 0.5
113
+ """
114
+
115
+ def __call__(self, image, label):
116
+ if random.random() < 0.5:
117
+ image = cv2.flip(image, 0) # horizontal flip
118
+ label = cv2.flip(label, 0) # horizontal flip
119
+ if random.random() < 0.5:
120
+ image = cv2.flip(image, 1) # veritcal flip
121
+ label = cv2.flip(label, 1) # veritcal flip
122
+ return [image, label]
123
+
124
+
125
+ class RandomExchange(object):
126
+ """
127
+ Randomly flip the given Image with a probability of 0.5
128
+ """
129
+
130
+ def __call__(self, image, label):
131
+ if random.random() < 0.5:
132
+ pre_img = image[:, :, 0:3]
133
+ post_img = image[:, :, 3:6]
134
+ image = numpy.concatenate((post_img, pre_img), axis=2)
135
+ return [image, label]
136
+
137
+
138
+ class Normalize(object):
139
+ """
140
+ Given mean: (B, G, R) and std: (B, G, R),
141
+ will normalize each channel of the torch.*Tensor, i.e.
142
+ channel = (channel - mean) / std
143
+ """
144
+
145
+ def __init__(self, mean, std):
146
+ '''
147
+ :param mean: global mean computed from dataset
148
+ :param std: global std computed from dataset
149
+ '''
150
+ self.mean = mean
151
+ self.std = std
152
+ self.depth_mean = [0.5]
153
+ self.depth_std = [0.5]
154
+
155
+ def __call__(self, image, label):
156
+ image = image.astype(np.float32)
157
+ image = image / 255
158
+ label = np.ceil(label / 255)
159
+ for i in range(6):
160
+ image[:, :, i] -= self.mean[i]
161
+ for i in range(6):
162
+ image[:, :, i] /= self.std[i]
163
+
164
+ return [image, label]
165
+
166
+
167
+ class GaussianNoise(object):
168
+ def __init__(self, std=0.05):
169
+ '''
170
+ :param mean: global mean computed from dataset
171
+ :param std: global std computed from dataset
172
+ '''
173
+ self.std = std
174
+
175
+ def __call__(self, image, label):
176
+ noise = np.random.normal(loc=0, scale=self.std, size=image.shape)
177
+ image = image + noise.astype(np.float32)
178
+ return [image, label]
179
+
180
+
181
+ class ToTensor(object):
182
+ '''
183
+ This class converts the data to tensor so that it can be processed by PyTorch
184
+ '''
185
+
186
+ def __init__(self, scale=1):
187
+ '''
188
+ :param scale: set this parameter according to the output scale
189
+ '''
190
+ self.scale = scale
191
+
192
+ def __call__(self, image, label):
193
+ if self.scale != 1:
194
+ h, w = label.shape[:2]
195
+ image = cv2.resize(image, (int(w), int(h)))
196
+ label = cv2.resize(label, (int(w / self.scale), int(h / self.scale)), \
197
+ interpolation=cv2.INTER_NEAREST)
198
+ image = image[:, :, ::-1].copy() # .copy() is to solve "torch does not support negative index"
199
+ image = image.transpose((2, 0, 1))
200
+ image_tensor = torch.from_numpy(image)
201
+ label_tensor = torch.LongTensor(np.array(label, dtype=np.int)).unsqueeze(dim=0)
202
+
203
+ return [image_tensor, label_tensor]
204
+
205
+
206
+ class Compose(object):
207
+ """
208
+ Composes several transforms together.
209
+ """
210
+
211
+ def __init__(self, transforms):
212
+ self.transforms = transforms
213
+
214
+ def __call__(self, *args):
215
+ for t in self.transforms:
216
+ args = t(*args)
217
+ return args
dataset/dataset.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import os
3
+ from os.path import join as osp
4
+ import numpy
5
+ import torch.utils.data
6
+
7
+
8
+ class Dataset(torch.utils.data.Dataset):
9
+ def __init__(self, file_root='data/', mode='train', transform=None):
10
+ self.file_list = os.listdir(osp(file_root, mode, 'A'))
11
+
12
+ self.pre_images = [osp(file_root, mode, 'A', x) for x in self.file_list]
13
+ self.post_images = [osp(file_root, mode, 'B', x) for x in self.file_list]
14
+ self.gts = [osp(file_root, mode, 'label', x) for x in self.file_list]
15
+
16
+ self.transform = transform
17
+
18
+ def __len__(self):
19
+ return len(self.pre_images)
20
+
21
+ def __getitem__(self, idx):
22
+ pre_image_name = self.pre_images[idx]
23
+ label_name = self.gts[idx]
24
+ post_image_name = self.post_images[idx]
25
+
26
+ pre_image = cv2.imread(pre_image_name)
27
+ label = cv2.imread(label_name, 0)
28
+ post_image = cv2.imread(post_image_name)
29
+
30
+ img = numpy.concatenate((pre_image, post_image), axis=2)
31
+
32
+ if self.transform:
33
+ [img, label] = self.transform(img, label)
34
+
35
+ return img, label
36
+
37
+ def get_img_info(self, idx):
38
+ img = cv2.imread(self.pre_images[idx])
39
+ return {"height": img.shape[0], "width": img.shape[1]}