zhouyik's picture
Upload folder using huggingface_hub
032e687 verified
import os
import copy
from collections import OrderedDict
from typing import List, Optional, Tuple, Union
from types import MethodType
from enum import Enum
import torch
import torch.amp
import torch.distributed
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from peft import get_peft_model, prepare_model_for_kbit_training
from safetensors.torch import load_file
from safetensors import safe_open
import torch.utils
import torch.utils.checkpoint
from xtuner.registry import BUILDER
from xtuner.model.modules import dispatch_modules
from xtuner.utils import DEFAULT_IMAGE_TOKEN
from transformers import AutoModel, AutoConfig, AutoTokenizer, BitsAndBytesConfig, GenerationConfig, AutoImageProcessor
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput, BaseModelOutputWithPooling
from transformers import ConvNextV2ForImageClassification
from .utils import (LoadWoInit, traverse_dict, make_inputs_require_grad, find_all_linear_names,
guess_load_checkpoint, get_peft_model_state_dict)
from ..dataset.utils import (get_conv_template, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN,
VPT_CONTEXT_TOKEN, VPT_START_TOKEN, VPT_END_TOKEN)
class WrapInternVL(BaseModel):
# Build the wrapped InternVL model together with its tokenizer and, when
# ``use_object_tokens`` is set, an object tokenizer plus the ``ot_mlp1``
# projector; then apply freezing, activation checkpointing, LoRA and
# checkpoint loading as configured.
#
# Args:
#   mllm: config mapping for the multimodal LLM. Must provide
#       "pretrained_model_name_or_path" and a callable under "type"; "type"
#       is popped and called with the remaining entries plus ``config``.
#   tokenizer: optional ``nn.Module`` or config dict (see
#       ``_build_from_cfg_or_module``); when None the tokenizer is loaded
#       with ``AutoTokenizer`` from the ``mllm`` checkpoint path.
#   freeze_llm / freeze_visual_encoder / freeze_connector / freeze_ot_mlp:
#       disable gradients on ``language_model`` / ``vision_model`` / ``mlp1``
#       / ``ot_mlp1`` respectively.
#   unfreeze_vocab / unfreeze_lm_head: re-enable gradients on the LLM input /
#       output embeddings; applied after the LLM LoRA wrapping.
#   llm_lora / visual_encoder_lora: LoRA configs; a non-None value enables
#       LoRA on the corresponding sub-model.
#   pretrained_pth: optional checkpoint read with ``guess_load_checkpoint``;
#       entries whose key starts with "model." are loaded into ``self.model``.
#   radio_pretrained_pth: optional checkpoint read with ``torch.load`` that
#       supplies the "ot_mlp1." entries instead of ``pretrained_pth``.
#   use_activation_checkpointing: make LLM inputs require grad so gradient
#       checkpointing works.
#   vocab_embeds_name / lm_head_name / contras_loss: stored on the instance;
#       not otherwise used in this method.
#   use_object_tokens: build ``object_tokenizer`` and ``ot_mlp1``.
#   object_tokenizer: config mapping for the object tokenizer (same
#       "type" / "pretrained_model_name_or_path" convention as ``mllm``).
#   object_tokenizer_pretrain: stored on the instance; read by ``forward``.
#
# NOTE(review): the leading whitespace of this file has been lost, so block
# nesting cannot be read off the text below; confirm against the upstream
# file wherever a note says so.
def __init__(self,
mllm,
tokenizer=None,
freeze_llm=False,
freeze_visual_encoder=False,
freeze_connector=False,
freeze_ot_mlp=False,
unfreeze_vocab=False,
unfreeze_lm_head=False,
llm_lora=None,
visual_encoder_lora=None,
pretrained_pth=None,
radio_pretrained_pth=None,
use_activation_checkpointing=True,
vocab_embeds_name="tok_embeddings",
lm_head_name="output",
contras_loss=False,
use_object_tokens=False,
object_tokenizer=None,
object_tokenizer_pretrain=False,
):
super().__init__()
# Keep the training switches on the instance; state_dict() and forward()
# read several of them later.
self.freeze_llm = freeze_llm
self.freeze_visual_encoder = freeze_visual_encoder
self.freeze_connector = freeze_connector
self.freeze_ot_mlp = freeze_ot_mlp
self.unfreeze_vocab = unfreeze_vocab
self.unfreeze_lm_head = unfreeze_lm_head
self.use_llm_lora = llm_lora is not None
self.use_visual_encoder_lora = visual_encoder_lora is not None
self.use_activation_checkpointing=use_activation_checkpointing
self.vocab_embeds_name = vocab_embeds_name
self.lm_head_name = lm_head_name
self.contras_loss = contras_loss
self.object_tokenizer_pretrain=object_tokenizer_pretrain
# Load the HF config once; it is kept on the wrapper and injected into the
# model constructor below.
config = AutoConfig.from_pretrained(mllm["pretrained_model_name_or_path"], trust_remote_code=True)
self.config = config
# Request FlashAttention-2 for the language model; the attribute that is
# set differs between InternLM2 and every other LLM config.
if config.llm_config.model_type == 'internlm2':
config.llm_config.attn_implementation = 'flash_attention_2'
else:
config.llm_config._attn_implementation = 'flash_attention_2'
# NOTE(review): traverse_dict comes from .utils; presumably it resolves
# nested config entries in place so that "type" is callable — confirm.
traverse_dict(mllm)
model_clazz = mllm.pop('type')
mllm.update(dict(config=config))
self.model = model_clazz(**mllm)
# The KV cache is switched off on the language model config.
self.model.language_model.config.use_cache = False
dispatch_modules(self.model.language_model)
# Optional object tokenizer: built from its own config; the projector input
# width is chosen by which backbone name appears in the checkpoint path.
if use_object_tokens:
# 1280
ot_config = AutoConfig.from_pretrained(object_tokenizer["pretrained_model_name_or_path"], trust_remote_code=True)
self.ot_config = ot_config
traverse_dict(object_tokenizer)
ot_clazz = object_tokenizer.pop('type')
self.object_tokenizer = ot_clazz(**object_tokenizer)
if 'dinov2' in object_tokenizer["pretrained_model_name_or_path"]:
ot_hidden_size = self.ot_config.hidden_size
self.vfm_name = 'DINOv2'
elif 'RADIO' in object_tokenizer["pretrained_model_name_or_path"]:
ot_hidden_size = self.object_tokenizer.model.num_features
self.vfm_name = 'RADIO'
elif 'convnext' in object_tokenizer["pretrained_model_name_or_path"]:
ot_hidden_size = self.ot_config.hidden_sizes[-1]
self.vfm_name = "ConvNext"
else:
raise NotImplementedError
# Projector from object-tokenizer features to the LLM hidden size.
self.ot_mlp1 = nn.Sequential(
nn.LayerNorm(ot_hidden_size,),
nn.Linear(ot_hidden_size, config.llm_config.hidden_size,),
nn.GELU(),
nn.Linear(config.llm_config.hidden_size, config.llm_config.hidden_size)
)
else:
# Note that ``vfm_name`` is not defined on this path.
self.object_tokenizer = None
self.ot_mlp1 = None
self.ot_config = None
# Tokenizer: either built from the supplied module/config or loaded from
# the same checkpoint as the model.
if tokenizer is not None:
self.tokenizer = self._build_from_cfg_or_module(tokenizer)
else:
self.tokenizer = AutoTokenizer.from_pretrained(mllm["pretrained_model_name_or_path"], trust_remote_code=True)
# Tell the wrapped model which token id marks image-context positions.
img_context_token_id = self.tokenizer.convert_tokens_to_ids('<IMG_CONTEXT>')
self.model.img_context_token_id = img_context_token_id
self._add_special_tokens()
# Freezing. The object tokenizer itself is always frozen when present; only
# its projector (ot_mlp1) is optionally trainable.
if self.freeze_llm:
self.model.language_model.requires_grad_(False)
if self.freeze_visual_encoder:
self.model.vision_model.requires_grad_(False)
if self.freeze_connector:
self.model.mlp1.requires_grad_(False)
if self.object_tokenizer is not None:
self.object_tokenizer.requires_grad_(False)
if self.freeze_ot_mlp and self.ot_mlp1 is not None:
self.ot_mlp1.requires_grad_(False)
if use_activation_checkpointing:
# it is necessary when using gradient checkpointing
if hasattr(self.model.language_model, 'enable_input_require_grads'):
self.model.language_model.enable_input_require_grads()
else:
self.model.language_model.get_input_embeddings(
).register_forward_hook(make_inputs_require_grad)
# NOTE(review): with indentation lost it cannot be told whether the next
# call is guarded by ``use_activation_checkpointing`` or runs always —
# confirm against the upstream file.
self.gradient_checkpointing_enable()
if self.use_llm_lora:
self._prepare_llm_for_lora(llm_lora)
# put this after llm_lora
if self.unfreeze_vocab:
self.model.language_model.get_input_embeddings().requires_grad_(True)
if self.unfreeze_lm_head:
self.model.language_model.get_output_embeddings().requires_grad_(True)
if self.use_visual_encoder_lora:
self._prepare_visual_encoder_for_lora(visual_encoder_lora)
# Checkpoint loading: "model."-prefixed entries go into the wrapped model
# (non-strict), "ot_mlp1."-prefixed entries into the object-token projector.
if pretrained_pth is not None:
pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
# for k, v in pretrained_state_dict.items():
# if 'radio_model.summary_idxs' in k:
# self.object_tokenizer.radio_model.summary_idxs = v
# if 'radio_model.input_conditioner.norm_mean' in k:
# self.object_tokenizer.radio_model.input_conditioner.norm_mean = v
# if 'radio_model.input_conditioner.norm_std' in k:
# self.object_tokenizer.radio_model.input_conditioner.norm_std = v
mllm_state_dict = {}
for k, v in pretrained_state_dict.items():
if k.startswith('model.'):
mllm_state_dict[k[len('model.'):]] = v
if len(mllm_state_dict) != 0:
self.model.load_state_dict(mllm_state_dict, strict=False)
# NOTE(review): the radio checkpoint handling below appears to sit inside
# the ``pretrained_pth is not None`` branch (it reuses
# ``pretrained_state_dict``); if so, ``radio_pretrained_pth`` is ignored
# whenever ``pretrained_pth`` is None — confirm the intended nesting.
if use_object_tokens:
ot_adapter_state_dict = {}
if radio_pretrained_pth is not None:
pretrained_state_dict = torch.load(radio_pretrained_pth, map_location='cpu')
if 'state_dict' in pretrained_state_dict:
pretrained_state_dict = pretrained_state_dict['state_dict']
for k, v in pretrained_state_dict.items():
if k.startswith('ot_mlp1.'):
ot_adapter_state_dict[k[len('ot_mlp1.'):]] = v
if len(ot_adapter_state_dict) != 0:
self.ot_mlp1.load_state_dict(ot_adapter_state_dict, strict=False)
# Sanity check that each ot_mlp1 parameter now equals the checkpoint tensor.
# NOTE(review): this raises KeyError when the checkpoint lacks a parameter
# (the load above is non-strict) and is skipped entirely under ``python -O``.
for k, v in self.ot_mlp1.named_parameters():
assert v.equal(ot_adapter_state_dict[k])
# pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
# self.load_state_dict(pretrained_state_dict, strict=False) # TODO, check whether the internvl2 weights are loaded correctly.
print(f"Load pretrained weight from {pretrained_pth}")
if radio_pretrained_pth is not None:
print(f"Load pretrained weight from {radio_pretrained_pth}")
# Token count derived from the config: (image size // patch size) squared,
# scaled by the squared downsample ratio. forward() uses its square root as
# the edge length of the visual-prompt mask grid.
self.patch_token = int(
(config.force_image_size // config.vision_config.patch_size)**2 * (config.downsample_ratio**2))
self._count = 0
print_log(self, logger="current")
print_log('InternVL_V1_5 construction is complete', logger='current')
def _merge_lora(self):
# print('pre merge lora: ', self.mllm.model.language_model.base_model.model.get_input_embeddings().weight.shape)
try:
self.model.language_model = self.model.language_model.merge_and_unload()
except:
print("Skip language model, no LoRA in it !!!")
try:
self.model.vision_model = self.model.vision_model.merge_and_unload()
except:
print("Skip vision encoder, no LoRA in it !!!")
# print('after merge lora: ', self.mllm.model.language_model.get_input_embeddings().weight.shape)
return
def _add_special_tokens(self):
    """Register the visual-prompt context token and remember its id.

    Adds ``VPT_CONTEXT_TOKEN`` to ``self.tokenizer`` as a special token,
    logs how many tokens the tokenizer reports as added, and stores the
    first id the tokenizer produces for that token in
    ``self.vpt_content_token_idx``.
    """
    assert hasattr(self, "tokenizer")
    new_tokens = [VPT_CONTEXT_TOKEN]
    added = self.tokenizer.add_tokens(new_tokens, special_tokens=True)
    print_log(f"Added {added} special tokens.")
    encoded = self.tokenizer(VPT_CONTEXT_TOKEN, add_special_tokens=False)
    self.vpt_content_token_idx = encoded.input_ids[0]
def _build_from_cfg_or_module(self, cfg_or_mod):
if isinstance(cfg_or_mod, nn.Module):
return cfg_or_mod
elif isinstance(cfg_or_mod, dict):
traverse_dict(cfg_or_mod)
return BUILDER.build(cfg_or_mod)
else:
raise NotImplementedError
def _parse_lora_config(self, lora_config):
    """Build a LoRA config object from a config mapping.

    A ``dict``, ``Config`` or ``ConfigDict`` is built with ``BUILDER``;
    any other value is assumed to be an already-built config and is
    returned unchanged.
    """
    buildable_types = (dict, Config, ConfigDict)
    if not isinstance(lora_config, buildable_types):
        return lora_config
    return BUILDER.build(lora_config)
def _prepare_llm_for_lora(self, lora_config, use_activation_checkpointing=True):
lora_config = self._parse_lora_config(lora_config)
if self.config.llm_config.architectures[0] == 'InternLM2ForCausalLM':
target_modules = ['attention.wqkv', 'attention.wo', 'feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']
elif self.config.llm_config.architectures[0] == 'Phi3ForCausalLM':
target_modules = ['mlp.down_proj', 'mlp.gate_up_proj', 'self_attn.o_proj', 'self_attn.qkv_proj']
elif self.config.llm_config.architectures[0] in ['Qwen2ForCausalLM', 'LlamaForCausalLM']:
target_modules = ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj']
else:
raise NotImplementedError
lora_config.target_modules = target_modules
self.model.language_model = get_peft_model(self.model.language_model, lora_config)
def _prepare_visual_encoder_for_lora(self, lora_config):
    """Wrap the vision encoder with LoRA adapters.

    Target modules are resolved in this order:

    1. If ``lora_config.target_modules`` is None, every linear layer found
       by ``find_all_linear_names`` on the vision model is targeted.
    2. If the instance defines ``interaction_indexes`` (an iterable of
       ``(first_layer, last_layer)`` pairs, both inclusive), the targets
       are replaced by the attention and MLP projections of those layers.

    ``interaction_indexes`` is not assigned anywhere in this class, so it
    is read with a ``None`` default instead of raising ``AttributeError``.

    NOTE(review): the original indentation is lost, so it is not certain
    whether step 2 was meant to apply only when step 1 applied; here it
    applies whenever ``interaction_indexes`` is set — confirm.

    Args:
        lora_config: LoRA config object or config mapping (see
            ``_parse_lora_config``).
    """
    lora_config = self._parse_lora_config(lora_config)
    if lora_config.target_modules is None:
        lora_config.target_modules = find_all_linear_names(self.model.vision_model)

    interaction_indexes = getattr(self, 'interaction_indexes', None)
    if interaction_indexes is not None:
        layer_parts = ['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2']
        modules = []
        for block in interaction_indexes:
            for layer_id in range(block[0], block[1] + 1):
                modules.extend(f'{layer_id}.{elem}' for elem in layer_parts)
        lora_config.target_modules = modules

    self.model.vision_model = get_peft_model(self.model.vision_model, lora_config)
def gradient_checkpointing_enable(self):
    """Turn on checkpointing; alias of ``activation_checkpointing_enable``."""
    enable = self.activation_checkpointing_enable
    enable()
def activation_checkpointing_enable(self):
    """Enable gradient checkpointing on the wrapped language model."""
    language_model = self.model.language_model
    language_model.gradient_checkpointing_enable()
def gradient_checkpointing_disable(self):
    """Turn off checkpointing; alias of ``activation_checkpointing_disable``."""
    disable = self.activation_checkpointing_disable
    disable()
def activation_checkpointing_disable(self):
    """Disable gradient checkpointing on the wrapped language model."""
    language_model = self.model.language_model
    language_model.gradient_checkpointing_disable()
def all_state_dict(self, *args, **kwargs):
    """Return the complete, unfiltered state dict of the parent class.

    ``state_dict`` on this class returns only the trainable parts; this
    method gives access to everything.
    """
    return super().state_dict(*args, **kwargs)
def state_dict(self, *args, **kwargs):
    """Return only the parts of the model that this wrapper trains.

    The full state dict of the parent class is filtered as follows:

    * vision encoder: LoRA weights if LoRA is enabled, otherwise all
      ``model.vision_model.`` entries unless the encoder is frozen;
    * language model: LoRA weights if LoRA is enabled, otherwise all
      ``model.language_model.`` entries unless the LLM is frozen;
    * connector: ``model.mlp1.`` entries unless frozen;
    * object-token projector: ``ot_mlp1.`` entries unless frozen.

    Use ``all_state_dict`` for the unfiltered dict.

    NOTE(review): with ``freeze_llm=True`` and no LLM LoRA nothing from the
    language model is saved, even when ``unfreeze_vocab`` /
    ``unfreeze_lm_head`` re-enabled gradients on its embeddings — confirm
    that this is intended.

    Returns:
        OrderedDict: the filtered state dict.
    """
    state_dict = super().state_dict(*args, **kwargs)
    to_return = OrderedDict()

    # Step 1. visual_encoder
    if self.use_visual_encoder_lora:
        to_return.update(
            get_peft_model_state_dict(
                self.model.vision_model, state_dict=state_dict))
    elif not self.freeze_visual_encoder:
        to_return.update({
            k: v
            for k, v in state_dict.items() if 'model.vision_model.' in k
        })

    # Step 2. LLM
    if self.use_llm_lora:
        to_return.update(
            get_peft_model_state_dict(
                self.model.language_model, state_dict=state_dict))
    elif not self.freeze_llm:
        # The filter must test the key; a bare non-empty string is always
        # true and would copy every entry of the full state dict.
        to_return.update({
            k: v
            for k, v in state_dict.items() if 'model.language_model.' in k
        })

    # Step 3. Projector
    if not self.freeze_connector:
        to_return.update(
            {k: v
             for k, v in state_dict.items() if 'model.mlp1.' in k})

    # Step 4. Custom part
    if not self.freeze_ot_mlp:
        to_return.update({k: v for k, v in state_dict.items() if 'ot_mlp1.' in k})

    return to_return
def init_weights(self):
    """Do nothing: this override performs no weight initialisation."""
    return None
def forward(self, data, data_samples=None, mode='loss'):
    """Encode the visual inputs and compute the training loss.

    Two paths exist:

    * Object-tokenizer pre-training (``self.object_tokenizer_pretrain`` set
      and object-tokenizer inputs present): returns a contrastive loss
      between region embeddings pooled from the object tokenizer and from
      the vision encoder. ``mode`` is not consulted on this path.
    * Otherwise, for ``mode == 'loss'``: region embeddings pooled from the
      object tokenizer (if any) are handed to ``_llm_forward`` and the
      language-model loss is returned. The loss is multiplied by zero when
      the batch had to be patched up (mismatching prompt counts).

    Args:
        data (dict): collated batch. Always read: ``merged_visual_prompts``,
            ``ot_pixel_values`` (may be None), ``vprompt_flags``. With
            object-tokenizer inputs: ``num_vprompts``, ``num_images``,
            ``ot_visual_prompts``. Pre-training additionally reads
            ``region_ids``, ``num_patches``, ``visual_prompts``. Loss mode
            reads ``input_ids``, ``position_ids``, ``attention_mask``,
            ``image_flags``, ``labels``.
        data_samples: unused.
        mode (str): ``'loss'``; ``'predict'`` and ``'tensor'`` are accepted
            but not implemented and return None.

    Returns:
        dict | None: ``{'loss': tensor}``, or None for the unimplemented
        modes.

    Raises:
        NotImplementedError: for any other ``mode``, or for an unknown
            ``self.vfm_name``.
    """
    vision_dtype = self.model.vision_model.dtype
    # ``pixel_values`` is not cast here: nothing in this method consumes it.
    merged_visual_prompts = data['merged_visual_prompts'].to(vision_dtype)
    has_ot_input = data['ot_pixel_values'] is not None
    use_ot = has_ot_input and self.object_tokenizer is not None

    vit_embeds = self.model.extract_feature(merged_visual_prompts)

    ot_embeds = None
    ot_grid_size = None
    if use_ot:
        ot_pixel_values = data['ot_pixel_values'].to(self.object_tokenizer.dtype)
        ot_h, ot_w = ot_pixel_values.shape[-2:]
        if self.vfm_name == "RADIO":
            _, ot_embeds = self.object_tokenizer(ot_pixel_values)
            patch_size = self.ot_config.patch_size
            ot_grid_size = (ot_h // patch_size, ot_w // patch_size)
        elif self.vfm_name == "DINOv2":
            ot_outputs = self.object_tokenizer(pixel_values=ot_pixel_values)
            # The first token is dropped (presumably the CLS token).
            ot_embeds = ot_outputs.last_hidden_state[:, 1:, :]
            patch_size = self.ot_config.patch_size
            ot_grid_size = (ot_h // patch_size, ot_w // patch_size)
        elif self.vfm_name == "ConvNext":
            ot_outputs = self.object_tokenizer.convnextv2(pixel_values=ot_pixel_values,
                                                          output_hidden_states=True,
                                                          return_dict=True)
            ot_embeds = ot_outputs.last_hidden_state  # N, C, H, W
            ot_grid_size = tuple(ot_embeds.shape[-2:])
            ot_embeds = ot_embeds.flatten(2, 3).transpose(1, 2)
        else:
            raise NotImplementedError(f"Unknown object tokenizer: {self.vfm_name}")
        with torch.amp.autocast(device_type='cuda', dtype=vision_dtype):
            ot_embeds = self.ot_mlp1(ot_embeds)

    # Pre-training needs the object-tokenizer features computed above.
    if self.object_tokenizer_pretrain and use_ot:
        return self._object_tokenizer_pretrain_loss(data, vit_embeds, ot_embeds, ot_grid_size)

    ot_object_embeds = None
    vprompt_flags = data['vprompt_flags']
    skip_this_batch = False
    if use_ot:
        ot_object_embeds, skip_this_batch = self._pool_ot_object_embeds(
            data, ot_embeds, ot_grid_size)

    if mode == "loss":
        outputs, _skip_this_case = self._llm_forward(
            input_ids=data['input_ids'],
            position_ids=data['position_ids'],
            attention_mask=data['attention_mask'],
            image_flags=data['image_flags'],
            labels=data['labels'],
            use_cache=False,
            vit_embeds=vit_embeds,
            ot_object_embeds=ot_object_embeds,
            vprompt_flags=vprompt_flags,
        )
        if skip_this_batch or _skip_this_case:
            print("skip this batch!")
            # Keep the graph alive but contribute no gradient.
            loss_dict = {'loss': outputs.loss * 0.0}
        else:
            loss_dict = {'loss': outputs.loss}
        # No auxiliary contrastive term is implemented on this path, so the
        # LM loss is returned whether or not ``self.contras_loss`` is set
        # (falling through here used to return None).
        return loss_dict
    elif mode == "predict":
        return None
    elif mode == "tensor":
        return None
    else:
        raise NotImplementedError

def _masked_average_pool(self, masks, feats, grid_size, threshold):
    """Average token features under each binarised region mask.

    Args:
        masks: 4-D mask tensor whose first dimension indexes regions.
        feats: 2-D ``(tokens, channels)`` features; the token count must
            equal the number of mask cells per region after resizing.
        grid_size: target size passed to ``F.interpolate``.
        threshold: a resized mask cell counts as inside when above this.

    Returns:
        tuple: ``(pooled, num_tokens)`` — per-region mean feature and the
        number of selected tokens per region.
    """
    masks = F.interpolate(masks.to(feats.dtype), grid_size, mode='bilinear')
    masks = (masks > threshold).to(feats.dtype)
    masks = masks.reshape(masks.shape[0], -1)
    num_tokens = torch.sum(masks, dim=-1, keepdim=False)
    # The 1e-4 keeps empty masks from dividing by zero.
    weighted = masks[:, :, None] / (num_tokens[:, None, None] + 1e-4) * feats[None, :, :]
    return torch.sum(weighted, dim=1), num_tokens

def _pool_ot_object_embeds(self, data, ot_embeds, ot_grid_size):
    """Pool one object-tokenizer embedding per visual prompt.

    When the prompt masks cannot be split according to ``num_vprompts``,
    one prompt per image is substituted so shapes stay consistent and the
    batch is flagged to be skipped.

    Returns:
        tuple: ``(ot_object_embeds, skip_this_batch)`` — all pooled
        embeddings concatenated along dim 0, and the skip flag.
    """
    num_vprompts = data['num_vprompts']
    num_images = data['num_images']
    batch_size = len(num_images)
    ot_visual_prompts = data['ot_visual_prompts']
    skip_this_batch = False
    try:
        split_ot_visual_prompts = torch.split(ot_visual_prompts, list(num_vprompts))
    except (RuntimeError, ValueError):
        nvp_list = [1 for _ in num_vprompts]
        if ot_visual_prompts.shape[0] >= len(nvp_list):
            split_ot_visual_prompts = torch.split(ot_visual_prompts[:len(nvp_list)], nvp_list)
        else:
            split_ot_visual_prompts = torch.stack([ot_visual_prompts[0] for _ in nvp_list])
        num_vprompts = torch.tensor(nvp_list).to(num_vprompts.dtype).to(num_vprompts.device)
        skip_this_batch = True
    split_num_vprompts = torch.split(num_vprompts, list(num_images))

    pooled_embeds = []
    start_idx = 0
    for bidx in range(batch_size):
        stop_idx = start_idx + num_images[bidx]
        num_vprompt = split_num_vprompts[bidx]
        ot_visual_prompts_bi = split_ot_visual_prompts[start_idx:stop_idx]
        ot_embeds_bi = ot_embeds[start_idx:stop_idx]  # bs, s, c
        start_idx = stop_idx
        for fidx, prompts in enumerate(ot_visual_prompts_bi):
            h, w = prompts.shape[-2:]
            masks = prompts.reshape(num_vprompt[fidx], 1, h, w)
            object_embeds, _ = self._masked_average_pool(
                masks, ot_embeds_bi[fidx], ot_grid_size, 0.5)
            pooled_embeds.append(object_embeds)
    return torch.cat(pooled_embeds, dim=0), skip_this_batch

def _object_tokenizer_pretrain_loss(self, data, vit_embeds, ot_embeds, ot_grid_size):
    """Contrastive loss aligning object-tokenizer and vision-encoder regions.

    For each sample, region embeddings are pooled per image from both
    encoders. Object-tokenizer regions of image ``i`` act as anchors
    against the vision-encoder regions of the image at the mirrored
    position in the sample (the per-sample lists are zipped against their
    reverse); the region with the same id is the positive and the other
    valid regions are negatives.

    Returns:
        dict: ``{'loss': tensor}``; a zero that still depends on
        ``ot_embeds`` when no valid pair exists or the loss is NaN.
    """
    region_ids = data['region_ids']
    num_images = data['num_images']
    batch_size = len(num_images)
    num_vprompts = data['num_vprompts']
    num_patches = data['num_patches']
    visual_prompts = data['visual_prompts']
    ot_visual_prompts = data['ot_visual_prompts']

    images_per_sample = list(num_images)
    split_num_vprompts = torch.split(num_vprompts, images_per_sample)
    split_num_patches = torch.split(num_patches, images_per_sample)
    vprompt_split_size = [n_vp * n_p for n_vp, n_p in zip(num_vprompts, num_patches)]
    split_visual_prompts = torch.split(visual_prompts, vprompt_split_size)
    split_vit_embeds = torch.split(vit_embeds, list(num_patches))
    split_ot_visual_prompts = torch.split(ot_visual_prompts, list(num_vprompts))
    patch_token_edge = int(self.patch_token ** 0.5)

    object_embeds_in_batch, valid_flag_in_batch = [], []
    ot_object_embeds_in_batch, ot_valid_flag_in_batch = [], []
    start_idx = 0
    for bidx in range(batch_size):
        stop_idx = start_idx + num_images[bidx]
        num_vprompt = split_num_vprompts[bidx]
        num_patch = split_num_patches[bidx]

        # Vision-encoder regions.
        visual_prompts_bi = split_visual_prompts[start_idx:stop_idx]
        split_vit_embeds_bi = split_vit_embeds[start_idx:stop_idx]
        object_embed_list, valid_flag_list = [], []
        for fidx, prompts in enumerate(visual_prompts_bi):
            h, w = prompts.shape[-2:]
            masks = prompts.reshape(num_vprompt[fidx], num_patch[fidx], h, w)
            feats = split_vit_embeds_bi[fidx].flatten(0, 1)
            object_embeds, num_vp_tokens = self._masked_average_pool(
                masks, feats, patch_token_edge, 0.55)
            object_embed_list.append(object_embeds)
            valid_flag_list.append(num_vp_tokens > 0)
        object_embeds_in_batch.append(object_embed_list)
        valid_flag_in_batch.append(valid_flag_list)

        # Object-tokenizer regions.
        ot_visual_prompts_bi = split_ot_visual_prompts[start_idx:stop_idx]
        ot_embeds_bi = ot_embeds[start_idx:stop_idx]  # bs, s, c
        object_embed_list, valid_flag_list = [], []
        for fidx, prompts in enumerate(ot_visual_prompts_bi):
            masks = prompts[:, None, :, :]
            object_embeds, num_vp_tokens = self._masked_average_pool(
                masks, ot_embeds_bi[fidx], ot_grid_size, 0.55)
            object_embed_list.append(object_embeds)
            valid_flag_list.append(num_vp_tokens > 0)
        ot_object_embeds_in_batch.append(object_embed_list)
        ot_valid_flag_in_batch.append(valid_flag_list)

        start_idx = stop_idx

    # Accumulate on the features' device rather than a hard-coded CUDA device.
    contras_loss = torch.zeros(size=(1, ), dtype=torch.float32, device=ot_embeds.device)
    valid_contras_sample = 0
    for bidx in range(batch_size):
        region_ids_bi = region_ids[bidx]
        object_embeds_bi = object_embeds_in_batch[bidx]
        ot_object_embeds_bi = ot_object_embeds_in_batch[bidx]
        valid_flags_bi = valid_flag_in_batch[bidx]
        ot_valid_flags_bi = ot_valid_flag_in_batch[bidx]
        for ot_object_embeds, object_embeds, ot_region_ids, _region_ids, ot_valid_flags, valid_flags in zip(
                ot_object_embeds_bi, object_embeds_bi[::-1],
                region_ids_bi, region_ids_bi[::-1],
                ot_valid_flags_bi, valid_flags_bi[::-1],
        ):
            region_id_to_indices = {region_id: idx for idx, region_id in enumerate(_region_ids)}
            for anchor_embed, valid_flag, region_id in zip(ot_object_embeds, ot_valid_flags, ot_region_ids):
                if not valid_flag:
                    continue
                anchor_embed = anchor_embed.unsqueeze(0)  # 1, C
                pos_idx = region_id_to_indices[region_id]
                if not valid_flags[pos_idx]:
                    continue
                pos_embed = object_embeds[pos_idx].unsqueeze(0)  # 1, C
                # Negatives: every other valid region of the paired image.
                neg_mask = valid_flags.clone()
                neg_mask[pos_idx] = False
                neg_embeds = object_embeds[neg_mask]  # N, C

                pos_neg_embeds = torch.cat([pos_embed, neg_embeds], dim=0)
                pos_neg_label = pos_neg_embeds.new_zeros((pos_neg_embeds.shape[0], ), dtype=torch.int64)
                pos_neg_label[:1] = 1
                # dot product
                dot_product = torch.einsum('ac,kc->ak', [anchor_embed, pos_neg_embeds])
                pos_neg_label = pos_neg_label.unsqueeze(0)
                pos_inds = (pos_neg_label == 1)
                neg_inds = (pos_neg_label == 0)
                pred_pos = dot_product * pos_inds.float()
                pred_neg = dot_product * neg_inds.float()
                # use -inf to mask out unwanted elements
                pred_pos[neg_inds] = pred_pos[neg_inds] + float('inf')
                pred_neg[pos_inds] = pred_neg[pos_inds] + float('-inf')
                _pos_expand = torch.repeat_interleave(pred_pos, dot_product.shape[1], dim=1)
                _neg_expand = pred_neg.repeat(1, dot_product.shape[1])
                # Masked entries become -inf, so the result reduces to
                # log(1 + sum_j exp(neg_j - pos)).
                x = F.pad((_neg_expand - _pos_expand), (0, 1), "constant", 0)
                # Any failure here propagates with its traceback instead of
                # terminating the process via exit(0).
                contras_loss += torch.logsumexp(x, dim=1)
                valid_contras_sample += 1

    if valid_contras_sample == 0 or torch.any(torch.isnan(contras_loss)):
        return {"loss": ot_embeds.sum() * 0.0}
    return {"loss": contras_loss / valid_contras_sample}
def _llm_forward(
    self,
    vit_embeds: torch.FloatTensor,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    image_flags: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    ot_object_embeds: torch.FloatTensor = None,
    vprompt_flags: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Splice visual / object embeddings into the token embeddings and run the LLM.

    The token embeddings of ``input_ids`` are computed first; rows whose id is
    ``self.vpt_content_token_idx`` are overwritten with ``ot_object_embeds``
    (when given) and rows whose id is ``self.model.img_context_token_id`` are
    overwritten with ``vit_embeds``. The result is fed to
    ``self.model.language_model`` and, when ``labels`` is given, a next-token
    cross-entropy loss is computed.

    Args:
        vit_embeds: Visual features; entries with ``image_flags == 1`` are kept
            and flattened to ``(-1, C)`` before being written into the sequence.
        input_ids: 2-D ``(B, N)`` token ids.
        attention_mask, position_ids, past_key_values, use_cache,
        output_attentions, output_hidden_states: Forwarded unchanged to the
            language model.
        image_flags: Selector applied to ``vit_embeds``.
        labels: Optional targets for the shifted cross-entropy loss.
        return_dict: Falls back to ``self.model.config.use_return_dict`` when
            ``None``.
        ot_object_embeds: Optional object embeddings; entries with
            ``vprompt_flags > 0`` are written at the visual-prompt positions.
        vprompt_flags: Selector applied to ``ot_object_embeds``.

    Returns:
        When ``return_dict`` is truthy: a 2-tuple
        ``(CausalLMOutputWithPast, skip_this_case)`` where ``skip_this_case`` is
        ``True`` if the object embeddings could not be matched to the
        visual-prompt positions and their mean was used instead.
        Otherwise: the plain tuple ``(loss?, logits, *outputs[1:])``.
        NOTE(review): the plain-tuple branch does not carry ``skip_this_case``;
        callers that unpack two values must use ``return_dict=True``.
    """
    return_dict = return_dict if return_dict is not None \
        else self.model.config.use_return_dict

    B, N = input_ids.shape
    # Embed a copy of the ids in which visual-prompt placeholders are remapped
    # to the image-context id; the original ``input_ids`` is kept untouched so
    # the placeholder positions can still be located below.
    temp_input_ids = input_ids.clone().flatten()
    temp_input_ids[temp_input_ids == self.vpt_content_token_idx] = self.model.img_context_token_id
    input_embeds = self.model.language_model.get_input_embeddings()(temp_input_ids.reshape(B, N)).clone()

    vit_embeds = vit_embeds[image_flags == 1]
    vit_batch_size = vit_embeds.shape[0]

    B, N, C = input_embeds.shape
    input_embeds = input_embeds.reshape(B * N, C)
    input_ids = input_ids.reshape(B * N)

    skip_this_case = False
    if ot_object_embeds is not None:
        selected = (input_ids == self.vpt_content_token_idx)
        try:
            ot_object_embeds = ot_object_embeds[vprompt_flags > 0]
            input_embeds[selected] = input_embeds[selected] * 0.0 + ot_object_embeds
        except Exception:
            # Count mismatch between object embeds and placeholder tokens:
            # fill every placeholder with the mean embedding and flag the
            # sample so the caller can discard it. ``Exception`` (not a bare
            # ``except``) so Ctrl-C / SystemExit still propagate.
            print("The number of the provided object embeds is not match with vprompt_flags or VPT_CONTENT_TOKEN.")
            input_embeds[selected] = input_embeds[selected] * 0.0 + ot_object_embeds.mean(dim=0, keepdim=True).to(input_embeds.dtype)
            skip_this_case = True

    selected = (input_ids == self.model.img_context_token_id)
    try:
        input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
    except Exception as e:
        # More visual tokens than image-context slots: keep only the first
        # ``n_token`` rows so the assignment shapes agree.
        vit_embeds = vit_embeds.reshape(-1, C)
        print(f"warning: {e}, input_embeds[selected].shape="
              f"{input_embeds[selected].shape}, "
              f"vit_embeds.shape={vit_embeds.shape}")
        n_token = selected.sum()
        input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]

    input_embeds = input_embeds.reshape(B, N, C)

    # ``get_rank()`` raises when no process group exists (e.g. single-process
    # inference), so a non-distributed run is treated as the main process.
    dist = torch.distributed
    is_main_process = (
        not (dist.is_available() and dist.is_initialized())
        or dist.get_rank() == 0
    )
    if is_main_process and self._count % 100 == 0:
        print(f"dynamic ViT batch size: {vit_batch_size}, "
              f"images per sample: {vit_batch_size / B}, "
              f"dynamic token length: {N}")
    self._count += 1

    outputs = self.model.language_model(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    logits = outputs.logits

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.model.language_model.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits, ) + outputs[1:]
        return (loss, ) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    ), skip_this_case
def batch_chat(self, pixel_values, questions, merged_visual_prompts,
               generation_config, num_patches_list=None, history=None, return_history=False,
               ):
    """Answer a batch of single-turn questions about images.

    Each question is wrapped in the conversation template named by
    ``self.config.template``, its first ``DEFAULT_IMAGE_TOKEN`` is expanded to
    ``self.patch_token * num_patches`` image-context tokens, and the
    left-padded batch is passed to ``self.generate``.

    Args:
        pixel_values: Only checked against ``None`` to decide whether
            ``DEFAULT_IMAGE_TOKEN`` is prepended to questions lacking it.
        questions: Question strings, indexed in step with ``num_patches_list``.
        merged_visual_prompts: Input handed to ``self.model.extract_feature``
            to produce the visual embeddings.
        generation_config: Mapping of keyword arguments for generation. It is
            not modified; ``eos_token_id`` is set on a copy.
        num_patches_list: Patch count per question. Required despite the
            ``None`` default.
        history: Must be ``None``; multi-turn chat is unsupported.
        return_history: Must be falsy; multi-turn chat is unsupported.

    Returns:
        A list with one decoded, stripped response string per query.

    Raises:
        NotImplementedError: If ``history`` is given or ``return_history`` is set.
        TypeError: If ``num_patches_list`` is ``None``.

    Note:
        ``self.tokenizer.padding_side`` is set to ``'left'`` and left that way.
    """
    if history is not None or return_history:
        print('Now multi-turn chat is not supported in batch_chat.')
        raise NotImplementedError

    if num_patches_list is None:
        # Previously surfaced as "'NoneType' object is not iterable".
        raise TypeError(
            'batch_chat() requires num_patches_list '
            '(one patch count per question), got None.')

    # Resolve the separator up front so it does not depend on the loop below
    # having executed at least once.
    sep = get_conv_template(self.config.template).sep

    # tokenize and embed the text prompts
    queries = []
    for idx, num_patches in enumerate(num_patches_list):
        question = questions[idx]
        if pixel_values is not None and DEFAULT_IMAGE_TOKEN not in question:
            question = DEFAULT_IMAGE_TOKEN + '\n' + question
        template = get_conv_template(self.config.template)
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.patch_token * num_patches + IMG_END_TOKEN
        query = query.replace(DEFAULT_IMAGE_TOKEN, image_tokens, 1)
        queries.append(query)

    self.tokenizer.padding_side = 'left'  # Important: training sets this attribute to 'right'
    model_inputs = self.tokenizer(queries, return_tensors='pt', padding=True)
    input_ids = model_inputs['input_ids'].cuda()
    attention_mask = model_inputs['attention_mask'].cuda()

    eos_token_id = self.tokenizer.convert_tokens_to_ids(sep)
    # Work on a copy so the caller's config mapping is not mutated.
    generation_config = {**generation_config, 'eos_token_id': eos_token_id}

    input_embeds = self.model.language_model.get_input_embeddings()(input_ids).clone()
    vit_embeds = self.model.extract_feature(merged_visual_prompts)

    # Overwrite the image-context token rows with the visual embeddings.
    B, N, C = input_embeds.shape
    input_embeds = input_embeds.reshape(B * N, C)
    input_ids = input_ids.reshape(B * N)
    selected = (input_ids == self.model.img_context_token_id)
    input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
    input_embeds = input_embeds.reshape(B, N, C)

    # generate
    generate_output = self.generate(
        input_embeds=input_embeds,
        attention_mask=attention_mask,
        **generation_config
    )
    responses = self.tokenizer.batch_decode(generate_output, skip_special_tokens=True)
    # Keep only the text before the first template separator.
    responses = [response.split(sep)[0].strip() for response in responses]
    return responses
@torch.no_grad()
def generate(
    self,
    input_embeds: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **generate_kwargs,
) -> torch.LongTensor:
    """Run generation on the wrapped language model from precomputed embeddings.

    Gradients are disabled for the call. ``input_embeds`` is passed to the
    language model's ``generate`` as ``inputs_embeds``, ``use_cache`` is
    always ``True``, and every other argument (including
    ``**generate_kwargs``) is forwarded unchanged.

    Returns:
        Whatever ``self.model.language_model.generate`` returns.
    """
    language_model = self.model.language_model
    return language_model.generate(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
        generation_config=generation_config,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        use_cache=True,
        **generate_kwargs,
    )