huu-ontocord committed
Commit
e83e090
1 Parent(s): 9e922eb

Create image_processing_phi3_v.py

Files changed (1)
  1. image_processing_phi3_v.py +273 -0
image_processing_phi3_v.py ADDED
@@ -0,0 +1,273 @@
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Image processor class for Phi3-V."""

from typing import List, Optional, Union

import numpy as np

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import convert_to_rgb
from transformers.image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
    ImageInput,
    make_list_of_images,
    valid_images,
)
from transformers.utils import TensorType, is_vision_available, logging

from transformers import AutoImageProcessor

logger = logging.get_logger(__name__)


if is_vision_available():
    from PIL import Image

import torch
import torchvision

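# The helpers below implement the "HD transform": an image is resized so that its
# longer side is a multiple of 336 pixels and its shorter side is padded (with white)
# up to a multiple of 336, chosen so the result tiles into at most `hd_num`
# 336x336 crops; `preprocess` later adds one downscaled 336x336 global view.
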
def padding_336(b):
    width, height = b.size
    tar = int(np.ceil(height / 336) * 336)
    top_padding = int((tar - height) / 2)
    bottom_padding = tar - height - top_padding
    left_padding = 0
    right_padding = 0
    b = torchvision.transforms.functional.pad(
        b, [left_padding, top_padding, right_padding, bottom_padding], fill=[255, 255, 255]
    )

    return b

def calc_padded_size(width, height, padding_unit=336):
    target_height = int(np.ceil(height / padding_unit) * padding_unit)
    top_padding = int((target_height - height) / 2)
    bottom_padding = target_height - height - top_padding
    left_padding = 0
    right_padding = 0
    padded_width = width + left_padding + right_padding
    padded_height = height + top_padding + bottom_padding
    return padded_width, padded_height

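# Example: calc_padded_size(1000, 500) pads only the height: target_height =
# ceil(500 / 336) * 336 = 672, so the padded size is (1000, 672).
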
def HD_transform(img, hd_num=16):
    width, height = img.size
    trans = False
    if width < height:
        img = img.transpose(Image.TRANSPOSE)
        trans = True
        width, height = img.size
    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1
    new_w = int(scale * 336)
    new_h = int(new_w / ratio)

    img = torchvision.transforms.functional.resize(img, [new_h, new_w])
    img = padding_336(img)
    width, height = img.size
    if trans:
        img = img.transpose(Image.TRANSPOSE)

    return img

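# Example with the default hd_num=16: a 1344x672 image (ratio 2.0) stops the scale
# search at scale=5 (5 * ceil(5/2) = 15 <= 16, while 6 * 3 = 18 > 16), so it is
# resized to 1680x840 and padded to 1680x1008, i.e. a 5x3 grid of 336x336 tiles.
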
def calc_hd_transform_size(width, height, hd_num=16):
    transposed = False
    if width < height:
        width, height = height, width
        transposed = True

    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1

    new_width = int(scale * 336)
    new_height = int(new_width / ratio)

    padded_width, padded_height = calc_padded_size(new_width, new_height)

    if transposed:
        padded_width, padded_height = padded_height, padded_width

    return padded_width, padded_height

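# calc_hd_transform_size mirrors the arithmetic of HD_transform without touching any
# pixel data, so the post-transform size (and hence the token count) can be computed
# from the raw width/height alone.
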
def pad_to_max_num_crops_tensor(images, max_crops=5):
    """
    images: B x 3 x H x W, B <= max_crops
    """
    B, _, H, W = images.shape
    if B < max_crops:
        pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device)
        images = torch.cat([images, pad], dim=0)
    return images


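# pad_to_max_num_crops_tensor zero-pads the crop dimension so every image in a batch
# carries the same number of crops and the batch can be stacked into a single tensor.

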
class Phi3VImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Phi3 image processor. Based on [`CLIPImageProcessor`], incorporating additional techniques
    for processing high-resolution images, as explained in [InternLM-XComposer2-4KHD](https://arxiv.org/abs/2401.16420).

    Args:
        num_crops (`int`, *optional*, defaults to 1):
            Maximum number of 336x336 local crops extracted from an image by the HD transform.
        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        num_crops: int = 1,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.num_crops = num_crops
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
        self.do_convert_rgb = do_convert_rgb

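    # Instantiation sketch (the num_crops value is illustrative, not taken from a
    # released config):
    #   processor = Phi3VImageProcessor(num_crops=16)
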
    def calc_num_image_tokens(
        self,
        images: ImageInput
    ):
        """Calculate the number of image tokens for each image.
        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
        """
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        images = [image.convert("RGB") for image in images]
        # (H, W, C)
        elems = [HD_transform(im, hd_num=self.num_crops) for im in images]
        shapes = [[im.size[1], im.size[0]] for im in elems]
        num_img_tokens = [int((h // 336 * w // 336 + 1) * 144 + 1 + (h // 336 + 1) * 12) for h, w in shapes]
        return num_img_tokens

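    # Breakdown of the token formula above: (h//336 * w//336 + 1) sub-images (local
    # crops plus one global crop) at 144 tokens each, plus 1 extra token and
    # (h//336 + 1) * 12 additional tokens; the 144 presumably corresponds to a 12x12
    # pooled patch grid per 336x336 sub-image.
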
    def calc_num_image_tokens_from_image_size(self, width, height):
        """
        Calculate the number of image tokens for a given image size.
        Args:
            width (`int`): Width of the image.
            height (`int`): Height of the image.
        """
        new_width, new_height = calc_hd_transform_size(width, height, hd_num=self.num_crops)
        num_img_tokens = int((new_height // 336 * new_width // 336 + 1) * 144 + 1 + (new_height // 336 + 1) * 12)
        return num_img_tokens

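    # Worked example with num_crops=16: a 1344x672 image becomes 1680x1008 after the
    # HD transform, so it yields (3 * 5 + 1) * 144 + 1 + (3 + 1) * 12 = 2353 tokens.
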
    def preprocess(
        self,
        images: ImageInput,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ):
        """
        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
        """
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        image_sizes = []
        img_processor = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(image_mean, image_std),
        ])

        # PIL images
        # HD_transform pads images to a size that is a multiple of 336 x 336
        # convert to RGB first
        images = [image.convert("RGB") for image in images]
        elems = [HD_transform(im, hd_num=self.num_crops) for im in images]
        # tensor transform and normalize
        hd_images = [img_processor(im) for im in elems]
        # create the global image: the whole picture downscaled to a single 336x336 view
        global_image = [
            torch.nn.functional.interpolate(im.unsqueeze(0).float(), size=(336, 336), mode="bicubic").to(im.dtype)
            for im in hd_images
        ]

        # [(3, h, w)], where h, w are multiples of 336
        shapes = [[im.size(1), im.size(2)] for im in hd_images]
        num_img_tokens = [int((h // 336 * w // 336 + 1) * 144 + 1 + (h // 336 + 1) * 12) for h, w in shapes]
        # reshape to channel dimension -> (num_images, num_crops, 3, 336, 336)
        # (1, 3, h//336, 336, w//336, 336) -> (1, h//336, w//336, 3, 336, 336) -> (h//336*w//336, 3, 336, 336)
        hd_images_reshape = [
            im.reshape(1, 3, h // 336, 336, w // 336, 336)
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(-1, 3, 336, 336)
            .contiguous()
            for im, (h, w) in zip(hd_images, shapes)
        ]
        # concat global image and local image
        hd_images_reshape = [
            torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in zip(global_image, hd_images_reshape)
        ]

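        # Shape walk-through for one HD-transformed image of 1008x1680 (h x w): the
        # normalized (3, 1008, 1680) tensor is cut into a 3x5 grid of 336x336 tiles,
        # giving (15, 3, 336, 336); prepending the global view yields (16, 3, 336, 336).
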
        # pad to max_num_crops
        image_transformed = [pad_to_max_num_crops_tensor(im, self.num_crops + 1) for im in hd_images_reshape]
        image_transformed = torch.stack(image_transformed, dim=0)
        padded_images = image_transformed
        image_sizes = shapes

        data = {
            "pixel_values": padded_images,
            "image_sizes": image_sizes,
            "num_img_tokens": num_img_tokens,
        }

        return BatchFeature(data=data, tensor_type=return_tensors)


AutoImageProcessor.register("Phi3VImageProcessor", Phi3VImageProcessor)
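
# Usage sketch (the file name and num_crops value below are illustrative only):
#   from PIL import Image
#   processor = Phi3VImageProcessor(num_crops=16)
#   batch = processor.preprocess([Image.open("example.jpg")], return_tensors="pt")
#   batch["pixel_values"].shape  # (num_images, num_crops + 1, 3, 336, 336)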