from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.modeling_outputs import ModelOutput
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, AutoConfig
from .configuration_wemm import WeMMConfig
from .vision_model import Idefics2VisionTransformer
from .connector import Idefics2Connector
from .image_processor import Idefics2ImageProcessor
from .modeling_downsampler import DownsamplerModel
from .modeling_projector import ProjectorModel
from .modeling_internlm2 import InternLM2ForCausalLM
from .tokenization_internlm2 import InternLM2Tokenizer
from peft import PeftModel
from peft import PeftConfig
import os
from PIL import Image
import numpy as np

IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = '<image>'
IGNORE_INDEX = -100

from transformers import StoppingCriteria
from transformers import PreTrainedTokenizerFast, StoppingCriteriaList
import torch.nn.functional as F


class StopWordStoppingCriteria(StoppingCriteria):
    """StopWord stopping criteria."""

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        self.length = len(self.stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        cur_text = self.tokenizer.decode(input_ids[0])
        cur_text = cur_text.replace('\r', '').replace('\n', '')
        return cur_text[-self.length:] == self.stop_word


def get_stop_criteria(
    tokenizer,
    stop_words=[],
):
    stop_criteria = StoppingCriteriaList()
    for word in stop_words:
        stop_criteria.append(StopWordStoppingCriteria(tokenizer, word))
    return stop_criteria


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H, W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H, W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=-1)  # (H, W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: grid of positions to be encoded, shape (1, H, W)
    out: (H, W, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = np.squeeze(pos)  # (1, H, W) -> (H, W)
    out = np.einsum('hw,d->hwd', pos, omega)  # (H, W, D/2), outer product

    emb_sin = np.sin(out)  # (H, W, D/2)
    emb_cos = np.cos(out)  # (H, W, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=-1)  # (H, W, D)
    return emb


# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size_h, grid_size_w, cls_token=False):
    """
    grid_size_h, grid_size_w: height and width of the grid
    return:
    pos_embed: (grid_size_h, grid_size_w, embed_dim), with an extra all-zero row prepended if cls_token is True
    """
    grid_h = np.arange(grid_size_h, dtype=np.float32)
    grid_w = np.arange(grid_size_w, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def recover_navit_subimages_with_pos_emb(
        sub_image_hidden_states,
        attention_mask,
        num_sub_images,
        visual_embedding_group,
        pos_hidden_size,
        thumbnail_only=False):
    _slice = int(np.sqrt(num_sub_images))
    N, L, D = sub_image_hidden_states.shape
    _, H, W = attention_mask.shape

    if thumbnail_only is True:
        num_sub_images += 1

    sub_image_hidden_states = sub_image_hidden_states.reshape(-1, num_sub_images, H, W, D)
    attention_mask = attention_mask.reshape(-1, num_sub_images, H, W)

    if thumbnail_only is True:
        sub_image_hidden_states = sub_image_hidden_states[:, -1:, :, :, :]
        attention_mask = attention_mask[:, -1:, :, :]
        _slice = 1

    def _infer_ori_image_patch_shape(sub_image_attention_mask):
        ind_h, ind_w = torch.where(sub_image_attention_mask > 0)
        return torch.max(ind_h) + 1, torch.max(ind_w) + 1

    def _pad_to_same(image_hidden):
        _dtype = image_hidden.dtype
        visual_downsample_stride = int(np.sqrt(visual_embedding_group))
        full_h, full_w, _ = image_hidden.shape
        target_h, target_w = H * _slice, W * _slice
        # ensure all contents are included during downsampling
        to_pad_h = (target_h - full_h) + (
            visual_downsample_stride - target_h % visual_downsample_stride) % visual_downsample_stride
        to_pad_w = (target_w - full_w) + (
            visual_downsample_stride - target_w % visual_downsample_stride) % visual_downsample_stride
        # (H,W,D) -> (1,D,H,W) to support replicate padding
        image_hidden = image_hidden.permute(2, 0, 1).unsqueeze(0)
        pad_size = (0, to_pad_w, 0, to_pad_h)
        # (1,D,H,W) -> (H,W,D)
        image_hidden = F.pad(image_hidden.to(torch.float32), pad_size, mode='replicate').squeeze(0).permute(1, 2, 0)
        return image_hidden.to(_dtype)

    image_hidden_states = list()
    valid_image_token = list()
    image_2d_pos = list()
    for batch_id in range(len(sub_image_hidden_states)):
        ori_h, ori_w = _infer_ori_image_patch_shape(attention_mask[batch_id][0])
        full_h, full_w = ori_h * _slice, ori_w * _slice

        # (S,H,W,D) -> (S_h,S_w,H,W,D) -> (S_h,H,S_w,W,D) -> (S_h*H,S_w*W,D)
        this_image_hidden = sub_image_hidden_states[batch_id][:, 0:ori_h, 0:ori_w, :] \
            .view(_slice, _slice, ori_h, ori_w, D).permute(0, 2, 1, 3, 4).contiguous().view(full_h, full_w, D)
        pos_emb = get_2d_sincos_pos_embed(pos_hidden_size, grid_size_h=full_h, grid_size_w=full_w)  # (H, W, D)
        pos_emb = torch.tensor(pos_emb, dtype=this_image_hidden.dtype, device=this_image_hidden.device)
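        # Pad each recovered feature map (and its positional embedding) to a common
        # spatial size so the per-image tensors can be batch-stacked below; the
        # unpadded height/width is kept in `valid_image_token` for later cropping.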
        image_hidden_states.append(_pad_to_same(this_image_hidden))
        image_2d_pos.append(_pad_to_same(pos_emb))
        valid_image_token.append([full_h, full_w])

    image_hidden_states = torch.stack(image_hidden_states)
    image_2d_pos = torch.stack(image_2d_pos)
    valid_image_token = torch.tensor(valid_image_token, dtype=torch.int64)

    return image_hidden_states, image_2d_pos, valid_image_token


def visiual_token_downsample(
        visual_downsampler,
        image_hidden_states,
        valid_image_token,
        visual_embedding_group,
        image_2d_pos):
    if image_2d_pos is not None:
        image_hidden_states = image_hidden_states + image_2d_pos
    image_hidden_states = visual_downsampler(image_hidden_states)
    valid_image_token = torch.ceil(valid_image_token / np.sqrt(visual_embedding_group)).to(torch.int64)
    return image_hidden_states, valid_image_token


def merge_native_qformer(
        clip_embeddings_native_patch,
        valid_image_token_shape,
        clip_embeddings_qformer,
        visual_source_spliter,
        num_sub_images):
    assert clip_embeddings_native_patch.size(0) == valid_image_token_shape.size(0) == clip_embeddings_qformer.size(0)

    def add_split_token_for_qformer_token(qformer_emb):
        # + 1 for thumbnail
        len_per_token = int(qformer_emb.size(0) // (num_sub_images + 1))
        qformer_emb_with_spliter = list()
        for i in range(num_sub_images + 1):
            qformer_emb_with_spliter.append(
                visual_source_spliter(torch.tensor([2 * i]).to(visual_source_spliter.weight.device))
            )
            qformer_emb_with_spliter.append(qformer_emb[i * len_per_token:(i + 1) * len_per_token])
            qformer_emb_with_spliter.append(
                visual_source_spliter(torch.tensor([2 * i + 1]).to(visual_source_spliter.weight.device))
            )
        return torch.cat(qformer_emb_with_spliter, dim=0)

    merged_visual_embeddings = list()
    for batch_id in range(clip_embeddings_native_patch.size(0)):
        h, w = valid_image_token_shape[batch_id]
        native_patch_emb = clip_embeddings_native_patch[batch_id][:h, :w, :].reshape(h * w, -1)
        qformer_emb = clip_embeddings_qformer[batch_id]
        qformer_emb = add_split_token_for_qformer_token(qformer_emb)
        merged_visual_embeddings.append(
            torch.cat(
                [visual_source_spliter(torch.tensor([10]).to(visual_source_spliter.weight.device)),
                 native_patch_emb,
                 visual_source_spliter(torch.tensor([11]).to(visual_source_spliter.weight.device)),
                 qformer_emb],
                dim=0))

    return merged_visual_embeddings


class WemmForConditionalGeneration(PreTrainedModel):
    config_class = WeMMConfig

    def __init__(self, config: WeMMConfig):
        super().__init__(config)
        self.vision_tower = Idefics2VisionTransformer(config.vision_config)
        self.image_processor = Idefics2ImageProcessor(config.image_processor)
        self.connector = Idefics2Connector(config.connector_config)
        self.projector = ProjectorModel(config.projector_config)
        self.language_model = InternLM2ForCausalLM(config.text_config)
        self.tokenizer = AutoTokenizer.from_pretrained(
            "internlm/internlm2-chat-7b", trust_remote_code=True, encode_special_tokens=True)
        self.downsampler = DownsamplerModel(config.downsampler_config)
        self.visual_source_spliter_emb = torch.nn.Embedding(**config.spliter_emb_config)
        self.gen_config = GenerationConfig(
            max_new_tokens=512,
            do_sample=False,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
        )
        self.do_image_splitting = config.do_image_splitting
        self.stop_criteria = get_stop_criteria(
            tokenizer=self.tokenizer, stop_words=['<|im_end|>'])
        self.config = config

    def mm_generate(self, image_path, prompt, gen_config=None):
        prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
        prompt = f"<|im_start|>user\n{prompt}<|im_end|><|im_start|>assistant\n"
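        # Visual pipeline for the single input image: encode the NaViT-style sub-images
        # with the vision tower, recover and downsample the native patch grid, compute a
        # resampled ("qformer") summary via the connector/projector, then merge both
        # streams with learned spliter embeddings before feeding the language model.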
        image = Image.open(image_path).convert('RGB')
        navit980_images = self.image_processor(
            [[image]], return_tensors="pt", do_image_splitting=self.do_image_splitting)
        batch_size_navit = navit980_images['pixel_values'].shape[0]

        navit_pixel_values = navit980_images['navit_pixel_values'].cuda()
        navit_patch_attention_mask = navit980_images["pixel_attention_mask"].cuda()
        clip_visual_outputs = self.vision_tower(
            pixel_values=navit_pixel_values,
            patch_attention_mask=navit_patch_attention_mask,
        ).last_hidden_state

        super_image_hidden_states, image_2d_pos, valid_image_token_shape = \
            recover_navit_subimages_with_pos_emb(
                clip_visual_outputs,
                navit_patch_attention_mask,
                num_sub_images=4,
                visual_embedding_group=1,
                pos_hidden_size=4096,
                thumbnail_only=True
            )
        clip_embeddings_native_patch, valid_image_token_shape = visiual_token_downsample(
            self.downsampler,
            super_image_hidden_states,
            valid_image_token_shape,
            visual_embedding_group=1,
            image_2d_pos=None
        )

        clip_embeddings_qformer = self.connector(
            clip_visual_outputs,
            attention_mask=navit_patch_attention_mask.view(navit_pixel_values.size(0), -1))
        hidden_size = clip_embeddings_qformer.shape[-1]
        clip_embeddings_qformer = clip_embeddings_qformer.view(batch_size_navit, -1, hidden_size)
        clip_embeddings_qformer = self.projector(clip_embeddings_qformer)

        merged_visual_embeddings = \
            merge_native_qformer(
                clip_embeddings_native_patch,
                valid_image_token_shape,
                clip_embeddings_qformer,
                visual_source_spliter=self.visual_source_spliter_emb,
                num_sub_images=4
            )

        chunk_encode = []
        for idx, chunk in enumerate(prompt.split(DEFAULT_IMAGE_TOKEN)):
            if idx == 0:
                cur_encode = self.tokenizer.encode(chunk)
            else:
                cur_encode = self.tokenizer.encode(chunk, add_special_tokens=False)
            chunk_encode.append(cur_encode)
        assert len(chunk_encode) == 2
        ids = []
        for idx, cur_chunk_encode in enumerate(chunk_encode):
            ids.extend(cur_chunk_encode)
            if idx != len(chunk_encode) - 1:
                ids.append(IMAGE_TOKEN_INDEX)
        ids = torch.tensor(ids).cuda().unsqueeze(0)

        pixel_values = None
        mm_inputs = self.prepare_inputs_labels_for_multimodal(
            llm=self.language_model,
            input_ids=ids,
            pixel_values=pixel_values,
            clip_embeddings=merged_visual_embeddings)

        generate_output = self.language_model.generate(
            **mm_inputs,
            generation_config=gen_config if gen_config is not None else self.gen_config,
            streamer=None,
            bos_token_id=self.tokenizer.bos_token_id,
            stopping_criteria=self.stop_criteria
        )
        predict = self.tokenizer.decode(
            generate_output[0], skip_special_tokens=True).strip()
        return predict

    def get_valid_visual_embedding(self, embedding, valid_token_shape):
        if valid_token_shape is None:
            return embedding
        h, w = valid_token_shape
        return embedding[:h, :w, :].reshape(h * w, -1)

    # Modified from https://github.com/haotian-liu/LLaVA/blob/82fc5e0e5f4393a4c26851fa32c69ab37ea3b146/llava/model/llava_arch.py#L99  # noqa: E501
    def prepare_inputs_labels_for_multimodal(
            self,
            llm: PreTrainedModel,
            input_ids: torch.LongTensor = None,
            position_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            pixel_values: Optional[torch.FloatTensor] = None,
            clip_embeddings: Optional[torch.FloatTensor] = None,
            hard_coded_max_len: Optional[int] = None,
            **kwargs):
        if pixel_values is None and clip_embeddings is None:
            return {
                'input_ids': input_ids,
                'position_ids': position_ids,
                'attention_mask': attention_mask,
                'past_key_values': past_key_values,
                'inputs_embeds': None,
                'labels': labels
            }
        valid_image_token_shape = kwargs.get('valid_image_token_shape', None)

        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(
                0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # remove the padding using attention_mask -- TODO: double check
        input_ids = [
            cur_input_ids[cur_attention_mask]
            for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
        ]
        labels = [
            cur_labels[cur_attention_mask]
            for cur_labels, cur_attention_mask in zip(labels, attention_mask)
        ]

        new_inputs_embeds = []
        new_labels = []
        new_img_masks = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                cur_pixel_values = pixel_values[cur_image_idx] if pixel_values is not None else None
                cur_clip_emb = self.get_valid_visual_embedding(
                    clip_embeddings[cur_image_idx],
                    valid_image_token_shape[cur_image_idx]) if clip_embeddings is not None else None
                cur_inputs_embeds_1 = llm.get_input_embeddings()(cur_input_ids)
                if cur_clip_emb is not None and cur_pixel_values is not None:
                    cur_inputs_embeds = torch.cat(
                        [cur_inputs_embeds_1, cur_pixel_values[0:0], cur_clip_emb[0:0]], dim=0)
                elif cur_pixel_values is not None:
                    cur_inputs_embeds = torch.cat(
                        [cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0)
                elif cur_clip_emb is not None:
                    cur_inputs_embeds = torch.cat(
                        [cur_inputs_embeds_1, cur_clip_emb[0:0]], dim=0)
                else:
                    raise ValueError
                new_inputs_embeds.append(cur_inputs_embeds)
                new_labels.append(labels[batch_idx])
                new_img_masks.append(torch.zeros(
                    cur_inputs_embeds.shape[0], device=cur_inputs_embeds.device).bool())
                cur_image_idx += 1
                continue

            image_token_indices = [-1] + torch.where(
                cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [
                    cur_input_ids.shape[0]
                ]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_inputs_embeds = llm.get_input_embeddings()(
                torch.cat(cur_input_ids_noim))
            cur_inputs_embeds_no_im = torch.split(
                cur_inputs_embeds, split_sizes, dim=0)
            cur_new_inputs_embeds = []
            cur_new_labels = []
            cur_img_masks = []

            for i in range(num_images + 1):
                cur_new_inputs_embeds.append(cur_inputs_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                cur_img_masks.append(torch.zeros(
                    cur_inputs_embeds_no_im[i].shape[0],
                    device=cur_inputs_embeds_no_im[i].device).bool())
                if i < num_images:
                    cur_pixel_values = pixel_values[cur_image_idx] if pixel_values is not None else None
                    if valid_image_token_shape is not None:
                        cur_clip_emb = \
                            self.get_valid_visual_embedding(
                                clip_embeddings[cur_image_idx],
                                valid_image_token_shape[cur_image_idx]) \
                            if clip_embeddings is not None else None
                    else:
                        cur_clip_emb = clip_embeddings[cur_image_idx] if clip_embeddings is not None else None
                    cur_image_idx += 1
                    # discrete token embeddings
                    if cur_pixel_values is not None:
                        cur_new_inputs_embeds.append(cur_pixel_values)
                        cur_img_masks.append(torch.ones(
                            cur_pixel_values.shape[0], device=cur_pixel_values.device).bool())
                        cur_new_labels.append(
                            torch.full((cur_pixel_values.shape[0], ),
                                       IGNORE_INDEX,
                                       device=cur_labels.device,
                                       dtype=cur_labels.dtype))
                    # clip embeddings
                    if cur_clip_emb is not None:
                        cur_new_inputs_embeds.append(cur_clip_emb)
                        cur_img_masks.append(torch.zeros(
                            cur_clip_emb.shape[0], device=cur_clip_emb.device).bool())
                        cur_new_labels.append(
                            torch.full((cur_clip_emb.shape[0], ),
                                       IGNORE_INDEX,
                                       device=cur_labels.device,
                                       dtype=cur_labels.dtype))

            cur_new_inputs_embeds = torch.cat(cur_new_inputs_embeds)
            cur_new_labels = torch.cat(cur_new_labels)
            cur_img_masks = torch.cat(cur_img_masks)

            new_inputs_embeds.append(cur_new_inputs_embeds)
            new_labels.append(cur_new_labels)
            new_img_masks.append(cur_img_masks)

        # Combine them
        max_len = max(x.shape[0] for x in new_inputs_embeds)
        if hard_coded_max_len is not None:
            max_len = min(max_len, hard_coded_max_len)
        batch_size = len(new_inputs_embeds)

        new_inputs_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len),
                                       IGNORE_INDEX,
                                       dtype=new_labels[0].dtype,
                                       device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len),
                                     dtype=attention_mask.dtype,
                                     device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len),
                                   dtype=position_ids.dtype,
                                   device=position_ids.device)
        new_img_masks_padded = torch.zeros((batch_size, max_len),
                                           device=new_img_masks[0].device).bool()

        for i, (cur_new_embed, cur_new_labels, cur_new_img_masks) in enumerate(
                zip(new_inputs_embeds, new_labels, new_img_masks)):
            cur_new_embed = cur_new_embed[:max_len]
            cur_new_labels = cur_new_labels[:max_len]
            cur_new_img_masks = cur_new_img_masks[:max_len]

            cur_len = cur_new_embed.shape[0]
            new_inputs_embeds_padded.append(
                torch.cat((cur_new_embed,
                           torch.zeros((max_len - cur_len, cur_new_embed.shape[1]),
                                       dtype=cur_new_embed.dtype,
                                       device=cur_new_embed.device)),
                          dim=0))
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(
                    0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
                new_img_masks_padded[i, :cur_len] = cur_new_img_masks

        new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)

        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        prepared_data = {
            'input_ids': None,
            'position_ids': position_ids,
            'attention_mask': attention_mask,
            'past_key_values': past_key_values,
            'inputs_embeds': new_inputs_embeds,
            'labels': new_labels,
        }
        if pixel_values is not None:
            prepared_data.update({'im_mask': new_img_masks_padded})
        return prepared_data


AutoConfig.register("wemm_hf", WeMMConfig)
AutoModel.register(WeMMConfig, WemmForConditionalGeneration)
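

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch, not part of the model definition.
# The checkpoint directory and image path are placeholders (assumptions); any
# directory containing this code plus the WeMM weights and config should load
# through AutoModel with trust_remote_code=True.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model_dir = "/path/to/wemm_checkpoint"  # placeholder path
    model = AutoModel.from_pretrained(
        model_dir, trust_remote_code=True, torch_dtype=torch.bfloat16)
    model = model.cuda().eval()
    # mm_generate takes a single image path and a text prompt; the
    # DEFAULT_IMAGE_TOKEN placeholder is prepended to the prompt internally.
    answer = model.mm_generate("/path/to/example.jpg", "Describe this image.")
    print(answer)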