import json
import math
from copy import deepcopy
from typing import List, Optional

import torch
import torchvision
from PIL import Image
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from torchvision import transforms
from transformers import (
    LlamaTokenizer,
    LlamaPreTrainedModel,
    LlamaForCausalLM,
    AutoModel,
    PreTrainedTokenizerFast,
)
from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer

from .configuration_minicpm import MiniCPMVConfig
from .resampler import Resampler


class MiniCPMVPreTrainedModel(LlamaPreTrainedModel):
    """Pre-trained model base class bound to ``MiniCPMVConfig``."""

    config_class = MiniCPMVConfig


class MiniCPMV(MiniCPMVPreTrainedModel):
    """Vision-language model: a vision transformer and a resampler feeding a Llama LM.

    Image features produced by ``self.vpm`` are compressed by ``self.resampler``
    and written into the token-embedding sequence of ``self.llm`` at the
    positions listed in ``image_bound``.
    """

    def __init__(self, config):
        super().__init__(config)

        self.llm = LlamaForCausalLM(config)
        self.vpm = self.init_vision_module()
        self.vision_dim = self.vpm.embed_dim
        self.embed_dim = self.llm.config.hidden_size
        self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
        self.transform = self.init_transform()

    def init_vision_module(self):
        """Build the vision transformer from ``config.vision_config``.

        :return: the vision model, with ``embed_dim`` and ``patch_size`` copied
            from its embedding layer onto the model object itself
        """
        # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
        model = Idefics2VisionTransformer(self.config.vision_config)
        if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]

        # Expose these on the model so callers do not reach into `embeddings`.
        model.embed_dim = model.embeddings.embed_dim
        model.patch_size = model.embeddings.patch_size

        return model

    def init_resampler(self, embed_dim, vision_dim):
        """Build the resampler that maps vision features to the LM width.

        :param embed_dim: output width (the LM hidden size)
        :param vision_dim: input width (the vision model embedding size)
        :return: a ``Resampler`` with ``config.query_num`` queries
        """
        return Resampler(
            num_queries=self.config.query_num,
            embed_dim=embed_dim,
            num_heads=embed_dim // 128,
            kv_dim=vision_dim,
            adaptive=True,
        )

    def init_transform(self):
        """Return the image preprocessing pipeline (to-tensor, then normalize)."""
        return transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD
                ),
            ]
        )

    def get_vllm_embedding(self, data):
        """Embed ``input_ids`` and splice image features into the embeddings.

        :param data: dict with ``input_ids`` and ``image_bound``, plus either
            ``vision_hidden_states`` (precomputed, used as-is) or both
            ``pixel_values`` (one list of image tensors per sample) and
            ``tgt_sizes``
        :return: ``(vllm_embedding, vision_hidden_states)`` where
            ``vision_hidden_states`` holds one entry per sample (a tensor, or
            an empty list when the sample has no image)
        """
        if 'vision_hidden_states' not in data:
            dtype = self.vpm.embeddings.position_embedding.weight.dtype
            device = self.vpm.embeddings.position_embedding.weight.device
            tgt_sizes = data['tgt_sizes']
            pixel_values_list = data['pixel_values']
            vision_hidden_states = []

            # Flatten every image of every sample into one list of 2-D tensors.
            all_pixel_values = []
            for pixel_values in pixel_values_list:
                all_pixel_values.extend(
                    [i.flatten(end_dim=1).permute(1, 0) for i in pixel_values]
                )

            # exist image
            if all_pixel_values:
                tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)

                if self.config.batch_vision_input:
                    patch_counts = tgt_sizes[:, 0] * tgt_sizes[:, 1]
                    max_patches = int(patch_counts.max())

                    all_pixel_values = torch.nn.utils.rnn.pad_sequence(
                        all_pixel_values, batch_first=True, padding_value=0.0
                    )
                    B, L, _ = all_pixel_values.shape
                    all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)

                    patch_attn_mask = torch.zeros(
                        (B, 1, max_patches), dtype=torch.bool, device=device
                    )
                    for i in range(B):
                        # Index the singleton middle axis explicitly: `[i, :n]`
                        # would slice that axis and mark the whole row
                        # (padding included) as valid.
                        patch_attn_mask[i, 0, :int(patch_counts[i])] = True

                    vision_embedding = self.vpm(
                        all_pixel_values.type(dtype),
                        patch_attention_mask=patch_attn_mask,
                    ).last_hidden_state
                    vision_embedding = self.resampler(vision_embedding, tgt_sizes)
                else:
                    # get vision_embedding foreach
                    vision_embedding = []
                    for single_tgt_size, single_pixel_values in zip(tgt_sizes, all_pixel_values):
                        single_pixel_values = single_pixel_values.unsqueeze(0)
                        B, L, _ = single_pixel_values.shape
                        single_pixel_values = single_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
                        single_vision_embedding = self.vpm(
                            single_pixel_values.type(dtype)
                        ).last_hidden_state
                        single_vision_embedding = self.resampler(
                            single_vision_embedding, single_tgt_size.unsqueeze(0)
                        )
                        vision_embedding.append(single_vision_embedding)
                    vision_embedding = torch.vstack(vision_embedding)

                # Hand each sample back its own contiguous run of image features.
                start = 0
                for pixel_values in pixel_values_list:
                    img_cnt = len(pixel_values)
                    if img_cnt > 0:
                        vision_hidden_states.append(vision_embedding[start: start + img_cnt])
                        start += img_cnt
                    else:
                        vision_hidden_states.append([])
            else:  # no image
                if self.training:
                    # Run a dummy image through the vision stack during training.
                    # NOTE(review): presumably so the vision parameters still
                    # take part in the graph for image-free batches - confirm.
                    dummy_image = torch.zeros(
                        (1, 3, 224, 224), device=device, dtype=dtype
                    )
                    tgt_sizes = torch.Tensor(
                        [[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]
                    ).type(torch.int32)
                    dummy_feature = self.resampler(
                        self.vpm(dummy_image).last_hidden_state, tgt_sizes
                    )
                else:
                    dummy_feature = []
                for _ in range(len(pixel_values_list)):
                    vision_hidden_states.append(dummy_feature)
        else:
            vision_hidden_states = data['vision_hidden_states']

        if hasattr(self.llm.config, 'scale_emb'):
            vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
        else:
            vllm_embedding = self.llm.model.embed_tokens(data['input_ids'])

        vision_hidden_states = [
            i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i
            for i in vision_hidden_states
        ]

        bs = len(data['input_ids'])
        for i in range(bs):
            cur_vs_hs = vision_hidden_states[i]
            if len(cur_vs_hs) > 0:
                cur_vllm_emb = vllm_embedding[i]
                cur_image_bound = data['image_bound'][i]
                if len(cur_image_bound) > 0:
                    image_indices = torch.stack(
                        [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                    ).to(vllm_embedding.device)

                    # Overwrite the placeholder-token embeddings in place with
                    # the image features, row by row.
                    cur_vllm_emb.scatter_(
                        0,
                        image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                        cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
                    )
                elif self.training:
                    # Adds zero; numerically a no-op that still references the
                    # vision features.
                    cur_vllm_emb += cur_vs_hs[0].mean() * 0

        return vllm_embedding, vision_hidden_states

    def forward(self, data, **kwargs):
        """Run the language model on embeddings built from ``data``.

        :param data: same dict as for ``get_vllm_embedding``, plus ``position_ids``
        :param kwargs: forwarded to ``self.llm``
        :return: whatever ``self.llm`` returns
        """
        vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
        position_ids = data["position_ids"]
        if position_ids.dtype != torch.int64:
            position_ids = position_ids.long()

        return self.llm(
            input_ids=None,
            position_ids=position_ids,
            inputs_embeds=vllm_embedding,
            **kwargs
        )

    def _convert_to_tensors(
        self, tokenizer, input_ids, max_inp_length: Optional[int] = None
    ):
        """Turn one token-id sequence into model inputs and locate image spans.

        :param tokenizer: provides ``im_start_id`` and ``im_end_id``
        :param input_ids: sequence of token ids
        :param max_inp_length: if given, ``input_ids`` is truncated to this length
        :return: dict with ``input_ids`` (with a leading batch axis, on
            ``self.device``) and ``image_bound``, one ``[start, end)`` row per
            image span
        """
        if max_inp_length is not None:
            input_ids = input_ids[:max_inp_length]
        input_ids = torch.tensor(input_ids, dtype=torch.int32)

        image_start_tokens = torch.where(input_ids == tokenizer.im_start_id)[0]
        # skip past the im_start token itself
        image_start_tokens += 1
        image_end_tokens = torch.where(input_ids == tokenizer.im_end_id)[0]

        # Keep only complete (start, end) pairs. Truncation can cut off a
        # closing marker, and hstack needs both columns to be the same length.
        valid_image_nums = min(len(image_start_tokens), len(image_end_tokens))
        image_bound = torch.hstack(
            [
                image_start_tokens[:valid_image_nums].unsqueeze(-1),
                image_end_tokens[:valid_image_nums].unsqueeze(-1),
            ]
        )

        model_input = {
            "input_ids": input_ids.unsqueeze(0).to(self.device),
            "image_bound": image_bound,
        }
        return model_input

    def _process_list(
        self, tokenizer, input_id_list, max_inp_length: Optional[int] = None
    ):
        """Convert a batch of token-id sequences and left-pad ``input_ids``.

        :param tokenizer: passed through to ``_convert_to_tensors``
        :param input_id_list: one token-id sequence per sample
        :param max_inp_length: optional truncation length per sample
        :return: dict with padded ``input_ids`` on ``self.device`` and a list
            of per-sample ``image_bound`` tensors
        """
        pad_keys = ["input_ids"]
        input_tensors = [
            self._convert_to_tensors(tokenizer, input_ids, max_inp_length)
            for input_ids in input_id_list
        ]

        padded = {}
        for key in pad_keys:
            padded[key] = pad(input_tensors, key, padding_side="left").to(self.device)
        padded["image_bound"] = [i["image_bound"] for i in input_tensors]
        return padded

    def _decode(self, inputs_embeds, tokenizer, **kwargs):
        """Generate from embeddings and decode the result to text.

        :param inputs_embeds: embeddings passed to ``self.llm.generate``
        :param tokenizer: used for the terminator ids and for decoding
        :param kwargs: forwarded to ``self.llm.generate``
        :return: list of decoded strings, one per generated sequence
        """
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        output = self.llm.generate(
            inputs_embeds=inputs_embeds,
            pad_token_id=0,
            eos_token_id=terminators,
            **kwargs
        )
        return self._decode_text(output, tokenizer)

    def _decode_text(self, result_ids, tokenizer):
        """Strip padding and BOS/EOS/EOT ids, then decode each sequence.

        :param result_ids: iterable of 1-D id tensors
        :param tokenizer: provides ``bos_id``, ``eos_id``, ``eot_id`` and ``decode``
        :return: list of stripped strings
        """
        result_text = []
        for result in result_ids:
            result = result[result != 0]  # 0 is the pad id used in _decode
            # Length guards: a sequence can be empty once padding is removed.
            if len(result) > 0 and result[0] == tokenizer.bos_id:
                result = result[1:]
            if len(result) > 0 and (
                result[-1] == tokenizer.eos_id or result[-1] == tokenizer.eot_id
            ):
                result = result[:-1]
            result_text.append(tokenizer.decode(result).strip())
        return result_text

    def slice_image(self, image):
        """Slice ``image`` using the parameters in ``config.slice_config``.

        :return: ``(source_image, patches, best_grid)`` from the module-level
            ``slice_image``
        """
        return slice_image(
            image,
            self.config.slice_config.max_slice_nums,
            self.config.slice_config.scale_resolution,
            self.config.slice_config.patch_size,
        )

    def get_slice_image_placeholder(self, image, tokenizer):
        """Slice an image and build the matching placeholder text.

        :param image: image to slice
        :param tokenizer: provides the marker strings and ``unk_token``
        :return: ``(slice_images, final_placeholder)``; ``slice_images`` is the
            resized source image followed by every patch in row-major order
        """
        image_placeholder = (
            tokenizer.im_start
            + tokenizer.unk_token * self.config.query_num
            + tokenizer.im_end
        )

        source_image, patches, best_grid = self.slice_image(image)

        slice_images = [source_image]
        final_placeholder = image_placeholder

        if len(patches) > 0:
            for row in patches:
                slice_images.extend(row)

            final_placeholder += get_grid_placeholder(
                tokenizer, best_grid, self.config.query_num
            )

        return slice_images, final_placeholder

    def reshape_by_patch(self, image_tensor):
        """
        Rearrange an image tensor into a strip of patches.

        :param image_tensor: shape [3, H, W]
        :return: [3, patch_size, HW/patch_size]
        """
        patch_size = self.config.patch_size
        patches = torch.nn.functional.unfold(
            image_tensor,
            (patch_size, patch_size),
            stride=(patch_size, patch_size)
        )

        channels = image_tensor.size(0)
        patches = patches.reshape(channels, patch_size, patch_size, -1)
        patches = patches.permute(0, 1, 3, 2).reshape(channels, patch_size, -1)
        return patches

    def generate(
        self,
        input_id_list=None,
        img_list=None,
        tgt_sizes=None,
        tokenizer=None,
        max_inp_length: Optional[int] = None,
        vision_hidden_states=None,
        return_vision_hidden_states=False,
        **kwargs
    ):
        """Generate text for a batch of token-id sequences with optional images.

        :param input_id_list: one token-id sequence per sample (required)
        :param img_list: one list of image tensors per sample; defaults to no images
        :param tgt_sizes: target sizes stored under ``tgt_sizes`` for
            ``get_vllm_embedding``
        :param tokenizer: tokenizer used for span detection and decoding
        :param max_inp_length: optional truncation length per sample
        :param vision_hidden_states: precomputed image features; when given,
            ``img_list`` and ``tgt_sizes`` are not used
        :param return_vision_hidden_states: also return the image features
        :param kwargs: forwarded to ``self.llm.generate``
        :return: list of strings, or ``(list of strings, vision_hidden_states)``
        """
        assert input_id_list is not None
        bs = len(input_id_list)
        if img_list is None:
            img_list = [[] for _ in range(bs)]
        assert bs == len(img_list)

        model_inputs = self._process_list(tokenizer, input_id_list, max_inp_length)

        if vision_hidden_states is None:
            model_inputs["pixel_values"] = [
                [img.to(self.device) for img in img_list[i]] for i in range(bs)
            ]
            model_inputs['tgt_sizes'] = tgt_sizes
        else:
            model_inputs["vision_hidden_states"] = vision_hidden_states

        with torch.inference_mode():
            (
                model_inputs["inputs_embeds"],
                vision_hidden_states,
            ) = self.get_vllm_embedding(model_inputs)

            result = self._decode(model_inputs["inputs_embeds"], tokenizer, **kwargs)

        if return_vision_hidden_states:
            return result, vision_hidden_states

        return result

    def chat(
        self,
        image,
        msgs,
        tokenizer,
        vision_hidden_states=None,
        max_new_tokens=1024,
        sampling=True,
        max_inp_length=2048,
        **kwargs
    ):
        """Answer a conversation that may contain images.

        :param image: optional image; if the first message's content is a
            plain string, this image is placed in front of it
        :param msgs: list of ``{"role", "content"}`` dicts, or a JSON string
            encoding such a list; content is a string or a list of strings
            and PIL images
        :param tokenizer: tokenizer providing the chat template and markers
        :param vision_hidden_states: optional precomputed image features
        :param max_new_tokens: forwarded to generation
        :param sampling: choose the sampling preset (True) or the beam-search
            preset (False)
        :param max_inp_length: truncation length for the prompt
        :param kwargs: may override values of keys already present in the
            chosen generation preset; other keys are ignored
        :return: the generated answer as a string
        """
        if isinstance(msgs, str):
            msgs = json.loads(msgs)

        # Work on a copy: message contents are rewritten below.
        copy_msgs = deepcopy(msgs)
        assert len(copy_msgs) > 0, 'msgs is empty'

        if image is not None and isinstance(copy_msgs[0]['content'], str):
            copy_msgs[0]['content'] = [image, copy_msgs[0]['content']]

        images = []
        tgt_sizes = []
        for i, msg in enumerate(copy_msgs):
            role = msg["role"]
            content = msg["content"]
            assert role in ["user", "assistant"]
            if i == 0:
                assert role == "user", "The role of first msg should be user"
            if isinstance(content, str):
                content = [content]

            cur_msgs = []
            for c in content:
                if isinstance(c, Image.Image):
                    image = c
                    if self.config.slice_mode:
                        slice_images, image_placeholder = self.get_slice_image_placeholder(
                            image, tokenizer
                        )
                        cur_msgs.append(image_placeholder)
                        for slice_img in slice_images:
                            slice_tensor = self.transform(slice_img)
                            H, W = slice_tensor.shape[1:]
                            images.append(self.reshape_by_patch(slice_tensor))
                            tgt_sizes.append(
                                torch.Tensor(
                                    [H // self.config.patch_size, W // self.config.patch_size]
                                ).type(torch.int32)
                            )
                    else:
                        images.append(self.transform(image))
                        cur_msgs.append(
                            tokenizer.im_start
                            + tokenizer.unk_token * self.config.query_num
                            + tokenizer.im_end
                        )
                elif isinstance(c, str):
                    cur_msgs.append(c)

            msg['content'] = '\n'.join(cur_msgs)

        if tgt_sizes:
            tgt_sizes = torch.vstack(tgt_sizes)

        input_ids = tokenizer.apply_chat_template(
            copy_msgs, tokenize=True, add_generation_prompt=False
        )

        if sampling:
            generation_config = {
                "top_p": 0.8,
                "top_k": 100,
                "temperature": 0.7,
                "do_sample": True,
                "repetition_penalty": 1.05
            }
        else:
            generation_config = {
                "num_beams": 3,
                "repetition_penalty": 1.2,
            }

        # Only keys already in the preset can be overridden by the caller.
        generation_config.update(
            (k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()
        )

        with torch.inference_mode():
            res, vision_hidden_states = self.generate(
                input_id_list=[input_ids],
                max_inp_length=max_inp_length,
                img_list=[images],
                tgt_sizes=[tgt_sizes],
                tokenizer=tokenizer,
                max_new_tokens=max_new_tokens,
                vision_hidden_states=vision_hidden_states,
                return_vision_hidden_states=True,
                **generation_config
            )
        answer = res[0]

        return answer


class PreTrainedTokenizerFastWrapper(PreTrainedTokenizerFast):
    """Fast tokenizer extended with marker strings and id shortcuts."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.eot_token = "<|eot_id|>"
        # NOTE(review): every start/end marker below is an empty string, which
        # makes im_start_id == im_end_id. These look like angle-bracket tag
        # literals lost in extraction - TODO confirm against the released
        # tokenizer before relying on them.
        self.im_start = ""
        self.im_end = ""
        self.ref_start = ""
        self.ref_end = ""
        self.box_start = ""
        self.box_end = ""
        self.quad_start = ""
        self.quad_end = ""
        self.slice_start = ""
        self.slice_end = ""

    @property
    def eos_id(self):
        return self.eos_token_id

    @property
    def bos_id(self):
        return self.bos_token_id

    @property
    def unk_id(self):
        return self.unk_token_id

    @property
    def eot_id(self):
        return self.convert_tokens_to_ids(self.eot_token)

    @property
    def im_start_id(self):
        return self.convert_tokens_to_ids(self.im_start)

    @property
    def im_end_id(self):
        return self.convert_tokens_to_ids(self.im_end)

    @staticmethod
    def escape(text: str) -> str:
        """Return ``text`` unchanged."""
        return text

    @staticmethod
    def unescape(text: str) -> str:
        """Return ``text`` unchanged."""
        return text


def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"):
    """Collate ``item[key]`` tensors from a list of dicts into one padded tensor.

    :param orig_items: list of dicts; ``item[key]`` is a tensor, or a list of
        tensors (which is flattened into one entry per tensor)
    :param key: dict key to collate
    :param max_length: minimum padded length; the longest item wins if larger
    :param padding_value: fill value for padded positions
    :param padding_side: ``"left"`` pads at the front, anything else at the back
    :return: concatenation along dim 0 for 1-D tensors and for 2-D tensors of
        equal length, otherwise a new padded tensor
    """
    items = []
    if isinstance(orig_items[0][key], list):
        assert isinstance(orig_items[0][key][0], torch.Tensor)
        for it in orig_items:
            for tr in it[key]:
                items.append({key: tr})
    else:
        assert isinstance(orig_items[0][key], torch.Tensor)
        items = orig_items

    batch_size = len(items)
    shape = items[0][key].shape
    dim = len(shape)
    assert dim <= 3
    if max_length is None:
        max_length = 0
    max_length = max(max_length, max(item[key].shape[-1] for item in items))
    min_length = min(item[key].shape[-1] for item in items)
    dtype = items[0][key].dtype

    if dim == 1:
        return torch.cat([item[key] for item in items], dim=0)
    elif dim == 2:
        if max_length == min_length:
            return torch.cat([item[key] for item in items], dim=0)
        tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
    else:
        # NOTE(review): the padded axis is sized from shape[-1] while rows are
        # copied along axis 1 - confirm this is intended for 3-D inputs.
        tensor = (
            torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype)
            + padding_value
        )

    for i, item in enumerate(items):
        row = item[key][0]
        if dim == 2:
            if padding_side == "left":
                # Explicit start index: `-len(row):` selects the whole row
                # when len(row) == 0.
                tensor[i, max_length - len(row):] = row.clone()
            else:
                tensor[i, : len(row)] = row.clone()
        elif dim == 3:
            if padding_side == "left":
                tensor[i, -len(row):, :] = row.clone()
            else:
                tensor[i, : len(row), :] = row.clone()

    return tensor


def slice_image(
    image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
):
    """Resize an image and, when it is large enough, cut it into a grid.

    :param image: PIL image
    :param max_slice_nums: upper bound on the number of grid cells
    :param scale_resolution: reference side length; the area ratio against
        ``scale_resolution ** 2`` decides how many cells are wanted
    :param patch_size: output sizes are multiples of this
    :param never_split: if True, never cut the image into a grid
    :return: ``(source_image, patches, best_grid)``; ``patches`` is ``[]`` and
        ``best_grid`` is ``None`` when the image is not split
    """
    original_size = image.size
    original_width, original_height = original_size
    log_ratio = math.log(original_width / original_height)
    ratio = original_width * original_height / (scale_resolution * scale_resolution)
    multiple = min(math.ceil(ratio), max_slice_nums)

    best_grid = None
    patches = []

    if multiple <= 1 or never_split:
        # dont need to slice, upsample
        best_size = find_best_resize(
            original_size, scale_resolution, patch_size, allow_upscale=True
        )
        source_image = image.resize(best_size, Image.Resampling.BICUBIC)
    else:
        candidate_split_grids_nums = [
            i
            for i in (multiple - 1, multiple, multiple + 1)
            if i != 1 and i <= max_slice_nums
        ]

        # source image, down-sampling and ensure divided by patch_size
        best_resize = find_best_resize(original_size, scale_resolution, patch_size)
        source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)

        # find best grid: every factorization [m, n // m] of each candidate count
        candidate_grids = []
        for split_grids_nums in candidate_split_grids_nums:
            for m in range(1, split_grids_nums + 1):
                if split_grids_nums % m == 0:
                    candidate_grids.append([m, split_grids_nums // m])

        # Choose the grid whose aspect ratio is closest (in log space) to the
        # image's; min() keeps the first candidate on ties.
        best_grid = min(
            candidate_grids,
            key=lambda grid: abs(log_ratio - math.log(grid[0] / grid[1])),
            default=[1, 1],
        )

        refine_size = get_refine_size(
            original_size, best_grid, scale_resolution, patch_size, allow_upscale=True
        )

        refine_image = image.resize(refine_size, Image.Resampling.BICUBIC)
        patches = split_to_patches(refine_image, best_grid)

    return source_image, patches, best_grid


def ensure_divide(length, patch_size):
    """Round ``length`` to the nearest multiple of ``patch_size``, at least one patch."""
    return max(round(length / patch_size) * patch_size, patch_size)


def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False):
    """Pick a ``(width, height)`` whose sides are multiples of ``patch_size``.

    :param original_size: ``(width, height)``
    :param scale_resolution: target side length; when rescaling, the area
        becomes roughly ``scale_resolution ** 2`` with the aspect ratio kept
    :param patch_size: both returned sides are multiples of this
    :param allow_upscale: rescale even when the image is already small enough
    :return: ``(best_width, best_height)``
    """
    width, height = original_size
    if (width * height > scale_resolution * scale_resolution) or allow_upscale:
        r = width / height
        height = int(scale_resolution / math.sqrt(r))
        width = int(height * r)
    best_width = ensure_divide(width, patch_size)
    best_height = ensure_divide(height, patch_size)
    return (best_width, best_height)


def get_refine_size(
    original_size, grid, scale_resolution, patch_size, allow_upscale=False
):
    """Compute the full-image size whose grid cells each have a "best" size.

    :param original_size: ``(width, height)``
    :param grid: ``(grid_x, grid_y)`` cell counts
    :param scale_resolution: forwarded to ``find_best_resize`` per cell
    :param patch_size: forwarded to ``find_best_resize`` per cell
    :param allow_upscale: forwarded to ``find_best_resize`` per cell
    :return: ``(width, height)``, a multiple of the grid in each direction
    """
    width, height = original_size
    grid_x, grid_y = grid

    refine_width = ensure_divide(width, grid_x)
    refine_height = ensure_divide(height, grid_y)

    grid_width = refine_width / grid_x
    grid_height = refine_height / grid_y

    best_grid_size = find_best_resize(
        (grid_width, grid_height),
        scale_resolution,
        patch_size,
        allow_upscale=allow_upscale,
    )

    refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)

    return refine_size


def split_to_patches(image, grid):
    """Crop ``image`` into rows of equally sized cells.

    :param image: PIL image
    :param grid: ``(columns, rows)``
    :return: list of rows, each a list of cropped images, left to right
    """
    patches = []
    width, height = image.size
    grid_x = int(width / grid[0])
    grid_y = int(height / grid[1])

    for top in range(0, height, grid_y):
        row = [
            image.crop((left, top, left + grid_x, top + grid_y))
            for left in range(0, width, grid_x)
        ]
        patches.append(row)

    return patches


def get_grid_placeholder(tokenizer, grid, query_num):
    """Build the placeholder text for a grid of image slices.

    :param tokenizer: provides the marker strings and ``unk_token``
    :param grid: ``(columns, rows)``
    :param query_num: number of ``unk_token`` repeats per slice
    :return: one placeholder per cell, cells of a row concatenated, rows joined
        by newlines, all wrapped in ``slice_start`` / ``slice_end``
    """
    image_placeholder = (
        tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
    )

    cols = grid[0]
    rows = grid[1]
    slices = ["".join([image_placeholder] * cols) for _ in range(rows)]
    slice_placeholder = tokenizer.slice_start + "\n".join(slices) + tokenizer.slice_end
    return slice_placeholder