import os
import logging
from typing import Optional

import torch
import torch.utils.checkpoint
from torch import nn
from torch.cuda.amp import autocast as autocast

from modeling_internvideo2_vit import pretrain_internvideo2_giant_patch14_224_clean
from modeling_qformer import build_qformer
# from .flash_attention_class import FlashAttention
from model_config import VideoChat2Config
from transformers import AutoTokenizer, AutoModel, AutoConfig, PreTrainedModel, PretrainedConfig

logger = logging.getLogger(__name__)

# Hugging Face access token for gated checkpoints (e.g. Mistral-7B); HF_TOKEN must be set.
token = os.environ['HF_TOKEN']

IMG_TOKEN = "[<IMG_PLH>]"
VID_TOKEN = "[<VID_PLH>]"

DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_BOS_TOKEN = '<s>'
DEFAULT_EOS_TOKEN = '</s>'
DEFAULT_UNK_TOKEN = "<unk>"

DEFAULT_IMAGE_TOKEN = "[IMAGETOKEN]"
DEFAULT_VIDEO_TOKEN = "[VIDEOTOKEN]"

DEFAULT_IMG_PLACEHOLDER = "[<IMG_PLH>]"
DEFAULT_VID_PLACEHOLDER = "[<VID_PLH>]"


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def freeze_module(module):
    for _, param in module.named_parameters():
        param.requires_grad = False
    module = module.eval()
    module.train = disabled_train
    return module


class InternVideo2_Classification(PreTrainedModel):
    config_class = VideoChat2Config

    def __init__(self, config):
        self.model_config = config.model_config
        # config.model_config = None
        super().__init__(config)

        self.build_vision_encoder()
        self.build_llm()
        self.build_bridge()  # NOTE place it after freeze llm

        # Regression loss on visual tokens is optional; default to False when
        # the config does not define it (forward() reads this flag).
        self.use_vision_regression_loss = self.model_config.get("use_vision_regression_loss", False)

        for n, p in self.named_parameters():
            if p.requires_grad:
                logger.info(f'{n} requires_grad')

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        image: Optional[torch.Tensor] = None,
        video: Optional[torch.Tensor] = None,
        instruction=None,
        video_idx=None,
        image_idx=None,
    ):
        if self.use_vision_regression_loss:
            text_embeds, visual, visual_idx = self.pad_text_embeds(
                input_ids=input_ids, image=image, video=video, return_visual=True,
                video_idx=video_idx, image_idx=image_idx, instruction=instruction,
            )
        else:
            text_embeds = self.pad_text_embeds(
                input_ids=input_ids, image=image, video=video, return_visual=False,
                video_idx=video_idx, image_idx=image_idx, instruction=instruction,
            )

        outputs = self.lm(
            inputs_embeds=text_embeds,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
            return_dict=True,
        )

        return outputs

    def build_vision_encoder(self):
        # load pretrained internvideo2-1b here, simplified as it receives no args
        # note that the InternVideo2 pretrained weights are not loaded here
        if 'internvideo2' in self.model_config.vision_encoder.name.lower():
            encoder_name = self.model_config.vision_encoder.name
            logger.info(f"Build vision_encoder: {encoder_name}")
            if encoder_name == 'internvideo2-1B':
                self.vision_encoder = pretrain_internvideo2_giant_patch14_224_clean(self.model_config)
            else:
                raise ValueError(f"Not implemented: {encoder_name}")
        else:
            raise NotImplementedError(self.model_config.vision_encoder.name)

        if self.model_config.vision_encoder.vit_add_ln:
            self.vision_layernorm = nn.LayerNorm(self.model_config.vision_encoder.encoder_embed_dim, eps=1e-12)
        else:
            self.vision_layernorm = nn.Identity()

        self.freeze_vision_encoder = self.model_config.get("freeze_vision_encoder", False)
        if self.freeze_vision_encoder:
            logger.info("freeze vision encoder")
            freeze_module(self.vision_encoder)
            freeze_module(self.vision_layernorm)

    def build_bridge(self):
        # ViT to LM: 1792 -> 6656 NOTE 768 is qformer dim
        self.project_up = nn.Linear(768, self.lm.config.hidden_size)  # whether bias is needed?
        # LM to ViT: 6656 -> 1792
        self.project_down = nn.Linear(self.lm.config.hidden_size, 768)

        if 'qformer' in self.model_config.bridge.name.lower():
            from transformers import BertTokenizer
            self.qformer_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="left")
            self.qformer_tokenizer.add_special_tokens({"bos_token": "[DEC]"})
            self.qformer_tokenizer.padding_side = "left"
            if self.model_config.bridge.name == 'qformer':
                self.qformer, self.query_tokens = build_qformer(
                    self.model_config.bridge.num_query_token,
                    self.model_config.vision_encoder.encoder_embed_dim,
                    qformer_hidden_dropout_prob=self.model_config.bridge.qformer_hidden_dropout_prob,
                    qformer_attention_probs_dropout_prob=self.model_config.bridge.qformer_attention_probs_dropout_prob,
                    qformer_drop_path_rate=self.model_config.bridge.qformer_drop_path_rate,
                )
            self.qformer.resize_token_embeddings(len(self.qformer_tokenizer))
            self.qformer.cls = None
            self.extra_num_query_token = self.model_config.bridge.extra_num_query_token
            if self.model_config.bridge.extra_num_query_token > 0:
                logger.info(f"Add extra {self.model_config.bridge.extra_num_query_token} tokens in QFormer")
                self.extra_query_tokens = nn.Parameter(
                    torch.zeros(1, self.model_config.bridge.extra_num_query_token, self.query_tokens.shape[-1])
                )

        self.freeze_bridge = self.model_config.get("freeze_bridge", False)
        if self.freeze_bridge:
            logger.info("freeze bridge")
            freeze_module(self.qformer)
            self.query_tokens.requires_grad = False

    def build_llm(self):
        self.lm_name = self.model_config.llm.name
        if self.model_config.llm.name == 'mistral_7b':
            from transformers import AutoModelForSequenceClassification
            config = AutoConfig.from_pretrained(
                self.model_config.llm.pretrained_llm_path,
                torch_dtype=torch.bfloat16,
                token=token,
                # attn_implementation="flash_attention_2",
            )
            self.lm = AutoModelForSequenceClassification.from_config(config)
        elif self.model_config.llm.name == 'internlm_20b':
            from transformers import AutoModelForSequenceClassification
            self.lm = AutoModelForSequenceClassification.from_pretrained(
                self.model_config.llm.pretrained_llm_path,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )
            self.lm.gradient_checkpointing = True
            self.lm._set_gradient_checkpointing()
        elif self.model_config.llm.name == 'internlm2_5_7b':
            from transformers import AutoModelForSequenceClassification
            self.lm = AutoModelForSequenceClassification.from_pretrained(
                self.model_config.llm.pretrained_llm_path,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                local_files_only=True,
            )
        else:
            raise NotImplementedError(self.model_config.llm.name)

        self.freeze_llm = self.model_config.get("freeze_llm", True)
        logger.info(f'freeze_llm: {self.freeze_llm}')
        if self.freeze_llm:
            logger.info("freeze llm")
            freeze_module(self.lm)

        if self.model_config.llm.use_lora:
            self.use_lora = True
            from peft import get_peft_model, LoraConfig, TaskType
            logger.info("Use lora")
            if self.model_config.llm.name == 'internlm_20b':
                peft_config = LoraConfig(
                    task_type=TaskType.CAUSAL_LM, inference_mode=False,
                    r=self.model_config.llm.lora_r,
                    lora_alpha=self.model_config.llm.lora_alpha,
                    lora_dropout=self.model_config.llm.lora_dropout,
                    target_modules=['wqkv', 'wo', 'w1', 'w2', 'w3', 'output'],
                )
            else:
                peft_config = LoraConfig(
                    task_type=TaskType.CAUSAL_LM, inference_mode=False,
                    r=self.model_config.llm.lora_r,
                    lora_alpha=self.model_config.llm.lora_alpha,
                    lora_dropout=self.model_config.llm.lora_dropout,
                    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                                    "gate_proj", "up_proj", "down_proj", "lm_head"],
                )
            self.lm = get_peft_model(self.lm, peft_config)
            self.lm.enable_input_require_grads()
            self.lm.print_trainable_parameters()
        else:
            self.use_lora = False
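
    # The prompt assembled below follows the Mistral chat convention
    # ("[INST] ... [/INST]") used by the default `mistral_7b` LLM; the media
    # placeholder (IMG_TOKEN / VID_TOKEN) is later located by build_input_ids
    # and replaced with visual embeddings in pad_text_embeds.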
    def build_conversation(self, instruction, user_prompt, media_type='video', msg=''):
        conversation = ""
        if instruction:
            conversation += instruction
        conversation += ("[INST]" + " ")

        if media_type == 'image':
            conversation += ("<Image>" + IMG_TOKEN + "</Image>")  # *ilen
        else:
            conversation += ("<Video>" + VID_TOKEN + "</Video>")  # *ilen

        conversation += (msg.rstrip() + "[/INST]")
        conversation += (" [INST] " + user_prompt + " [/INST]")
        conversation += ("")
        return conversation

    def pad_text_embeds(
        self,
        input_ids: torch.LongTensor = None,
        image: Optional[torch.Tensor] = None,
        video: Optional[torch.Tensor] = None,
        image_idx=None,
        video_idx=None,
        return_visual: bool = False,
        instruction=None,
    ):
        # text_embeds
        text_embeds = self.lm.get_input_embeddings()(input_ids.long()).detach()

        visual = None
        visual_idx = None

        if image is not None:
            B, T, C, H, W = image.shape
            image = image.permute(0, 2, 1, 3, 4)
            prompt_image_embeds = self.encode_vision(image, instruction=instruction)
            visual = prompt_image_embeds
            prompt_image_embeds = self.project_up(prompt_image_embeds)
            prompt_image_embeds = prompt_image_embeds.view(-1, prompt_image_embeds.shape[-1])
            visual_idx = image_idx
            text_embeds[image_idx == 1] = text_embeds[image_idx == 1] * 0 + prompt_image_embeds.to(text_embeds.device)
        elif video is not None:
            if len(video.shape) == 5:
                B, T, C, H, W = video.shape
                N = 1
            else:
                B, N, T, C, H, W = video.shape
            video = video.reshape(B * N, T, C, H, W).permute(0, 2, 1, 3, 4)
            prompt_video_embeds = self.encode_vision(video, instruction=instruction)
            visual = prompt_video_embeds
            prompt_video_embeds = self.project_up(prompt_video_embeds)
            prompt_video_embeds = prompt_video_embeds.view(-1, prompt_video_embeds.shape[-1])
            visual_idx = video_idx
            text_embeds[video_idx == 1] = text_embeds[video_idx == 1] * 0 + prompt_video_embeds.to(text_embeds.device).to(text_embeds.dtype)
        else:
            logger.warning(f"don't get visual input, input_ids: {input_ids}")

        if return_visual:
            return text_embeds, visual, visual_idx

        return text_embeds

    def encode_vision(self, image, instruction):
        device = image.device
        B = image.shape[0]
        T = image.shape[2]
        use_image = True if T == 1 else False
        image_embeds = self.vision_encoder(image, use_image=use_image)

        C = image_embeds.shape[-1]
        image_embeds = image_embeds.reshape(B, -1, C)
        image_embeds = self.vision_layernorm(image_embeds).to(device)  # [B, T*L, C]

        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
        if self.extra_num_query_token > 0:
            query_tokens = torch.cat([self.query_tokens, self.extra_query_tokens], dim=1)
        else:
            query_tokens = self.query_tokens
        query_tokens = query_tokens.expand(image_embeds.shape[0], -1, -1)
        if instruction is not None:
            text_Qformer = self.qformer_tokenizer(
                instruction,
                padding='longest',
                truncation=True,
                max_length=512,
                return_tensors="pt",
            ).to(image_embeds.device)
            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image_embeds.device)
            Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
            query_output = self.qformer.bert(
                text_Qformer.input_ids,
                attention_mask=Qformer_atts,
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
        else:
            query_output = self.qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

        return query_output.last_hidden_state[:, :query_tokens.size(1), :]
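
    # build_input_ids reserves a fixed block of 96 token positions for every
    # image/video placeholder it finds; pad_text_embeds later overwrites these
    # positions with the projected Q-Former output.  The hard-coded 96 is
    # assumed to equal bridge.num_query_token + bridge.extra_num_query_token in
    # the shipped config; if the bridge configuration changes, this constant
    # has to change with it.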
    def build_input_ids(
        self,
        tokenizer,
        conversation,
        max_length,
        add_special_tokens,
        truncation,
        image=None,
        video=None,
        padding="longest",
        return_tensors="pt",
        image_placeholder: str = DEFAULT_IMG_PLACEHOLDER,
        video_placeholder: str = DEFAULT_VID_PLACEHOLDER,
    ):
        input_ids = []
        indexs = []
        attention_mask = []
        start, total_len = 0, 0

        while True:
            index1 = conversation.find(image_placeholder, start)
            index2 = conversation.find(video_placeholder, start)
            if index1 == -1 and index2 == -1:
                index = -1
            elif index1 == -1:
                index = index2
            elif index2 == -1:
                index = index1
            else:
                index = min(index1, index2)

            if index == -1:
                inputs = tokenizer(conversation[start:], max_length=max_length - total_len, truncation=truncation, padding=padding, return_tensors=return_tensors)
            else:
                inputs = tokenizer(conversation[start:index], max_length=max_length, truncation=truncation, padding='longest', return_tensors=return_tensors)

            input_ids += inputs.input_ids
            attention_mask += inputs.attention_mask
            total_len += inputs.input_ids[0].shape[0]
            indexs += torch.zeros_like(inputs.input_ids)

            if index != -1:
                input_ids += [torch.zeros(96).long()]
                attention_mask += [torch.ones(96).long()]
                indexs += [torch.ones(96)]

            if index == -1:
                return {
                    'input_ids': torch.cat(input_ids),
                    'attention_mask': torch.cat(attention_mask),
                    'index': torch.cat(indexs).to(torch.bool),
                }
            start = index + len(DEFAULT_IMG_PLACEHOLDER)

    @property
    def dtype(self):
        return self.lm.dtype

    @property
    def device(self):
        return self.lm.device


class InternVideo2_Classification_test(PreTrainedModel):
    config_class = VideoChat2Config

    def __init__(self, config):
        super().__init__(config)
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)
        self.model_config = config.model_config
        self.build_bridge()

    def forward(self, x):
        x = self.conv1(x)
        return self.conv2(x)

    def test_lol(self, x):
        return x

    def build_bridge(self):
        if 'qformer' in self.model_config.bridge.name.lower():
            from transformers import BertTokenizer
            self.qformer_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="left")
            self.qformer_tokenizer.add_special_tokens({"bos_token": "[DEC]"})
            self.qformer_tokenizer.padding_side = "left"
            if self.model_config.bridge.name == 'qformer':
                self.qformer, self.query_tokens = build_qformer(
                    self.model_config.bridge.num_query_token,
                    self.model_config.vision_encoder.encoder_embed_dim,
                    qformer_hidden_dropout_prob=self.model_config.bridge.qformer_hidden_dropout_prob,
                    qformer_attention_probs_dropout_prob=self.model_config.bridge.qformer_attention_probs_dropout_prob,
                    qformer_drop_path_rate=self.model_config.bridge.qformer_drop_path_rate,
                )
            self.qformer.resize_token_embeddings(len(self.qformer_tokenizer))
            self.qformer.cls = None
            self.extra_num_query_token = self.model_config.bridge.extra_num_query_token
            if self.model_config.bridge.extra_num_query_token > 0:
                logger.info(f"Add extra {self.model_config.bridge.extra_num_query_token} tokens in QFormer")
                self.extra_query_tokens = nn.Parameter(
                    torch.zeros(1, self.model_config.bridge.extra_num_query_token, self.query_tokens.shape[-1])
                )

        self.freeze_bridge = self.model_config.get("freeze_bridge", False)
        if self.freeze_bridge:
            logger.info("freeze bridge")
            freeze_module(self.qformer)
            self.query_tokens.requires_grad = False
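

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original pipeline): one way to drive a
# supervised step through InternVideo2_Classification.forward with explicit
# labels.  The helper name `_example_training_step` and the `label_id`
# argument are hypothetical; whether `outputs.loss` is meaningful depends on
# how many labels the underlying AutoModelForSequenceClassification head was
# configured with.
# ---------------------------------------------------------------------------
def _example_training_step(model, tokenizer, video_tensor, label_id):
    """Sketch: build the prompt, splice in the video, and run the classifier."""
    conversation = model.build_conversation(
        instruction="this is an instruction",
        user_prompt="this is a user prompt",
        media_type='video',
    )
    tokenized = model.build_input_ids(
        tokenizer, conversation,
        max_length=248, add_special_tokens=True, truncation=False,
        padding=False, return_tensors='pt',
    )
    input_ids = tokenized['input_ids'].unsqueeze(0).to(model.device)
    attn_mask = tokenized['attention_mask'].unsqueeze(0).to(model.device)
    video_idx = tokenized['index'].unsqueeze(0).to(model.device)
    labels = torch.tensor([label_id], device=model.device)
    # forward() replaces the reserved placeholder slots with projected
    # Q-Former embeddings before calling the sequence-classification head.
    outputs = model(
        input_ids=input_ids,
        attention_mask=attn_mask,
        labels=labels,
        video=video_tensor,
        video_idx=video_idx,
    )
    return outputs.loss, outputs.logits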


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained('OpenGVLab/InternVideo2-Chat-8B', trust_remote_code=True, use_fast=False)
    config = AutoConfig.from_pretrained('OpenGVLab/InternVideo2-Chat-8B', torch_dtype=torch.bfloat16, trust_remote_code=True)
    model = InternVideo2_Classification(config).cuda()

    B, T, C, H, W = 1, 8, 3, 224, 224
    video_tensor = torch.randn(B, T, C, H, W).cuda()
    user_prompt = "this is a user prompt"
    instruction = "this is an instruction"

    conversation = model.build_conversation(instruction=instruction, user_prompt=user_prompt, media_type='video')
    tokenized = model.build_input_ids(tokenizer, conversation, max_length=248, add_special_tokens=True, truncation=False, padding=False, return_tensors='pt')

    input_ids = tokenized['input_ids'].unsqueeze(0).to(model.device)
    attn_mask = tokenized['attention_mask'].unsqueeze(0).to(model.device)
    # keep the boolean index on the same device as the text embeddings it masks
    indexes = tokenized['index'].unsqueeze(0).to(model.device)

    text_embeds = model.pad_text_embeds(input_ids=input_ids, video=video_tensor, video_idx=indexes)
    outputs = model.lm(inputs_embeds=text_embeds, attention_mask=attn_mask, output_hidden_states=True, return_dict=True)
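
    # Hedged follow-up: the sequence-classification head returns `logits` of
    # shape [B, num_labels]; num_labels comes from the LLM config, so the
    # argmax below is only illustrative.
    pred = outputs.logits.argmax(dim=-1)
    print("predicted class index:", pred.item())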