|
import os |
|
import copy |
|
from collections import OrderedDict |
|
from typing import List, Optional, Tuple, Union |
|
from types import MethodType |
|
from enum import Enum |
|
import torch |
|
import torch.amp |
|
import torch.distributed |
|
import torch.nn as nn |
|
from torch.nn import CrossEntropyLoss |
|
import torch.nn.functional as F |
|
from mmengine import print_log |
|
from mmengine.config import Config, ConfigDict |
|
from mmengine.model import BaseModel |
|
from peft import get_peft_model, prepare_model_for_kbit_training |
|
from safetensors.torch import load_file |
|
from safetensors import safe_open |
|
|
|
import torch.utils |
|
import torch.utils.checkpoint |
|
from xtuner.registry import BUILDER |
|
from xtuner.model.modules import dispatch_modules |
|
from xtuner.utils import DEFAULT_IMAGE_TOKEN |
|
from transformers import AutoModel, AutoConfig, AutoTokenizer, BitsAndBytesConfig, GenerationConfig, AutoImageProcessor |
|
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutput, BaseModelOutputWithPooling |
|
from transformers import ConvNextV2ForImageClassification |
|
from .utils import (LoadWoInit, traverse_dict, make_inputs_require_grad, find_all_linear_names, |
|
guess_load_checkpoint, get_peft_model_state_dict) |
|
from ..dataset.utils import (get_conv_template, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN, |
|
VPT_CONTEXT_TOKEN, VPT_START_TOKEN, VPT_END_TOKEN) |
|
|
|
|
|
class WrapInternVL(BaseModel): |
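    """Wrapper around an InternVL-style MLLM for xtuner training.

    The wrapper can freeze or LoRA-tune the LLM, vision encoder and connector,
    and can optionally attach an extra object tokenizer (DINOv2, RADIO or
    ConvNeXt) whose mask-pooled region features are injected at the VPT context
    token positions of the prompt.
    """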
|
def __init__(self, |
|
mllm, |
|
tokenizer=None, |
|
freeze_llm=False, |
|
freeze_visual_encoder=False, |
|
freeze_connector=False, |
|
freeze_ot_mlp=False, |
|
unfreeze_vocab=False, |
|
unfreeze_lm_head=False, |
|
llm_lora=None, |
|
visual_encoder_lora=None, |
|
pretrained_pth=None, |
|
radio_pretrained_pth=None, |
|
use_activation_checkpointing=True, |
|
vocab_embeds_name="tok_embeddings", |
|
lm_head_name="output", |
|
contras_loss=False, |
|
use_object_tokens=False, |
|
object_tokenizer=None, |
|
object_tokenizer_pretrain=False, |
|
): |
|
super().__init__() |
|
|
|
self.freeze_llm = freeze_llm |
|
self.freeze_visual_encoder = freeze_visual_encoder |
|
self.freeze_connector = freeze_connector |
|
self.freeze_ot_mlp = freeze_ot_mlp |
|
self.unfreeze_vocab = unfreeze_vocab |
|
self.unfreeze_lm_head = unfreeze_lm_head |
|
self.use_llm_lora = llm_lora is not None |
|
self.use_visual_encoder_lora = visual_encoder_lora is not None |
|
        self.use_activation_checkpointing = use_activation_checkpointing
|
self.vocab_embeds_name = vocab_embeds_name |
|
self.lm_head_name = lm_head_name |
|
self.contras_loss = contras_loss |
|
        self.object_tokenizer_pretrain = object_tokenizer_pretrain
|
|
|
config = AutoConfig.from_pretrained(mllm["pretrained_model_name_or_path"], trust_remote_code=True) |
|
self.config = config |
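        # InternLM2 configs expose `attn_implementation` directly, while other HF LLM
        # configs use the private `_attn_implementation` attribute.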
|
if config.llm_config.model_type == 'internlm2': |
|
config.llm_config.attn_implementation = 'flash_attention_2' |
|
else: |
|
config.llm_config._attn_implementation = 'flash_attention_2' |
|
|
|
traverse_dict(mllm) |
|
model_clazz = mllm.pop('type') |
|
mllm.update(dict(config=config)) |
|
self.model = model_clazz(**mllm) |
|
self.model.language_model.config.use_cache = False |
|
dispatch_modules(self.model.language_model) |
|
|
|
if use_object_tokens: |
|
|
|
ot_config = AutoConfig.from_pretrained(object_tokenizer["pretrained_model_name_or_path"], trust_remote_code=True) |
|
self.ot_config = ot_config |
|
traverse_dict(object_tokenizer) |
|
ot_clazz = object_tokenizer.pop('type') |
|
self.object_tokenizer = ot_clazz(**object_tokenizer) |
|
if 'dinov2' in object_tokenizer["pretrained_model_name_or_path"]: |
|
ot_hidden_size = self.ot_config.hidden_size |
|
self.vfm_name = 'DINOv2' |
|
elif 'RADIO' in object_tokenizer["pretrained_model_name_or_path"]: |
|
ot_hidden_size = self.object_tokenizer.model.num_features |
|
self.vfm_name = 'RADIO' |
|
elif 'convnext' in object_tokenizer["pretrained_model_name_or_path"]: |
|
ot_hidden_size = self.ot_config.hidden_sizes[-1] |
|
self.vfm_name = "ConvNext" |
|
else: |
|
raise NotImplementedError |
|
|
|
self.ot_mlp1 = nn.Sequential( |
|
nn.LayerNorm(ot_hidden_size,), |
|
nn.Linear(ot_hidden_size, config.llm_config.hidden_size,), |
|
nn.GELU(), |
|
nn.Linear(config.llm_config.hidden_size, config.llm_config.hidden_size) |
|
) |
|
else: |
|
self.object_tokenizer = None |
|
self.ot_mlp1 = None |
|
self.ot_config = None |
|
|
|
if tokenizer is not None: |
|
self.tokenizer = self._build_from_cfg_or_module(tokenizer) |
|
else: |
|
self.tokenizer = AutoTokenizer.from_pretrained(mllm["pretrained_model_name_or_path"], trust_remote_code=True) |
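        # Record the id of the <IMG_CONTEXT> placeholder so image features can later be
        # scattered into the matching token positions of the prompt.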
|
img_context_token_id = self.tokenizer.convert_tokens_to_ids('<IMG_CONTEXT>') |
|
self.model.img_context_token_id = img_context_token_id |
|
self._add_special_tokens() |
|
|
|
if self.freeze_llm: |
|
self.model.language_model.requires_grad_(False) |
|
if self.freeze_visual_encoder: |
|
self.model.vision_model.requires_grad_(False) |
|
if self.freeze_connector: |
|
self.model.mlp1.requires_grad_(False) |
|
if self.object_tokenizer is not None: |
|
self.object_tokenizer.requires_grad_(False) |
|
if self.freeze_ot_mlp and self.ot_mlp1 is not None: |
|
self.ot_mlp1.requires_grad_(False) |
|
|
|
if use_activation_checkpointing: |
|
|
|
if hasattr(self.model.language_model, 'enable_input_require_grads'): |
|
self.model.language_model.enable_input_require_grads() |
|
else: |
|
self.model.language_model.get_input_embeddings( |
|
).register_forward_hook(make_inputs_require_grad) |
|
|
|
self.gradient_checkpointing_enable() |
|
|
|
if self.use_llm_lora: |
|
self._prepare_llm_for_lora(llm_lora) |
|
|
|
if self.unfreeze_vocab: |
|
self.model.language_model.get_input_embeddings().requires_grad_(True) |
|
if self.unfreeze_lm_head: |
|
self.model.language_model.get_output_embeddings().requires_grad_(True) |
|
if self.use_visual_encoder_lora: |
|
self._prepare_visual_encoder_for_lora(visual_encoder_lora) |
|
|
|
if pretrained_pth is not None: |
|
            pretrained_state_dict = guess_load_checkpoint(pretrained_pth)

mllm_state_dict = {} |
|
for k, v in pretrained_state_dict.items(): |
|
if k.startswith('model.'): |
|
mllm_state_dict[k[len('model.'):]] = v |
|
if len(mllm_state_dict) != 0: |
|
self.model.load_state_dict(mllm_state_dict, strict=False) |
|
|
|
if use_object_tokens: |
|
ot_adapter_state_dict = {} |
|
if radio_pretrained_pth is not None: |
|
pretrained_state_dict = torch.load(radio_pretrained_pth, map_location='cpu') |
|
if 'state_dict' in pretrained_state_dict: |
|
pretrained_state_dict = pretrained_state_dict['state_dict'] |
|
for k, v in pretrained_state_dict.items(): |
|
if k.startswith('ot_mlp1.'): |
|
ot_adapter_state_dict[k[len('ot_mlp1.'):]] = v |
|
if len(ot_adapter_state_dict) != 0: |
|
self.ot_mlp1.load_state_dict(ot_adapter_state_dict, strict=False) |
|
|
|
for k, v in self.ot_mlp1.named_parameters(): |
|
assert v.equal(ot_adapter_state_dict[k]) |
|
|
|
|
|
|
|
print(f"Load pretrained weight from {pretrained_pth}") |
|
if radio_pretrained_pth is not None: |
|
print(f"Load pretrained weight from {radio_pretrained_pth}") |
|
|
|
self.patch_token = int( |
|
(config.force_image_size // config.vision_config.patch_size)**2 * (config.downsample_ratio**2)) |
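        # patch_token example: a 448 px tile with 14 px patches and a 0.5 downsample
        # ratio gives (448 // 14) ** 2 * 0.5 ** 2 = 256 visual tokens per tile.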
|
|
|
self._count = 0 |
|
print_log(self, logger="current") |
|
print_log('InternVL_V1_5 construction is complete', logger='current') |
|
|
|
def _merge_lora(self): |
|
|
|
        try:
            self.model.language_model = self.model.language_model.merge_and_unload()
        except Exception:
            print("Skip language model: no LoRA adapter to merge.")
        try:
            self.model.vision_model = self.model.vision_model.merge_and_unload()
        except Exception:
            print("Skip vision encoder: no LoRA adapter to merge.")
|
|
|
return |
|
|
|
|
|
def _add_special_tokens(self): |
|
assert hasattr(self, "tokenizer") |
|
|
|
special_tokens = [VPT_CONTEXT_TOKEN, ] |
|
num_new_tokens = self.tokenizer.add_tokens(special_tokens, special_tokens=True) |
|
print_log(f"Added {num_new_tokens} special tokens.") |
|
|
|
self.vpt_content_token_idx = self.tokenizer(VPT_CONTEXT_TOKEN, add_special_tokens=False).input_ids[0] |
|
|
|
def _build_from_cfg_or_module(self, cfg_or_mod): |
|
if isinstance(cfg_or_mod, nn.Module): |
|
return cfg_or_mod |
|
elif isinstance(cfg_or_mod, dict): |
|
traverse_dict(cfg_or_mod) |
|
return BUILDER.build(cfg_or_mod) |
|
else: |
|
raise NotImplementedError |
|
|
|
def _parse_lora_config(self, lora_config): |
|
if isinstance(lora_config, dict) or isinstance( |
|
lora_config, Config) or isinstance(lora_config, ConfigDict): |
|
lora_config = BUILDER.build(lora_config) |
|
return lora_config |
|
|
|
def _prepare_llm_for_lora(self, lora_config, use_activation_checkpointing=True): |
|
lora_config = self._parse_lora_config(lora_config) |
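        # LoRA target modules depend on the wrapped LLM architecture, since the
        # attention/MLP projection layers are named differently in InternLM2,
        # Phi-3 and Qwen2/Llama.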
|
if self.config.llm_config.architectures[0] == 'InternLM2ForCausalLM': |
|
target_modules = ['attention.wqkv', 'attention.wo', 'feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3'] |
|
elif self.config.llm_config.architectures[0] == 'Phi3ForCausalLM': |
|
target_modules = ['mlp.down_proj', 'mlp.gate_up_proj', 'self_attn.o_proj', 'self_attn.qkv_proj'] |
|
elif self.config.llm_config.architectures[0] in ['Qwen2ForCausalLM', 'LlamaForCausalLM']: |
|
target_modules = ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj', |
|
'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj'] |
|
else: |
|
raise NotImplementedError |
|
lora_config.target_modules = target_modules |
|
self.model.language_model = get_peft_model(self.model.language_model, lora_config) |
|
|
|
def _prepare_visual_encoder_for_lora(self, lora_config): |
|
lora_config = self._parse_lora_config(lora_config) |
|
if lora_config.target_modules is None: |
|
modules = find_all_linear_names(self.model.vision_model) |
|
lora_config.target_modules = modules |
|
        # `interaction_indexes` is not set by this wrapper's __init__, so guard the lookup.
        if getattr(self, 'interaction_indexes', None) is not None:
|
modules = [] |
|
for block in self.interaction_indexes: |
|
for layer_id in range(block[0], block[1]+1): |
|
modules.extend([f'{layer_id}.{elem}' for elem in ['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2']]) |
|
lora_config.target_modules = modules |
|
|
|
self.model.vision_model = get_peft_model(self.model.vision_model, lora_config) |
|
|
|
def gradient_checkpointing_enable(self): |
|
self.activation_checkpointing_enable() |
|
|
|
def activation_checkpointing_enable(self): |
|
self.model.language_model.gradient_checkpointing_enable() |
|
|
|
def gradient_checkpointing_disable(self): |
|
self.activation_checkpointing_disable() |
|
|
|
def activation_checkpointing_disable(self): |
|
self.model.language_model.gradient_checkpointing_disable() |
|
|
|
def all_state_dict(self, *args, **kwargs): |
|
state_dict = super().state_dict(*args, **kwargs) |
|
return state_dict |
|
|
|
def state_dict(self, *args, **kwargs): |
|
state_dict = super().state_dict(*args, **kwargs) |
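        # Only export the parameter groups that can change during training
        # (LoRA adapters, unfrozen towers, the connector MLP and the OT adapter).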
|
to_return = OrderedDict() |
|
|
|
|
|
if self.use_visual_encoder_lora: |
|
to_return.update( |
|
get_peft_model_state_dict( |
|
self.model.vision_model, state_dict=state_dict)) |
|
elif not self.freeze_visual_encoder: |
|
to_return.update({ |
|
k: v |
|
for k, v in state_dict.items() if 'model.vision_model.' in k |
|
}) |
|
|
|
if self.use_llm_lora: |
|
to_return.update( |
|
get_peft_model_state_dict( |
|
self.model.language_model, state_dict=state_dict)) |
|
elif not self.freeze_llm: |
|
to_return.update({ |
|
k: v |
|
                for k, v in state_dict.items() if 'model.language_model.' in k
|
}) |
|
|
|
if not self.freeze_connector: |
|
to_return.update( |
|
{k: v |
|
for k, v in state_dict.items() if 'model.mlp1.' in k}) |
|
|
|
|
if not self.freeze_ot_mlp: |
|
to_return.update({k: v for k, v in state_dict.items() if 'ot_mlp1.' in k}) |
|
|
|
|
|
|
|
|
|
return to_return |
|
|
|
def init_weights(self): |
|
pass |
|
|
|
def forward(self, data, data_samples=None, mode='loss'): |
|
|
|
|
|
|
|
|
|
pixel_values = data['pixel_values'].to(self.model.vision_model.dtype) |
|
merged_visual_prompts = data['merged_visual_prompts'].to(self.model.vision_model.dtype) |
|
num_patches = data['num_patches'] |
|
has_ot_input = data['ot_pixel_values'] is not None |
|
|
|
vit_embeds = self.model.extract_feature(merged_visual_prompts) |
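        # Object-tokenizer features: each supported VFM exposes its patch tokens
        # differently (RADIO returns (summary, features), DINOv2 keeps a leading CLS
        # token that is dropped, ConvNeXt yields a 2D feature map that is flattened).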
|
if has_ot_input and self.object_tokenizer: |
|
ot_pixel_values = data['ot_pixel_values'].to(self.object_tokenizer.dtype) |
|
ot_h, ot_w = ot_pixel_values.shape[-2:] |
|
if self.vfm_name == "RADIO": |
|
summary, ot_embeds = self.object_tokenizer(ot_pixel_values) |
|
ot_num_tokens_h, ot_num_tokens_w = ot_h // self.ot_config.patch_size, ot_w // self.ot_config.patch_size |
|
elif self.vfm_name == "DINOv2": |
|
ot_outputs = self.object_tokenizer(pixel_values=ot_pixel_values) |
|
ot_embeds = ot_outputs.last_hidden_state[:, 1:, :] |
|
ot_num_tokens_h, ot_num_tokens_w = ot_h // self.ot_config.patch_size, ot_w // self.ot_config.patch_size |
|
elif self.vfm_name == "ConvNext": |
|
ot_outputs = self.object_tokenizer.convnextv2(pixel_values=ot_pixel_values, |
|
output_hidden_states=True, |
|
return_dict=True) |
|
ot_embeds = ot_outputs.last_hidden_state |
|
ot_num_tokens_h, ot_num_tokens_w = ot_embeds.shape[-2:] |
|
ot_embeds = ot_embeds.flatten(2, 3).transpose(1, 2) |
|
with torch.amp.autocast(device_type='cuda', dtype=self.model.vision_model.dtype): |
|
ot_embeds = self.ot_mlp1(ot_embeds) |
|
|
|
|
|
if self.object_tokenizer_pretrain: |
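            # Pretraining mode: mask-pool region embeddings from both the MLLM ViT and
            # the object tokenizer and align them with a contrastive loss; the LLM is
            # not run in this mode.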
|
region_ids = data['region_ids'] |
|
|
|
|
|
num_images = data['num_images'] |
|
batch_size = len(num_images) |
|
num_vprompts = data['num_vprompts'] |
|
num_patches = data['num_patches'] |
|
visual_prompts = data['visual_prompts'] |
|
split_num_vprompts = torch.split(num_vprompts, [nimg for nimg in num_images]) |
|
split_num_patches = torch.split(num_patches, [nimg for nimg in num_images]) |
|
vprompt_split_size = [n_vp*n_p for n_vp , n_p in zip(num_vprompts, num_patches)] |
|
split_visual_prompts = torch.split(visual_prompts, vprompt_split_size) |
|
split_vit_embeds = torch.split(vit_embeds, [np for np in num_patches]) |
|
|
|
object_embeds_in_batch = [] |
|
valid_flag_in_batch = [] |
|
start_idx = 0 |
|
for bidx in range(batch_size): |
|
num_vprompt = split_num_vprompts[bidx] |
|
num_patch = split_num_patches[bidx] |
|
visual_prompts_bi = split_visual_prompts[start_idx:start_idx+num_images[bidx]] |
|
split_vit_embeds_bi = split_vit_embeds[start_idx:start_idx+num_images[bidx]] |
|
start_idx = start_idx + num_images[bidx] |
|
|
|
object_embed_list, valid_flag_list = [], [] |
|
for fidx, visual_prompts_fi in enumerate(visual_prompts_bi): |
|
h, w = visual_prompts_fi.shape[-2:] |
|
visual_prompts_fi = visual_prompts_fi.reshape(num_vprompt[fidx], num_patch[fidx], h, w) |
|
patch_token_edge = int(self.patch_token ** 0.5) |
|
visual_prompts_fi = F.interpolate(visual_prompts_fi.to(vit_embeds.dtype), patch_token_edge, mode='bilinear') |
|
visual_prompts_fi = (visual_prompts_fi > 0.55).to(vit_embeds.dtype) |
|
visual_prompts_fi = visual_prompts_fi.reshape(num_vprompt[fidx], -1) |
|
|
|
num_vp_tokens = torch.sum(visual_prompts_fi, dim=-1, keepdim=False) |
|
valid_flag = num_vp_tokens > 0 |
|
|
|
vit_embeds_fi = split_vit_embeds_bi[fidx].flatten(0, 1) |
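                    # Mask-pool ViT tokens: average the patch embeddings covered by each
                    # visual prompt (the 1e-4 term guards against empty masks).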
|
object_embeds = (visual_prompts_fi[:, :, None] / (num_vp_tokens[:, None, None] + 1e-4) * vit_embeds_fi[None, :, :]) |
|
object_embeds = torch.sum(object_embeds, dim=1) |
|
|
|
object_embed_list.append(object_embeds) |
|
valid_flag_list.append(valid_flag) |
|
|
|
object_embeds_in_batch.append(object_embed_list) |
|
valid_flag_in_batch.append(valid_flag_list) |
|
|
|
|
|
ot_visual_prompts = data['ot_visual_prompts'] |
|
split_ot_visual_prompts = torch.split(ot_visual_prompts, [nvp for nvp in num_vprompts]) |
|
|
|
ot_object_embeds_in_batch = [] |
|
ot_valid_flag_in_batch = [] |
|
start_idx = 0 |
|
for bidx in range(batch_size): |
|
num_vprompt = split_num_vprompts[bidx] |
|
ot_visual_prompts_bi = split_ot_visual_prompts[start_idx:start_idx+num_images[bidx]] |
|
ot_embeds_bi = ot_embeds[start_idx:start_idx+num_images[bidx]] |
|
start_idx = start_idx + num_images[bidx] |
|
|
|
object_embed_list, valid_flag_list = [], [] |
|
for fidx, ot_visual_prompts_fi in enumerate(ot_visual_prompts_bi): |
|
h, w = ot_visual_prompts_fi.shape[-2:] |
|
|
|
ot_visual_prompts_fi = ot_visual_prompts_fi[:, None, :, :] |
|
ot_visual_prompts_fi = F.interpolate(ot_visual_prompts_fi.to(ot_embeds.dtype), (ot_num_tokens_h, ot_num_tokens_w), mode="bilinear") |
|
ot_visual_prompts_fi = (ot_visual_prompts_fi > 0.55).to(ot_embeds.dtype) |
|
ot_visual_prompts_fi = ot_visual_prompts_fi.reshape(num_vprompt[fidx], -1) |
|
|
|
num_vp_tokens = torch.sum(ot_visual_prompts_fi, dim=-1, keepdim=False) |
|
valid_flag = num_vp_tokens > 0 |
|
|
|
ot_embeds_fi = ot_embeds_bi[fidx] |
|
object_embeds = (ot_visual_prompts_fi[:, :, None] / (num_vp_tokens[:, None, None] + 1e-4) * ot_embeds_fi[None, :, :]) |
|
object_embeds = torch.sum(object_embeds, dim=1) |
|
|
|
object_embed_list.append(object_embeds) |
|
valid_flag_list.append(valid_flag) |
|
ot_object_embeds_in_batch.append(object_embed_list) |
|
ot_valid_flag_in_batch.append(valid_flag_list) |
|
|
|
|
|
contras_loss = torch.zeros(size=(1, ), dtype=torch.float32).cuda() |
|
|
|
|
|
valid_contras_sample = 0 |
|
for bidx in range(batch_size): |
|
region_ids_bi = region_ids[bidx] |
|
object_embeds_bi = object_embeds_in_batch[bidx] |
|
ot_object_embeds_bi = ot_object_embeds_in_batch[bidx] |
|
valid_flags_bi = valid_flag_in_batch[bidx] |
|
ot_valid_flags_bi = ot_valid_flag_in_batch[bidx] |
|
|
|
for ot_object_embeds, object_embeds, ot_region_ids, _region_ids, ot_valid_flags, valid_flags in zip( |
|
ot_object_embeds_bi, object_embeds_bi[::-1], |
|
region_ids_bi, region_ids_bi[::-1], |
|
ot_valid_flags_bi, valid_flags_bi[::-1], |
|
): |
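                    # The [::-1] pairs each image's object-tokenizer embeddings with the
                    # ViT embeddings of the reverse-ordered image in the same sample, so
                    # regions sharing a region id across views form the positive pairs.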
|
region_id_to_indices = {region_id: idx for idx, region_id in enumerate(_region_ids)} |
|
for anchor_embed, valid_flag, region_id in zip(ot_object_embeds, ot_valid_flags, ot_region_ids): |
|
if not valid_flag: |
|
continue |
|
anchor_embed = anchor_embed.unsqueeze(0) |
|
pos_idx = region_id_to_indices[region_id] |
|
if not valid_flags[pos_idx]: |
|
continue |
|
pos_embed = object_embeds[pos_idx].unsqueeze(0) |
|
if pos_idx == 0: |
|
neg_embeds = object_embeds[1:, :][valid_flags[1:]] |
|
elif pos_idx == (len(object_embeds) - 1): |
|
neg_embeds = object_embeds[:-1, :][valid_flags[:-1]] |
|
else: |
|
neg_embeds = torch.cat([ |
|
object_embeds[:pos_idx, :][valid_flags[:pos_idx]], |
|
object_embeds[pos_idx+1:, :][valid_flags[pos_idx+1:]] |
|
], dim=0) |
|
|
|
pos_neg_embeds = torch.cat([pos_embed, neg_embeds], dim=0) |
|
pos_neg_label = pos_neg_embeds.new_zeros((pos_neg_embeds.shape[0], ), dtype=torch.int64) |
|
pos_neg_label[:1] = 1 |
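                        # Multi-positive InfoNCE via logsumexp over (negative - positive)
                        # similarity gaps; the +/- inf assignments mask out entries that do
                        # not belong to the positive / negative term respectively.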
|
|
|
|
|
dot_product = torch.einsum('ac,kc->ak', [anchor_embed, pos_neg_embeds]) |
|
pos_neg_label = pos_neg_label.unsqueeze(0) |
|
pos_inds = (pos_neg_label == 1) |
|
neg_inds = (pos_neg_label == 0) |
|
pred_pos = dot_product * pos_inds.float() |
|
pred_neg = dot_product * neg_inds.float() |
|
|
|
pred_pos[neg_inds] = pred_pos[neg_inds] + float('inf') |
|
pred_neg[pos_inds] = pred_neg[pos_inds] + float('-inf') |
|
|
|
_pos_expand = torch.repeat_interleave(pred_pos, dot_product.shape[1], dim=1) |
|
_neg_expand = pred_neg.repeat(1, dot_product.shape[1]) |
|
x = F.pad((_neg_expand - _pos_expand), (0, 1), "constant", 0) |
|
try: |
|
contras_loss += torch.logsumexp(x, dim=1) |
|
valid_contras_sample += 1 |
|
                        except Exception:
                            # Re-raise instead of silently exiting with code 0; calling
                            # logsumexp again here could fail for the same reason.
                            print("contrastive loss failed, x.shape =", x.shape)
                            raise
|
if valid_contras_sample == 0 or torch.any(torch.isnan(contras_loss)): |
|
loss_dict = {"loss": ot_embeds.sum() * 0.0} |
|
else: |
|
loss_dict = {"loss": contras_loss / valid_contras_sample} |
|
            return loss_dict

ot_object_embeds = None |
|
vprompt_flags = data['vprompt_flags'] |
|
ot_object_embeds_in_batch = [] |
|
skip_this_batch = False |
|
if has_ot_input and self.object_tokenizer: |
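            # Mask-pool object-tokenizer features for every visual prompt; the resulting
            # region embeddings are later written into the VPT context token slots.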
|
|
|
num_vprompts = data['num_vprompts'] |
|
num_images = data['num_images'] |
|
batch_size = len(num_images) |
|
ot_visual_prompts = data['ot_visual_prompts'] |
|
|
|
try: |
|
split_ot_visual_prompts = torch.split(ot_visual_prompts, [nvp for nvp in num_vprompts]) |
|
            except Exception:
                # Fallback for malformed visual-prompt counts: assume one prompt per image
                # and flag the batch so its loss is zeroed out below.
                nvp_list = [1 for nvp in num_vprompts]
|
if ot_visual_prompts.shape[0] >= len(nvp_list): |
|
split_ot_visual_prompts = torch.split(ot_visual_prompts[:len(nvp_list)], nvp_list) |
|
else: |
|
split_ot_visual_prompts = torch.stack([ot_visual_prompts[0] for nvp in nvp_list]) |
|
num_vprompts = torch.tensor(nvp_list).to(num_vprompts.dtype).to(num_vprompts.device) |
|
skip_this_batch = True |
|
split_num_vprompts = torch.split(num_vprompts, [nimg for nimg in num_images]) |
|
|
|
start_idx = 0 |
|
for bidx in range(batch_size): |
|
num_vprompt = split_num_vprompts[bidx] |
|
ot_visual_prompts_bi = split_ot_visual_prompts[start_idx:start_idx+num_images[bidx]] |
|
ot_embeds_bi = ot_embeds[start_idx:start_idx+num_images[bidx]] |
|
start_idx = start_idx + num_images[bidx] |
|
|
|
ot_object_embeds_list = [] |
|
for fidx, ot_visual_prompts_fi in enumerate(ot_visual_prompts_bi): |
|
h, w = ot_visual_prompts_fi.shape[-2:] |
|
ot_visual_prompts_fi = ot_visual_prompts_fi.reshape(num_vprompt[fidx], 1, h, w) |
|
|
|
ot_visual_prompts_fi = F.interpolate(ot_visual_prompts_fi.to(ot_embeds.dtype), (ot_num_tokens_h, ot_num_tokens_w), mode="bilinear") |
|
ot_visual_prompts_fi = (ot_visual_prompts_fi > 0.5).to(ot_embeds.dtype) |
|
ot_visual_prompts_fi = ot_visual_prompts_fi.reshape(num_vprompt[fidx], -1) |
|
|
|
num_vp_tokens = torch.sum(ot_visual_prompts_fi, dim=-1, keepdim=False) |
|
ot_embeds_fi = ot_embeds_bi[fidx] |
|
object_embeds = (ot_visual_prompts_fi[:, :, None] / (num_vp_tokens[:, None, None] + 1e-4) * ot_embeds_fi[None, :, :]) |
|
object_embeds = torch.sum(object_embeds, dim=1) |
|
ot_object_embeds_list.append(object_embeds) |
|
ot_object_embeds_in_batch.append(ot_object_embeds_list) |
|
ot_object_embeds = [] |
|
for ele in ot_object_embeds_in_batch: |
|
ot_object_embeds.extend(ele) |
|
ot_object_embeds = torch.cat(ot_object_embeds, dim=0) |
|
|
|
if mode == "loss": |
|
input_ids = data['input_ids'] |
|
position_ids = data['position_ids'] |
|
attention_mask = data['attention_mask'] |
|
image_flags = data['image_flags'] |
|
|
|
labels = data['labels'] |
|
use_cache = False |
|
|
|
outputs, _skip_this_case = self._llm_forward( |
|
input_ids=input_ids, |
|
position_ids=position_ids, |
|
attention_mask=attention_mask, |
|
image_flags=image_flags, |
|
labels=labels, |
|
use_cache=use_cache, |
|
vit_embeds=vit_embeds, |
|
ot_object_embeds=ot_object_embeds, |
|
vprompt_flags=vprompt_flags, |
|
) |
|
|
|
if skip_this_batch or _skip_this_case: |
|
print("skip this batch!") |
|
loss_dict = {'loss': outputs.loss * 0.0} |
|
else: |
|
loss_dict = {'loss': outputs.loss} |
|
if not self.contras_loss: |
|
                return loss_dict

elif mode == "predict": |
|
pass |
|
elif mode == "tensor": |
|
pass |
|
else: |
|
raise NotImplementedError |
|
|
|
def _llm_forward( |
|
self, |
|
vit_embeds: torch.FloatTensor, |
|
input_ids: torch.LongTensor = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
image_flags: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[List[torch.FloatTensor]] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
ot_object_embeds: torch.FloatTensor = None, |
|
vprompt_flags: Optional[torch.LongTensor] = None, |
|
) -> Union[Tuple, CausalLMOutputWithPast]: |
|
return_dict = return_dict if return_dict is not None \ |
|
else self.model.config.use_return_dict |
|
|
|
B, N = input_ids.shape |
|
temp_input_ids = input_ids.clone().flatten() |
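        # Map VPT context tokens to the image context id for the embedding lookup;
        # presumably this keeps the newly added token id in range, and the embeddings at
        # those positions are overwritten with object embeddings below anyway.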
|
temp_input_ids[temp_input_ids == self.vpt_content_token_idx] = self.model.img_context_token_id |
|
input_embeds = self.model.language_model.get_input_embeddings()(temp_input_ids.reshape(B, N)).clone() |
|
|
|
|
|
|
|
vit_embeds = vit_embeds[image_flags == 1] |
|
vit_batch_size = vit_embeds.shape[0] |
|
|
|
B, N, C = input_embeds.shape |
|
input_embeds = input_embeds.reshape(B * N, C) |
|
input_ids = input_ids.reshape(B * N) |
|
|
|
        skip_this_case = False
        if ot_object_embeds is not None:
            try:
                ot_object_embeds = ot_object_embeds[vprompt_flags > 0]
                selected = (input_ids == self.vpt_content_token_idx)
                input_embeds[selected] = input_embeds[selected] * 0.0 + ot_object_embeds
                skip_this_case = False
            except Exception:
                print("The number of provided object embeddings does not match "
                      "vprompt_flags or the number of VPT context tokens.")
                selected = (input_ids == self.vpt_content_token_idx)
                input_embeds[selected] = (input_embeds[selected] * 0.0
                                          + ot_object_embeds.mean(dim=0, keepdim=True).to(input_embeds.dtype))
                skip_this_case = True
|
|
|
selected = (input_ids == self.model.img_context_token_id) |
|
try: |
|
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C) |
|
except Exception as e: |
|
vit_embeds = vit_embeds.reshape(-1, C) |
|
print(f"warning: {e}, input_embeds[selected].shape=" |
|
f"{input_embeds[selected].shape}, " |
|
f"vit_embeds.shape={vit_embeds.shape}") |
|
n_token = selected.sum() |
|
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token] |
|
input_embeds = input_embeds.reshape(B, N, C) |
|
|
|
        if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0 and self._count % 100 == 0:
            print(f"dynamic ViT batch size: {vit_batch_size}, "
                  f"images per sample: {vit_batch_size / B}, "
                  f"dynamic token length: {N}")
|
self._count += 1 |
|
|
|
outputs = self.model.language_model( |
|
            inputs_embeds=input_embeds,
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values, |
|
use_cache=use_cache, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
logits = outputs.logits |
|
|
|
loss = None |
|
if labels is not None: |
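            # Standard causal-LM objective: predict token t+1 from position t.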
|
|
|
shift_logits = logits[..., :-1, :].contiguous() |
|
shift_labels = labels[..., 1:].contiguous() |
|
|
|
loss_fct = CrossEntropyLoss() |
|
shift_logits = shift_logits.view(-1, self.model.language_model.config.vocab_size) |
|
shift_labels = shift_labels.view(-1) |
|
|
|
shift_labels = shift_labels.to(shift_logits.device) |
|
loss = loss_fct(shift_logits, shift_labels) |
|
|
|
if not return_dict: |
|
output = (logits, ) + outputs[1:] |
|
return (loss, ) + output if loss is not None else output |
|
|
|
return CausalLMOutputWithPast( |
|
loss=loss, |
|
logits=logits, |
|
past_key_values=outputs.past_key_values, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
), skip_this_case |
|
|
|
def batch_chat(self, pixel_values, questions, merged_visual_prompts, |
|
generation_config, num_patches_list=None, history=None, return_history=False, |
|
): |
|
        if history is not None or return_history:
            raise NotImplementedError('Multi-turn chat is not yet supported in batch_chat.')
|
|
|
|
|
        queries = []
        if num_patches_list is None:
            # Assumption: one tile per image when no per-sample patch counts are given.
            num_patches_list = [1] * len(questions)
        for idx, num_patches in enumerate(num_patches_list):
|
question = questions[idx] |
|
if pixel_values is not None and DEFAULT_IMAGE_TOKEN not in question: |
|
question = DEFAULT_IMAGE_TOKEN + '\n' + question |
|
template = get_conv_template(self.config.template) |
|
template.append_message(template.roles[0], question) |
|
template.append_message(template.roles[1], None) |
|
query = template.get_prompt() |
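            # Expand the single <image> placeholder into IMG_START + N context tokens +
            # IMG_END, where N = patch_token * num_patches for this sample.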
|
|
|
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.patch_token * num_patches + IMG_END_TOKEN |
|
query = query.replace(DEFAULT_IMAGE_TOKEN, image_tokens, 1) |
|
queries.append(query) |
|
|
|
        self.tokenizer.padding_side = 'left'  # left-pad so generation continues right after each prompt
|
model_inputs = self.tokenizer(queries, return_tensors='pt', padding=True) |
|
input_ids = model_inputs['input_ids'].cuda() |
|
attention_mask = model_inputs['attention_mask'].cuda() |
|
eos_token_id = self.tokenizer.convert_tokens_to_ids(template.sep) |
|
generation_config['eos_token_id'] = eos_token_id |
|
|
|
input_embeds = self.model.language_model.get_input_embeddings()(input_ids).clone() |
|
|
|
vit_embeds = self.model.extract_feature(merged_visual_prompts) |
|
|
|
B, N, C = input_embeds.shape |
|
input_embeds = input_embeds.reshape(B * N, C) |
|
input_ids = input_ids.reshape(B * N) |
|
|
|
selected = (input_ids == self.model.img_context_token_id) |
|
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C) |
|
input_embeds = input_embeds.reshape(B, N, C) |
|
|
|
|
|
generate_output = self.generate( |
|
input_embeds=input_embeds, |
|
attention_mask=attention_mask, |
|
**generation_config |
|
) |
|
|
|
responses = self.tokenizer.batch_decode(generate_output, skip_special_tokens=True) |
|
responses = [response.split(template.sep)[0].strip() for response in responses] |
|
|
|
return responses |
|
|
|
@torch.no_grad() |
|
def generate( |
|
self, |
|
input_embeds: Optional[torch.FloatTensor] = None, |
|
attention_mask: Optional[torch.LongTensor] = None, |
|
generation_config: Optional[GenerationConfig] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
**generate_kwargs, |
|
) -> torch.LongTensor: |
|
outputs = self.model.language_model.generate( |
|
            inputs_embeds=input_embeds,
|
attention_mask=attention_mask, |
|
generation_config=generation_config, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
use_cache=True, |
|
**generate_kwargs, |
|
) |
|
|
|
        return outputs