katuni4ka commited on
Commit
bc58373
·
verified ·
1 Parent(s): c5e5ff2

Upload 17 files

Browse files
config.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "_name_or_path": "minicpm_v26",
3
  "architectures": [
4
  "MiniCPMV"
5
  ],
@@ -17,7 +17,7 @@
17
  "hidden_size": 256,
18
  "image_size": 28,
19
  "initializer_range": 0.02,
20
- "intermediate_size": 37,
21
  "max_position_embeddings": 32768,
22
  "max_window_layers": 2,
23
  "model_type": "minicpmv",
@@ -37,16 +37,17 @@
37
  "sliding_window": null,
38
  "tie_word_embeddings": false,
39
  "torch_dtype": "float32",
40
- "transformers_version": "4.45.1",
41
  "use_cache": true,
42
  "use_image_id": true,
43
  "use_sliding_window": false,
44
  "version": 2.6,
45
  "vision_batch_size": 16,
46
  "vision_config": {
 
47
  "hidden_size": 64,
48
  "image_size": 28,
49
- "intermediate_size": 4304,
50
  "model_type": "siglip_vision_model",
51
  "num_attention_heads": 2,
52
  "num_hidden_layers": 4,
 
1
  {
2
+ "_name_or_path": "/home/ea/work/my_optimum_intel/optimum-intel/tiny-random-minicpmv-2_6",
3
  "architectures": [
4
  "MiniCPMV"
5
  ],
 
17
  "hidden_size": 256,
18
  "image_size": 28,
19
  "initializer_range": 0.02,
20
+ "intermediate_size": 128,
21
  "max_position_embeddings": 32768,
22
  "max_window_layers": 2,
23
  "model_type": "minicpmv",
 
37
  "sliding_window": null,
38
  "tie_word_embeddings": false,
39
  "torch_dtype": "float32",
40
+ "transformers_version": "4.46.1",
41
  "use_cache": true,
42
  "use_image_id": true,
43
  "use_sliding_window": false,
44
  "version": 2.6,
45
  "vision_batch_size": 16,
46
  "vision_config": {
47
+ "_attn_implementation_autoset": true,
48
  "hidden_size": 64,
49
  "image_size": 28,
50
+ "intermediate_size": 128,
51
  "model_type": "siglip_vision_model",
52
  "num_attention_heads": 2,
53
  "num_hidden_layers": 4,
configuration_minicpm.py CHANGED
@@ -4,10 +4,12 @@
4
  import os
5
  from typing import Union
6
 
 
7
  from transformers.utils import logging
8
- from transformers import Qwen2Config, PretrainedConfig
9
  from .modeling_navit_siglip import SiglipVisionConfig
10
 
 
11
  logger = logging.get_logger(__name__)
12
 
13
 
@@ -44,7 +46,6 @@ class MiniCPMVSliceConfig(PretrainedConfig):
44
  return cls.from_dict(config_dict, **kwargs)
45
 
46
 
47
-
48
  class MiniCPMVConfig(Qwen2Config):
49
  model_type = "minicpmv"
50
  keys_to_ignore_at_inference = ["past_key_values"]
 
4
  import os
5
  from typing import Union
6
 
7
+ from transformers import PretrainedConfig, Qwen2Config
8
  from transformers.utils import logging
9
+
10
  from .modeling_navit_siglip import SiglipVisionConfig
11
 
12
+
13
  logger = logging.get_logger(__name__)
14
 
15
 
 
46
  return cls.from_dict(config_dict, **kwargs)
47
 
48
 
 
49
  class MiniCPMVConfig(Qwen2Config):
50
  model_type = "minicpmv"
51
  keys_to_ignore_at_inference = ["past_key_values"]
generation_config.json CHANGED
@@ -2,5 +2,5 @@
2
  "_from_model_config": true,
3
  "bos_token_id": 151643,
4
  "eos_token_id": 151645,
5
- "transformers_version": "4.45.1"
6
  }
 
2
  "_from_model_config": true,
3
  "bos_token_id": 151643,
4
  "eos_token_id": 151645,
5
+ "transformers_version": "4.46.1"
6
  }
image_processing_minicpmv.py CHANGED
@@ -1,27 +1,23 @@
1
- from typing import Optional, Union, Dict, Any, List
2
-
3
- import torch
4
  import math
5
- import PIL.Image
6
- import PIL.ImageSequence
7
  import numpy as np
8
  import PIL
 
 
 
9
  from PIL import Image
10
-
11
- from transformers.utils import TensorType, requires_backends, is_torch_dtype, is_torch_device
12
- from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
13
  from transformers import AutoImageProcessor
 
14
  from transformers.image_transforms import to_channel_dimension_format
15
  from transformers.image_utils import (
16
- ImageInput,
17
- make_list_of_images,
18
- valid_images,
19
- is_torch_tensor,
20
- is_batched,
21
- to_numpy_array,
22
  infer_channel_dimension_format,
23
- ChannelDimension
 
 
24
  )
 
25
 
26
 
27
  def recursive_converter(converter, value):
@@ -38,6 +34,7 @@ class MiniCPMVBatchFeature(BatchFeature):
38
  r"""
39
  Extend from BatchFeature for supporting various image size
40
  """
 
41
  def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
42
  super().__init__(data)
43
  self.convert_to_tensors(tensor_type=tensor_type)
@@ -45,7 +42,7 @@ class MiniCPMVBatchFeature(BatchFeature):
45
  def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
46
  if tensor_type is None:
47
  return self
48
-
49
  is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
50
 
51
  def converter(value):
@@ -61,11 +58,10 @@ class MiniCPMVBatchFeature(BatchFeature):
61
  "with 'padding=True' to have batched tensors with the same length."
62
  )
63
 
64
-
65
  for key, value in self.items():
66
  self[key] = recursive_converter(converter, value)
67
  return self
68
-
69
  def to(self, *args, **kwargs) -> "MiniCPMVBatchFeature":
70
  requires_backends(self, ["torch"])
71
  import torch
@@ -104,12 +100,7 @@ class MiniCPMVBatchFeature(BatchFeature):
104
  class MiniCPMVImageProcessor(BaseImageProcessor):
105
  model_input_names = ["pixel_values"]
106
 
107
- def __init__(
108
- self,
109
- max_slice_nums=9,
110
- scale_resolution=448,
111
- patch_size=14,
112
- **kwargs):
113
  super().__init__(**kwargs)
114
  self.max_slice_nums = max_slice_nums
115
  self.scale_resolution = scale_resolution
@@ -131,14 +122,9 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
131
  def ensure_divide(self, length, patch_size):
132
  return max(round(length / patch_size) * patch_size, patch_size)
133
 
134
- def find_best_resize(self,
135
- original_size,
136
- scale_resolution,
137
- patch_size,
138
- allow_upscale=False):
139
  width, height = original_size
140
- if (width * height >
141
- scale_resolution * scale_resolution) or allow_upscale:
142
  r = width / height
143
  height = int(scale_resolution / math.sqrt(r))
144
  width = int(height * r)
@@ -146,12 +132,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
146
  best_height = self.ensure_divide(height, patch_size)
147
  return (best_width, best_height)
148
 
149
- def get_refine_size(self,
150
- original_size,
151
- grid,
152
- scale_resolution,
153
- patch_size,
154
- allow_upscale=False):
155
  width, height = original_size
156
  grid_x, grid_y = grid
157
 
@@ -161,10 +142,9 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
161
  grid_width = refine_width / grid_x
162
  grid_height = refine_height / grid_y
163
 
164
- best_grid_size = self.find_best_resize((grid_width, grid_height),
165
- scale_resolution,
166
- patch_size,
167
- allow_upscale=allow_upscale)
168
  refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
169
  return refine_size
170
 
@@ -182,9 +162,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
182
  patches.append(images)
183
  return patches
184
 
185
- def slice_image(
186
- self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
187
- ):
188
  original_size = image.size
189
  source_image = None
190
  best_grid = self.get_sliced_grid(original_size, max_slice_nums, never_split)
@@ -192,9 +170,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
192
 
193
  if best_grid is None:
194
  # dont need to slice, upsample
195
- best_size = self.find_best_resize(
196
- original_size, scale_resolution, patch_size, allow_upscale=True
197
- )
198
  source_image = image.resize(best_size, resample=Image.Resampling.BICUBIC)
199
  else:
200
  # source image, down-sampling and ensure divided by patch_size
@@ -212,9 +188,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
212
  if grid is None:
213
  return ""
214
  slice_image_placeholder = (
215
- self.slice_start_token
216
- + self.unk_token * self.image_feature_size
217
- + self.slice_end_token
218
  )
219
 
220
  cols = grid[0]
@@ -225,13 +199,13 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
225
  for j in range(cols):
226
  lines.append(slice_image_placeholder)
227
  slices.append("".join(lines))
228
-
229
  slice_placeholder = "\n".join(slices)
230
  return slice_placeholder
231
 
232
  def get_image_id_placeholder(self, idx=0):
233
  return f"{self.im_id_start}{idx}{self.im_id_end}"
234
-
235
  def get_sliced_images(self, image, max_slice_nums=None):
236
  slice_images = []
237
 
@@ -239,12 +213,9 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
239
  return [image]
240
 
241
  max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
242
- assert max_slice_nums > 0
243
  source_image, patches, sliced_grid = self.slice_image(
244
- image,
245
- max_slice_nums, # default: 9
246
- self.scale_resolution, # default: 448
247
- self.patch_size # default: 14
248
  )
249
 
250
  slice_images.append(source_image)
@@ -266,7 +237,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
266
  if i == 1 or i > max_slice_nums:
267
  continue
268
  candidate_split_grids_nums.append(i)
269
-
270
  candidate_grids = []
271
  for split_grids_nums in candidate_split_grids_nums:
272
  m = 1
@@ -282,19 +253,15 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
282
  if error < min_error:
283
  best_grid = grid
284
  min_error = error
285
-
286
  return best_grid
287
-
288
  def get_slice_image_placeholder(self, image_size, image_idx=0, max_slice_nums=None, use_image_id=None):
289
  max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
290
- assert max_slice_nums > 0
291
  grid = self.get_sliced_grid(image_size=image_size, max_slice_nums=max_slice_nums)
292
 
293
- image_placeholder = (
294
- self.im_start_token
295
- + self.unk_token * self.image_feature_size
296
- + self.im_end_token
297
- )
298
  use_image_id = self.use_image_id if use_image_id is None else bool(use_image_id)
299
  if use_image_id:
300
  final_placeholder = self.get_image_id_placeholder(image_idx) + image_placeholder
@@ -304,7 +271,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
304
  if self.slice_mode:
305
  final_placeholder = final_placeholder + self.get_grid_placeholder(grid=grid)
306
  return final_placeholder
307
-
308
  def to_pil_image(self, image, rescale=None) -> PIL.Image.Image:
309
  """
310
  Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
@@ -343,24 +310,20 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
343
  """
344
  image = torch.from_numpy(image)
345
  patch_size = self.patch_size
346
- patches = torch.nn.functional.unfold(
347
- image,
348
- (patch_size, patch_size),
349
- stride=(patch_size, patch_size)
350
- )
351
 
352
  patches = patches.reshape(image.size(0), patch_size, patch_size, -1)
353
  patches = patches.permute(0, 1, 3, 2).reshape(image.size(0), patch_size, -1)
354
  return patches.numpy()
355
 
356
  def preprocess(
357
- self,
358
- images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]],
359
- do_pad: Optional[bool] = True, # TODO: add pad for MiniCPM-Llama3-V-2_5
360
- max_slice_nums: int = None,
361
- return_tensors: Optional[Union[str, TensorType]] = None,
362
- **kwargs
363
- ) -> MiniCPMVBatchFeature:
364
  if isinstance(images, Image.Image):
365
  images_list = [[images]]
366
  elif isinstance(images[0], Image.Image):
@@ -371,19 +334,19 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
371
  new_images_list = []
372
  image_sizes_list = []
373
  tgt_sizes_list = []
374
-
375
  for _images in images_list:
376
  if _images is None or len(_images) == 0:
377
  new_images_list.append([])
378
  image_sizes_list.append([])
379
  tgt_sizes_list.append([])
380
- continue
381
  if not valid_images(_images):
382
  raise ValueError(
383
  "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
384
  "torch.Tensor, tf.Tensor or jax.ndarray."
385
  )
