magicgh commited on
Commit
b55a547
1 Parent(s): fac2a26

Upload processing_phi3_v.py

Browse files

Update the phi3-v processor to support multi-frame images as input.

Files changed (1) hide show
  1. processing_phi3_v.py +34 -11
processing_phi3_v.py CHANGED
@@ -41,7 +41,7 @@ from transformers.image_transforms import (
41
  from transformers.image_utils import (
42
  OPENAI_CLIP_MEAN,
43
  OPENAI_CLIP_STD,
44
- ImageInput,
45
  make_list_of_images,
46
  valid_images,
47
  )
@@ -57,6 +57,7 @@ if is_vision_available():
57
  import torch
58
  import torchvision
59
 
 
60
 
61
  def padding_336(b):
62
  width, height = b.size
@@ -139,6 +140,11 @@ def pad_to_max_num_crops_tensor(images, max_crops=5):
139
  images = torch.cat([images, pad], dim=0)
140
  return images
141
 
 
 
 
 
 
142
 
143
  class Phi3VImageProcessor(BaseImageProcessor):
144
  r"""
@@ -330,7 +336,7 @@ class Phi3VProcessor(ProcessorMixin):
330
  def __call__(
331
  self,
332
  text: Union[TextInput, List[TextInput]],
333
- images: ImageInput = None,
334
  padding: Union[bool, str, PaddingStrategy] = False,
335
  truncation: Union[bool, str, TruncationStrategy] = None,
336
  max_length=None,
@@ -382,6 +388,8 @@ class Phi3VProcessor(ProcessorMixin):
382
  - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
383
  """
384
  if images is not None:
 
 
385
  image_inputs = self.image_processor(images, return_tensors=return_tensors)
386
  else:
387
  image_inputs = {}
@@ -421,7 +429,14 @@ class Phi3VProcessor(ProcessorMixin):
421
  return BatchFeature(data={**model_inputs})
422
 
423
  pattern = r"<\|image_\d+\|>"
424
- prompt_chunks = [self.tokenizer(chunk, truncation=truncation, max_length=max_length).input_ids for chunk in re.split(pattern, texts)]
 
 
 
 
 
 
 
425
 
426
  if 'num_img_tokens' in images:
427
  num_img_tokens = images['num_img_tokens']
@@ -433,18 +448,23 @@ class Phi3VProcessor(ProcessorMixin):
433
  images, image_sizes = images['pixel_values'], images['image_sizes']
434
 
435
  # image_tags needs to start from 1 to n
436
- image_tags = re.findall(pattern, texts)
437
  # image_ids = [int(s.split("|")[1].split("_")[-1]) * -1 for s in image_tags]
438
  # image_ids_pad = [[iid]*num_img_tokens[i] for i, iid in enumerate(image_ids)]
439
- image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
440
- unique_image_ids = sorted(list(set(image_ids)))
 
 
 
 
 
441
  # image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be [1, 4, 5]
442
  # check the condition
443
  assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
444
  # total images must be the same as the number of image tags
445
  assert len(unique_image_ids) == len(images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(images)} images"
446
 
447
- image_ids_pad = [[-iid] * num_img_tokens[iid - 1] for iid in image_ids]
448
 
449
  def insert_separator(X, sep_list):
450
  if len(X) > len(sep_list):
@@ -452,12 +472,15 @@ class Phi3VProcessor(ProcessorMixin):
452
  return [ele for sublist in zip(X, sep_list) for ele in sublist]
453
 
454
  input_ids = []
455
- offset = 0
456
- for x in insert_separator(prompt_chunks, image_ids_pad):
457
- input_ids.extend(x[offset:])
 
 
458
 
459
- input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
460
  attention_mask = (input_ids > -1000000).to(torch.long)
 
461
 
462
  return BatchFeature(data={"input_ids": input_ids,
463
  "attention_mask": attention_mask,
 
41
  from transformers.image_utils import (
42
  OPENAI_CLIP_MEAN,
43
  OPENAI_CLIP_STD,
44
+ is_valid_image,
45
  make_list_of_images,
46
  valid_images,
47
  )
 
57
  import torch
58
  import torchvision
59
 
60
+ MultiFrameImageInput = Union[List[List["Image.Image"]], List[List[np.ndarray]], List[List["torch.Tensor"]]]
61
 
62
  def padding_336(b):
63
  width, height = b.size
 
140
  images = torch.cat([images, pad], dim=0)
141
  return images
142
 
143
+ def is_multi_frames(images):
144
+ if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)):
145
+ return is_valid_image(images[0][0])
146
+ else:
147
+ return False
148
 
149
  class Phi3VImageProcessor(BaseImageProcessor):
150
  r"""
 
336
  def __call__(
337
  self,
338
  text: Union[TextInput, List[TextInput]],
339
+ images: Union[ImageInput, MultiFrameImageInput] = None,
340
  padding: Union[bool, str, PaddingStrategy] = False,
341
  truncation: Union[bool, str, TruncationStrategy] = None,
342
  max_length=None,
 
388
  - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
389
  """
390
  if images is not None:
391
+ if is_multi_frames(images):
392
+ images = [image for sample_images in images for image in sample_images]
393
  image_inputs = self.image_processor(images, return_tensors=return_tensors)
394
  else:
395
  image_inputs = {}
 
429
  return BatchFeature(data={**model_inputs})
430
 
431
  pattern = r"<\|image_\d+\|>"
432
+ if isinstance(texts, str):
433
+ texts = [texts]
434
+
435
+ prompt_chunks = []
436
+ image_tags = []
437
+ for text in texts:
438
+ prompt_chunks.append([self.tokenizer(chunk, truncation=truncation, max_length=max_length).input_ids for chunk in re.split(pattern, text)])
439
+ image_tags.append(re.findall(pattern, text))
440
 
441
  if 'num_img_tokens' in images:
442
  num_img_tokens = images['num_img_tokens']
 
448
  images, image_sizes = images['pixel_values'], images['image_sizes']
449
 
450
  # image_tags needs to start from 1 to n
451
+ # image_tags = re.findall(pattern, texts)
452
  # image_ids = [int(s.split("|")[1].split("_")[-1]) * -1 for s in image_tags]
453
  # image_ids_pad = [[iid]*num_img_tokens[i] for i, iid in enumerate(image_ids)]
454
+
455
+ image_ids_counter = 0
456
+ image_ids = []
457
+ for tags in image_tags:
458
+ image_ids.append([int(s.split("|")[1].split("_")[-1]) + image_ids_counter for s in tags])
459
+ image_ids_counter += len(tags)
460
+ unique_image_ids = sorted(list(set([iid for ids in image_ids for iid in ids])))
461
  # image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be [1, 4, 5]
462
  # check the condition
463
  assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
464
  # total images must be the same as the number of image tags
465
  assert len(unique_image_ids) == len(images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(images)} images"
466
 
467
+ image_ids_pad = [[[-iid]*num_img_tokens[iid-1] for iid in ids] for ids in image_ids]
468
 
469
  def insert_separator(X, sep_list):
470
  if len(X) > len(sep_list):
 
472
  return [ele for sublist in zip(X, sep_list) for ele in sublist]
473
 
474
  input_ids = []
475
+ for sub_prompt_chunks, sub_image_ids_pad in zip(prompt_chunks, image_ids_pad):
476
+ input_ids.append([])
477
+ offset = 0
478
+ for x in insert_separator(sub_prompt_chunks, sub_image_ids_pad):
479
+ input_ids[-1].extend(x[offset:])
480
 
481
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
482
  attention_mask = (input_ids > -1000000).to(torch.long)
483
+ attention_mask[input_ids == self.tokenizer.pad_token_id] = 0
484
 
485
  return BatchFeature(data={"input_ids": input_ids,
486
  "attention_mask": attention_mask,