# modeling_videochat2_classification.py
import os
import torch
import torch.utils.checkpoint
from torch import nn
from torch.cuda.amp import autocast as autocast
from typing import Optional
from modeling_internvideo2_vit import pretrain_internvideo2_giant_patch14_224_clean
from modeling_qformer import build_qformer
# from .flash_attention_class import FlashAttention
from model_config import VideoChat2Config
from transformers import AutoTokenizer,AutoModel, AutoConfig, PreTrainedModel, PretrainedConfig
import logging
logger = logging.getLogger(__name__)
# Hugging Face access token for gated checkpoints; .get() avoids a KeyError at import time when it is unset.
token = os.environ.get('HF_TOKEN')
IMG_TOKEN = "[<IMG_PLH>]"
VID_TOKEN = "[<VID_PLH>]"
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_BOS_TOKEN = '<s>'
DEFAULT_EOS_TOKEN = '</s>'
DEFAULT_UNK_TOKEN = "<unk>"
DEFAULT_IMAGE_TOKEN = "[IMAGETOKEN]"
DEFAULT_VIDEO_TOKEN = "[VIDEOTOKEN]"
DEFAULT_IMG_PLACEHOLDER = "[<IMG_PLH>]"
DEFAULT_VID_PLACEHOLDER = "[<VID_PLH>]"
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def freeze_module(module):
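    """Disable gradients, switch to eval mode, and lock train() so the mode cannot be flipped back."""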
for _, param in module.named_parameters():
param.requires_grad = False
module = module.eval()
module.train = disabled_train
return module
class InternVideo2_Classification(PreTrainedModel):
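    """InternVideo2 video classifier: an InternVideo2-1B vision encoder, a QFormer bridge,
    and a sequence-classification LLM head, stitched together by the build_* methods below."""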
config_class = VideoChat2Config
def __init__(self, config):
self.model_config = config.model_config
# config.model_config = None
        super().__init__(config)
        # forward() reads this flag but the original never initialized it; default to False
        # (assumed to be configurable via model_config when present).
        self.use_vision_regression_loss = self.model_config.get("use_vision_regression_loss", False)
        self.build_vision_encoder()
        self.build_llm()
        self.build_bridge()
        # NOTE: keep this after the LLM has been frozen so only the genuinely trainable parameters get logged.
for n, p in self.named_parameters():
if p.requires_grad:
logger.info(f'{n} requires_grad')
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
image: Optional[torch.Tensor] = None,
video: Optional[torch.Tensor] = None,
instruction = None,
video_idx = None,
image_idx = None,
):
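        """Embed the text, splice projected visual features into the placeholder positions,
        and run the classification LLM on the resulting embeddings."""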
if self.use_vision_regression_loss:
text_embeds, visual, visual_idx = self.pad_text_embeds(input_ids=input_ids, image=image,video=video, return_visual=True, video_idx=video_idx, image_idx=image_idx, instruction = instruction)
else:
text_embeds = self.pad_text_embeds(input_ids=input_ids, image=image, video=video, return_visual=False, video_idx=video_idx, image_idx=image_idx, instruction = instruction)
outputs = self.lm(
inputs_embeds=text_embeds,
attention_mask=attention_mask,
labels=labels,
output_hidden_states=True,
return_dict=True,
)
return outputs
def build_vision_encoder(self):
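        """Instantiate the InternVideo2-1B ViT backbone (plus an optional LayerNorm) and freeze it if configured."""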
        # Build the InternVideo2-1B backbone here, simplified so it receives no extra args.
        # Note that the InternVideo2 pretrained weights are not loaded at this point.
if 'internvideo2' in self.model_config.vision_encoder.name.lower():
encoder_name = self.model_config.vision_encoder.name
logger.info(f"Build vision_encoder: {encoder_name}")
if encoder_name == 'internvideo2-1B':
self.vision_encoder = pretrain_internvideo2_giant_patch14_224_clean(self.model_config)
else:
raise ValueError(f"Not implemented: {encoder_name}")
else:
raise NotImplementedError(self.model_config.vision_encoder.name)
if self.model_config.vision_encoder.vit_add_ln:
self.vision_layernorm = nn.LayerNorm(self.model_config.vision_encoder.encoder_embed_dim, eps=1e-12)
else:
self.vision_layernorm = nn.Identity()
self.freeze_vision_encoder = self.model_config.get("freeze_vision_encoder", False)
if self.freeze_vision_encoder:
logger.info("freeze vision encoder")
freeze_module(self.vision_encoder)
freeze_module(self.vision_layernorm)
def build_bridge(self):
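        """Build the QFormer bridge and the linear projections between the 768-d QFormer space and the LM hidden space."""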
        # Bridge to LM: QFormer dim (768) -> LM hidden size.
        self.project_up = nn.Linear(768, self.lm.config.hidden_size)  # unclear whether the bias is needed
        # LM to bridge: LM hidden size -> QFormer dim (768).
        self.project_down = nn.Linear(self.lm.config.hidden_size, 768)
if 'qformer' in self.model_config.bridge.name.lower():
from transformers import BertTokenizer
self.qformer_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="left")
self.qformer_tokenizer.add_special_tokens({"bos_token": "[DEC]"})
self.qformer_tokenizer.padding_side = "left"
if self.model_config.bridge.name == 'qformer':
self.qformer, self.query_tokens = build_qformer(
self.model_config.bridge.num_query_token, self.model_config.vision_encoder.encoder_embed_dim,
qformer_hidden_dropout_prob=self.model_config.bridge.qformer_hidden_dropout_prob,
qformer_attention_probs_dropout_prob=self.model_config.bridge.qformer_attention_probs_dropout_prob,
qformer_drop_path_rate=self.model_config.bridge.qformer_drop_path_rate,
)
self.qformer.resize_token_embeddings(len(self.qformer_tokenizer))
self.qformer.cls = None
self.extra_num_query_token = self.model_config.bridge.extra_num_query_token
if self.model_config.bridge.extra_num_query_token > 0:
logger.info(f"Add extra {self.model_config.bridge.extra_num_query_token} tokens in QFormer")
self.extra_query_tokens = nn.Parameter(
torch.zeros(1, self.model_config.bridge.extra_num_query_token, self.query_tokens.shape[-1])
)
self.freeze_bridge = self.model_config.get("freeze_bridge", False)
if self.freeze_bridge:
logger.info("freeze bridge")
freeze_module(self.qformer)
self.query_tokens.requires_grad = False
def build_llm(self):
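        """Load the LLM with a sequence-classification head, optionally freeze it, and optionally wrap it with LoRA."""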
self.lm_name = self.model_config.llm.name
if self.model_config.llm.name == 'mistral_7b':
from transformers import AutoModelForSequenceClassification
config = AutoConfig.from_pretrained(
self.model_config.llm.pretrained_llm_path,
torch_dtype=torch.bfloat16,
token=token,
# attn_implementation="flash_attention_2",
)
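            # from_config builds the architecture with randomly initialized weights (nothing is downloaded);
            # the classification weights presumably come from this repo's composite checkpoint.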
self.lm = AutoModelForSequenceClassification.from_config(config)
elif self.model_config.llm.name == 'internlm_20b':
from transformers import AutoModelForSequenceClassification
self.lm = AutoModelForSequenceClassification.from_pretrained(
self.model_config.llm.pretrained_llm_path,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
)
self.lm.gradient_checkpointing = True
self.lm._set_gradient_checkpointing()
elif self.model_config.llm.name == 'internlm2_5_7b':
from transformers import AutoModelForSequenceClassification
self.lm = AutoModelForSequenceClassification.from_pretrained(
self.model_config.llm.pretrained_llm_path,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
local_files_only=True,
)
else:
raise NotImplementedError(self.model_config.llm.name)
self.freeze_llm = self.model_config.get("freeze_llm", True)
logger.info(f'freeze_llm: {self.freeze_llm}')
if self.freeze_llm:
logger.info("freeze llm")
freeze_module(self.lm)
if self.model_config.llm.use_lora:
self.use_lora = True
from peft import get_peft_model, LoraConfig, TaskType
logger.info("Use lora")
if self.model_config.llm.name == 'internlm_20b':
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, inference_mode=False,
r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout,
target_modules=['wqkv', 'wo', 'w1', 'w2', 'w3', 'output']
)
else:
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, inference_mode=False,
r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj", "lm_head"]
)
self.lm = get_peft_model(self.lm, peft_config)
self.lm.enable_input_require_grads()
self.lm.print_trainable_parameters()
else:
self.use_lora = False
def build_conversation(self,instruction, user_prompt,media_type='video',msg=''):
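        """Assemble an [INST]-style prompt around the image/video placeholder token for the given media_type."""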
conversation = ""
if instruction:
conversation += instruction
conversation += ("[INST]" + " ")
if media_type == 'image':
conversation +=( "<Image>" + IMG_TOKEN + "</Image>")#*ilen
else:
conversation += ("<Video>" + VID_TOKEN + "</Video>")#*ilen
conversation += (msg.rstrip() + "[/INST]")
conversation += (" [INST] " + user_prompt + " [/INST]")
conversation += ("")
return conversation
def pad_text_embeds(
self,
input_ids: torch.LongTensor = None,
image: Optional[torch.Tensor] = None,
video: Optional[torch.Tensor] = None,
image_idx = None,
video_idx = None,
return_visual: bool = False,
instruction = None,
):
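        """Embed input_ids with the LM and overwrite the positions flagged by image_idx/video_idx with projected visual features."""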
# text_embeds
text_embeds = self.lm.get_input_embeddings()(input_ids.long()).detach()
visual = None
visual_idx = None
if image is not None:
B, T, C, H, W = image.shape
image = image.permute(0, 2, 1, 3, 4)
prompt_image_embeds = self.encode_vision(image, instruction=instruction)
visual = prompt_image_embeds
prompt_image_embeds = self.project_up(prompt_image_embeds)
prompt_image_embeds = prompt_image_embeds.view(-1, prompt_image_embeds.shape[-1])
visual_idx = image_idx
            # Match both device and dtype of the text embeddings (mirrors the video branch below).
            text_embeds[image_idx == 1] = text_embeds[image_idx == 1] * 0 + prompt_image_embeds.to(text_embeds.device).to(text_embeds.dtype)
elif video is not None:
if len(video.shape) == 5:
B, T, C, H, W = video.shape
N = 1
else:
B, N, T, C, H, W = video.shape
video = video.reshape(B*N, T, C, H, W).permute(0, 2, 1, 3, 4)
prompt_video_embeds = self.encode_vision(video, instruction=instruction)
visual = prompt_video_embeds
prompt_video_embeds = self.project_up(prompt_video_embeds)
prompt_video_embeds = prompt_video_embeds.view(-1, prompt_video_embeds.shape[-1])
visual_idx = video_idx
text_embeds[video_idx == 1] = text_embeds[video_idx == 1] * 0 + prompt_video_embeds.to(text_embeds.device).to(text_embeds.dtype)
else:
            logger.warning(f"no visual input received, input_ids: {input_ids}")
if return_visual:
return text_embeds, visual, visual_idx
return text_embeds
def encode_vision(
self,
image,
instruction
):
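        """Encode frames with the ViT and distill them through the QFormer; returns [B, num_query_tokens, 768]."""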
device = image.device
B = image.shape[0]
T = image.shape[2]
        use_image = (T == 1)  # a single frame is treated as an image
image_embeds = self.vision_encoder(image, use_image=use_image)
C = image_embeds.shape[-1]
image_embeds = image_embeds.reshape(B, -1, C)
image_embeds = self.vision_layernorm(image_embeds).to(device) # [B, T*L, C]
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
        if self.extra_num_query_token > 0:
            query_tokens = torch.cat([self.query_tokens, self.extra_query_tokens], dim=1)
        else:
            # The original left query_tokens undefined on this path; fall back to the base query tokens.
            query_tokens = self.query_tokens
        query_tokens = query_tokens.expand(image_embeds.shape[0], -1, -1)
if instruction is not None:
text_Qformer = self.qformer_tokenizer(
instruction,
padding='longest',
truncation=True,
max_length=512,
return_tensors="pt",
).to(image_embeds.device)
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image_embeds.device)
Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
query_output = self.qformer.bert(
text_Qformer.input_ids,
attention_mask=Qformer_atts,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
else:
query_output = self.qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
return query_output.last_hidden_state[:, :query_tokens.size(1), :]
def build_input_ids(
self,
tokenizer,
conversation,
max_length,
add_special_tokens,
truncation,
image = None,
video = None,
padding = "longest",
return_tensors = "pt",
image_placeholder: str = DEFAULT_IMG_PLACEHOLDER,
video_placeholder: str = DEFAULT_VID_PLACEHOLDER,
):
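        """Tokenize the conversation around the visual placeholders, reserving a fixed block of positions per
        placeholder (marked in 'index') that pad_text_embeds later fills with visual embeddings."""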
input_ids = []
indexs = []
attention_mask = []
start, total_len = 0, 0
while True:
index1 = conversation.find(image_placeholder, start)
index2 = conversation.find(video_placeholder, start)
if index1 == -1 and index2 == -1:
index = -1
elif index1 == -1:
index = index2
elif index2 == -1:
index = index1
            else:
                index = min(index1, index2)
                assert index != -1
            if index == -1:
inputs = tokenizer(conversation[start:], max_length=max_length-total_len, truncation=truncation, padding=padding, return_tensors=return_tensors)
else:
inputs = tokenizer(conversation[start:index], max_length=max_length, truncation=truncation, padding='longest', return_tensors=return_tensors)
input_ids += inputs.input_ids
attention_mask += inputs.attention_mask
total_len += inputs.input_ids[0].shape[0]
indexs += torch.zeros_like(inputs.input_ids)
if index != -1:
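                # Reserve 96 placeholder positions per visual input, presumably the total number of query
                # tokens (base + extra) that encode_vision returns; pad_text_embeds overwrites them later.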
input_ids += [torch.zeros(96).long()]
attention_mask += [torch.ones(96).long()]
indexs += [torch.ones(96)]
if index == -1:
return {
'input_ids': torch.cat(input_ids),
'attention_mask': torch.cat(attention_mask),
'index': torch.cat(indexs).to(torch.bool),
}
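            # Both placeholder strings have the same length, so advancing by the image placeholder length works for either match.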
start = index + len(DEFAULT_IMG_PLACEHOLDER)
@property
def dtype(self):
return self.lm.dtype
@property
def device(self):
return self.lm.device
class InternVideo2_Classification_test(PreTrainedModel):
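    """Lightweight test-only variant: two dummy conv layers plus the real QFormer bridge."""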
config_class = VideoChat2Config
def __init__(self, config):
super().__init__(config)
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
self.model_config = config.model_config
self.build_bridge()
def forward(self, x):
x = self.conv1(x)
return self.conv2(x)
def test_lol(self, x):
return x
def build_bridge(self):
if 'qformer' in self.model_config.bridge.name.lower():
from transformers import BertTokenizer
self.qformer_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="left")
self.qformer_tokenizer.add_special_tokens({"bos_token": "[DEC]"})
self.qformer_tokenizer.padding_side = "left"
if self.model_config.bridge.name == 'qformer':
self.qformer, self.query_tokens = build_qformer(
self.model_config.bridge.num_query_token, self.model_config.vision_encoder.encoder_embed_dim,
qformer_hidden_dropout_prob=self.model_config.bridge.qformer_hidden_dropout_prob,
qformer_attention_probs_dropout_prob=self.model_config.bridge.qformer_attention_probs_dropout_prob,
qformer_drop_path_rate=self.model_config.bridge.qformer_drop_path_rate,
)
self.qformer.resize_token_embeddings(len(self.qformer_tokenizer))
self.qformer.cls = None
self.extra_num_query_token = self.model_config.bridge.extra_num_query_token
if self.model_config.bridge.extra_num_query_token > 0:
logger.info(f"Add extra {self.model_config.bridge.extra_num_query_token} tokens in QFormer")
self.extra_query_tokens = nn.Parameter(
torch.zeros(1, self.model_config.bridge.extra_num_query_token, self.query_tokens.shape[-1])
)
self.freeze_bridge = self.model_config.get("freeze_bridge", False)
if self.freeze_bridge:
logger.info("freeze bridge")
freeze_module(self.qformer)
self.query_tokens.requires_grad = False
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained('OpenGVLab/InternVideo2-Chat-8B', trust_remote_code=True, use_fast=False)
    config = AutoConfig.from_pretrained('OpenGVLab/InternVideo2-Chat-8B', torch_dtype=torch.bfloat16, trust_remote_code=True)
model = InternVideo2_Classification(config).cuda()
B, T, C, H, W = 1, 8, 3, 224, 224
video_tensor = torch.randn(B,T,C,H,W).cuda()
user_prompt = "this is a user prompt"
instruction = "this is an instruction"
conversation = model.build_conversation(instruction=instruction, user_prompt=user_prompt, media_type='video')
    tokenized = model.build_input_ids(tokenizer, conversation, max_length=248, add_special_tokens=True, truncation=False, padding=False, return_tensors='pt')
input_ids = tokenized['input_ids'].unsqueeze(0).to(model.device)
attn_mask = tokenized['attention_mask'].unsqueeze(0).to(model.device)
    indexes = tokenized['index'].unsqueeze(0).to(model.device)  # keep the placeholder mask on the same device as the embeddings
    text_embeds = model.pad_text_embeds(input_ids=input_ids, video=video_tensor, video_idx=indexes)
    outputs = model.lm(inputs_embeds=text_embeds, attention_mask=attn_mask, output_hidden_states=True, return_dict=True)
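    # The classification scores can then be read from outputs.logits ([batch, num_labels] from the sequence-classification head).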