386
-
387
  _images = [self.to_pil_image(image).convert("RGB") for image in _images]
388
  input_data_format = infer_channel_dimension_format(np.array(_images[0]))
389
 
@@ -395,24 +358,28 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
395
  image_patches = [to_numpy_array(image).astype(np.float32) / 255 for image in image_patches]
396
  image_patches = [
397
  self.normalize(image=image, mean=self.mean, std=self.std, input_data_format=input_data_format)
398
- for image in image_patches
399
  ]
400
  image_patches = [
401
- to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
402
- for image in image_patches
403
  ]
404
  for slice_image in image_patches:
405
  new_images.append(self.reshape_by_patch(slice_image))
406
- tgt_sizes.append(np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size)))
 
 
407
 
408
  if tgt_sizes:
409
  tgt_sizes = np.vstack(tgt_sizes)
410
-
411
  new_images_list.append(new_images)
412
  image_sizes_list.append(image_sizes)
413
  tgt_sizes_list.append(tgt_sizes)
414
  return MiniCPMVBatchFeature(
415
- data={"pixel_values": new_images_list, "image_sizes": image_sizes_list, "tgt_sizes": tgt_sizes_list}, tensor_type=return_tensors
 
416
  )
417
 
 
418
  AutoImageProcessor.register("MiniCPMVImageProcessor", MiniCPMVImageProcessor)
 
 
 
 
1
  import math
2
+ from typing import Any, Dict, List, Optional, Union
3
+
4
  import numpy as np
5
  import PIL
6
+ import PIL.Image
7
+ import PIL.ImageSequence
8
+ import torch
9
  from PIL import Image
 
 
 
10
  from transformers import AutoImageProcessor
11
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
12
  from transformers.image_transforms import to_channel_dimension_format
13
  from transformers.image_utils import (
14
+ ChannelDimension,
 
 
 
 
 
15
  infer_channel_dimension_format,
16
+ is_torch_tensor,
17
+ to_numpy_array,
18
+ valid_images,
19
  )
20
+ from transformers.utils import TensorType, is_torch_device, is_torch_dtype, requires_backends
21
 
22
 
23
  def recursive_converter(converter, value):
 
34
  r"""
35
  Extend from BatchFeature for supporting various image size
36
  """
37
+
38
  def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
39
  super().__init__(data)
40
  self.convert_to_tensors(tensor_type=tensor_type)
 
42
  def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
43
  if tensor_type is None:
44
  return self
45
+
46
  is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
47
 
48
  def converter(value):
 
58
  "with 'padding=True' to have batched tensors with the same length."
59
  )
60
 
 
61
  for key, value in self.items():
62
  self[key] = recursive_converter(converter, value)
63
  return self
64
+
65
  def to(self, *args, **kwargs) -> "MiniCPMVBatchFeature":
66
  requires_backends(self, ["torch"])
67
  import torch
 
100
  class MiniCPMVImageProcessor(BaseImageProcessor):
101
  model_input_names = ["pixel_values"]
102
 
103
+ def __init__(self, max_slice_nums=9, scale_resolution=448, patch_size=14, **kwargs):
 
 
 
 
 
104
  super().__init__(**kwargs)
105
  self.max_slice_nums = max_slice_nums
106
  self.scale_resolution = scale_resolution
 
122
  def ensure_divide(self, length, patch_size):
123
  return max(round(length / patch_size) * patch_size, patch_size)
124
 
125
+ def find_best_resize(self, original_size, scale_resolution, patch_size, allow_upscale=False):
 
 
 
 
126
  width, height = original_size
127
+ if (width * height > scale_resolution * scale_resolution) or allow_upscale:
 
128
  r = width / height
129
  height = int(scale_resolution / math.sqrt(r))
130
  width = int(height * r)
 
132
  best_height = self.ensure_divide(height, patch_size)
133
  return (best_width, best_height)
134
 
135
+ def get_refine_size(self, original_size, grid, scale_resolution, patch_size, allow_upscale=False):
 
 
 
 
 
136
  width, height = original_size
137
  grid_x, grid_y = grid
138
 
 
142
  grid_width = refine_width / grid_x
143
  grid_height = refine_height / grid_y
144
 
145
+ best_grid_size = self.find_best_resize(
146
+ (grid_width, grid_height), scale_resolution, patch_size, allow_upscale=allow_upscale
147
+ )
 
148
  refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
149
  return refine_size
150
 
 
162
  patches.append(images)
163
  return patches
164
 
165
+ def slice_image(self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False):
 
 
166
  original_size = image.size
167
  source_image = None
168
  best_grid = self.get_sliced_grid(original_size, max_slice_nums, never_split)
 
170
 
171
  if best_grid is None:
172
  # dont need to slice, upsample
173
+ best_size = self.find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True)
 
 
174
  source_image = image.resize(best_size, resample=Image.Resampling.BICUBIC)
175
  else:
176
  # source image, down-sampling and ensure divided by patch_size
 
188
  if grid is None:
189
  return ""
190
  slice_image_placeholder = (
191
+ self.slice_start_token + self.unk_token * self.image_feature_size + self.slice_end_token
 
 
192
  )
193
 
194
  cols = grid[0]
 
199
  for j in range(cols):
200
  lines.append(slice_image_placeholder)
201
  slices.append("".join(lines))
202
+
203
  slice_placeholder = "\n".join(slices)
204
  return slice_placeholder
205
 
206
  def get_image_id_placeholder(self, idx=0):
207
  return f"{self.im_id_start}{idx}{self.im_id_end}"
208
+
209
  def get_sliced_images(self, image, max_slice_nums=None):
210
  slice_images = []
211
 
 
213
  return [image]
214
 
215
  max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
216
+ assert max_slice_nums > 0
217
  source_image, patches, sliced_grid = self.slice_image(
218
+ image, max_slice_nums, self.scale_resolution, self.patch_size # default: 9 # default: 448 # default: 14
 
 
 
219
  )
220
 
221
  slice_images.append(source_image)
 
237
  if i == 1 or i > max_slice_nums:
238
  continue
239
  candidate_split_grids_nums.append(i)
240
+
241
  candidate_grids = []
242
  for split_grids_nums in candidate_split_grids_nums:
243
  m = 1
 
253
  if error < min_error:
254
  best_grid = grid
255
  min_error = error
256
+
257
  return best_grid
258
+
259
  def get_slice_image_placeholder(self, image_size, image_idx=0, max_slice_nums=None, use_image_id=None):
260
  max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
261
+ assert max_slice_nums > 0
262
  grid = self.get_sliced_grid(image_size=image_size, max_slice_nums=max_slice_nums)
263
 
264
+ image_placeholder = self.im_start_token + self.unk_token * self.image_feature_size + self.im_end_token
 
 
 
 
265
  use_image_id = self.use_image_id if use_image_id is None else bool(use_image_id)
266
  if use_image_id:
267
  final_placeholder = self.get_image_id_placeholder(image_idx) + image_placeholder
 
271
  if self.slice_mode:
272
  final_placeholder = final_placeholder + self.get_grid_placeholder(grid=grid)
273
  return final_placeholder
274
+
275
  def to_pil_image(self, image, rescale=None) -> PIL.Image.Image:
276
  """
277
  Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
 
310
  """
311
  image = torch.from_numpy(image)
312
  patch_size = self.patch_size
313
+ patches = torch.nn.functional.unfold(image, (patch_size, patch_size), stride=(patch_size, patch_size))
 
 
 
 
314
 
315
  patches = patches.reshape(image.size(0), patch_size, patch_size, -1)
316
  patches = patches.permute(0, 1, 3, 2).reshape(image.size(0), patch_size, -1)
317
  return patches.numpy()
318
 
319
  def preprocess(
320
+ self,
321
+ images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]],
322
+ do_pad: Optional[bool] = True, # TODO: add pad for MiniCPM-Llama3-V-2_5
323
+ max_slice_nums: int = None,
324
+ return_tensors: Optional[Union[str, TensorType]] = None,
325
+ **kwargs,
326
+ ) -> MiniCPMVBatchFeature:
327
  if isinstance(images, Image.Image):
328
  images_list = [[images]]
329
  elif isinstance(images[0], Image.Image):
 
334
  new_images_list = []
335
  image_sizes_list = []
336
  tgt_sizes_list = []
337
+
338
  for _images in images_list:
339
  if _images is None or len(_images) == 0:
340
  new_images_list.append([])
341
  image_sizes_list.append([])
342
  tgt_sizes_list.append([])
343
+ continue
344
  if not valid_images(_images):
345
  raise ValueError(
346
  "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
347
  "torch.Tensor, tf.Tensor or jax.ndarray."
348
  )
349
+
350
  _images = [self.to_pil_image(image).convert("RGB") for image in _images]
351
  input_data_format = infer_channel_dimension_format(np.array(_images[0]))
352
 
 
358
  image_patches = [to_numpy_array(image).astype(np.float32) / 255 for image in image_patches]
359
  image_patches = [
360
  self.normalize(image=image, mean=self.mean, std=self.std, input_data_format=input_data_format)
361
+ for image in image_patches
362
  ]
363
  image_patches = [
364
+ to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
365
+ for image in image_patches
366
  ]
367
  for slice_image in image_patches:
368
  new_images.append(self.reshape_by_patch(slice_image))
369
+ tgt_sizes.append(
370
+ np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size))
371
+ )
372
 
373
  if tgt_sizes:
374
  tgt_sizes = np.vstack(tgt_sizes)
375
+
376
  new_images_list.append(new_images)
377
  image_sizes_list.append(image_sizes)
378
  tgt_sizes_list.append(tgt_sizes)
379
  return MiniCPMVBatchFeature(
380
+ data={"pixel_values": new_images_list, "image_sizes": image_sizes_list, "tgt_sizes": tgt_sizes_list},
381
+ tensor_type=return_tensors,
382
  )
383
 
384
+
385
  AutoImageProcessor.register("MiniCPMVImageProcessor", MiniCPMVImageProcessor)
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:244f72a0389de521d87c3411aaf425ebb85e19144f557f6ed0363ce84eb385f5
3
- size 323558976
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a13c2a624f4445809755648b73465369b127bf9f4c7a6a87ccf0c7498039149
3
+ size 315498808
modeling_minicpmv.py CHANGED
@@ -1,20 +1,17 @@
1
- import math
2
- from typing import List, Optional
3
  import json
4
- import torch
5
- import torchvision
6
-
7
- from threading import Thread
8
  from copy import deepcopy
 
 
 
9
  from PIL import Image
10
- from transformers import AutoProcessor, Qwen2PreTrainedModel, Qwen2ForCausalLM, TextIteratorStreamer
11
 
12
  from .configuration_minicpm import MiniCPMVConfig
13
  from .modeling_navit_siglip import SiglipVisionTransformer
14
  from .resampler import Resampler
15
 
16
 
17
-
18
  class MiniCPMVPreTrainedModel(Qwen2PreTrainedModel):
19
  config_class = MiniCPMVConfig
20
 
@@ -29,21 +26,21 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
29
  self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
30
  self.processor = None
31
 
32
- self.terminators = ['<|im_end|>', '<|endoftext|>']
33
 
34
  def init_vision_module(self):
35
  # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit add tgt_sizes
36
- if self.config._attn_implementation == 'flash_attention_2':
37
- self.config.vision_config._attn_implementation = 'flash_attention_2'
38
  else:
39
  # not suport sdpa
40
- self.config.vision_config._attn_implementation = 'eager'
41
  model = SiglipVisionTransformer(self.config.vision_config)
42
  if self.config.drop_vision_last_layer:
43
  model.encoder.layers = model.encoder.layers[:-1]
44
 
45
- setattr(model, 'embed_dim', model.embeddings.embed_dim)
46
- setattr(model, 'patch_size', model.embeddings.patch_size)
47
 
48
  return model
49
 
@@ -53,7 +50,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
53
  embed_dim=embed_dim,
54
  num_heads=embed_dim // 128,
55
  kv_dim=vision_dim,
56
- adaptive=True
57
  )
58
 
59
  def get_input_embeddings(self):
@@ -75,11 +72,11 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
75
  return self.llm
76
 
77
  def get_vllm_embedding(self, data):
78
- if 'vision_hidden_states' not in data:
79
  dtype = self.llm.model.embed_tokens.weight.dtype
80
  device = self.llm.model.embed_tokens.weight.device
