import numpy as np
from transformers import AutoImageProcessor, AutoProcessor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from .image_processing_vectorllm import VectorLLMImageProcessor
class VectorLLMImagesKwargs(ImagesKwargs, total=False):
    """Image-processor keyword arguments specific to VectorLLM.

    Declared with ``total=False`` because both keys are optional per-call
    overrides: the processor reads them with ``.get(...)`` and falls back to
    the image processor's own attribute when a key is absent.
    """

    # Side length override; the processor derives the per-image token count
    # as (resized_size // patch_size) ** 2.
    resized_size: int
    # Patch side length override used in the same token-count formula.
    patch_size: int
class VectorLLMProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword-argument schema and defaults for ``VectorLLMProcessor``.

    ``images_kwargs`` is typed with the VectorLLM-specific image kwargs, and
    ``_defaults`` supplies the text-side defaults merged into every call.
    """

    images_kwargs: VectorLLMImagesKwargs

    # By default: no padding, and no multimodal token-type ids in the output.
    _defaults = {
        "text_kwargs": {"padding": False, "return_mm_token_type_ids": False},
    }
class VectorLLMProcessor(ProcessorMixin):
    """Processor that pairs a VectorLLM image processor with a Qwen2 tokenizer.

    Each ``<pixel>`` marker in the input text is expanded to one ``<pixel>``
    token per image patch, ``(resized_size // patch_size) ** 2`` in total,
    before the text is tokenized.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "VectorLLMImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
        """Store the image marker token and its id, then initialise the mixin.

        Args:
            image_processor: Image processor instance handed to ``ProcessorMixin``.
            tokenizer: Tokenizer instance; used here to resolve the id of ``<pixel>``.
            chat_template: Optional chat template forwarded to ``ProcessorMixin``.
            **kwargs: Extra arguments forwarded to ``ProcessorMixin.__init__``.
        """
        self.image_token = "<pixel>"
        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)

    def _tokens_per_image(self, images_kwargs=None):
        """Return the number of image tokens one image expands to.

        A truthy ``resized_size`` / ``patch_size`` in ``images_kwargs`` takes
        precedence over the image processor's attribute of the same name.
        """
        images_kwargs = images_kwargs or {}
        resized_size = images_kwargs.get("resized_size") or self.image_processor.resized_size
        patch_size = images_kwargs.get("patch_size") or self.image_processor.patch_size
        return (resized_size // patch_size) ** 2

    def __call__(
        self,
        images: ImageInput = None,
        text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
        **kwargs: Unpack[VectorLLMProcessorKwargs],
    ) -> BatchFeature:
        """Preprocess images, expand image markers in the text, and tokenize.

        Args:
            images: Images passed to the image processor; skipped when ``None``.
            text: A single text or a list of texts. A non-list value is wrapped
                in a one-element list.
            **kwargs: Processor keyword arguments, merged with the defaults by
                ``_merge_kwargs``.

        Returns:
            A ``BatchFeature`` holding the tokenizer outputs, the image
            processor outputs (when ``images`` is given) and, when
            ``return_mm_token_type_ids`` is truthy, ``mm_token_type_ids`` with
            1 at every image-token position and 0 elsewhere.
        """
        output_kwargs = self._merge_kwargs(
            VectorLLMProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        images_kwargs = output_kwargs["images_kwargs"]

        image_inputs = {}
        if images is not None:
            image_inputs = self.image_processor(images=images, **images_kwargs)

        # Work on a fresh list so the caller's list is never modified.
        text = text.copy() if isinstance(text, list) else [text]

        if images is not None:
            # Use the same per-call overrides that were forwarded to the image
            # processor, so the token count agrees with
            # `_get_num_multimodal_tokens`.
            expanded_marker = self.image_token * self._tokens_per_image(images_kwargs)
            # `str.replace` never rescans the text it just inserted, so one pass
            # expands every marker without a temporary placeholder string.
            text = [
                sample.replace(self.image_token, expanded_marker) if self.image_token in sample else sample
                for sample in text
            ]

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

        if return_mm_token_type_ids:
            # Built per sequence so unpadded batches of unequal length (padding
            # defaults to False) do not have to form a rectangular array.
            text_inputs["mm_token_type_ids"] = [
                [int(token_id == self.image_token_id) for token_id in sequence]
                for sequence in text_inputs["input_ids"]
            ]

        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)

    def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
        """Report how many image tokens/patches each image contributes.

        Args:
            image_sizes: One entry per image. Only the number of entries is
                used; every image yields the same count.
            **kwargs: Optional ``resized_size`` / ``patch_size`` overrides.

        Returns:
            ``MultiModalData`` with ``num_image_tokens`` and
            ``num_image_patches`` when ``image_sizes`` is given, otherwise an
            empty ``MultiModalData``.
        """
        vision_data = {}
        if image_sizes is not None:
            # Merge into a new dict so the class-level `_defaults` is never mutated.
            images_kwargs = {**VectorLLMProcessorKwargs._defaults.get("images_kwargs", {}), **kwargs}
            tokens_per_image = self._tokens_per_image(images_kwargs)
            num_image_patches = [tokens_per_image for _ in image_sizes]
            vision_data.update(
                {"num_image_tokens": num_image_patches, "num_image_patches": num_image_patches}
            )
        return MultiModalData(**vision_data)

    def post_process_image_text_to_text(
        self,
        generated_outputs,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ):
        """Decode generated token ids to text with ``tokenizer.batch_decode``.

        Args:
            generated_outputs: Batch of generated token ids.
            skip_special_tokens: Forwarded to ``batch_decode``.
            clean_up_tokenization_spaces: Forwarded to ``batch_decode``.
            **kwargs: Extra arguments forwarded to ``batch_decode``.

        Returns:
            Whatever ``tokenizer.batch_decode`` returns for the batch.
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    def model_input_names(self):
        """Tokenizer input names followed by image-processor input names, de-duplicated in order."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        # dict.fromkeys keeps first-seen order while dropping repeated names.
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
# Register the processor and image-processor classes with the transformers
# Auto* factories as an import-time side effect of loading this module.
# NOTE(review): both calls pass a class-name string as the first argument;
# verify this is the key type `register` expects (it may require a config
# class instead) — TODO confirm against the transformers Auto classes docs.
AutoProcessor.register("VectorLLMProcessor", VectorLLMProcessor)
AutoImageProcessor.register("VectorLLMImageProcessor", VectorLLMImageProcessor)
|