shunk031 committed on
Commit
3b947a6
·
verified ·
1 Parent(s): cdd61a2

Upload processor

Browse files
image_processing_basnet.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Tuple, Union
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PilImage
8
+ from torchvision import transforms
9
+ from transformers.image_processing_base import BatchFeature
10
+ from transformers.image_processing_utils import BaseImageProcessor
11
+ from transformers.image_utils import ImageInput
12
+
13
+
14
class RescaleT(object):
    """Resize an image/label sample to a fixed size and scale the image by 1/255.

    Args:
        output_size: Target size. An ``int`` produces a square
            ``output_size x output_size`` result (the aspect ratio is not
            preserved). A tuple is interpreted as ``(height, width)``.
    """

    def __init__(self, output_size: Union[int, Tuple[int, int]]) -> None:
        super().__init__()
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def _dsize(self) -> Tuple[int, int]:
        """Return the target size as the ``(width, height)`` pair cv2.resize takes."""
        if isinstance(self.output_size, int):
            return (self.output_size, self.output_size)
        new_h, new_w = self.output_size
        return (int(new_w), int(new_h))

    def __call__(self, sample):
        """Resize ``sample["image"]`` and ``sample["label"]``.

        Args:
            sample: Mapping with an ``"image"`` array and a ``"label"`` array.

        Returns:
            A dict with the resized image divided by 255.0 under ``"image"``
            and the resized label, with a trailing channel axis appended,
            under ``"label"``.
        """
        image, label = sample["image"], sample["label"]
        dsize = self._dsize()

        # Area interpolation for the image; the division maps 8-bit
        # intensities from [0, 255] to [0, 1].
        img = cv2.resize(image, dsize, interpolation=cv2.INTER_AREA) / 255.0

        # Nearest-neighbour keeps label values discrete (no blended values).
        lbl = cv2.resize(label, dsize, interpolation=cv2.INTER_NEAREST)
        lbl = np.expand_dims(lbl, axis=-1)
        lbl = np.clip(lbl, np.min(label), np.max(label))

        return {"image": img, "label": lbl}
59
+
60
+
61
class ToTensorLab(object):
    """Normalize an image/label sample and convert it to channel-first tensors.

    ``flag`` selects the colour representation of the returned image:

    * ``0`` (default): RGB, divided by its maximum value and then shifted and
      scaled per channel with the constants in ``_RGB_MEAN`` / ``_RGB_STD``.
    * ``1``: Lab; each channel is min-max scaled and then standardized.
    * ``2``: RGB followed by Lab (6 channels); each channel is min-max scaled
      and then standardized.
    """

    _RGB_MEAN = (0.485, 0.456, 0.406)
    _RGB_STD = (0.229, 0.224, 0.225)

    def __init__(self, flag=0):
        self.flag = flag

    @staticmethod
    def _as_three_channels(image):
        """Replicate a single-channel image to three channels; else pass through."""
        if image.shape[2] == 1:
            return np.repeat(image[:, :, :1], 3, axis=2)
        return image

    @staticmethod
    def _to_lab(rgb):
        """Convert an RGB array to Lab with OpenCV."""
        # cv2.cvtColor does not accept float64 input, which is what dividing
        # by 255.0 yields, so narrow it to float32 first.
        if rgb.dtype == np.float64:
            rgb = rgb.astype(np.float32)
        return cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)

    @staticmethod
    def _min_max(channels):
        """Scale every channel of an (H, W, C) array to the range [0, 1]."""
        channels = np.asarray(channels, dtype=np.float64)
        low = channels.min(axis=(0, 1))
        span = channels.max(axis=(0, 1)) - low
        # A constant channel has zero span; divide by 1 instead of producing NaN.
        return (channels - low) / np.where(span > 0, span, 1.0)

    @staticmethod
    def _standardize(channels):
        """Shift every channel to zero mean and scale it to unit std."""
        channels = np.asarray(channels, dtype=np.float64)
        std = channels.std(axis=(0, 1))
        # A constant channel has zero std; divide by 1 instead of producing NaN.
        return (channels - channels.mean(axis=(0, 1))) / np.where(std > 0, std, 1.0)

    def __call__(self, sample):
        """Convert ``sample["image"]`` and ``sample["label"]`` to tensors.

        Args:
            sample: Mapping with an (H, W, C) ``"image"`` array and an
                (H, W, 1) ``"label"`` array.

        Returns:
            A dict with the normalized image and the label, both transposed
            to channel-first order and wrapped with ``torch.from_numpy``.
        """
        image, label = sample["image"], sample["label"]

        # Scale the label by its maximum unless it is (numerically) empty.
        label_max = np.max(label)
        if not (label_max < 1e-6):
            label = label / label_max

        if self.flag == 2:  # RGB and Lab colours
            rgb = self._as_three_channels(image)
            lab = self._to_lab(rgb)
            stacked = np.concatenate(
                [self._min_max(rgb[:, :, :3]), self._min_max(lab)], axis=2
            )
            normalized = self._standardize(stacked)

        elif self.flag == 1:  # Lab colour only
            lab = self._to_lab(self._as_three_channels(image))
            normalized = self._standardize(self._min_max(lab))

        else:  # RGB colour
            normalized = np.zeros((image.shape[0], image.shape[1], 3))
            peak = np.max(image)
            if peak != 0:
                # An all-zero image is left as is: 0 / 0 would turn it into NaN.
                image = image / peak
            if image.shape[2] == 1:
                # Single-channel input: the first-channel constants are
                # applied and the result is copied to all three channels.
                gray = (image[:, :, 0] - self._RGB_MEAN[0]) / self._RGB_STD[0]
                normalized[:] = gray[:, :, None]
            else:
                for c, (mean, std) in enumerate(zip(self._RGB_MEAN, self._RGB_STD)):
                    normalized[:, :, c] = (image[:, :, c] - mean) / std

        # (H, W, C) -> (C, H, W)
        image_chw = normalized.transpose((2, 0, 1))
        label_chw = label.transpose((2, 0, 1))

        return {
            "image": torch.from_numpy(image_chw),
            "label": torch.from_numpy(label_chw),
        }