81
- tgt_sizes = data['tgt_sizes']
82
- pixel_values_list = data['pixel_values']
83
  vision_hidden_states = []
84
  all_pixel_values = []
85
  img_cnt = []
@@ -94,14 +91,15 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
94
 
95
  max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
96
 
97
- all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True,
98
- padding_value=0.0)
 
99
  B, L, _ = all_pixel_values.shape
100
  all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
101
 
102
  patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
103
  for i in range(B):
104
- patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
105
 
106
  vision_batch_size = self.config.vision_batch_size
107
  all_pixel_values = all_pixel_values.type(dtype)
@@ -110,28 +108,33 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
110
  for i in range(0, B, vision_batch_size):
111
  start_idx = i
112
  end_idx = i + vision_batch_size
113
- tmp_hs = self.vpm(all_pixel_values[start_idx:end_idx], patch_attention_mask=patch_attn_mask[start_idx:end_idx], tgt_sizes=tgt_sizes[start_idx:end_idx]).last_hidden_state
 
 
 
 
114
  hs.append(tmp_hs)
115
  vision_embedding = torch.cat(hs, dim=0)
116
  else:
117
- vision_embedding = self.vpm(all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes).last_hidden_state
 
 
118
  vision_embedding = self.resampler(vision_embedding, tgt_sizes)
119
 
120
  start = 0
121
  for pixel_values in pixel_values_list:
122
  img_cnt = len(pixel_values)
123
  if img_cnt > 0:
124
- vision_hidden_states.append(vision_embedding[start: start + img_cnt])
125
  start += img_cnt
126
  else:
127
  vision_hidden_states.append([])
128
- else: # no image
129
  if self.training:
130
- dummy_image = torch.zeros(
131
- (1, 3, 224, 224),
132
- device=device, dtype=dtype
133
- )
134
- tgt_sizes = torch.Tensor([[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]).type(torch.int32)
135
  dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
136
  else:
137
  dummy_feature = []
@@ -139,29 +142,33 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
139
  vision_hidden_states.append(dummy_feature)
140
 
141
  else:
142
- vision_hidden_states = data['vision_hidden_states']
143
 
144
- if hasattr(self.llm.config, 'scale_emb'):
145
- vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
146
  else:
147
- vllm_embedding = self.llm.model.embed_tokens(data['input_ids'])
148
 
149
- vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance(
150
- i, torch.Tensor) else i for i in vision_hidden_states]
 
151
 
152
- bs = len(data['input_ids'])
153
  for i in range(bs):
154
  cur_vs_hs = vision_hidden_states[i]
155
  if len(cur_vs_hs) > 0:
156
  cur_vllm_emb = vllm_embedding[i]
157
- cur_image_bound = data['image_bound'][i]
158
  if len(cur_image_bound) > 0:
159
  image_indices = torch.stack(
160
  [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
161
  ).to(vllm_embedding.device)
162
 
163
- cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
164
- cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
 
 
 
165
  elif self.training:
166
  cur_vllm_emb += cur_vs_hs[0].mean() * 0
167
 
@@ -173,13 +180,8 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
173
  if position_ids.dtype != torch.int64:
174
  position_ids = position_ids.long()
175
 
176
- return self.llm(
177
- input_ids=None,
178
- position_ids=position_ids,
179
- inputs_embeds=vllm_embedding,
180
- **kwargs
181
- )
182
-
183
  def _decode(self, inputs_embeds, tokenizer, attention_mask, decode_text=False, **kwargs):
184
  terminators = None
185
  if tokenizer is not None:
@@ -187,10 +189,10 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
187
  kwargs.pop("image_sizes")
188
  output = self.llm.generate(
189
  inputs_embeds=inputs_embeds,
190
- #pad_token_id=0,
191
  eos_token_id=terminators,
192
  attention_mask=attention_mask,
193
- **kwargs
194
  )
195
  if decode_text:
196
  return self._decode_text(output, tokenizer)
@@ -200,16 +202,16 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
200
  terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
201
  streamer = TextIteratorStreamer(tokenizer=tokenizer)
202
  generation_kwargs = {
203
- 'inputs_embeds': inputs_embeds,
204
- 'pad_token_id': 0,
205
- 'eos_token_id': terminators,
206
- 'streamer': streamer
207
  }
208
  generation_kwargs.update(kwargs)
209
 
210
  thread = Thread(target=self.llm.generate, kwargs=generation_kwargs)
211
  thread.start()
212
-
213
  return streamer
214
 
215
  def _decode_text(self, result_ids, tokenizer):
@@ -236,7 +238,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
236
  return_vision_hidden_states=False,
237
  stream=False,
238
  decode_text=False,
239
- **kwargs
240
  ):
241
  assert input_ids is not None
242
  assert len(input_ids) == len(pixel_values)
@@ -248,7 +250,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
248
 
249
  if vision_hidden_states is None:
250
  model_inputs["pixel_values"] = pixel_values
251
- model_inputs['tgt_sizes'] = tgt_sizes
252
  else:
253
  model_inputs["vision_hidden_states"] = vision_hidden_states
254
 
@@ -261,11 +263,13 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
261
  if stream:
262
  result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
263
  else:
264
- result = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, decode_text=decode_text, **kwargs)
 
 
265
 
266
  if return_vision_hidden_states:
267
  return result, vision_hidden_states
268
-
269
  return result
270
 
271
  def chat(
@@ -279,11 +283,11 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
279
  min_new_tokens=0,
280
  sampling=True,
281
  max_inp_length=8192,
282
- system_prompt='',
283
  stream=False,
284
  max_slice_nums=None,
285
  use_image_id=None,
286
- **kwargs
287
  ):
288
  if isinstance(msgs[0], list):
289
  batched = True
@@ -291,7 +295,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
291
  batched = False
292
  msgs_list = msgs
293
  images_list = image
294
-
295
  if batched is False:
296
  images_list, msgs_list = [images_list], [msgs_list]
297
  else:
@@ -303,12 +307,22 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
303
  if self.processor is None:
304
  self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
305
  processor = self.processor
306
-
307
- assert self.config.query_num == processor.image_processor.image_feature_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
308
- assert self.config.patch_size == processor.image_processor.patch_size, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
309
- assert self.config.use_image_id == processor.image_processor.use_image_id, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
310
- assert self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
311
- assert self.config.slice_mode == processor.image_processor.slice_mode, "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
 
 
 
 
 
 
 
 
 
 
312
 
313
  prompts_lists = []
314
  input_images_lists = []
@@ -342,19 +356,21 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
342
  msg["content"] = "\n".join(cur_msgs)
343
 
344
  if system_prompt:
345
- sys_msg = {'role': 'system', 'content': system_prompt}
346
- copy_msgs = [sys_msg] + copy_msgs
347
 
348
- prompts_lists.append(processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True))
 
 
349
  input_images_lists.append(images)
350
 
351
  inputs = processor(
352
- prompts_lists,
353
- input_images_lists,
354
  max_slice_nums=max_slice_nums,
355
  use_image_id=use_image_id,
356
- return_tensors="pt",
357
- max_length=max_inp_length
358
  ).to(self.device)
359
 
360
  if sampling:
@@ -363,20 +379,18 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
363
  "top_k": 100,
364
  "temperature": 0.7,
365
  "do_sample": True,
366
- "repetition_penalty": 1.05
367
  }
368
  else:
369
  generation_config = {
370
  "num_beams": 3,
371
  "repetition_penalty": 1.2,
372
  }
373
-
374
  if min_new_tokens > 0:
375
- generation_config['min_new_tokens'] = min_new_tokens
376
 
377
- generation_config.update(
378
- (k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()
379
- )
380
 
381
  inputs.pop("image_sizes")
382
  with torch.inference_mode():
@@ -387,15 +401,17 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
387
  vision_hidden_states=vision_hidden_states,
388
  stream=stream,
389
  decode_text=True,
390
- **generation_config
391
  )
392
-
393
  if stream:
 
394
  def stream_gen():
395
  for text in res:
396
  for term in self.terminators:
397
- text = text.replace(term, '')
398
  yield text
 
399
  return stream_gen()
400
 
401
  else:
 
 
 
1
  import json
2
+ import math
 
 
 
3
  from copy import deepcopy
4
+ from threading import Thread
5
+
6
+ import torch
7
  from PIL import Image
8
+ from transformers import AutoProcessor, Qwen2ForCausalLM, Qwen2PreTrainedModel, TextIteratorStreamer
9
 
10
  from .configuration_minicpm import MiniCPMVConfig
11
  from .modeling_navit_siglip import SiglipVisionTransformer
12
  from .resampler import Resampler
13
 
14
 
 
15
  class MiniCPMVPreTrainedModel(Qwen2PreTrainedModel):
16
  config_class = MiniCPMVConfig
17
 
 
26
  self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
27
  self.processor = None
28
 
29
+ self.terminators = ["<|im_end|>", "<|endoftext|>"]
30
 
31
  def init_vision_module(self):
32
  # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit add tgt_sizes
33
+ if self.config._attn_implementation == "flash_attention_2":
34
+ self.config.vision_config._attn_implementation = "flash_attention_2"
35
  else:
36
  # not suport sdpa
37
+ self.config.vision_config._attn_implementation = "eager"
38
  model = SiglipVisionTransformer(self.config.vision_config)
39
  if self.config.drop_vision_last_layer:
40
  model.encoder.layers = model.encoder.layers[:-1]
41
 
42
+ setattr(model, "embed_dim", model.embeddings.embed_dim)
43
+ setattr(model, "patch_size", model.embeddings.patch_size)
44
 
45
  return model
46
 
 
50
  embed_dim=embed_dim,
51
  num_heads=embed_dim // 128,
52
  kv_dim=vision_dim,
53
+ adaptive=True,
54
  )
55
 
56
  def get_input_embeddings(self):
 
72
  return self.llm
73
 
74
  def get_vllm_embedding(self, data):
75
+ if "vision_hidden_states" not in data:
76
  dtype = self.llm.model.embed_tokens.weight.dtype
77
  device = self.llm.model.embed_tokens.weight.device
78
+ tgt_sizes = data["tgt_sizes"]
79
+ pixel_values_list = data["pixel_values"]
80
  vision_hidden_states = []
81
  all_pixel_values = []
82
  img_cnt = []
 
91
 
92
  max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
93
 
94
+ all_pixel_values = torch.nn.utils.rnn.pad_sequence(
95
+ all_pixel_values, batch_first=True, padding_value=0.0
96
+ )
97
  B, L, _ = all_pixel_values.shape
98
  all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
99
 
100
  patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
101
  for i in range(B):
102
+ patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
103
 
104
  vision_batch_size = self.config.vision_batch_size
105
  all_pixel_values = all_pixel_values.type(dtype)
 
108
  for i in range(0, B, vision_batch_size):
109
  start_idx = i
110
  end_idx = i + vision_batch_size
111
+ tmp_hs = self.vpm(
112
+ all_pixel_values[start_idx:end_idx],
113
+ patch_attention_mask=patch_attn_mask[start_idx:end_idx],
114
+ tgt_sizes=tgt_sizes[start_idx:end_idx],
115
+ ).last_hidden_state
116
  hs.append(tmp_hs)
117
  vision_embedding = torch.cat(hs, dim=0)
118
  else:
119
+ vision_embedding = self.vpm(
120
+ all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes
121
+ ).last_hidden_state
122
  vision_embedding = self.resampler(vision_embedding, tgt_sizes)
123
 
124
  start = 0
125
  for pixel_values in pixel_values_list:
126
  img_cnt = len(pixel_values)
127
  if img_cnt > 0:
128
+ vision_hidden_states.append(vision_embedding[start : start + img_cnt])
129
  start += img_cnt
130
  else:
131
  vision_hidden_states.append([])
132
+ else: # no image
133
  if self.training:
134
+ dummy_image = torch.zeros((1, 3, 224, 224), device=device, dtype=dtype)
135
+ tgt_sizes = torch.Tensor(
136
+ [[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]
137
+ ).type(torch.int32)
 
138
  dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
139
  else:
140
  dummy_feature = []
 
142
  vision_hidden_states.append(dummy_feature)
143
 
144
  else:
145
+ vision_hidden_states = data["vision_hidden_states"]
146
 
147
+ if hasattr(self.llm.config, "scale_emb"):
148
+ vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) * self.llm.config.scale_emb
149
  else:
150
+ vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
151
 
