pogzyb committed
Commit 0342dad
1 Parent(s): 7c916ae

Upload processor

Files changed (2)
  1. image_processor.py +257 -0
  2. preprocessor_config.json +22 -0
image_processor.py ADDED
@@ -0,0 +1,257 @@
from typing import Dict, List, Optional, Tuple, Union, Iterable

import numpy as np
import torch
import transformers
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import (
    ChannelDimension,
    get_resize_output_image_size,
    rescale,
    resize,
    to_channel_dimension_format,
)
from transformers.image_utils import (
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    get_channel_dimension_axis,
    make_list_of_images,
    to_numpy_array,
    valid_images,
)
from transformers.utils import is_torch_tensor


class FaceSegformerImageProcessor(BaseImageProcessor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.image_size = kwargs.get("image_size", (224, 224))
        self.normalize_mean = kwargs.get("normalize_mean", [0.485, 0.456, 0.406])
        self.normalize_std = kwargs.get("normalize_std", [0.229, 0.224, 0.225])
        self.resample = kwargs.get("resample", PILImageResampling.BILINEAR)
        self.data_format = kwargs.get("data_format", ChannelDimension.FIRST)

    @staticmethod
    def normalize(
        image: np.ndarray,
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        max_pixel_value: float = 255.0,
        data_format: Optional[ChannelDimension] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Copied from:
        https://github.com/huggingface/transformers/blob/3eddda1111f70f3a59485e08540e8262b927e867/src/transformers/image_transforms.py#L209

        BUT uses the formula from albumentations:
        https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Normalize

        img = (img - mean * max_pixel_value) / (std * max_pixel_value)
        """
        if not isinstance(image, np.ndarray):
            raise ValueError("image must be a numpy array")

        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)
        channel_axis = get_channel_dimension_axis(
            image, input_data_format=input_data_format
        )
        num_channels = image.shape[channel_axis]

        # We cast to float32 to avoid errors that can occur when subtracting uint8 values.
        # We preserve the original dtype if it is a float type to prevent upcasting float16.
        if not np.issubdtype(image.dtype, np.floating):
            image = image.astype(np.float32)

        if isinstance(mean, Iterable):
            if len(mean) != num_channels:
                raise ValueError(
                    f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}"
                )
        else:
            mean = [mean] * num_channels
        mean = np.array(mean, dtype=image.dtype)

        if isinstance(std, Iterable):
            if len(std) != num_channels:
                raise ValueError(
                    f"std must have {num_channels} elements if it is an iterable, got {len(std)}"
                )
        else:
            std = [std] * num_channels
        std = np.array(std, dtype=image.dtype)

        # Uses max_pixel_value for normalization
        if input_data_format == ChannelDimension.LAST:
            image = (image - mean * max_pixel_value) / (std * max_pixel_value)
        else:
            image = ((image.T - mean * max_pixel_value) / (std * max_pixel_value)).T

        image = (
            to_channel_dimension_format(image, data_format, input_data_format)
            if data_format is not None
            else image
        )
        return image

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Copied from:
        https://github.com/huggingface/transformers/blob/3eddda1111f70f3a59485e08540e8262b927e867/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py
        """
        default_to_square = True
        if "shortest_edge" in size:
            size = size["shortest_edge"]
            default_to_square = False
        elif "height" in size and "width" in size:
            size = (size["height"], size["width"])
        else:
            raise ValueError(
                "Size must contain either 'shortest_edge' or 'height' and 'width'."
            )

        output_size = get_resize_output_image_size(
            image,
            size=size,
            default_to_square=default_to_square,
            input_data_format=input_data_format,
        )
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

    def __call__(self, images: ImageInput, masks: ImageInput = None, **kwargs):
        """
        Adapted from:
        https://github.com/huggingface/transformers/blob/3eddda1111f70f3a59485e08540e8262b927e867/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py
        """
        # single to iterable if needed
        images = make_list_of_images(images)

        # validate
        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        # make numpy arrays
        images = [to_numpy_array(image) for image in images]

        # get channel dimensions
        input_data_format = kwargs.get("input_data_format")
        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        # check if training
        # todo: can also assume if masks are passed that we are doing training?
        if kwargs.get("do_training", False) is True:
            if masks is None:
                raise ValueError("must pass masks if doing training.")
            # todo: implement this soon.
            raise NotImplementedError("not yet implemented.")
            # Assume we want to do all transformations for training
        else:
            # do transformations for inference...
            images = [
                self.resize(
                    image=image,
                    size={
                        "shortest_edge": min(
                            kwargs.get("image_size") or self.image_size
                        )
                    },
                    resample=kwargs.get("resample") or self.resample,
                    input_data_format=input_data_format,
                )
                for image in images
            ]
            images = [
                self.normalize(
                    image=image,
                    mean=kwargs.get("normalize_mean") or self.normalize_mean,
                    std=kwargs.get("normalize_std") or self.normalize_std,
                    input_data_format=input_data_format,
                )
                for image in images
            ]
            # fix dimensions
            images = [
                to_channel_dimension_format(
                    image,
                    kwargs.get("data_format") or self.data_format,
                    input_channel_dim=input_data_format,
                )
                for image in images
            ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type="pt")

    # Copied from transformers.models.segformer.image_processing_segformer.SegformerImageProcessor.post_process_semantic_segmentation
    def post_process_semantic_segmentation(
        self, outputs, target_sizes: List[Tuple] = None
    ):
        """
        Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.

        Args:
            outputs ([`SegformerForSemanticSegmentation`]):
                Raw outputs of the model.
            target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
                List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
                predictions will not be resized.

        Returns:
            semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
            specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
        """
        # TODO: add support for other frameworks
        logits = outputs.logits

        # Resize logits and compute semantic segmentation maps
        if target_sizes is not None:
            if len(logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )

            if is_torch_tensor(target_sizes):
                target_sizes = target_sizes.numpy()

            semantic_segmentation = []

            for idx in range(len(logits)):
                resized_logits = torch.nn.functional.interpolate(
                    logits[idx].unsqueeze(dim=0),
                    size=target_sizes[idx],
                    mode="bilinear",
                    align_corners=False,
                )
                semantic_map = resized_logits[0].argmax(dim=0)
                semantic_segmentation.append(semantic_map)
        else:
            semantic_segmentation = logits.argmax(dim=1)
            semantic_segmentation = [
                semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])
            ]

        return semantic_segmentation
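Note on the `normalize` override above: the stock `transformers.image_transforms.normalize` expects pixel values already rescaled to [0, 1], while this variant folds the rescaling into the normalization through `max_pixel_value`, following albumentations. A quick sanity check (a sketch for illustration, not part of this commit) that the two formulations agree:

import numpy as np

img = np.random.randint(0, 256, size=(4, 4, 3)).astype(np.float32)  # HWC image in the uint8 range
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# albumentations-style: rescale and normalize in a single step
a = (img - mean * 255.0) / (std * 255.0)

# equivalent two-step form: rescale to [0, 1], then normalize
b = (img / 255.0 - mean) / std

assert np.allclose(a, b, atol=1e-5)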
preprocessor_config.json ADDED
@@ -0,0 +1,22 @@
{
  "auto_map": {
    "AutoImageProcessor": "image_processor.FaceSegformerImageProcessor"
  },
  "data_format": "channels_first",
  "image_processor_type": "FaceSegformerImageProcessor",
  "image_size": [
    224,
    224
  ],
  "normalize_mean": [
    0.485,
    0.456,
    0.406
  ],
  "normalize_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 2
}
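For reference, `"resample": 2` is the integer value of `PILImageResampling.BILINEAR`, matching the processor's default, and the `auto_map` entry lets the processor load straight from the Hub via `AutoImageProcessor` with `trust_remote_code=True`. A minimal inference sketch follows; the repo id, image URL, and the choice of `SegformerForSemanticSegmentation` are assumptions for illustration, not confirmed by this commit:

import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation

repo_id = "pogzyb/face-segformer"  # hypothetical repo id

# auto_map resolves AutoImageProcessor to image_processor.FaceSegformerImageProcessor
processor = AutoImageProcessor.from_pretrained(repo_id, trust_remote_code=True)
model = SegformerForSemanticSegmentation.from_pretrained(repo_id)

url = "https://example.com/face.jpg"  # placeholder image URL
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image)  # BatchFeature with "pixel_values" as a torch tensor
with torch.no_grad():
    outputs = model(**inputs)

# Upsample logits back to the input size and take the per-pixel argmax
seg_map = processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]  # PIL size is (width, height); reverse to (height, width)
)[0]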