189
+
190
+
191
def apply_transform(
    data: Dict[str, np.ndarray], rescale_size: int, to_tensor_lab_flag: int
) -> Dict[str, torch.Tensor]:
    """Run the resize step and then the tensor-conversion step on ``data``.

    Args:
        data: Sample dict handed to ``RescaleT`` (it reads the ``"image"``
            and ``"label"`` entries).
        rescale_size: Forwarded to ``RescaleT`` as ``output_size``.
        to_tensor_lab_flag: Forwarded to ``ToTensorLab`` as ``flag``.

    Returns:
        The dict produced by ``ToTensorLab``.
    """
    resize_step = RescaleT(output_size=rescale_size)
    tensor_step = ToTensorLab(flag=to_tensor_lab_flag)
    pipeline = transforms.Compose([resize_step, tensor_step])
    return pipeline(data)  # type: ignore
198
+
199
+
200
class BASNetImageProcessor(BaseImageProcessor):
    """Image processor that prepares a PIL image for BASNet and renders its output.

    Args:
        rescale_size: Size passed to the resize step of ``apply_transform``.
        to_tensor_lab_flag: Colour-space flag passed to the tensor-conversion
            step of ``apply_transform``.
        **kwargs: Forwarded to ``BaseImageProcessor.__init__``.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self, rescale_size: int = 256, to_tensor_lab_flag: int = 0, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.rescale_size = rescale_size
        self.to_tensor_lab_flag = to_tensor_lab_flag

    def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature:
        """Turn a single PIL image into a batched ``pixel_values`` tensor.

        Args:
            images: A single ``PIL.Image.Image``; other input kinds are rejected.
            **kwargs: Accepted for interface compatibility and not used here.

        Returns:
            A ``BatchFeature`` whose ``"pixel_values"`` entry is the
            transformed image as a float tensor with a leading batch axis.

        Raises:
            ValueError: If ``images`` is not a PIL image.
        """
        if not isinstance(images, PilImage):
            raise ValueError(f"Expected PIL.Image, got {type(images)}")

        # Grayscale, palette and RGBA images do not yield an (H, W, 3) array,
        # so bring every input to three RGB channels first.
        image_pil = images.convert("RGB")
        image_npy = np.array(image_pil, dtype=np.uint8)

        # The transform pipeline expects a label next to the image; there is
        # none when preprocessing, so pass an all-zero placeholder.
        width, height = image_pil.size
        label_npy = np.zeros((height, width), dtype=np.uint8)

        output = apply_transform(
            {"image": image_npy, "label": label_npy},
            rescale_size=self.rescale_size,
            to_tensor_lab_flag=self.to_tensor_lab_flag,
        )
        image = output["image"]

        assert isinstance(image, torch.Tensor)

        return BatchFeature(
            data={"pixel_values": image.float().unsqueeze(dim=0)}, tensor_type="pt"
        )

    def postprocess(
        self, prediction: torch.Tensor, width: int, height: int
    ) -> PilImage:
        """Render a prediction tensor as an RGB PIL image of the given size.

        Args:
            prediction: Model output; it is min-max normalized and squeezed,
                so after squeezing it is expected to be 2-D — TODO confirm
                the exact shape against the model.
            width: Width of the returned image.
            height: Height of the returned image.

        Returns:
            The normalized prediction as an RGB image resized to
            ``(width, height)`` with bilinear resampling.
        """
        # Detach and widen so tensors that require grad or are half precision
        # can still be converted to NumPy below.
        prediction = prediction.detach().float()

        # Min-max normalize to [0, 1]; eps avoids a zero division when the
        # prediction is constant.
        highest, lowest = torch.max(prediction), torch.min(prediction)
        prediction = (prediction - lowest) / (
            (highest - lowest) + torch.finfo(torch.float32).eps
        )
        prediction = prediction.squeeze()

        # Add 0.5 after scaling to [0, 255] so that a later truncation to
        # integers rounds to the nearest value.
        prediction = prediction * 255 + 0.5
        prediction = prediction.clamp(0, 255)

        prediction_np = prediction.cpu().numpy()
        image = Image.fromarray(prediction_np).convert("RGB")
        image = image.resize((width, height), resample=Image.Resampling.BILINEAR)
        return image
preprocessor_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "image_processing_basnet.BASNetImageProcessor"
4
+ },
5
+ "image_processor_type": "BASNetImageProcessor",
6
+ "rescale_size": 256,
7
+ "to_tensor_lab_flag": 0
8
+ }