152
+ vision_hidden_states = [
153
+ i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
154
+ ]
155
 
156
+ bs = len(data["input_ids"])
157
  for i in range(bs):
158
  cur_vs_hs = vision_hidden_states[i]
159
  if len(cur_vs_hs) > 0:
160
  cur_vllm_emb = vllm_embedding[i]
161
+ cur_image_bound = data["image_bound"][i]
162
  if len(cur_image_bound) > 0:
163
  image_indices = torch.stack(
164
  [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
165
  ).to(vllm_embedding.device)
166
 
167
+ cur_vllm_emb.scatter_(
168
+ 0,
169
+ image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
170
+ cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
171
+ )
172
  elif self.training:
173
  cur_vllm_emb += cur_vs_hs[0].mean() * 0
174
 
 
180
  if position_ids.dtype != torch.int64:
181
  position_ids = position_ids.long()
182
 
183
+ return self.llm(input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, **kwargs)
184
+
 
 
 
 
 
185
  def _decode(self, inputs_embeds, tokenizer, attention_mask, decode_text=False, **kwargs):
186
  terminators = None
187
  if tokenizer is not None:
 
189
  kwargs.pop("image_sizes")
190
  output = self.llm.generate(
191
  inputs_embeds=inputs_embeds,
192
+ # pad_token_id=0,
193
  eos_token_id=terminators,
194
  attention_mask=attention_mask,
195
+ **kwargs,
196
  )
197
  if decode_text:
198
  return self._decode_text(output, tokenizer)
 
202
  terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
203
  streamer = TextIteratorStreamer(tokenizer=tokenizer)
204
  generation_kwargs = {
205
+ "inputs_embeds": inputs_embeds,
206
+ "pad_token_id": 0,
207
+ "eos_token_id": terminators,
208
+ "streamer": streamer,
209
  }
210
  generation_kwargs.update(kwargs)
211
 
212
  thread = Thread(target=self.llm.generate, kwargs=generation_kwargs)
213
  thread.start()
214
+
215
  return streamer
216
 
217
  def _decode_text(self, result_ids, tokenizer):
 
238
  return_vision_hidden_states=False,
239
  stream=False,
240
  decode_text=False,
241
+ **kwargs,
242
  ):
243
  assert input_ids is not None
244
  assert len(input_ids) == len(pixel_values)
 
250
 
251
  if vision_hidden_states is None:
252
  model_inputs["pixel_values"] = pixel_values
253
+ model_inputs["tgt_sizes"] = tgt_sizes
254
  else:
255
  model_inputs["vision_hidden_states"] = vision_hidden_states
256
 
 
263
  if stream:
264
  result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
265
  else:
266
+ result = self._decode(
267
+ model_inputs["inputs_embeds"], tokenizer, attention_mask, decode_text=decode_text, **kwargs
268
+ )
269
 
270
  if return_vision_hidden_states:
271
  return result, vision_hidden_states
272
+
273
  return result
274
 
275
  def chat(
 
283
  min_new_tokens=0,
284
  sampling=True,
285
  max_inp_length=8192,
286
+ system_prompt="",
287
  stream=False,
288
  max_slice_nums=None,
289
  use_image_id=None,
290
+ **kwargs,
291
  ):
292
  if isinstance(msgs[0], list):
293
  batched = True
 
295
  batched = False
296
  msgs_list = msgs
297
  images_list = image
298
+
299
  if batched is False:
300
  images_list, msgs_list = [images_list], [msgs_list]
301
  else:
 
307
  if self.processor is None:
308
  self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
309
  processor = self.processor
310
+
311
+ assert (
312
+ self.config.query_num == processor.image_processor.image_feature_size
313
+ ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
314
+ assert (
315
+ self.config.patch_size == processor.image_processor.patch_size
316
+ ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
317
+ assert (
318
+ self.config.use_image_id == processor.image_processor.use_image_id
319
+ ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
320
+ assert (
321
+ self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums
322
+ ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
323
+ assert (
324
+ self.config.slice_mode == processor.image_processor.slice_mode
325
+ ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
326
 
327
  prompts_lists = []
328
  input_images_lists = []
 
356
  msg["content"] = "\n".join(cur_msgs)
357
 
358
  if system_prompt:
359
+ sys_msg = {"role": "system", "content": system_prompt}
360
+ copy_msgs = [sys_msg] + copy_msgs
361
 
362
+ prompts_lists.append(
363
+ processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True)
364
+ )
365
  input_images_lists.append(images)
366
 
367
  inputs = processor(
368
+ prompts_lists,
369
+ input_images_lists,
370
  max_slice_nums=max_slice_nums,
371
  use_image_id=use_image_id,
372
+ return_tensors="pt",
373
+ max_length=max_inp_length,
374
  ).to(self.device)
375
 
376
  if sampling:
 
379
  "top_k": 100,
380
  "temperature": 0.7,
381
  "do_sample": True,
382
+ "repetition_penalty": 1.05,
383
  }
384
  else:
385
  generation_config = {
386
  "num_beams": 3,
387
  "repetition_penalty": 1.2,
388
  }
389
+
390
  if min_new_tokens > 0:
391
+ generation_config["min_new_tokens"] = min_new_tokens
392
 
393
+ generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())
 
 
394
 
395
  inputs.pop("image_sizes")
396
  with torch.inference_mode():
 
401
  vision_hidden_states=vision_hidden_states,
402
  stream=stream,
403
  decode_text=True,
404
+ **generation_config,
405
  )
406
+
407
  if stream:
408
+
409
  def stream_gen():
410
  for text in res:
411
  for term in self.terminators:
412
+ text = text.replace(term, "")
413
  yield text
414
+
415
  return stream_gen()
416
 
417
  else:
modeling_navit_siglip.py CHANGED
@@ -16,11 +16,11 @@
16
  # Copied from HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes
17
 
18
 
19
- import os
20
  import math
 
21
  import warnings
22
  from dataclasses import dataclass
23
- from typing import Any, Optional, Tuple, Union
24
 
25
  import numpy as np
26
  import torch
@@ -28,12 +28,11 @@ import torch.nn.functional as F
28
  import torch.utils.checkpoint
29
  from torch import nn
30
  from torch.nn.init import _calculate_fan_in_and_fan_out
31
-
32
  from transformers.activations import ACT2FN
 
33
  from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
34
  from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
35
  from transformers.modeling_utils import PreTrainedModel
36
- from transformers.configuration_utils import PretrainedConfig
37
  from transformers.utils import (
38
  ModelOutput,
39
  add_start_docstrings,
@@ -42,10 +41,11 @@ from transformers.utils import (
42
  logging,
43
  replace_return_docstrings,
44
  )
45
- from transformers.utils import logging
46
 
47
  logger = logging.get_logger(__name__)
48
 
 
49
  class SiglipVisionConfig(PretrainedConfig):
50
  r"""
51
  This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
@@ -133,7 +133,7 @@ class SiglipVisionConfig(PretrainedConfig):
133
  )
134
 
135
  return cls.from_dict(config_dict, **kwargs)
136
-
137
 
138
  _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
139
 
@@ -148,7 +148,6 @@ try:
148
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
149
  except:
150
  pass
151
-
152
 
153
 
154
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
@@ -318,7 +317,12 @@ class SiglipVisionEmbeddings(nn.Module):
318
  self.num_positions = self.num_patches
319
  self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
320
 
321
- def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor, tgt_sizes: Optional[torch.IntTensor]=None) -> torch.Tensor:
 
 
 
 
 
322
  batch_size = pixel_values.size(0)
323
 
324
  patch_embeds = self.patch_embedding(pixel_values)
@@ -643,11 +647,7 @@ class SiglipEncoderLayer(nn.Module):
643
  super().__init__()
644
  self.embed_dim = config.hidden_size
645
  self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
646
- self.self_attn = (
647
- SiglipAttention(config)
648
- if not self._use_flash_attention_2
649
- else SiglipFlashAttention2(config)
650
- )
651
  self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
652
  self.mlp = SiglipMLP(config)
653
  self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -847,9 +847,9 @@ class SiglipEncoder(nn.Module):
847
  last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
848
  )
849
 
 
850
  @add_start_docstrings(
851
- """The vision model from SigLIP without any head or projection on top.""",
852
- SIGLIP_START_DOCSTRING
853
  )
854
  class SiglipVisionTransformer(SiglipPreTrainedModel):
855
  config_class = SiglipVisionConfig
@@ -904,14 +904,16 @@ class SiglipVisionTransformer(SiglipPreTrainedModel):
904
  device=pixel_values.device,
905
  )
906
 
907
- hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes)
 
 
908
 
909
  patch_attention_mask = patch_attention_mask.view(batch_size, -1)
910
  # The call to `_upad_input` in `_flash_attention_forward` is expensive
911
  # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
912
  # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
913
  if not torch.any(~patch_attention_mask):
914
- attention_mask=None
915
  else:
916
  attention_mask = (
917
  _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
 
16
  # Copied from HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes
17
 
18
 
 
19
  import math
20
+ import os
21
  import warnings
22
  from dataclasses import dataclass
23
+ from typing import Optional, Tuple, Union
24
 
25
  import numpy as np
26
  import torch
 
28
  import torch.utils.checkpoint
29
  from torch import nn
30
  from torch.nn.init import _calculate_fan_in_and_fan_out
 
31
  from transformers.activations import ACT2FN
32
+ from transformers.configuration_utils import PretrainedConfig
33
  from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
34
  from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
35
  from transformers.modeling_utils import PreTrainedModel
 
36
  from transformers.utils import (
37
  ModelOutput,
38
  add_start_docstrings,
 
41
  logging,
42
  replace_return_docstrings,
43
  )
44
+
45
 
46
  logger = logging.get_logger(__name__)
47
 
48
+
49
  class SiglipVisionConfig(PretrainedConfig):
50
  r"""
51
  This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
 
133
  )
134
 
135
  return cls.from_dict(config_dict, **kwargs)
136
+
137
 
138
  _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
139
 
 
148
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
149
  except:
150
  pass
 
151
 
152
 
153
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 
317
  self.num_positions = self.num_patches
318
  self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
319
 
320
+ def forward(
321
+ self,
322
+ pixel_values: torch.FloatTensor,
323
+ patch_attention_mask: torch.BoolTensor,
324
+ tgt_sizes: Optional[torch.IntTensor] = None,
325
+ ) -> torch.Tensor:
326
  batch_size = pixel_values.size(0)
327
 
328
  patch_embeds = self.patch_embedding(pixel_values)
 
647
  super().__init__()
648
  self.embed_dim = config.hidden_size
649
  self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
650
+ self.self_attn = SiglipAttention(config) if not self._use_flash_attention_2 else SiglipFlashAttention2(config)
 
 
 
 
651
  self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
652
  self.mlp = SiglipMLP(config)
653
  self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
 
847
  last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
848
  )
849
 
850
+
851
  @add_start_docstrings(
852
+ """The vision model from SigLIP without any head or projection on top.""", SIGLIP_START_DOCSTRING
 
853
  )
854
  class SiglipVisionTransformer(SiglipPreTrainedModel):
855
  config_class = SiglipVisionConfig
 
904
  device=pixel_values.device,
905
  )
906
 
907
+ hidden_states = self.embeddings(
908
+ pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes
909
+ )
910
 
911
  patch_attention_mask = patch_attention_mask.view(batch_size, -1)
912
  # The call to `_upad_input` in `_flash_attention_forward` is expensive
913
  # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
914
  # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
915
  if not torch.any(~patch_attention_mask):
916
+ attention_mask = None
917
  else:
