| |
| |
| |
|
|
| import numpy as np |
| import torch |
| from transformers.feature_extraction_utils import BatchFeature |
| from transformers.image_utils import ImageInput |
| from transformers.processing_utils import ProcessorMixin |
| from transformers.tokenization_utils_base import PreTokenizedInput, TextInput |
| from transformers.video_utils import VideoInput |
|
|
|
|
class BrainOCRProcessor(ProcessorMixin):
    """Processor pairing an image processor with a tokenizer for BrainOCR.

    Expands every image token in the prompt text into the run of image tokens
    the model expects for that image, tokenizes the result, and adds the extra
    inputs built here: 4-row ``position_ids``, an ``attention_mask`` derived
    from ``pad_id`` and ``imgs_pos`` image spans.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        **kwargs,
    ):
        """Store the sub-processors and resolve the hard-coded special tokens.

        Args:
            image_processor: Image processor; ``__call__`` reads its
                ``merge_size`` and the ``image_grid_thw`` it returns.
            tokenizer: Tokenizer used to map the special token ids to strings.
                Must not be None (it is used before ``super().__init__``).
            chat_template: Forwarded to ``ProcessorMixin``.
            **kwargs: Accepted but not used.

        NOTE(review): the token ids below are hard-coded and assumed to match
        the BrainOCR tokenizer vocabulary -- confirm against the tokenizer.
        """
        # Set early because the token lookups below need it.
        self.tokenizer = tokenizer
        self.image_token_id = 120120
        self.image_token = self.tokenizer.convert_ids_to_tokens(self.image_token_id)
        # Tokens whose positions get_imgs_pos() uses to delimit image spans.
        self.im_start_token_id = 120118
        self.im_start_token = self.tokenizer.convert_ids_to_tokens(
            self.im_start_token_id
        )
        self.im_end_token_id = 120119
        self.im_end_token = self.tokenizer.convert_ids_to_tokens(self.im_end_token_id)
        # Temporary stand-in used while expanding image tokens so the
        # expansion loop does not re-match tokens it just inserted.
        self.placeholder_token = self.tokenizer.convert_ids_to_tokens(
            self.tokenizer.vocab_size - 1
        )
        # Id treated as padding when building the attention mask.
        # NOTE(review): hard-coded; may differ from tokenizer.pad_token_id.
        self.pad_id = 120002

        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(
        self,
        images: ImageInput = None,
        text: TextInput
        | PreTokenizedInput
        | list[TextInput]
        | list[PreTokenizedInput] = None,
        videos: VideoInput = None,
        **kwargs,
    ) -> BatchFeature:
        """Prepare model inputs from prompt text and optional images.

        Args:
            images: Passed as-is to ``self.image_processor``; its output must
                contain ``image_grid_thw``.
            text: A prompt or list of prompts. Each occurrence of
                ``self.image_token`` consumes the next image, in order.
            videos: Accepted for interface compatibility; ignored.
            **kwargs: Forwarded to the tokenizer. ``return_tensors`` is also
                used as the tensor type of the returned ``BatchFeature``.

        Returns:
            A ``BatchFeature`` holding the tokenizer outputs plus
            ``position_ids`` of shape ``(1, 4, seq_len)``, an ``attention_mask``
            recomputed as ``input_ids != pad_id``, ``imgs_pos`` per sample and
            the image-processor outputs.

        Raises:
            ValueError: If ``text`` is None, or the number of image tokens in
                ``text`` differs from the number of images.

        NOTE(review): position ids are computed from the first sample only, so
        calls with images effectively assume a batch size of 1.
        """
        if text is None:
            raise ValueError("BrainOCRProcessor requires `text`.")

        image_inputs = {}
        image_grid_thw = None
        if images is not None:
            image_inputs = self.image_processor(images=images)
            image_grid_thw = image_inputs["image_grid_thw"]

        if not isinstance(text, list):
            text = [text]

        image_tokens_cumsum = [0]
        if images is not None:
            # Builds a new list, so the caller's list is never mutated.
            text, image_tokens_cumsum = self._expand_image_tokens(text, image_grid_thw)

        text_inputs = self.tokenizer(text, add_special_tokens=False, **kwargs)
        self._check_special_mm_tokens(text, text_inputs, modalities=["image"])

        # as_tensor is a no-op for "pt" output and also accepts list/NumPy ids,
        # which the tensor operations below would otherwise reject.
        input_ids = torch.as_tensor(text_inputs["input_ids"])

        text_inputs["position_ids"] = self._build_position_ids(
            input_ids, image_grid_thw, image_tokens_cumsum
        )
        # Overrides the tokenizer's mask: every non-pad_id token is attended.
        text_inputs["attention_mask"] = input_ids.ne(self.pad_id)
        text_inputs["imgs_pos"] = [self.get_imgs_pos(row) for row in input_ids]

        return_tensors = kwargs.get("return_tensors")
        return BatchFeature(
            data={**text_inputs, **image_inputs},
            tensor_type=return_tensors,
        )

    def _grid_to_patches(self, grid_thw):
        """Return ``(patch_h, patch_w)`` as ints for one ``image_grid_thw`` row."""
        grid_h, grid_w = grid_thw[-2:]
        merge_size = self.image_processor.merge_size
        return int(grid_h) // merge_size, int(grid_w) // merge_size

    def _expand_image_tokens(self, text, image_grid_thw):
        """Replace each image token with ``patch_h * (patch_w + 1) + 2`` image tokens.

        Images are consumed in order across all strings of ``text``.

        Returns:
            The expanded list of strings, and the running total of image tokens
            inserted before each image (starting at 0).

        Raises:
            ValueError: If the image-token count differs from the image count.
        """
        num_images = len(image_grid_thw)
        expanded = []
        image_tokens_cumsum = [0]
        index = 0
        for sample in text:
            while self.image_token in sample:
                if index >= num_images:
                    raise ValueError(
                        f"Found more image tokens in `text` than images ({num_images})."
                    )
                patch_h, patch_w = self._grid_to_patches(image_grid_thw[index])
                # patch_h rows of (patch_w + 1) tokens plus 2 extra tokens.
                # NOTE(review): presumably the extra column is a row separator
                # and the +2 are boundary tokens -- confirm with the model.
                num_image_tokens = patch_h * (patch_w + 1) + 2
                image_tokens_cumsum.append(image_tokens_cumsum[-1] + num_image_tokens)
                sample = sample.replace(
                    self.image_token, self.placeholder_token * num_image_tokens, 1
                )
                index += 1
            expanded.append(sample.replace(self.placeholder_token, self.image_token))
        if index != num_images:
            raise ValueError(
                f"Found {index} image tokens in `text` but {num_images} images."
            )
        return expanded, image_tokens_cumsum

    def _build_position_ids(self, input_ids, image_grid_thw, image_tokens_cumsum):
        """Build ``(1, 4, seq_len)`` position ids for the first sample.

        Rows are ``[sequential, w, h, t]``; all start as ``arange(seq_len)``.
        For each image, the ``(patch_w + 1) * patch_h`` positions after the
        first token of its run get a column index (w), a row index (h) and 0 (t).
        """
        seq_len = len(input_ids[0])
        position_ids = torch.arange(seq_len)
        position_ids_w = torch.arange(seq_len)
        position_ids_h = torch.arange(seq_len)
        position_ids_t = torch.arange(seq_len)

        if image_grid_thw is not None:
            image_token_pos_indices = torch.where(input_ids[0] == self.image_token_id)[0]
            for i, grid_thw in enumerate(image_grid_thw):
                patch_h, patch_w = self._grid_to_patches(grid_thw)
                row_len = patch_w + 1
                # Skip the first image token of this image's run.
                start_pos = image_token_pos_indices[image_tokens_cumsum[i]].item() + 1
                end_pos = start_pos + row_len * patch_h
                position_ids_w[start_pos:end_pos] = torch.arange(
                    row_len, dtype=torch.int64
                ).repeat(patch_h)
                position_ids_h[start_pos:end_pos] = torch.arange(
                    patch_h, dtype=torch.int64
                ).repeat_interleave(row_len)
                position_ids_t[start_pos:end_pos] = 0

        return torch.stack(
            [position_ids, position_ids_w, position_ids_h, position_ids_t]
        ).unsqueeze(0)

    def batch_decode(self, *args, **kwargs):
        """Delegate to ``self.tokenizer.batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate to ``self.tokenizer.decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_image_text_to_text(
        self,
        generated_outputs,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ):
        """Decode generated token ids into text strings.

        Previously an ``assert 0`` stub, which returned None silently under
        ``python -O``; now it delegates to ``tokenizer.batch_decode``.
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    def apply_chat_template(self, *args, **kwargs):
        """Delegate to the tokenizer's chat template, forcing ``return_dict=False``."""
        kwargs["return_dict"] = False
        return self.tokenizer.apply_chat_template(*args, **kwargs)

    def get_imgs_pos(self, doc_ids):
        """Return ``[start, end]`` index pairs for each image span in ``doc_ids``.

        ``start`` is the index right after an ``im_start_token_id`` and ``end``
        is the index of the ``im_end_token_id``; the k-th start is paired with
        the k-th end. Returns ``[]`` when there are no markers.

        Raises:
            ValueError: If start and end marker counts differ.
        """
        doc_ids = np.array(doc_ids, dtype=np.int64)
        img_begin_index = np.where(doc_ids == self.im_start_token_id)[0]
        img_end_index = np.where(doc_ids == self.im_end_token_id)[0]
        if len(img_begin_index) != len(img_end_index):
            raise ValueError(
                f"Unbalanced image markers: {len(img_begin_index)} start tokens "
                f"vs {len(img_end_index)} end tokens."
            )
        return np.stack((img_begin_index + 1, img_end_index), axis=-1).tolist()

    @property
    def model_input_names(self):
        """Tokenizer then image-processor input names, de-duplicated in order."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
|
|
|
def split_image_into_patch_blocks(
    pixel_values: torch.Tensor,
    patch_size: int = 16,
    adaptor_patch_div: int = 4,
) -> torch.Tensor:
    """Split an image batch into patch blocks for the vision encoder.

    Each image is cut into a ``(height // patch_size) x (width // patch_size)``
    grid of patches, emitted row-major per image and in batch order. Inside a
    patch, the pixels are grouped into sub-blocks of
    ``adaptor_patch_div x adaptor_patch_div``. Each output block's elements are
    laid out as (sub-block row, sub-block col, channel, row within sub-block,
    col within sub-block) and then viewed as ``(3, patch_size, patch_size)``.
    The output is therefore a plain spatial crop only when
    ``adaptor_patch_div == patch_size``.

    Args:
        pixel_values: Tensor of shape ``(batch, 3, height, width)``.
        patch_size: Patch side length; must divide ``height`` and ``width``.
        adaptor_patch_div: Sub-block side length; must divide ``patch_size``.

    Returns:
        Tensor of shape
        ``(batch * (height // patch_size) * (width // patch_size), 3, patch_size, patch_size)``.

    Raises:
        ValueError: If the input is not 4-D with 3 channels, or the sizes are
            not positive and divisible as described above.
    """
    if pixel_values.dim() != 4:
        raise ValueError(
            "Pixel values must be a 4-D (batch, channels, height, width) tensor, "
            f"got shape {tuple(pixel_values.shape)}"
        )
    batch_size, channels, height, width = pixel_values.shape
    if channels != 3:
        raise ValueError("Pixel values must have 3 channels in dim=1")
    if patch_size <= 0 or height % patch_size or width % patch_size:
        raise ValueError("H and W must be divisible by patch_size")
    if adaptor_patch_div <= 0 or patch_size % adaptor_patch_div:
        raise ValueError("patch_size must be divisible by adaptor_patch_div")

    patch_height_num = height // patch_size
    patch_width_num = width // patch_size
    sub_blocks = patch_size // adaptor_patch_div

    # One reshape splits H and W into (patch, sub-block, pixel) axes:
    # (B, 3, Hp, sub, div, Wp, sub, div).
    img = pixel_values.reshape(
        batch_size,
        3,
        patch_height_num,
        sub_blocks,
        adaptor_patch_div,
        patch_width_num,
        sub_blocks,
        adaptor_patch_div,
    )
    # -> (B, Hp, Wp, sub_row, sub_col, 3, div_row, div_col)
    img = img.permute(0, 2, 5, 3, 6, 1, 4, 7)

    # An explicit count (instead of -1) keeps empty batches valid.
    num_patches = batch_size * patch_height_num * patch_width_num
    return img.reshape(num_patches, 3, patch_size, patch_size)
|
|