| | import torch |
| |
|
| | from collections import UserDict, OrderedDict |
| | from typing import Union, List, Dict, Any |
| |
|
| | from transformers.processing_utils import ProcessorMixin |
| | from transformers.feature_extraction_utils import BatchFeature |
| | from transformers.utils.chat_template_utils import render_jinja_template |
| |
|
| |
|
class SmallVLMProcessor(ProcessorMixin):
    """Joint text/image processor for a small vision-language model.

    Tokenizes text, runs images through the image processor, and expands each
    image placeholder token in ``input_ids`` into one token per image patch so
    the model can splice in the corresponding image embeddings. Produces a
    single-sample ``BatchFeature`` of tensors.
    """

    attributes = ["tokenizer", "image_processor"]
    optional_attributes = ['chat_template']
    model_input_names = ['input_ids', 'attention_mask', 'pixel_values']
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    # Placeholder written in the prompt wherever an image should be inserted.
    image_token = '<|image_pad|>'

    def __init__(self, tokenizer, image_processor, chat_template, **kwargs):
        super().__init__(tokenizer=tokenizer, image_processor=image_processor, chat_template=chat_template)
        # Register the placeholder so it maps to a single, stable token id.
        self.tokenizer.add_special_tokens({'additional_special_tokens': [self.image_token]}, replace_additional_special_tokens=False)
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)

    def __call__(self, inputs=None, images=None, text=None, **kwargs) -> BatchFeature:
        """Build model inputs for one sample.

        Args:
            inputs: Optional pre-built feature dict (or ``UserDict``); missing
                pieces (input_ids, attention_mask, pixel features) are filled in.
            images: Images referenced by the placeholder tokens in the text.
                Defaults to no images. (Fix: previously defaulted to a shared
                mutable list ``[]``; a ``None`` sentinel avoids cross-call
                aliasing of that list.)
            text: Raw text to tokenize when ``inputs`` has no ``input_ids``.
            **kwargs: ``truncation`` (default False), ``max_length`` (default
                1024) and ``padding`` (default False) control the final
                sequence length; everything else is forwarded to the tokenizer.

        Returns:
            ``BatchFeature`` with ``input_ids`` and ``attention_mask``; plus
            ``pixel_values`` / ``pixel_attention_mask`` / ``spatial_shapes``
            when present, and ``prompt_mask`` when assistant masks were given.
        """
        if images is None:
            images = []

        truncation = kwargs.pop('truncation', False)
        max_length = kwargs.pop('max_length', 1024)
        padding = kwargs.pop('padding', False)

        if inputs is None:
            inputs = {}
        if isinstance(inputs, UserDict):
            # Unwrap e.g. BatchEncoding/BatchFeature to a plain dict.
            inputs = inputs.data

        if 'input_ids' not in inputs:
            # NOTE(review): `.tolist()` assumes the tokenizer returns a tensor
            # row here (e.g. `return_tensors` arriving via kwargs) — with the
            # tokenizer's default list output this would raise AttributeError;
            # confirm against callers.
            input_ids = self.tokenizer(text, padding=False, truncation=False, return_attention_mask=False, **kwargs)['input_ids'][0]
            inputs['input_ids'] = input_ids.tolist()

        inputs = self.process_images(images, inputs=inputs)

        if 'attention_mask' not in inputs:
            inputs['attention_mask'] = [1] * len(inputs['input_ids'])

        if 'assistant_masks' in inputs:
            # prompt_mask is the complement of assistant_masks: 1 on prompt
            # positions, 0 on assistant positions.
            inputs['prompt_mask'] = [1 - x for x in inputs.pop('assistant_masks')]

        # Expand each image placeholder token into one token per image patch.
        inputs = self.process_inputs(inputs)

        if truncation and len(inputs['input_ids']) > max_length:
            inputs = self.truncate(inputs, max_length)

        if padding and len(inputs['input_ids']) < max_length:
            inputs = self.padding(inputs, max_length)

        inputs = self.to_tensor(inputs)

        # Sanity check: placeholder count must match total image patch count.
        self.check(inputs)

        new_inputs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        }
        if "pixel_values" in inputs:
            new_inputs['pixel_values'] = inputs['pixel_values']
            new_inputs['pixel_attention_mask'] = inputs['pixel_attention_mask']
            new_inputs['spatial_shapes'] = inputs['spatial_shapes']
        if 'prompt_mask' in inputs:
            new_inputs['prompt_mask'] = inputs['prompt_mask']

        return BatchFeature(new_inputs)

    def process_images(self, images, inputs):
        """Run images through the image processor; store pixel features in `inputs`.

        With no images, writes empty (0-row) tensors so downstream code can
        rely on the keys being present.
        """
        if len(images) > 0:
            pixel_values, spatial_shapes, pixel_attention_mask = self.image_transform(images)
        else:
            pixel_values = torch.zeros((0, self.image_processor.max_num_patches, 3 * self.image_processor.patch_size ** 2), dtype=torch.float32)
            spatial_shapes = torch.zeros((0, 2), dtype=torch.int64)
            pixel_attention_mask = torch.ones((0, self.image_processor.max_num_patches), dtype=torch.int32)

        inputs['pixel_values'] = pixel_values
        inputs['spatial_shapes'] = spatial_shapes
        inputs['pixel_attention_mask'] = pixel_attention_mask
        return inputs

    def image_transform(self, images):
        """Return (pixel_values, spatial_shapes, pixel_attention_mask) tensors."""
        image_inputs = self.image_processor(images, return_tensors='pt')
        return image_inputs['pixel_values'], image_inputs['spatial_shapes'], image_inputs['pixel_attention_mask']

    def truncate(self, inputs: Dict[str, Any], max_length: int):
        """Truncate sequence fields to `max_length`.

        Refuses (via assert) to cut through an image placeholder token, since
        that would desynchronize tokens and image patches.
        """
        assert self.image_token_id not in inputs['input_ids'][max_length:], "Truncate image token is not allowed."

        inputs['input_ids'] = inputs['input_ids'][:max_length]
        inputs['attention_mask'] = inputs['attention_mask'][:max_length]
        if 'prompt_mask' in inputs:
            inputs['prompt_mask'] = inputs['prompt_mask'][:max_length]

        return inputs

    def get_image_token_length(self, inputs: Dict[str, Any]) -> List[int]:
        """Per-image token counts: the product of each row of `spatial_shapes`.

        Returns an empty list when no spatial shapes are present.
        """
        spatial_shapes = inputs.get('spatial_shapes', None)
        if spatial_shapes is None:
            return []
        # Each row is presumably (h, w) in patches; h*w tokens per image.
        image_token_lens = spatial_shapes.prod(dim=1).tolist()
        return image_token_lens

    def process_inputs(self, inputs: Dict[str, Any]):
        """Expand each image placeholder position across all sequence fields."""
        graft_token_lens = self._get_graft_token_length(inputs)

        inputs['input_ids'] = self._graft_token(inputs['input_ids'], graft_token_lens, self.image_token_id)
        inputs['attention_mask'] = self._graft_token(inputs['attention_mask'], graft_token_lens, 'replicate')
        if 'prompt_mask' in inputs:
            inputs['prompt_mask'] = self._graft_token(inputs['prompt_mask'], graft_token_lens, 'replicate')

        return inputs

    def _graft_token(self, seq, graft_token_lens, value):
        """Replace seq[i] with `graft_token_lens[i]` copies at each position i.

        `value == 'replicate'` copies the existing element (for masks);
        otherwise the given `value` is inserted and the original element is
        asserted to already equal it. Positions are processed right-to-left so
        earlier indices stay valid while the sequence grows. Mutates and
        returns `seq`.
        """
        if value == 'replicate':
            for i in reversed(graft_token_lens.keys()):
                seq[i:] = [seq[i]] * graft_token_lens[i] + seq[i + 1:]
        else:
            for i in reversed(graft_token_lens.keys()):
                assert value == seq[i]
                seq[i:] = [value] * graft_token_lens[i] + seq[i + 1:]
        return seq

    def _get_graft_token_length(self, inputs: Dict[str, Any]) -> Dict[int, int]:
        """Map each placeholder position in input_ids to its expansion length.

        Raises:
            AssertionError: if the number of placeholders does not match the
                number of images.
        """
        image_token_pos = [i for i, x in enumerate(inputs['input_ids']) if x == self.image_token_id]
        image_token_lens = self.get_image_token_length(inputs)

        assert len(image_token_pos) == len(image_token_lens), \
            "Wrong image token count, " \
            f"image_token_count({len(image_token_pos)}) != image_count({len(image_token_lens)})"

        # Insertion order is ascending positions; _graft_token walks it reversed.
        graft_token_lens = OrderedDict(zip(image_token_pos, image_token_lens))

        return graft_token_lens

    def check(self, inputs: Dict[str, Any]):
        """Assert the expanded placeholder count equals total image patches."""
        image_embed_token_count = torch.count_nonzero(inputs['input_ids'] == self.image_token_id).item()
        image_embed_count = sum(self.get_image_token_length(inputs))
        assert image_embed_token_count == image_embed_count, "Wrong image embed token count"

    def padding(self, inputs: Dict[str, Any], max_length: int):
        """Right-pad sequence fields to `max_length` (pad id / zero mask)."""
        padding_len = max_length - len(inputs['input_ids'])
        inputs['input_ids'] += [self.pad_token_id] * padding_len
        inputs['attention_mask'] += [0] * padding_len
        if 'prompt_mask' in inputs:
            inputs['prompt_mask'] += [0] * padding_len
        return inputs

    def decode(self, token_ids: Union[List[int], torch.Tensor], **kwargs):
        """Decode one sequence of token ids to text."""
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        text = self.tokenizer.decode(token_ids, **kwargs)
        return text

    def batch_decode(self, sequences: Union[List[List[int]], torch.Tensor], **kwargs):
        """Decode a batch of token-id sequences to a list of strings."""
        if isinstance(sequences, torch.Tensor):
            sequences = sequences.tolist()
        texts = self.tokenizer.batch_decode(sequences, **kwargs)
        return texts

    def to_tensor(self, inputs):
        """Convert list-valued sequence fields to batched (1, L) tensors."""
        inputs['input_ids'] = torch.tensor([inputs['input_ids']], dtype=torch.long)
        inputs['attention_mask'] = torch.tensor([inputs['attention_mask']], dtype=torch.bool)
        if 'prompt_mask' in inputs:
            inputs['prompt_mask'] = torch.tensor([inputs['prompt_mask']], dtype=torch.bool)
        return inputs

    @property
    def pad_token_id(self):
        # Delegates to the underlying tokenizer.
        return self.tokenizer.pad_token_id

    @property
    def special_tokens(self):
        """All added-token strings known to the tokenizer."""
        return [token.content for token in self.tokenizer.added_tokens_decoder.values()]

    def __repr__(self):
        # Fix: the original body was `pass`, so repr() raised
        # "TypeError: __repr__ returned non-string (type NoneType)".
        # Return a minimal, stable description instead.
        return f"{type(self).__name__}(image_token={self.image_token!r})"

    def __str__(self):
        return ''