918
  attention_mask = (
919
  _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
processing_minicpmv.py CHANGED
@@ -16,15 +16,14 @@
16
  Processor class for MiniCPMV.
17
  """
18
 
19
- from typing import List, Optional, Union, Dict, Any
20
- import torch
21
  import re
 
22
 
23
- from transformers.image_processing_utils import BatchFeature
24
  from transformers.image_utils import ImageInput
25
  from transformers.processing_utils import ProcessorMixin
26
- from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
27
- from transformers.utils import TensorType, requires_backends, is_torch_dtype, is_torch_device
28
 
29
  from .image_processing_minicpmv import MiniCPMVBatchFeature
30
 
@@ -49,7 +48,7 @@ class MiniCPMVProcessor(ProcessorMixin):
49
  def __init__(self, image_processor=None, tokenizer=None):
50
  super().__init__(image_processor, tokenizer)
51
  self.version = image_processor.version
52
-
53
  def __call__(
54
  self,
55
  text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
@@ -59,14 +58,23 @@ class MiniCPMVProcessor(ProcessorMixin):
59
  max_slice_nums: int = None,
60
  use_image_id: bool = None,
61
  return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
62
- **kwargs
63
  ) -> MiniCPMVBatchFeature:
64
-
65
  image_inputs = None
66
  if images is not None:
67
- image_inputs = self.image_processor(images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors)
68
- return self._convert_images_texts_to_inputs(image_inputs, text, max_slice_nums=max_slice_nums, use_image_id=use_image_id, max_length=max_length, **kwargs, return_tensors=return_tensors)
69
-
 
 
 
 
 
 
 
 
 
 
70
  # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
71
  def batch_decode(self, *args, **kwargs):
72
  """
@@ -84,7 +92,7 @@ class MiniCPMVProcessor(ProcessorMixin):
84
  result_text.append(self.tokenizer.decode(result, *args[1:], **kwargs).strip())
85
  return result_text
86
  # return self.tokenizer.batch_decode(*args, **kwargs)
87
-
88
  # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
89
  def decode(self, *args, **kwargs):
90
  """
@@ -95,13 +103,13 @@ class MiniCPMVProcessor(ProcessorMixin):
95
  result = result[result != 0]
96
  if result[0] == self.tokenizer.bos_id:
97
  result = result[1:]
98
- if result[-1] == self.tokenizer.eos_id or (hasattr(self.tokenizer, "eot_id") and result[-1] == self.tokenizer.eot_id):
 
 
99
  result = result[:-1]
100
  return self.tokenizer.decode(result, *args[1:], **kwargs).strip()
101
 
102
- def _convert(
103
- self, input_str, max_inp_length: Optional[int] = None
104
- ):
105
  if self.version > 2.5 or not getattr(self.tokenizer, "add_bos_token", False):
106
  input_ids = self.tokenizer.encode(input_str)
107
  else:
@@ -128,23 +136,25 @@ class MiniCPMVProcessor(ProcessorMixin):
128
  return input_ids, image_bounds
129
 
130
  def _convert_images_texts_to_inputs(
131
- self,
132
- images,
133
- texts: Union[str, List[str]],
134
- truncation=None,
135
- max_length=None,
136
- max_slice_nums=None,
137
- use_image_id=None,
138
- return_tensors=None,
139
- **kwargs
140
- ):
141
  if images is None or not len(images):
142
- model_inputs = self.tokenizer(texts, return_tensors=return_tensors, truncation=truncation, max_length=max_length, **kwargs)
 
 
143
  return MiniCPMVBatchFeature(data={**model_inputs})
144
-
145
  pattern = "(<image>./</image>)"
146
  images, image_sizes, tgt_sizes = images["pixel_values"], images["image_sizes"], images["tgt_sizes"]
147
-
148
  if isinstance(texts, str):
149
  texts = [texts]
150
  input_ids_list = []
@@ -155,33 +165,32 @@ class MiniCPMVProcessor(ProcessorMixin):
155
  text_chunks = text.split(pattern)
156
  final_text = ""
157
  for i in range(len(image_tags)):
158
- final_text = final_text + text_chunks[i] + \
159
- self.image_processor.get_slice_image_placeholder(
160
- image_sizes[index][i],
161
- i,
162
- max_slice_nums,
163
- use_image_id
164
  )
 
165
  final_text += text_chunks[-1]
166
  input_ids, image_bounds = self._convert(final_text, max_length)
167
  input_ids_list.append(input_ids)
168
  image_bounds_list.append(image_bounds)
169
- padded_input_ids, padding_lengths = self.pad(
170
- input_ids_list,
171
- padding_side="left"
172
- )
173
  for i, length in enumerate(padding_lengths):
174
  image_bounds_list[i] = image_bounds_list[i] + length
175
  attention_mask = padded_input_ids.ne(0)
176
 
177
- return MiniCPMVBatchFeature(data={
178
- "input_ids": padded_input_ids,
179
- "attention_mask": attention_mask,
180
- "pixel_values": images,
181
- "image_sizes": image_sizes,
182
- "image_bound": image_bounds_list,
183
- "tgt_sizes": tgt_sizes
184
- })
 
 
185
 
186
  @property
187
  # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
@@ -190,7 +199,6 @@ class MiniCPMVProcessor(ProcessorMixin):
190
  image_processor_input_names = self.image_processor.model_input_names
191
  return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
192
 
193
-
194
  def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
195
  items = []
196
  if isinstance(inputs[0], list):
@@ -219,10 +227,7 @@ class MiniCPMVProcessor(ProcessorMixin):
219
  return torch.stack([item for item in items], dim=0), [0] * batch_size
220
  tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
221
  else:
222
- tensor = (
223
- torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype)
224
- + padding_value
225
- )
226
 
227
  padding_length = []
228
  for i, item in enumerate(items):
 
16
  Processor class for MiniCPMV.
17
  """
18
 
 
 
19
  import re
20
+ from typing import List, Optional, Union
21
 
22
+ import torch
23
  from transformers.image_utils import ImageInput
24
  from transformers.processing_utils import ProcessorMixin
25
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
26
+ from transformers.utils import TensorType
27
 
28
  from .image_processing_minicpmv import MiniCPMVBatchFeature
29
 
 
48
  def __init__(self, image_processor=None, tokenizer=None):
49
  super().__init__(image_processor, tokenizer)
50
  self.version = image_processor.version
51
+
52
  def __call__(
53
  self,
54
  text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
 
58
  max_slice_nums: int = None,
59
  use_image_id: bool = None,
60
  return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
61
+ **kwargs,
62
  ) -> MiniCPMVBatchFeature:
 
63
  image_inputs = None
64
  if images is not None:
65
+ image_inputs = self.image_processor(
66
+ images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors
67
+ )
68
+ return self._convert_images_texts_to_inputs(
69
+ image_inputs,
70
+ text,
71
+ max_slice_nums=max_slice_nums,
72
+ use_image_id=use_image_id,
73
+ max_length=max_length,
74
+ **kwargs,
75
+ return_tensors=return_tensors,
76
+ )
77
+
78
  # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
79
  def batch_decode(self, *args, **kwargs):
80
  """
 
92
  result_text.append(self.tokenizer.decode(result, *args[1:], **kwargs).strip())
93
  return result_text
94
  # return self.tokenizer.batch_decode(*args, **kwargs)
95
+
96
  # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
97
  def decode(self, *args, **kwargs):
98
  """
 
103
  result = result[result != 0]
104
  if result[0] == self.tokenizer.bos_id:
105
  result = result[1:]
106
+ if result[-1] == self.tokenizer.eos_id or (
107
+ hasattr(self.tokenizer, "eot_id") and result[-1] == self.tokenizer.eot_id
108
+ ):
109
  result = result[:-1]
110
  return self.tokenizer.decode(result, *args[1:], **kwargs).strip()
111
 
112
+ def _convert(self, input_str, max_inp_length: Optional[int] = None):
 
 
113
  if self.version > 2.5 or not getattr(self.tokenizer, "add_bos_token", False):
114
  input_ids = self.tokenizer.encode(input_str)
115
  else:
 
136
  return input_ids, image_bounds
137
 
138
  def _convert_images_texts_to_inputs(
139
+ self,
140
+ images,
141
+ texts: Union[str, List[str]],
142
+ truncation=None,
143
+ max_length=None,
144
+ max_slice_nums=None,
145
+ use_image_id=None,
146
+ return_tensors=None,
147
+ **kwargs,
148
+ ):
149
  if images is None or not len(images):
150
+ model_inputs = self.tokenizer(
151
+ texts, return_tensors=return_tensors, truncation=truncation, max_length=max_length, **kwargs
152
+ )
153
  return MiniCPMVBatchFeature(data={**model_inputs})
154
+
155
  pattern = "(<image>./</image>)"
156
  images, image_sizes, tgt_sizes = images["pixel_values"], images["image_sizes"], images["tgt_sizes"]
157
+
158
  if isinstance(texts, str):
159
  texts = [texts]
160
  input_ids_list = []
 
165
  text_chunks = text.split(pattern)
166
  final_text = ""
167
  for i in range(len(image_tags)):
168
+ final_text = (
169
+ final_text
170
+ + text_chunks[i]
171
+ + self.image_processor.get_slice_image_placeholder(
172
+ image_sizes[index][i], i, max_slice_nums, use_image_id
 
173
  )
174
+ )
175
  final_text += text_chunks[-1]
176
  input_ids, image_bounds = self._convert(final_text, max_length)
177
  input_ids_list.append(input_ids)
178
  image_bounds_list.append(image_bounds)
179
+ padded_input_ids, padding_lengths = self.pad(input_ids_list, padding_side="left")
 
 
 
180
  for i, length in enumerate(padding_lengths):
181
  image_bounds_list[i] = image_bounds_list[i] + length
182
  attention_mask = padded_input_ids.ne(0)
183
 
184
+ return MiniCPMVBatchFeature(
185
+ data={
186
+ "input_ids": padded_input_ids,
187
+ "attention_mask": attention_mask,
188
+ "pixel_values": images,
189
+ "image_sizes": image_sizes,
190
+ "image_bound": image_bounds_list,
191
+ "tgt_sizes": tgt_sizes,
192
+ }
193
+ )
194
 
195
  @property
196
  # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
 
199
  image_processor_input_names = self.image_processor.model_input_names
200
  return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
201
 
 
202
  def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
203
  items = []
204
  if isinstance(inputs[0], list):
 
227
  return torch.stack([item for item in items], dim=0), [0] * batch_size
228
  tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
229
  else:
230
+ tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
 
 
 
231
 
232
  padding_length = []
233
  for i, item in enumerate(items):
resampler.py CHANGED
@@ -1,18 +1,17 @@
 
1
  from functools import partial
2
  from typing import Optional, Tuple
3
- import numpy as np
4
- import warnings
5
 
 
6
  import torch
7
- from torch import nn
8
- from torch import Tensor
9
  import torch.nn.functional as F
 
10
  from torch.nn.functional import *
 
11
  from torch.nn.modules.activation import *
12
- from torch.nn.init import trunc_normal_, constant_, xavier_normal_, xavier_uniform_
13
-
14
  from transformers.integrations import is_deepspeed_zero3_enabled
15
 
 
16
  def get_2d_sincos_pos_embed(embed_dim, image_size):
17
  """
18
  image_size: image_size or (image_height, image_width)
@@ -52,10 +51,10 @@ def get_1d_sincos_pos_embed_from_grid_new(embed_dim, pos):
52
  """
53
  assert embed_dim % 2 == 0
54
  omega = np.arange(embed_dim // 2, dtype=np.float32)
55
- omega /= embed_dim / 2.
56
- omega = 1. / 10000 ** omega # (D/2,)
57
 
58
- out = np.einsum('hw,d->hwd', pos, omega) # (H, W, D/2), outer product
59
 
60
  emb_sin = np.sin(out) # (H, W, D/2)
61
  emb_cos = np.cos(out) # (H, W, D/2)
@@ -73,14 +72,14 @@ class Resampler(nn.Module):
73
  """
74
 
75
  def __init__(
76
- self,
77
- num_queries,
78
- embed_dim,
79
- num_heads,
80
- kv_dim=None,
81
- norm_layer=partial(nn.LayerNorm, eps=1e-6),
82
- adaptive=False,
83
- max_size=(70, 70),
84
  ):
85
  super().__init__()
86
  self.num_queries = num_queries
@@ -101,13 +100,13 @@ class Resampler(nn.Module):
101
  self.ln_kv = norm_layer(embed_dim)
102
 
103
  self.ln_post = norm_layer(embed_dim)
104
- self.proj = nn.Parameter((embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))
105
 
106
  self._set_2d_pos_cache(self.max_size)
107
 
108
- def _set_2d_pos_cache(self, max_size, device='cpu'):
109
  if is_deepspeed_zero3_enabled():
110
- device='cuda'
111
  pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.embed_dim, max_size)).float().to(device)
112
  self.register_buffer("pos_embed", pos_embed, persistent=False)
113
 
@@ -120,7 +119,7 @@ class Resampler(nn.Module):
120
 
121
  def _init_weights(self, m):
122
  if isinstance(m, nn.Linear):
123
- trunc_normal_(m.weight, std=.02)
124
  if isinstance(m, nn.Linear) and m.bias is not None:
125
  nn.init.constant_(m.bias, 0)
126
  elif isinstance(m, nn.LayerNorm):
@@ -145,10 +144,11 @@ class Resampler(nn.Module):
145
  for i in range(bs):
146
  tgt_h, tgt_w = tgt_sizes[i]
147
  pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype)) # patches * D
148
- key_padding_mask[i, patch_len[i]:] = True
149
 
150
- pos_embed = torch.nn.utils.rnn.pad_sequence(
151
- pos_embed, batch_first=True, padding_value=0.0).permute(1, 0, 2) # BLD => L * B * D
 
152
 
153
  x = self.kv_proj(x) # B * L * D
154
  x = self.ln_kv(x).permute(1, 0, 2) # L * B * D
@@ -159,7 +159,8 @@ class Resampler(nn.Module):
159
  self._repeat(q, bs), # Q * B * D
160
  x + pos_embed, # L * B * D + L * B * D
161
  x,
162
- key_padding_mask=key_padding_mask)[0]
 
163
  # out: Q * B * D
164
  x = out.permute(1, 0, 2) # B * Q * D
165
 
@@ -172,26 +173,44 @@ class Resampler(nn.Module):
172
 
173
 
174
  class MultiheadAttention(nn.MultiheadAttention):
175
- def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False,
176
- add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None):
177
- super().__init__(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first, device, dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
  # rewrite out_proj layer,with nn.Linear
180
  self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
181
 
182
  def forward(
183
- self,
184
- query: Tensor,
185
- key: Tensor,
186
- value: Tensor,
187
- key_padding_mask: Optional[Tensor] = None,
188
- need_weights: bool = True,
189
- attn_mask: Optional[Tensor] = None,
190
- average_attn_weights: bool = True,
191
- is_causal : bool = False) -> Tuple[Tensor, Optional[Tensor]]:
192
- why_not_fast_path = ''
193
- if ((attn_mask is not None and torch.is_floating_point(attn_mask))
194
- or (key_padding_mask is not None) and torch.is_floating_point(key_padding_mask)):
 
 
 
 
195
  why_not_fast_path = "floating-point masks are not supported for fast path."
196
 
197
  is_batched = query.dim() == 3
@@ -201,7 +220,7 @@ class MultiheadAttention(nn.MultiheadAttention):
201
  mask_name="key_padding_mask",
202
  other_type=F._none_or_dtype(attn_mask),
203
  other_name="attn_mask",
204
- target_type=query.dtype
205
  )
206
 
207
  attn_mask = _canonical_mask(
@@ -213,7 +232,6 @@ class MultiheadAttention(nn.MultiheadAttention):
213
  check_other=False,
214
  )
215
 
216
-
217
  if not is_batched:
218
  why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
219
  elif query is not key or key is not value:
@@ -222,12 +240,16 @@ class MultiheadAttention(nn.MultiheadAttention):
222
  # they don't!
223
  why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
224
  elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
225
- why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
 
 
226
  elif self.in_proj_weight is None:
227
  why_not_fast_path = "in_proj_weight was None"
228
  elif query.dtype != self.in_proj_weight.dtype:
229
  # this case will fail anyway, but at least they'll get a useful error message.
230
- why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
 
 
231
  elif self.training:
232
  why_not_fast_path = "training is enabled"
233
  elif (self.num_heads % 2) != 0:
@@ -265,11 +287,15 @@ class MultiheadAttention(nn.MultiheadAttention):
265
  elif _is_make_fx_tracing():
266
  why_not_fast_path = "we are running make_fx tracing"
267
  elif not all(_check_arg_device(x) for x in tensor_args):
268
- why_not_fast_path = ("some Tensor argument's device is neither one of "
269
- f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}")
 
 
270
  elif torch.is_grad_enabled() and any(_arg_requires_grad(x) for x in tensor_args):
271
- why_not_fast_path = ("grad is enabled and at least one of query or the "
272
- "input/output projection weights or biases requires_grad")
 
 
273
  if not why_not_fast_path:
274
  merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query)
275
 
@@ -287,11 +313,14 @@ class MultiheadAttention(nn.MultiheadAttention):
287
  merged_mask,
288
  need_weights,
289
  average_attn_weights,
290
- mask_type)
 
291
 
292
  any_nested = query.is_nested or key.is_nested or value.is_nested
293
- assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
294
- f"The fast path was not hit because {why_not_fast_path}")
 
 
295
 
296
  if self.batch_first and is_batched:
297
  # make sure that the transpose op does not affect the "is" property
@@ -303,38 +332,60 @@ class MultiheadAttention(nn.MultiheadAttention):
303
  value = key
304
  else:
305
  query, key, value = (x.transpose(1, 0) for x in (query, key, value))
306
-
307
  if not self._qkv_same_embed_dim:
308
  attn_output, attn_output_weights = self.multi_head_attention_forward(
309
- query, key, value, self.embed_dim, self.num_heads,
310
- self.in_proj_weight, self.in_proj_bias,
311
- self.bias_k, self.bias_v, self.add_zero_attn,
312
- self.dropout, self.out_proj.weight, self.out_proj.bias,
 
 
 
 
 
 
 
 
 
313
  training=self.training,
314
- key_padding_mask=key_padding_mask, need_weights=need_weights,
 
315
  attn_mask=attn_mask,
316
  use_separate_proj_weight=True,
317
- q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
 
318
  v_proj_weight=self.v_proj_weight,
319
  average_attn_weights=average_attn_weights,
320
- is_causal=is_causal)
 
321
  else:
322
  attn_output, attn_output_weights = self.multi_head_attention_forward(
323
- query, key, value, self.embed_dim, self.num_heads,
324
- self.in_proj_weight, self.in_proj_bias,
325
- self.bias_k, self.bias_v, self.add_zero_attn,
326
- self.dropout, self.out_proj.weight, self.out_proj.bias,
 
 
 
 
 
 
 
 
 
327
  training=self.training,
328
  key_padding_mask=key_padding_mask,
329
  need_weights=need_weights,
330
  attn_mask=attn_mask,
331
  average_attn_weights=average_attn_weights,
332
- is_causal=is_causal)
 
333
  if self.batch_first and is_batched:
334
  return attn_output.transpose(1, 0), attn_output_weights
335
  else:
336
  return attn_output, attn_output_weights
337
-
338
  def multi_head_attention_forward(
339
  self,
340
  query: Tensor,
@@ -364,9 +415,9 @@ class MultiheadAttention(nn.MultiheadAttention):
364
  is_causal: bool = False,
365
  ) -> Tuple[Tensor, Optional[Tensor]]:
366
  tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
367
-
368
  is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
369
-
370
  # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
371
  # is batched, run the computation and before returning squeeze the
372
  # batch dimension so that the output doesn't carry this temporary batch dimension.
@@ -377,26 +428,26 @@ class MultiheadAttention(nn.MultiheadAttention):
377
  value = value.unsqueeze(1)
378
  if key_padding_mask is not None:
379
  key_padding_mask = key_padding_mask.unsqueeze(0)
380
-
381
  # set up shape vars
382
  tgt_len, bsz, embed_dim = query.shape
383
  src_len, _, _ = key.shape
384
-
385
  key_padding_mask = _canonical_mask(
386
  mask=key_padding_mask,
387
  mask_name="key_padding_mask",
388
  other_type=_none_or_dtype(attn_mask),
389
  other_name="attn_mask",
390
- target_type=query.dtype
391
  )
392
-
393
  if is_causal and attn_mask is None:
394
  raise RuntimeError(
395
  "Need attn_mask if specifying the is_causal hint. "
396
  "You may use the Transformer module method "
397
  "`generate_square_subsequent_mask` to create this mask."
398
  )
399
-
400
  if is_causal and key_padding_mask is None and not need_weights:
401
  # when we have a kpm or need weights, we need attn_mask
402
  # Otherwise, we use the is_causal hint go as is_causal
@@ -411,28 +462,30 @@ class MultiheadAttention(nn.MultiheadAttention):
411
  target_type=query.dtype,
412
  check_other=False,
413
  )
414
-
415
  if key_padding_mask is not None:
416
  # We have the attn_mask, and use that to merge kpm into it.
417
  # Turn off use of is_causal hint, as the merged mask is no
418
  # longer causal.
419
  is_causal = False
420
-
421
- assert embed_dim == embed_dim_to_check, \
422
- f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
 
423
  if isinstance(embed_dim, torch.Tensor):
424
  # embed_dim can be a tensor when JIT tracing
425
- head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
426
  else:
427
  head_dim = embed_dim // num_heads
428
  assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
429
  if use_separate_proj_weight:
430
  # allow MHA to have different embedding dimensions when separate projection weights are used
431
- assert key.shape[:2] == value.shape[:2], \
432
- f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
 
433
  else:
434
  assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
435
-
436
  #
437
  # compute in-projection
438
  #
@@ -448,23 +501,27 @@ class MultiheadAttention(nn.MultiheadAttention):
448
  else:
449
  b_q, b_k, b_v = in_proj_bias.chunk(3)
450
  q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
451
-
452
  # prep attention mask
453
-
454
  if attn_mask is not None:
455
  # ensure attn_mask's dim is 3
456
  if attn_mask.dim() == 2:
457
  correct_2d_size = (tgt_len, src_len)
458
  if attn_mask.shape != correct_2d_size:
459
- raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
 
 
460
  attn_mask = attn_mask.unsqueeze(0)
461
  elif attn_mask.dim() == 3:
462
  correct_3d_size = (bsz * num_heads, tgt_len, src_len)
463
  if attn_mask.shape != correct_3d_size:
464
- raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
 
 
465
  else:
466
  raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
467
-
468
  # add bias along batch dimension (currently second)
469
  if bias_k is not None and bias_v is not None:
470
  assert static_k is None, "bias cannot be added to static key."
@@ -478,7 +535,7 @@ class MultiheadAttention(nn.MultiheadAttention):
478
  else:
479
  assert bias_k is None
480
  assert bias_v is None
481
-
482
  #
483
  # reshape q, k, v for multihead attention and make em batch first
484
  #
@@ -487,21 +544,25 @@ class MultiheadAttention(nn.MultiheadAttention):
487
  k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
488
  else:
489
  # TODO finish disentangling control flow so we don't do in-projections when statics are passed
490
- assert static_k.size(0) == bsz * num_heads, \
491
- f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
492
- assert static_k.size(2) == head_dim, \
493
- f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
 
 
494
  k = static_k
495
  if static_v is None:
496
  v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
497
  else:
498
  # TODO finish disentangling control flow so we don't do in-projections when statics are passed
499
- assert static_v.size(0) == bsz * num_heads, \
500
- f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
501
- assert static_v.size(2) == head_dim, \
502
- f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
 
 
503
  v = static_v
504
-
505
  # add zero attention along batch dimension (now first)
506
  if add_zero_attn:
507
  zero_attn_shape = (bsz * num_heads, 1, head_dim)
@@ -511,35 +572,40 @@ class MultiheadAttention(nn.MultiheadAttention):
511
  attn_mask = pad(attn_mask, (0, 1))
512
  if key_padding_mask is not None:
513
  key_padding_mask = pad(key_padding_mask, (0, 1))
514
-
515
  # update source sequence length after adjustments
516
  src_len = k.size(1)
517
-
518
  # merge key padding and attention masks
519
  if key_padding_mask is not None:
520
- assert key_padding_mask.shape == (bsz, src_len), \
521
- f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
522
- key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
523
- expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
 
 
 
 
 
524
  if attn_mask is None:
525
  attn_mask = key_padding_mask
526
  else:
527
  attn_mask = attn_mask + key_padding_mask
528
-
529
  # adjust dropout probability
530
  if not training:
531
  dropout_p = 0.0
532
-
533
  #
534
  # (deep breath) calculate attention and out projection
535
  #
536
-
537
  if need_weights:
538
  B, Nt, E = q.shape
539
  q_scaled = q / math.sqrt(E)
540
-
541
  assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
542
-
543
  if attn_mask is not None:
544
  attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
545
  else:
@@ -547,18 +613,18 @@ class MultiheadAttention(nn.MultiheadAttention):
547
  attn_output_weights = softmax(attn_output_weights, dim=-1)
548
  if dropout_p > 0.0:
549
  attn_output_weights = dropout(attn_output_weights, p=dropout_p)
550
-
551
  attn_output = torch.bmm(attn_output_weights, v)
552
-
553
  attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
554
  attn_output = self.out_proj(attn_output)
555
  attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
556
-
557
  # optionally average attention weights over heads
558
  attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
559
  if average_attn_weights:
560
  attn_output_weights = attn_output_weights.mean(dim=1)
561
-
562
  if not is_batched:
563
  # squeeze the output if input was unbatched
564
  attn_output = attn_output.squeeze(1)
@@ -573,14 +639,14 @@ class MultiheadAttention(nn.MultiheadAttention):
573
  attn_mask = attn_mask.unsqueeze(0)
574
  else:
575
  attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
576
-
577
  q = q.view(bsz, num_heads, tgt_len, head_dim)
578
  k = k.view(bsz, num_heads, src_len, head_dim)
579
  v = v.view(bsz, num_heads, src_len, head_dim)
580
-
581
  attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
582
  attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
583
-
584
  attn_output = self.out_proj(attn_output)
585
  attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
586
  if not is_batched:
@@ -589,8 +655,14 @@ class MultiheadAttention(nn.MultiheadAttention):
589
  return attn_output, None
590
 
591
 
592
- def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor,
593
- key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], num_heads: int):
 
 
 
 
 
 
594
  # Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
595
  # and returns if the input is batched or not.
596
  # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor.
@@ -599,59 +671,65 @@ def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor,
599
  if query.dim() == 3:
600
  # Batched Inputs
601
  is_batched = True
602
- assert key.dim() == 3 and value.dim() == 3, \
603
- ("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
604
- f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
 
605
  if key_padding_mask is not None:
606
- assert key_padding_mask.dim() == 2, \
607
- ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
608
- f" but found {key_padding_mask.dim()}-D tensor instead")
 
609
  if attn_mask is not None:
610
- assert attn_mask.dim() in (2, 3), \
611
- ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
612
- f" but found {attn_mask.dim()}-D tensor instead")
 
613
  elif query.dim() == 2:
614
  # Unbatched Inputs
615
  is_batched = False
616
- assert key.dim() == 2 and value.dim() == 2, \
617
- ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
618
- f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
 
619
 
620
  if key_padding_mask is not None:
621
- assert key_padding_mask.dim() == 1, \
622
- ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
623
- f" but found {key_padding_mask.dim()}-D tensor instead")
 
624
 
625
  if attn_mask is not None:
626
- assert attn_mask.dim() in (2, 3), \
627
- ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
628
- f" but found {attn_mask.dim()}-D tensor instead")
 
629
  if attn_mask.dim() == 3:
630
  expected_shape = (num_heads, query.shape[0], key.shape[0])
631
- assert attn_mask.shape == expected_shape, \
632
- (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
 
633
  else:
634
  raise AssertionError(
635
- f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")
 
636
 
637
  return is_batched
638
 
639
 
640
  def _canonical_mask(
641
- mask: Optional[Tensor],
642
- mask_name: str,
643
- other_type: Optional[DType],
644
- other_name: str,
645
- target_type: DType,
646
- check_other: bool = True,
647
  ) -> Optional[Tensor]:
648
-
649
  if mask is not None:
650
  _mask_dtype = mask.dtype
651
  _mask_is_float = torch.is_floating_point(mask)
652
  if _mask_dtype != torch.bool and not _mask_is_float:
653
- raise AssertionError(
654
- f"only bool and floating types of {mask_name} are supported")
655
  if check_other and other_type is not None:
656
  if _mask_dtype != other_type:
657
  warnings.warn(
@@ -659,10 +737,7 @@ def _canonical_mask(
659
  "is deprecated. Use same type for both instead."
660
  )
661
  if not _mask_is_float:
662
- mask = (
663
- torch.zeros_like(mask, dtype=target_type)
664
- .masked_fill_(mask, float("-inf"))
665
- )
666
  return mask
667
 
668
 
@@ -673,6 +748,7 @@ def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
673
  return input.dtype
674
  raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
675
 
 
676
  def _in_projection_packed(
677
  q: Tensor,
678
  k: Tensor,
@@ -779,4 +855,4 @@ def _in_projection(
779
  assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
780
  assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
781
  assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
782
- return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
 
1
+ import warnings
2
  from functools import partial
3
  from typing import Optional, Tuple
 
 
4
 
5
+ import numpy as np
6
  import torch
 
 
7
  import torch.nn.functional as F
8
+ from torch import Tensor, nn
9
  from torch.nn.functional import *
10
+ from torch.nn.init import trunc_normal_
11
  from torch.nn.modules.activation import *
 
 
12
  from transformers.integrations import is_deepspeed_zero3_enabled
13
 
14
+
15
  def get_2d_sincos_pos_embed(embed_dim, image_size):
16
  """
17
  image_size: image_size or (image_height, image_width)
 
51
  """
52
  assert embed_dim % 2 == 0
53
  omega = np.arange(embed_dim // 2, dtype=np.float32)
54
+ omega /= embed_dim / 2.0
55
+ omega = 1.0 / 10000**omega # (D/2,)
56
 
57
+ out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
58
 
59
  emb_sin = np.sin(out) # (H, W, D/2)
60
  emb_cos = np.cos(out) # (H, W, D/2)
 
72
  """
73
 
74
  def __init__(
75
+ self,
76
+ num_queries,
77
+ embed_dim,
78
+ num_heads,
79
+ kv_dim=None,
80
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
81
+ adaptive=False,
82
+ max_size=(70, 70),
83
  ):
84
  super().__init__()
85
  self.num_queries = num_queries
 
100
  self.ln_kv = norm_layer(embed_dim)
101
 
102
  self.ln_post = norm_layer(embed_dim)
103
+ self.proj = nn.Parameter((embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
104
 
105
  self._set_2d_pos_cache(self.max_size)
106
 
107
+ def _set_2d_pos_cache(self, max_size, device="cpu"):
108
  if is_deepspeed_zero3_enabled():
109
+ device = "cuda"
110
  pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.embed_dim, max_size)).float().to(device)
111
  self.register_buffer("pos_embed", pos_embed, persistent=False)
112
 
 
119
 
120
  def _init_weights(self, m):
121
  if isinstance(m, nn.Linear):
122
+ trunc_normal_(m.weight, std=0.02)
123
  if isinstance(m, nn.Linear) and m.bias is not None:
124
  nn.init.constant_(m.bias, 0)
125
  elif isinstance(m, nn.LayerNorm):
 
144
  for i in range(bs):
145
  tgt_h, tgt_w = tgt_sizes[i]
146
  pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype)) # patches * D
