|
|
|
import os |
|
import torch |
|
import torch.utils.checkpoint |
|
from torch import nn |
|
from torch.cuda.amp import autocast
|
from typing import Optional |
|
from modeling_internvideo2_vit import pretrain_internvideo2_giant_patch14_224_clean |
|
from modeling_qformer import build_qformer |
|
|
|
from model_config import VideoChat2Config |
|
|
|
from transformers import AutoTokenizer, AutoModel, AutoConfig, PreTrainedModel, PretrainedConfig
|
import logging |
|
logger = logging.getLogger(__name__) |
|
|
|
# Hugging Face access token (used when downloading gated checkpoints); read with .get()
# so a missing HF_TOKEN environment variable does not break the import.
token = os.environ.get("HF_TOKEN")
|
|
|
IMG_TOKEN = "[<IMG_PLH>]" |
|
VID_TOKEN = "[<VID_PLH>]" |
|
|
|
DEFAULT_PAD_TOKEN = "[PAD]" |
|
DEFAULT_BOS_TOKEN = '<s>' |
|
DEFAULT_EOS_TOKEN = '</s>' |
|
DEFAULT_UNK_TOKEN = "<unk>" |
|
|
|
DEFAULT_IMAGE_TOKEN = "[IMAGETOKEN]" |
|
DEFAULT_VIDEO_TOKEN = "[VIDEOTOKEN]" |
|
|
|
DEFAULT_IMG_PLACEHOLDER = "[<IMG_PLH>]" |
|
DEFAULT_VID_PLACEHOLDER = "[<VID_PLH>]" |
|
|
|
|
|
|
|
def disabled_train(self, mode=True): |
|
"""Overwrite model.train with this function to make sure train/eval mode |
|
does not change anymore.""" |
|
return self |
|
|
|
|
|
def freeze_module(module): |
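    """Freeze all parameters, switch the module to eval mode, and disable train()/eval() toggling."""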
|
for _, param in module.named_parameters(): |
|
param.requires_grad = False |
|
module = module.eval() |
|
module.train = disabled_train |
|
return module |
|
|
|
|
|
class InternVideo2_Classification(PreTrainedModel): |
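    # Video/image classification model: a (typically frozen) InternVideo2 vision encoder,
    # a Q-Former bridge, and a sequence-classification LLM that consumes the fused embeddings.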
|
config_class = VideoChat2Config |
|
def __init__(self, config): |
|
self.model_config = config.model_config |
|
|
|
super().__init__(config) |
|
        self.build_vision_encoder()

        self.build_llm()

        self.build_bridge()

        # forward() reads this flag, but the original code never defined it; default to False
        # when the config does not provide it (assumed default).
        self.use_vision_regression_loss = self.model_config.get("use_vision_regression_loss", False)
|
|
|
for n, p in self.named_parameters(): |
|
if p.requires_grad: |
|
logger.info(f'{n} requires_grad') |
|
|
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
image: Optional[torch.Tensor] = None, |
|
video: Optional[torch.Tensor] = None, |
|
instruction = None, |
|
video_idx = None, |
|
image_idx = None, |
|
): |
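        # Embed the input tokens, splice the projected visual embeddings into the
        # placeholder positions, and run the LM head on the fused sequence.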
|
        if self.use_vision_regression_loss:

            text_embeds, visual, visual_idx = self.pad_text_embeds(
                input_ids=input_ids, image=image, video=video, return_visual=True,
                video_idx=video_idx, image_idx=image_idx, instruction=instruction,
            )

        else:

            text_embeds = self.pad_text_embeds(
                input_ids=input_ids, image=image, video=video, return_visual=False,
                video_idx=video_idx, image_idx=image_idx, instruction=instruction,
            )
|
|
|
outputs = self.lm( |
|
inputs_embeds=text_embeds, |
|
attention_mask=attention_mask, |
|
labels=labels, |
|
output_hidden_states=True, |
|
return_dict=True, |
|
) |
|
|
|
return outputs |
|
|
|
|
|
def build_vision_encoder(self): |
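        # Vision tower: currently only the InternVideo2-1B ViT-giant (patch14, 224 px) backbone
        # is supported; an optional LayerNorm is applied on top of its features.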
|
|
|
|
|
if 'internvideo2' in self.model_config.vision_encoder.name.lower(): |
|
encoder_name = self.model_config.vision_encoder.name |
|
logger.info(f"Build vision_encoder: {encoder_name}") |
|
if encoder_name == 'internvideo2-1B': |
|
self.vision_encoder = pretrain_internvideo2_giant_patch14_224_clean(self.model_config) |
|
else: |
|
raise ValueError(f"Not implemented: {encoder_name}") |
|
else: |
|
raise NotImplementedError(self.model_config.vision_encoder.name) |
|
|
|
if self.model_config.vision_encoder.vit_add_ln: |
|
self.vision_layernorm = nn.LayerNorm(self.model_config.vision_encoder.encoder_embed_dim, eps=1e-12) |
|
else: |
|
self.vision_layernorm = nn.Identity() |
|
|
|
self.freeze_vision_encoder = self.model_config.get("freeze_vision_encoder", False) |
|
|
|
if self.freeze_vision_encoder: |
|
logger.info("freeze vision encoder") |
|
freeze_module(self.vision_encoder) |
|
freeze_module(self.vision_layernorm) |
|
|
|
|
|
def build_bridge(self): |
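        # Bridge: linear projections between the 768-d Q-Former space and the LLM hidden size,
        # plus a BERT-based Q-Former (with optional extra query tokens) that compresses the
        # vision-encoder features into a fixed number of query embeddings.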
|
|
|
self.project_up = nn.Linear(768, self.lm.config.hidden_size) |
|
|
|
self.project_down = nn.Linear(self.lm.config.hidden_size, 768) |
|
|
|
if 'qformer' in self.model_config.bridge.name.lower(): |
|
from transformers import BertTokenizer |
|
self.qformer_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="left") |
|
self.qformer_tokenizer.add_special_tokens({"bos_token": "[DEC]"}) |
|
self.qformer_tokenizer.padding_side = "left" |
|
if self.model_config.bridge.name == 'qformer': |
|
self.qformer, self.query_tokens = build_qformer( |
|
self.model_config.bridge.num_query_token, self.model_config.vision_encoder.encoder_embed_dim, |
|
qformer_hidden_dropout_prob=self.model_config.bridge.qformer_hidden_dropout_prob, |
|
qformer_attention_probs_dropout_prob=self.model_config.bridge.qformer_attention_probs_dropout_prob, |
|
qformer_drop_path_rate=self.model_config.bridge.qformer_drop_path_rate, |
|
) |
|
self.qformer.resize_token_embeddings(len(self.qformer_tokenizer)) |
|
self.qformer.cls = None |
|
self.extra_num_query_token = self.model_config.bridge.extra_num_query_token |
|
if self.model_config.bridge.extra_num_query_token > 0: |
|
logger.info(f"Add extra {self.model_config.bridge.extra_num_query_token} tokens in QFormer") |
|
self.extra_query_tokens = nn.Parameter( |
|
torch.zeros(1, self.model_config.bridge.extra_num_query_token, self.query_tokens.shape[-1]) |
|
) |
|
|
|
self.freeze_bridge = self.model_config.get("freeze_bridge", False) |
|
if self.freeze_bridge: |
|
logger.info("freeze bridge") |
|
freeze_module(self.qformer) |
|
self.query_tokens.requires_grad = False |
|
|
|
|
|
def build_llm(self): |
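        # Language model: loaded as a sequence-classification head (Mistral-7B from config only,
        # i.e. randomly initialized here; InternLM variants from pretrained weights), optionally
        # frozen and/or wrapped with LoRA adapters.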
|
self.lm_name = self.model_config.llm.name |
|
if self.model_config.llm.name == 'mistral_7b': |
|
from transformers import AutoModelForSequenceClassification |
|
config = AutoConfig.from_pretrained( |
|
self.model_config.llm.pretrained_llm_path, |
|
torch_dtype=torch.bfloat16, |
|
                token=token,
            )
|
self.lm = AutoModelForSequenceClassification.from_config(config) |
|
elif self.model_config.llm.name == 'internlm_20b': |
|
from transformers import AutoModelForSequenceClassification |
|
self.lm = AutoModelForSequenceClassification.from_pretrained( |
|
self.model_config.llm.pretrained_llm_path, |
|
torch_dtype=torch.bfloat16, |
|
trust_remote_code=True, |
|
) |
|
            # Enable activation checkpointing through the public transformers API.
            self.lm.gradient_checkpointing_enable()
|
elif self.model_config.llm.name == 'internlm2_5_7b': |
|
from transformers import AutoModelForSequenceClassification |
|
self.lm = AutoModelForSequenceClassification.from_pretrained( |
|
self.model_config.llm.pretrained_llm_path, |
|
torch_dtype=torch.bfloat16, |
|
trust_remote_code=True, |
|
local_files_only=True, |
|
) |
|
else: |
|
raise NotImplementedError(self.model_config.llm.name) |
|
|
|
self.freeze_llm = self.model_config.get("freeze_llm", True) |
|
logger.info(f'freeze_llm: {self.freeze_llm}') |
|
if self.freeze_llm: |
|
logger.info("freeze llm") |
|
freeze_module(self.lm) |
|
|
|
if self.model_config.llm.use_lora: |
|
self.use_lora = True |
|
from peft import get_peft_model, LoraConfig, TaskType |
|
logger.info("Use lora") |
|
if self.model_config.llm.name == 'internlm_20b': |
|
peft_config = LoraConfig( |
|
task_type=TaskType.CAUSAL_LM, inference_mode=False, |
|
r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout, |
|
target_modules=['wqkv', 'wo', 'w1', 'w2', 'w3', 'output'] |
|
) |
|
else: |
|
peft_config = LoraConfig( |
|
task_type=TaskType.CAUSAL_LM, inference_mode=False, |
|
r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout, |
|
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", |
|
"gate_proj", "up_proj", "down_proj", "lm_head"] |
|
) |
|
|
|
self.lm = get_peft_model(self.lm, peft_config) |
|
self.lm.enable_input_require_grads() |
|
self.lm.print_trainable_parameters() |
|
else: |
|
self.use_lora = False |
|
|
|
|
|
    def build_conversation(self, instruction, user_prompt, media_type='video', msg=''):
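        # Assemble a Mistral-style "[INST] ... [/INST]" conversation string that wraps the
        # image/video placeholder token and the user prompt.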
|
|
|
conversation = "" |
|
if instruction: |
|
conversation += instruction |
|
conversation += ("[INST]" + " ") |
|
|
|
        if media_type == 'image':

            conversation += ("<Image>" + IMG_TOKEN + "</Image>")

        else:

            conversation += ("<Video>" + VID_TOKEN + "</Video>")



        conversation += (msg.rstrip() + "[/INST]")

        conversation += (" [INST] " + user_prompt + " [/INST]")

        return conversation
|
|
|
|
|
def pad_text_embeds( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
image: Optional[torch.Tensor] = None, |
|
video: Optional[torch.Tensor] = None, |
|
image_idx = None, |
|
video_idx = None, |
|
return_visual: bool = False, |
|
instruction = None, |
|
): |
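        # Embed the text tokens, then overwrite the embeddings at positions where
        # image_idx/video_idx == 1 with the projected visual query embeddings.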
|
|
|
text_embeds = self.lm.get_input_embeddings()(input_ids.long()).detach() |
|
|
|
visual = None |
|
visual_idx = None |
|
|
|
if image is not None: |
|
B, T, C, H, W = image.shape |
|
image = image.permute(0, 2, 1, 3, 4) |
|
prompt_image_embeds = self.encode_vision(image, instruction=instruction) |
|
visual = prompt_image_embeds |
|
prompt_image_embeds = self.project_up(prompt_image_embeds) |
|
prompt_image_embeds = prompt_image_embeds.view(-1, prompt_image_embeds.shape[-1]) |
|
visual_idx = image_idx |
|
            text_embeds[image_idx == 1] = text_embeds[image_idx == 1] * 0 + prompt_image_embeds.to(text_embeds.device).to(text_embeds.dtype)
|
elif video is not None: |
|
if len(video.shape) == 5: |
|
B, T, C, H, W = video.shape |
|
N = 1 |
|
else: |
|
B, N, T, C, H, W = video.shape |
|
video = video.reshape(B*N, T, C, H, W).permute(0, 2, 1, 3, 4) |
|
prompt_video_embeds = self.encode_vision(video, instruction=instruction) |
|
visual = prompt_video_embeds |
|
prompt_video_embeds = self.project_up(prompt_video_embeds) |
|
prompt_video_embeds = prompt_video_embeds.view(-1, prompt_video_embeds.shape[-1]) |
|
visual_idx = video_idx |
|
text_embeds[video_idx == 1] = text_embeds[video_idx == 1] * 0 + prompt_video_embeds.to(text_embeds.device).to(text_embeds.dtype) |
|
else: |
|
logger.warn(f"don't get visual input, input_ids: {input_ids}") |
|
|
|
if return_visual: |
|
return text_embeds, visual, visual_idx |
|
|
|
return text_embeds |
|
|
|
|
|
def encode_vision( |
|
self, |
|
image, |
|
instruction |
|
): |
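        # Run the vision encoder, optionally layer-normalize its output, then cross-attend
        # with the Q-Former queries (instruction-conditioned when `instruction` is given)
        # and return only the query-token outputs.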
|
device = image.device |
|
B = image.shape[0] |
|
T = image.shape[2] |
|
        use_image = (T == 1)
|
image_embeds = self.vision_encoder(image, use_image=use_image) |
|
C = image_embeds.shape[-1] |
|
image_embeds = image_embeds.reshape(B, -1, C) |
|
image_embeds = self.vision_layernorm(image_embeds).to(device) |
|
|
|
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) |
|
        if self.extra_num_query_token > 0:

            query_tokens = torch.cat([self.query_tokens, self.extra_query_tokens], dim=1)

        else:

            # No extra query tokens configured: fall back to the base query tokens.
            query_tokens = self.query_tokens

        query_tokens = query_tokens.expand(image_embeds.shape[0], -1, -1)
|
if instruction is not None: |
|
text_Qformer = self.qformer_tokenizer( |
|
instruction, |
|
padding='longest', |
|
truncation=True, |
|
max_length=512, |
|
return_tensors="pt", |
|
).to(image_embeds.device) |
|
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image_embeds.device) |
|
Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1) |
|
query_output = self.qformer.bert( |
|
text_Qformer.input_ids, |
|
attention_mask=Qformer_atts, |
|
query_embeds=query_tokens, |
|
encoder_hidden_states=image_embeds, |
|
encoder_attention_mask=image_atts, |
|
return_dict=True, |
|
) |
|
else: |
|
query_output = self.qformer.bert( |
|
query_embeds=query_tokens, |
|
encoder_hidden_states=image_embeds, |
|
encoder_attention_mask=image_atts, |
|
return_dict=True, |
|
) |
|
|
|
return query_output.last_hidden_state[:, :query_tokens.size(1), :] |
|
|
|
|
|
def build_input_ids( |
|
self, |
|
tokenizer, |
|
conversation, |
|
max_length, |
|
add_special_tokens, |
|
truncation, |
|
image = None, |
|
video = None, |
|
padding = "longest", |
|
return_tensors = "pt", |
|
image_placeholder: str = DEFAULT_IMG_PLACEHOLDER, |
|
video_placeholder: str = DEFAULT_VID_PLACEHOLDER, |
|
): |
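        # Tokenize the conversation piecewise around the image/video placeholders, inserting
        # fixed-size dummy blocks whose positions are returned in `index` so that
        # pad_text_embeds can splice in the visual embeddings later.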
|
input_ids = [] |
|
indexs = [] |
|
attention_mask = [] |
|
start, total_len = 0, 0 |
|
while True: |
|
index1 = conversation.find(image_placeholder, start) |
|
index2 = conversation.find(video_placeholder, start) |
|
if index1 == -1 and index2 == -1: |
|
index = -1 |
|
elif index1 == -1: |
|
index = index2 |
|
elif index2 == -1: |
|
index = index1 |
|
else: |
|
index = min(index1, index2) |
|
assert index != -1 |
|
if index == -1: |
|
inputs = tokenizer(conversation[start:], max_length=max_length-total_len, truncation=truncation, padding=padding, return_tensors=return_tensors) |
|
else: |
|
inputs = tokenizer(conversation[start:index], max_length=max_length, truncation=truncation, padding='longest', return_tensors=return_tensors) |
|
|
|
input_ids += inputs.input_ids |
|
attention_mask += inputs.attention_mask |
|
total_len += inputs.input_ids[0].shape[0] |
|
indexs += torch.zeros_like(inputs.input_ids) |
|
|
|
if index != -1: |
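                # Reserve a fixed block of placeholder slots (96 here, presumably the total
                # Q-Former query-token count including extras); `indexs` marks their positions.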
|
input_ids += [torch.zeros(96).long()] |
|
attention_mask += [torch.ones(96).long()] |
|
indexs += [torch.ones(96)] |
|
|
|
if index == -1: |
|
return { |
|
'input_ids': torch.cat(input_ids), |
|
'attention_mask': torch.cat(attention_mask), |
|
'index': torch.cat(indexs).to(torch.bool), |
|
} |
|
start = index + len(DEFAULT_IMG_PLACEHOLDER) |
|
|
|
|
|
@property |
|
def dtype(self): |
|
return self.lm.dtype |
|
|
|
@property |
|
def device(self): |
|
return self.lm.device |
|
|
|
|
|
class InternVideo2_Classification_test(PreTrainedModel): |
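    # Lightweight variant used for testing: builds only the Q-Former bridge and two dummy
    # conv layers; no vision encoder or LLM is constructed.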
|
config_class = VideoChat2Config |
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.conv1 = nn.Conv2d(1, 20, 5) |
|
self.conv2 = nn.Conv2d(20, 20, 5) |
|
self.model_config = config.model_config |
|
self.build_bridge() |
|
|
|
|
|
def forward(self, x): |
|
x = self.conv1(x) |
|
return self.conv2(x) |
|
|
|
def test_lol(self, x): |
|
return x |
|
|
|
def build_bridge(self): |
|
|
|
if 'qformer' in self.model_config.bridge.name.lower(): |
|
from transformers import BertTokenizer |
|
self.qformer_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="left") |
|
self.qformer_tokenizer.add_special_tokens({"bos_token": "[DEC]"}) |
|
self.qformer_tokenizer.padding_side = "left" |
|
if self.model_config.bridge.name == 'qformer': |
|
self.qformer, self.query_tokens = build_qformer( |
|
self.model_config.bridge.num_query_token, self.model_config.vision_encoder.encoder_embed_dim, |
|
qformer_hidden_dropout_prob=self.model_config.bridge.qformer_hidden_dropout_prob, |
|
qformer_attention_probs_dropout_prob=self.model_config.bridge.qformer_attention_probs_dropout_prob, |
|
qformer_drop_path_rate=self.model_config.bridge.qformer_drop_path_rate, |
|
) |
|
self.qformer.resize_token_embeddings(len(self.qformer_tokenizer)) |
|
self.qformer.cls = None |
|
self.extra_num_query_token = self.model_config.bridge.extra_num_query_token |
|
if self.model_config.bridge.extra_num_query_token > 0: |
|
logger.info(f"Add extra {self.model_config.bridge.extra_num_query_token} tokens in QFormer") |
|
self.extra_query_tokens = nn.Parameter( |
|
torch.zeros(1, self.model_config.bridge.extra_num_query_token, self.query_tokens.shape[-1]) |
|
) |
|
|
|
self.freeze_bridge = self.model_config.get("freeze_bridge", False) |
|
if self.freeze_bridge: |
|
logger.info("freeze bridge") |
|
freeze_module(self.qformer) |
|
self.query_tokens.requires_grad = False |
|
|
|
if __name__ == "__main__": |
|
|
|
    tokenizer = AutoTokenizer.from_pretrained('OpenGVLab/InternVideo2-Chat-8B', trust_remote_code=True, use_fast=False)

    config = AutoConfig.from_pretrained('OpenGVLab/InternVideo2-Chat-8B', torch_dtype=torch.bfloat16, trust_remote_code=True)
|
model = InternVideo2_Classification(config).cuda() |
|
|
|
B, T, C, H, W = 1, 8, 3, 224, 224 |
|
    video_tensor = torch.randn(B, T, C, H, W).cuda()
|
user_prompt = "this is a user prompt" |
|
instruction = "this is an instruction" |
|
|
|
conversation = model.build_conversation(instruction=instruction, user_prompt=user_prompt, media_type='video') |
|
    tokenized = model.build_input_ids(tokenizer, conversation, max_length=248, add_special_tokens=True, truncation=False, padding=False, return_tensors='pt')
|
|
|
input_ids = tokenized['input_ids'].unsqueeze(0).to(model.device) |
|
attn_mask = tokenized['attention_mask'].unsqueeze(0).to(model.device) |
|
    indexes = tokenized['index'].unsqueeze(0).to(model.device)

    text_embeds = model.pad_text_embeds(input_ids=input_ids, video=video_tensor, video_idx=indexes)

    outputs = model.lm(inputs_embeds=text_embeds, attention_mask=attn_mask, output_hidden_states=True, return_dict=True)
|
|