| from PIL import Image | |
| from scipy import special | |
| import torch | |
| import numpy as np | |
| from math import e | |
| from param import output | |
| from transformers.feature_extraction_utils import BatchFeature | |
| from transformers.processing_utils import ProcessorMixin | |
class UMMProcessor(ProcessorMixin):
    """Processor pairing an image processor with a Qwen2 tokenizer.

    Text may contain two kinds of image placeholders: ``<image>`` (images fed
    for understanding) and ``<image_gen>`` (images preprocessed with
    ``und=False``). After tokenization, placeholders are rewritten to sentinel
    ids in ``input_ids`` and the images are preprocessed in placeholder order.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
        """Resolve/register the special tokens on ``tokenizer`` and build the processor.

        Each token string is read from the tokenizer attribute for its role
        (``image_token``, ``image_gen_token``, ``image_gen_start``,
        ``image_gen_end``, ``no_mean``) when present and not None, otherwise a
        default (``<image>``, ``<image_gen>``, ``<im_start>``, ``<im_end>``,
        ``<no_mean>``) is used. When the tokenizer exposes no id attribute for a
        token, the token is registered via ``add_tokens(..., special_tokens=True)``;
        ``<image>`` / ``<image_gen>`` then get the sentinel ids -200 / -300,
        and the other tokens get the id the tokenizer assigns.

        Args:
            image_processor: Image processor forwarded to ``ProcessorMixin``.
            tokenizer: Tokenizer forwarded to ``ProcessorMixin``; it is mutated
                by ``add_tokens`` as described above.
            chat_template: Defaults to ``tokenizer.chat_template`` when available.
            **kwargs: Accepted for compatibility and ignored.
        """
        self.image_token, self.image_token_id = self._resolve_special_token(
            tokenizer, "image_token", "image_token_id", "<image>", fallback_id=-200)
        self.image_gen_token, self.image_gen_token_id = self._resolve_special_token(
            tokenizer, "image_gen_token", "image_gen_token_id", "<image_gen>", fallback_id=-300)
        self.image_gen_start_token, self.image_gen_start_token_id = self._resolve_special_token(
            tokenizer, "image_gen_start", "image_gen_start_token_id", "<im_start>")
        self.image_gen_end_token, self.image_gen_end_token_id = self._resolve_special_token(
            tokenizer, "image_gen_end", "image_gen_end_token_id", "<im_end>")
        self.no_mean_token, self.no_mean_token_id = self._resolve_special_token(
            tokenizer, "no_mean", "no_mean_id", "<no_mean>")
        if chat_template is None and hasattr(tokenizer, "chat_template"):
            chat_template = tokenizer.chat_template
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    @staticmethod
    def _resolve_special_token(tokenizer, token_attr, id_attr, default_token, fallback_id=None):
        """Return ``(token, token_id)`` for one special token.

        ``token`` is ``tokenizer.<token_attr>`` if set (not None), else
        ``default_token``. If ``tokenizer.<id_attr>`` is set (not None) it is
        returned as the id. Otherwise ``token`` itself (not a hard-coded
        default) is registered with the tokenizer, and the id is
        ``fallback_id`` when given, else the tokenizer's id for ``token``.
        """
        token = getattr(tokenizer, token_attr, None)
        if token is None:
            token = default_token
        token_id = getattr(tokenizer, id_attr, None)
        # `is not None` rather than truthiness so that a legitimate id of 0 is kept.
        if token_id is not None:
            return token, token_id
        tokenizer.add_tokens([token], special_tokens=True)
        if fallback_id is None:
            fallback_id = tokenizer.convert_tokens_to_ids(token)
        return token, fallback_id

    def __call__(self, images=None, text=None, max_resolution=None, add_im_start_id=False, **kwargs):
        """Tokenize ``text`` and preprocess ``images`` into one ``BatchFeature``.

        Args:
            images: None, a single entry, or a list with one entry per sample.
                An entry may be a single image or a list of images; a None
                entry is replaced by a random 256x256 RGB dummy image that does
                not consume a placeholder.
            text: A string or a list of strings containing placeholders.
            max_resolution: Forwarded to the image processor for real images.
            add_im_start_id: If True, widen ``input_ids`` by one column and put
                ``image_gen_start_token_id`` right after the last attended
                token of every row (works for left or right padding); the
                attention mask is widened and set to 1 at that position.
            **kwargs: Forwarded to the tokenizer. ``padding`` and
                ``truncation`` default to True; ``return_tensors`` is also
                used as the ``BatchFeature`` tensor type.

        Returns:
            ``BatchFeature`` with the tokenizer outputs (placeholders in
            ``input_ids`` replaced by ``image_token_id`` /
            ``image_gen_token_id``). When ``images`` is given it also holds
            ``pixel_values`` and ``grid_hws`` (concatenated along dim 0, if any
            image was processed) and ``t`` (one value per placeholder: 1.0 for
            ``<image>``, ``torch.rand(1)`` for ``<image_gen>``; presumably a
            timestep consumed by the model -- TODO confirm).

        Raises:
            ValueError: If more real images are supplied than placeholders.
        """
        kwargs.setdefault("padding", True)
        kwargs.setdefault("truncation", True)
        text = list(text) if isinstance(text, list) else [text]
        return_tensors = kwargs.pop("return_tensors", None)
        text_inputs = self.tokenizer(text, **kwargs, return_tensors=return_tensors)
        img_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
        img_gen_token_id = self.tokenizer.convert_tokens_to_ids(self.image_gen_token)

        source_ids = text_inputs["input_ids"]
        if add_im_start_id:
            new_input_ids = self._extend_with_im_start(text_inputs)
        else:
            # Placeholders are rewritten in place in the tokenizer output.
            new_input_ids = source_ids
        t, und_gen_mask_list = self._replace_placeholders(
            source_ids, new_input_ids, img_token_id, img_gen_token_id)

        image_inputs = {}
        if images is not None:
            pixel_values, grid_hws = self._process_images(images, und_gen_mask_list, max_resolution)
            # Guard: torch.concat raises on an empty list (e.g. images=[]).
            if pixel_values:
                image_inputs["pixel_values"] = torch.concat(pixel_values, dim=0)
                image_inputs["grid_hws"] = torch.concat(grid_hws, dim=0)
            if t:
                image_inputs["t"] = torch.cat(t)

        text_inputs["input_ids"] = new_input_ids
        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)

    def _extend_with_im_start(self, text_inputs):
        """Return ``input_ids`` widened by one column with the image-generation
        start token placed right after each row's last attended token.

        New cells are filled with ``pad_token_id``. If ``text_inputs`` holds an
        ``attention_mask`` it is widened the same way and set to 1 at the
        inserted position; otherwise attention is derived from ``!= pad_token_id``.
        """
        input_ids = torch.as_tensor(text_inputs["input_ids"])
        batch_size, seq_len = input_ids.shape
        pad_id = self.tokenizer.pad_token_id
        new_input_ids = torch.full((batch_size, seq_len + 1), pad_id, dtype=input_ids.dtype)
        new_input_ids[:, :seq_len] = input_ids

        has_mask = "attention_mask" in text_inputs
        if has_mask:
            attention_mask = torch.as_tensor(text_inputs["attention_mask"])
        else:
            attention_mask = (input_ids != pad_id).long()

        # Per row: 1 + index of the last attended token, or 0 if none. This is
        # the valid length for right padding and seq_len for left padding.
        # The leading zero column keeps max() well-defined when seq_len == 0.
        after_token = attention_mask.bool().long() * torch.arange(1, seq_len + 1)
        floor = torch.zeros((batch_size, 1), dtype=after_token.dtype)
        positions = torch.cat([floor, after_token], dim=1).max(dim=1).values

        rows = torch.arange(batch_size)
        new_input_ids[rows, positions] = self.image_gen_start_token_id
        if has_mask:
            new_attention_mask = torch.zeros((batch_size, seq_len + 1), dtype=attention_mask.dtype)
            new_attention_mask[:, :seq_len] = attention_mask
            # The inserted start token must be attended in every row, not only
            # in rows whose start token lands in the last column.
            new_attention_mask[rows, positions] = 1
            text_inputs["attention_mask"] = new_attention_mask
        return new_input_ids

    def _replace_placeholders(self, source_ids, target_ids, img_token_id, img_gen_token_id):
        """Write sentinel ids into ``target_ids`` wherever ``source_ids`` has a placeholder.

        Scans row-major and returns ``(t, und_gen_mask_list)`` with one entry
        per placeholder in scan order: ``t`` gets ``torch.tensor([1.0])`` and
        the mask 1 for ``<image>``; ``t`` gets ``torch.rand(1)`` and the mask 0
        for ``<image_gen>``.
        """
        # Iterate plain ints instead of 0-d tensors (much cheaper per element).
        rows = source_ids.tolist() if hasattr(source_ids, "tolist") else source_ids
        t, und_gen_mask_list = [], []
        for i, ids in enumerate(rows):
            for j, token_id in enumerate(ids):
                if token_id == img_token_id:
                    target_ids[i][j] = self.image_token_id
                    t.append(torch.tensor([1.0]))
                    und_gen_mask_list.append(1)
                elif token_id == img_gen_token_id:
                    target_ids[i][j] = self.image_gen_token_id
                    t.append(torch.rand(1))
                    und_gen_mask_list.append(0)
        return t, und_gen_mask_list

    def _process_images(self, images, und_gen_mask_list, max_resolution):
        """Run the image processor over ``images`` and collect its outputs.

        Returns parallel lists of per-image ``pixel_values`` and ``grid_hws``;
        every processed image (including dummy images for None entries) is
        collected. The k-th real image uses generation preprocessing
        (``und=False``) when ``und_gen_mask_list[k] == 0``.
        """
        pixel_values, grid_hws = [], []
        image_idx = 0
        for per_images in images if isinstance(images, list) else [images]:
            if per_images is None:
                dummy_image = Image.fromarray(np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8))
                image_info = self.image_processor(images=dummy_image)
                pixel_values.append(image_info.pixel_values)
                grid_hws.append(image_info.grid_hws)
                continue
            for per_image in per_images if isinstance(per_images, list) else [per_images]:
                if image_idx >= len(und_gen_mask_list):
                    raise ValueError(
                        f"Got more images than image placeholder tokens "
                        f"({len(und_gen_mask_list)}) in text."
                    )
                if und_gen_mask_list[image_idx] == 0:
                    image_info = self.image_processor(images=per_image, max_resolution=max_resolution, und=False)
                else:
                    image_info = self.image_processor(images=per_image, max_resolution=max_resolution)
                image_idx += 1
                pixel_values.append(image_info.pixel_values)
                grid_hws.append(image_info.grid_hws)
        return pixel_values, grid_hws
# Public API of this module: only the processor class is exported.
__all__ = ["UMMProcessor"]