147
+ key_padding_mask[i, patch_len[i] :] = True
148
 
149
+ pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, batch_first=True, padding_value=0.0).permute(
150
+ 1, 0, 2
151
+ ) # BLD => L * B * D
152
 
153
  x = self.kv_proj(x) # B * L * D
154
  x = self.ln_kv(x).permute(1, 0, 2) # L * B * D
 
159
  self._repeat(q, bs), # Q * B * D
160
  x + pos_embed, # L * B * D + L * B * D
161
  x,
162
+ key_padding_mask=key_padding_mask,
163
+ )[0]
164
  # out: Q * B * D
165
  x = out.permute(1, 0, 2) # B * Q * D
166
 
 
173
 
174
 
175
  class MultiheadAttention(nn.MultiheadAttention):
176
+ def __init__(
177
+ self,
178
+ embed_dim,
179
+ num_heads,
180
+ dropout=0.0,
181
+ bias=True,
182
+ add_bias_kv=False,
183
+ add_zero_attn=False,
184
+ kdim=None,
185
+ vdim=None,
186
+ batch_first=False,
187
+ device=None,
188
+ dtype=None,
189
+ ):
190
+ super().__init__(
191
+ embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first, device, dtype
192
+ )
193
 
194
  # rewrite out_proj layer,with nn.Linear
195
  self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)
196
 
197
  def forward(
198
+ self,
199
+ query: Tensor,
200
+ key: Tensor,
201
+ value: Tensor,
202
+ key_padding_mask: Optional[Tensor] = None,
203
+ need_weights: bool = True,
204
+ attn_mask: Optional[Tensor] = None,
205
+ average_attn_weights: bool = True,
206
+ is_causal: bool = False,
207
+ ) -> Tuple[Tensor, Optional[Tensor]]:
208
+ why_not_fast_path = ""
209
+ if (
210
+ (attn_mask is not None and torch.is_floating_point(attn_mask))
211
+ or (key_padding_mask is not None)
212
+ and torch.is_floating_point(key_padding_mask)
213
+ ):
214
  why_not_fast_path = "floating-point masks are not supported for fast path."
215
 
216
  is_batched = query.dim() == 3
 
220
  mask_name="key_padding_mask",
221
  other_type=F._none_or_dtype(attn_mask),
222
  other_name="attn_mask",
223
+ target_type=query.dtype,
224
  )
225
 
226
  attn_mask = _canonical_mask(
 
232
  check_other=False,
233
  )
234
 
 
235
  if not is_batched:
236
  why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
237
  elif query is not key or key is not value:
 
240
  # they don't!
241
  why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
242
  elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
243
+ why_not_fast_path = (
244
+ f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
245
+ )
246
  elif self.in_proj_weight is None:
247
  why_not_fast_path = "in_proj_weight was None"
248
  elif query.dtype != self.in_proj_weight.dtype:
249
  # this case will fail anyway, but at least they'll get a useful error message.
250
+ why_not_fast_path = (
251
+ f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
252
+ )
253
  elif self.training:
254
  why_not_fast_path = "training is enabled"
255
  elif (self.num_heads % 2) != 0:
 
287
  elif _is_make_fx_tracing():
288
  why_not_fast_path = "we are running make_fx tracing"
289
  elif not all(_check_arg_device(x) for x in tensor_args):
290
+ why_not_fast_path = (
291
+ "some Tensor argument's device is neither one of "
292
+ f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}"
293
+ )
294
  elif torch.is_grad_enabled() and any(_arg_requires_grad(x) for x in tensor_args):
295
+ why_not_fast_path = (
296
+ "grad is enabled and at least one of query or the "
297
+ "input/output projection weights or biases requires_grad"
298
+ )
299
  if not why_not_fast_path:
300
  merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query)
301
 
 
313
  merged_mask,
314
  need_weights,
315
  average_attn_weights,
316
+ mask_type,
317
+ )
318
 
319
  any_nested = query.is_nested or key.is_nested or value.is_nested
320
+ assert not any_nested, (
321
+ "MultiheadAttention does not support NestedTensor outside of its fast path. "
322
+ + f"The fast path was not hit because {why_not_fast_path}"
323
+ )
324
 
325
  if self.batch_first and is_batched:
326
  # make sure that the transpose op does not affect the "is" property
 
332
  value = key
333
  else:
334
  query, key, value = (x.transpose(1, 0) for x in (query, key, value))
335
+
336
  if not self._qkv_same_embed_dim:
337
  attn_output, attn_output_weights = self.multi_head_attention_forward(
338
+ query,
339
+ key,
340
+ value,
341
+ self.embed_dim,
342
+ self.num_heads,
343
+ self.in_proj_weight,
344
+ self.in_proj_bias,
345
+ self.bias_k,
346
+ self.bias_v,
347
+ self.add_zero_attn,
348
+ self.dropout,
349
+ self.out_proj.weight,
350
+ self.out_proj.bias,
351
  training=self.training,
352
+ key_padding_mask=key_padding_mask,
353
+ need_weights=need_weights,
354
  attn_mask=attn_mask,
355
  use_separate_proj_weight=True,
356
+ q_proj_weight=self.q_proj_weight,
357
+ k_proj_weight=self.k_proj_weight,
358
  v_proj_weight=self.v_proj_weight,
359
  average_attn_weights=average_attn_weights,
360
+ is_causal=is_causal,
361
+ )
362
  else:
363
  attn_output, attn_output_weights = self.multi_head_attention_forward(
364
+ query,
365
+ key,
366
+ value,
367
+ self.embed_dim,
368
+ self.num_heads,
369
+ self.in_proj_weight,
370
+ self.in_proj_bias,
371
+ self.bias_k,
372
+ self.bias_v,
373
+ self.add_zero_attn,
374
+ self.dropout,
375
+ self.out_proj.weight,
376
+ self.out_proj.bias,
377
  training=self.training,
378
  key_padding_mask=key_padding_mask,
379
  need_weights=need_weights,
380
  attn_mask=attn_mask,
381
  average_attn_weights=average_attn_weights,
382
+ is_causal=is_causal,
383
+ )
384
  if self.batch_first and is_batched:
385
  return attn_output.transpose(1, 0), attn_output_weights
386
  else:
387
  return attn_output, attn_output_weights
388
+
389
  def multi_head_attention_forward(
390
  self,
391
  query: Tensor,
 
415
  is_causal: bool = False,
416
  ) -> Tuple[Tensor, Optional[Tensor]]:
417
  tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
418
+
419
  is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
420
+
421
  # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
422
  # is batched, run the computation and before returning squeeze the
423
  # batch dimension so that the output doesn't carry this temporary batch dimension.
 
428
  value = value.unsqueeze(1)
429
  if key_padding_mask is not None:
430
  key_padding_mask = key_padding_mask.unsqueeze(0)
431
+
432
  # set up shape vars
433
  tgt_len, bsz, embed_dim = query.shape
434
  src_len, _, _ = key.shape
435
+
436
  key_padding_mask = _canonical_mask(
437
  mask=key_padding_mask,
438
  mask_name="key_padding_mask",
439
  other_type=_none_or_dtype(attn_mask),
440
  other_name="attn_mask",
441
+ target_type=query.dtype,
442
  )
443
+
444
  if is_causal and attn_mask is None:
445
  raise RuntimeError(
446
  "Need attn_mask if specifying the is_causal hint. "
447
  "You may use the Transformer module method "
448
  "`generate_square_subsequent_mask` to create this mask."
449
  )
450
+
451
  if is_causal and key_padding_mask is None and not need_weights:
452
  # when we have a kpm or need weights, we need attn_mask
453
  # Otherwise, we use the is_causal hint go as is_causal
 
462
  target_type=query.dtype,
463
  check_other=False,
464
  )
465
+
466
  if key_padding_mask is not None:
467
  # We have the attn_mask, and use that to merge kpm into it.
468
  # Turn off use of is_causal hint, as the merged mask is no
469
  # longer causal.
470
  is_causal = False
471
+
472
+ assert (
473
+ embed_dim == embed_dim_to_check
474
+ ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
475
  if isinstance(embed_dim, torch.Tensor):
476
  # embed_dim can be a tensor when JIT tracing
477
+ head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
478
  else:
479
  head_dim = embed_dim // num_heads
480
  assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
481
  if use_separate_proj_weight:
482
  # allow MHA to have different embedding dimensions when separate projection weights are used
483
+ assert (
484
+ key.shape[:2] == value.shape[:2]
485
+ ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
486
  else:
487
  assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
488
+
489
  #
490
  # compute in-projection
491
  #
 
501
  else:
502
  b_q, b_k, b_v = in_proj_bias.chunk(3)
503
  q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
504
+
505
  # prep attention mask
506
+
507
  if attn_mask is not None:
508
  # ensure attn_mask's dim is 3
509
  if attn_mask.dim() == 2:
510
  correct_2d_size = (tgt_len, src_len)
511
  if attn_mask.shape != correct_2d_size:
512
+ raise RuntimeError(
513
+ f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
514
+ )
515
  attn_mask = attn_mask.unsqueeze(0)
516
  elif attn_mask.dim() == 3:
517
  correct_3d_size = (bsz * num_heads, tgt_len, src_len)
518
  if attn_mask.shape != correct_3d_size:
519
+ raise RuntimeError(
520
+ f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
521
+ )
522
  else:
523
  raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
524
+
525
  # add bias along batch dimension (currently second)
526
  if bias_k is not None and bias_v is not None:
527
  assert static_k is None, "bias cannot be added to static key."
 
535
  else:
536
  assert bias_k is None
537
  assert bias_v is None
538
+
539
  #
540
  # reshape q, k, v for multihead attention and make em batch first
541
  #
 
544
  k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
545
  else:
546
  # TODO finish disentangling control flow so we don't do in-projections when statics are passed
547
+ assert (
548
+ static_k.size(0) == bsz * num_heads
549
+ ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
550
+ assert (
551
+ static_k.size(2) == head_dim
552
+ ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
553
  k = static_k
554
  if static_v is None:
555
  v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
556
  else:
557
  # TODO finish disentangling control flow so we don't do in-projections when statics are passed
558
+ assert (
559
+ static_v.size(0) == bsz * num_heads
560
+ ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
561
+ assert (
562
+ static_v.size(2) == head_dim
563
+ ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
564
  v = static_v
565
+
566
  # add zero attention along batch dimension (now first)
567
  if add_zero_attn:
568
  zero_attn_shape = (bsz * num_heads, 1, head_dim)
 
572
  attn_mask = pad(attn_mask, (0, 1))
573
  if key_padding_mask is not None:
574
  key_padding_mask = pad(key_padding_mask, (0, 1))
575
+
576
  # update source sequence length after adjustments
577
  src_len = k.size(1)
578
+
579
  # merge key padding and attention masks
580
  if key_padding_mask is not None:
581
+ assert key_padding_mask.shape == (
582
+ bsz,
583
+ src_len,
584
+ ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
585
+ key_padding_mask = (
586
+ key_padding_mask.view(bsz, 1, 1, src_len)
587
+ .expand(-1, num_heads, -1, -1)
588
+ .reshape(bsz * num_heads, 1, src_len)
589
+ )
590
  if attn_mask is None:
591
  attn_mask = key_padding_mask
592
  else:
593
  attn_mask = attn_mask + key_padding_mask
594
+
595
  # adjust dropout probability
596
  if not training:
597
  dropout_p = 0.0
598
+
599
  #
600
  # (deep breath) calculate attention and out projection
601
  #
602
+
603
  if need_weights:
604
  B, Nt, E = q.shape
605
  q_scaled = q / math.sqrt(E)
606
+
607
  assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
608
+
609
  if attn_mask is not None:
610
  attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
611
  else:
 
613
  attn_output_weights = softmax(attn_output_weights, dim=-1)
614
  if dropout_p > 0.0:
615
  attn_output_weights = dropout(attn_output_weights, p=dropout_p)
616
+
617
  attn_output = torch.bmm(attn_output_weights, v)
618
+
619
  attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
620
  attn_output = self.out_proj(attn_output)
621
  attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
622
+
623
  # optionally average attention weights over heads
624
  attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
625
  if average_attn_weights:
626
  attn_output_weights = attn_output_weights.mean(dim=1)
627
+
628
  if not is_batched:
629
  # squeeze the output if input was unbatched
630
  attn_output = attn_output.squeeze(1)
 
639
  attn_mask = attn_mask.unsqueeze(0)
640
  else:
641
  attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
642
+
643
  q = q.view(bsz, num_heads, tgt_len, head_dim)
644
  k = k.view(bsz, num_heads, src_len, head_dim)
645
  v = v.view(bsz, num_heads, src_len, head_dim)
646
+
647
  attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
648
  attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
649
+
650
  attn_output = self.out_proj(attn_output)
651
  attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
652
  if not is_batched:
 
655
  return attn_output, None
656
 
657
 
658
+ def _mha_shape_check(
659
+ query: Tensor,
660
+ key: Tensor,
661
+ value: Tensor,
662
+ key_padding_mask: Optional[Tensor],
663
+ attn_mask: Optional[Tensor],
664
+ num_heads: int,
665
+ ):
666
  # Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
667
  # and returns if the input is batched or not.
668
  # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor.
 
671
  if query.dim() == 3:
672
  # Batched Inputs
673
  is_batched = True
674
+ assert key.dim() == 3 and value.dim() == 3, (
675
+ "For batched (3-D) `query`, expected `key` and `value` to be 3-D"
676
+ f" but found {key.dim()}-D and {value.dim()}-D tensors respectively"
677
+ )
678
  if key_padding_mask is not None:
679
+ assert key_padding_mask.dim() == 2, (
680
+ "For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
681
+ f" but found {key_padding_mask.dim()}-D tensor instead"
682
+ )
683
  if attn_mask is not None:
684
+ assert attn_mask.dim() in (2, 3), (
685
+ "For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
686
+ f" but found {attn_mask.dim()}-D tensor instead"
687
+ )
688
  elif query.dim() == 2:
689
  # Unbatched Inputs
690
  is_batched = False
691
+ assert key.dim() == 2 and value.dim() == 2, (
692
+ "For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
693
+ f" but found {key.dim()}-D and {value.dim()}-D tensors respectively"
694
+ )
695
 
696
  if key_padding_mask is not None:
697
+ assert key_padding_mask.dim() == 1, (
698
+ "For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
699
+ f" but found {key_padding_mask.dim()}-D tensor instead"
700
+ )
701
 
702
  if attn_mask is not None:
703
+ assert attn_mask.dim() in (2, 3), (
704
+ "For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
705
+ f" but found {attn_mask.dim()}-D tensor instead"
706
+ )
707
  if attn_mask.dim() == 3:
708
  expected_shape = (num_heads, query.shape[0], key.shape[0])
709
+ assert (
710
+ attn_mask.shape == expected_shape
711
+ ), f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}"
712
  else:
713
  raise AssertionError(
714
+ f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor"
715
+ )
716
 
717
  return is_batched
718
 
719
 
720
  def _canonical_mask(
721
+ mask: Optional[Tensor],
722
+ mask_name: str,
723
+ other_type: Optional[DType],
724
+ other_name: str,
725
+ target_type: DType,
726
+ check_other: bool = True,
727
  ) -> Optional[Tensor]:
 
728
  if mask is not None:
729
  _mask_dtype = mask.dtype
730
  _mask_is_float = torch.is_floating_point(mask)
731
  if _mask_dtype != torch.bool and not _mask_is_float:
732
+ raise AssertionError(f"only bool and floating types of {mask_name} are supported")
 
733
  if check_other and other_type is not None:
734
  if _mask_dtype != other_type:
735
  warnings.warn(
 
737
  "is deprecated. Use same type for both instead."
738
  )
739
  if not _mask_is_float:
740
+ mask = torch.zeros_like(mask, dtype=target_type).masked_fill_(mask, float("-inf"))
 
 
 
741
  return mask
742
 
743
 
 
748
  return input.dtype
749
  raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
750
 
751
+
752
  def _in_projection_packed(
753
  q: Tensor,
754
  k: Tensor,
 
855
  assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
856
  assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
857
  assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
858
+ return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
tokenization_minicpmv_fast.py CHANGED
@@ -40,7 +40,7 @@ class MiniCPMVTokenizerFast(Qwen2TokenizerFast):
40
  @property
41
  def slice_start_id(self):
42
  return self.convert_tokens_to_ids(self.slice_start)
43
-
44
  @property
45
  def slice_end_id(self):
46
  return self.convert_tokens_to_ids(self.slice_end)
@@ -48,14 +48,14 @@ class MiniCPMVTokenizerFast(Qwen2TokenizerFast):
48
  @property
49
  def im_id_start_id(self):
50
  return self.convert_tokens_to_ids(self.im_id_start)
51
-
52
  @property
53
  def im_id_end_id(self):
54
  return self.convert_tokens_to_ids(self.im_id_end)
55
-
56
  @property
57
  def newline_id(self):
58
- return self.convert_tokens_to_ids('\n')
59
 
60
  @staticmethod
61
  def escape(text: str) -> str:
@@ -63,4 +63,4 @@ class MiniCPMVTokenizerFast(Qwen2TokenizerFast):
63
 
64
  @staticmethod
65
  def unescape(text: str) -> str:
66
- return text
 
40
  @property
41
  def slice_start_id(self):
42
  return self.convert_tokens_to_ids(self.slice_start)
43
+
44
  @property
45
  def slice_end_id(self):
46
  return self.convert_tokens_to_ids(self.slice_end)
 
48
  @property
49
  def im_id_start_id(self):
50
  return self.convert_tokens_to_ids(self.im_id_start)
51
+
52
  @property
53
  def im_id_end_id(self):
54
  return self.convert_tokens_to_ids(self.im_id_end)
55
+
56
  @property
57
  def newline_id(self):
58
+ return self.convert_tokens_to_ids("\n")
59
 
60
  @staticmethod
61
  def escape(text: str) -> str:
 
63
 
64
  @staticmethod
65
  def unescape(text: str) -> str:
66
+ return text