import os
import copy
from collections import OrderedDict
from typing import List, Optional, Tuple, Union
from types import MethodType
from enum import Enum

import torch
import torch.amp
import torch.distributed
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch.utils
import torch.utils.checkpoint

from mmengine import print_log
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from peft import get_peft_model, prepare_model_for_kbit_training
from safetensors.torch import load_file
from safetensors import safe_open
from accelerate import init_empty_weights
from transformers import (AutoModel, AutoConfig, AutoTokenizer, BitsAndBytesConfig,
                          GenerationConfig, AutoImageProcessor, AutoProcessor)
from transformers.modeling_outputs import (CausalLMOutputWithPast, BaseModelOutput,
                                           BaseModelOutputWithPooling, BaseModelOutputWithPast)

from xtuner.registry import BUILDER
from xtuner.model.modules import dispatch_modules
from xtuner.utils import DEFAULT_IMAGE_TOKEN

from .utils import (LoadWoInit, traverse_dict, make_inputs_require_grad, find_all_linear_names,
                    guess_load_checkpoint, get_peft_model_state_dict)
from ..dataset.utils import (get_conv_template, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN,
                             VPT_CONTEXT_TOKEN, VPT_START_TOKEN, VPT_END_TOKEN)


class WrapQwen2VL(BaseModel):

    def __init__(self,
                 mllm,
                 freeze_llm=False,
                 freeze_visual_encoder=False,
                 freeze_connector=False,
                 freeze_ot_mlp=False,
                 unfreeze_vocab=False,
                 unfreeze_lm_head=False,
                 llm_lora=None,
                 visual_encoder_lora=None,
                 pretrained_pth=None,
                 use_activation_checkpointing=True,
                 vocab_embeds_name="tok_embeddings",
                 lm_head_name="output",
                 contras_loss=False,
                 use_object_tokens=False,
                 object_tokenizer=None,
                 object_tokenizer_pretrain=False,
                 ):
        super().__init__()

        self.freeze_llm = freeze_llm
        self.freeze_visual_encoder = freeze_visual_encoder
        self.freeze_connector = freeze_connector
        self.freeze_ot_mlp = freeze_ot_mlp
        self.unfreeze_vocab = unfreeze_vocab
        self.unfreeze_lm_head = unfreeze_lm_head
        self.use_llm_lora = llm_lora is not None
        self.use_visual_encoder_lora = visual_encoder_lora is not None
        self.use_activation_checkpointing = use_activation_checkpointing
        self.vocab_embeds_name = vocab_embeds_name
        self.lm_head_name = lm_head_name
        self.contras_loss = contras_loss
        self.object_tokenizer_pretrain = object_tokenizer_pretrain
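        # Illustrative sketch (an assumption, not the project's actual config):
        # `mllm` and `object_tokenizer` are xtuner-style dicts whose 'type'
        # entry is a callable; traverse_dict() resolves nested builders and the
        # remaining keys are passed as keyword arguments, e.g. something like
        #   mllm = dict(type=Qwen2VLForConditionalGeneration.from_pretrained,
        #               pretrained_model_name_or_path='Qwen/Qwen2-VL-7B-Instruct',
        #               torch_dtype=torch.bfloat16)
        # Only 'pretrained_model_name_or_path' is relied on below (for
        # AutoProcessor / AutoConfig); the builder and other keys may differ.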

        traverse_dict(mllm)
        model_clazz = mllm.pop('type')
        self.model = model_clazz(**mllm)
        self.model.model.config.use_cache = False
        dispatch_modules(self.model.model)

        self.model.model.forward = MethodType(Qwen2VLModel_forward, self.model.model)
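        # Note (based on the patched function defined at the bottom of this
        # file): the language model's forward is monkey-patched so that 1-D or
        # 2-D position_ids are expanded to Qwen2-VL's three M-RoPE axes
        # (temporal, height, width) before the rotary embeddings are built.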

        if use_object_tokens:
            ot_config = AutoConfig.from_pretrained(object_tokenizer["pretrained_model_name_or_path"], trust_remote_code=True)
            self.ot_config = ot_config
            traverse_dict(object_tokenizer)
            ot_clazz = object_tokenizer.pop('type')
            self.object_tokenizer = ot_clazz(**object_tokenizer)
            ot_hidden_size = self.object_tokenizer.model.num_features
            llm_hidden_size = self.model.model.config.hidden_size
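            # Projector from object-tokenizer features to the LLM embedding
            # space: (num_tokens, ot_hidden_size) -> (num_tokens, llm_hidden_size).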
            self.ot_mlp1 = nn.Sequential(
                nn.LayerNorm(ot_hidden_size),
                nn.Linear(ot_hidden_size, llm_hidden_size),
                nn.GELU(),
                nn.Linear(llm_hidden_size, llm_hidden_size)
            )
        else:
            self.object_tokenizer = None
            self.ot_mlp1 = None
            self.ot_config = None

        self.processor = AutoProcessor.from_pretrained(mllm["pretrained_model_name_or_path"])

        self._add_special_tokens()

        if self.freeze_llm:
            self.model.model.requires_grad_(False)
        if self.freeze_visual_encoder:
            assert self.freeze_connector
            self.model.visual.requires_grad_(False)
        if self.object_tokenizer is not None:
            self.object_tokenizer.requires_grad_(False)
        if self.freeze_ot_mlp and self.ot_mlp1 is not None:
            self.ot_mlp1.requires_grad_(False)

        if use_activation_checkpointing:
            if hasattr(self.model.model, 'enable_input_require_grads'):
                self.model.model.enable_input_require_grads()
            else:
                self.model.model.get_input_embeddings(
                ).register_forward_hook(make_inputs_require_grad)

            self.gradient_checkpointing_enable()

        if self.use_llm_lora:
            self._prepare_llm_for_lora(llm_lora)

        if self.unfreeze_vocab:
            self.model.get_input_embeddings().requires_grad_(True)
        else:
            self.model.get_input_embeddings().requires_grad_(False)
        if self.unfreeze_lm_head:
            self.model.get_output_embeddings().requires_grad_(True)
        else:
            self.model.get_output_embeddings().requires_grad_(False)

        if pretrained_pth is not None:
            pretrained_state_dict = guess_load_checkpoint(pretrained_pth)

            mllm_state_dict = {}
            for k, v in pretrained_state_dict.items():
                if k.startswith('model.'):
                    mllm_state_dict[k[len('model.'):]] = v
            if len(mllm_state_dict) != 0:
                self.model.load_state_dict(mllm_state_dict, strict=False)

            if use_object_tokens:
                ot_adapter_state_dict = {}
                for k, v in pretrained_state_dict.items():
                    if k.startswith('ot_mlp1.'):
                        ot_adapter_state_dict[k[len('ot_mlp1.'):]] = v
                if len(ot_adapter_state_dict) != 0:
                    self.ot_mlp1.load_state_dict(ot_adapter_state_dict, strict=False)

                    # sanity check: the loaded projector weights must match the checkpoint
                    for k, v in self.ot_mlp1.named_parameters():
                        assert v.equal(ot_adapter_state_dict[k])

            print(f"Loaded pretrained weights from {pretrained_pth}")

        self._count = 0
        print_log(self, logger="current")
        print_log('Qwen2-VL construction is complete', logger='current')

    def _add_special_tokens(self):
        assert hasattr(self, "processor")

        special_tokens = [VPT_CONTEXT_TOKEN]
        num_new_tokens = self.processor.tokenizer.add_tokens(special_tokens, special_tokens=True)
        print_log(f"Added {num_new_tokens} special tokens.")

        self.vpt_content_token_idx = self.processor.tokenizer(VPT_CONTEXT_TOKEN, add_special_tokens=False).input_ids[0]
        image_token = "<|image_pad|>" if not hasattr(self.processor.tokenizer, "image_token") else self.processor.tokenizer.image_token
        self.img_context_token_idx = self.processor.tokenizer(image_token, add_special_tokens=False).input_ids[0]
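        # The two indices recorded above are consumed in _llm_forward():
        # <|image_pad|> slots are overwritten with ViT embeddings and
        # VPT_CONTEXT_TOKEN slots with pooled object embeddings. A prompt thus
        # looks roughly like (illustrative, not the exact template):
        #   "<|vision_start|><|image_pad|>...<|vision_end|> Region <VPT> is ..."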

    def _parse_lora_config(self, lora_config):
        if isinstance(lora_config, (dict, Config, ConfigDict)):
            lora_config = BUILDER.build(lora_config)
        return lora_config

    def _prepare_llm_for_lora(self, lora_config, use_activation_checkpointing=True):
        lora_config = self._parse_lora_config(lora_config)
        self.model.model = prepare_model_for_kbit_training(self.model.model, use_activation_checkpointing)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.model.model)
            lora_config.target_modules = modules

        self.model.model = get_peft_model(self.model.model, lora_config)
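        # Illustrative `llm_lora` value (an assumption, not the project's actual
        # hyper-parameters): either a peft.LoraConfig instance or a config dict
        # that BUILDER resolves to one, e.g.
        #   llm_lora = dict(type=LoraConfig, r=64, lora_alpha=16,
        #                   lora_dropout=0.05, bias='none',
        #                   task_type='CAUSAL_LM', target_modules=None)
        # With target_modules=None, every linear layer name returned by
        # find_all_linear_names() is adapted.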

    def gradient_checkpointing_enable(self):
        self.activation_checkpointing_enable()

    def activation_checkpointing_enable(self):
        self.model.model.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        self.activation_checkpointing_disable()

    def activation_checkpointing_disable(self):
        self.model.model.gradient_checkpointing_disable()

    def state_dict(self, *args, **kwargs):
        state_dict = super().state_dict(*args, **kwargs)
        to_return = OrderedDict()

        # Step 1. visual encoder
        if self.use_visual_encoder_lora:
            to_return.update(
                get_peft_model_state_dict(
                    self.model.visual, state_dict=state_dict))
        elif not self.freeze_visual_encoder:
            to_return.update({
                k: v
                for k, v in state_dict.items() if 'model.visual.' in k
            })

        # Step 2. LLM
        if self.use_llm_lora:
            to_return.update(
                get_peft_model_state_dict(
                    self.model.model, state_dict=state_dict))
        elif not self.freeze_llm:
            to_return.update({
                k: v
                for k, v in state_dict.items() if 'model.model.' in k
            })

        # Step 3. object-token projector
        if self.ot_mlp1 is not None and not self.freeze_ot_mlp:
            to_return.update({k: v for k, v in state_dict.items() if 'ot_mlp1.' in k})

        return to_return

    def init_weights(self):
        pass

    def forward(self, data, data_samples=None, mode='loss'):
        pixel_values = data['pixel_values'].to(self.model.visual.dtype)
        merged_visual_prompts = data['merged_visual_prompts'].to(self.model.visual.dtype)
        has_ot_input = data['ot_pixel_values'] is not None
        image_grid_thw = data['image_grid_thw']

        vit_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw)

        if has_ot_input and self.object_tokenizer:
            ot_pixel_values = data['ot_pixel_values'].to(self.object_tokenizer.dtype)
            ot_h, ot_w = ot_pixel_values.shape[-2:]
            ot_num_tokens_h, ot_num_tokens_w = ot_h // self.ot_config.patch_size, ot_w // self.ot_config.patch_size
            summary, ot_embeds = self.object_tokenizer(ot_pixel_values)
            with torch.amp.autocast(device_type='cuda', dtype=self.model.visual.dtype):
                ot_embeds = self.ot_mlp1(ot_embeds)

        if self.object_tokenizer_pretrain:
            region_ids = data['region_ids']

            num_images = data['num_images']
            batch_size = len(num_images)
            num_vprompts = data['num_vprompts']
            visual_prompts = data['visual_prompts']
            image_grid_thw = data['image_grid_thw']
            merge_length = self.processor.image_processor.merge_size ** 2
            image_num_tokens = image_grid_thw[:, 0] * image_grid_thw[:, 1] * image_grid_thw[:, 2] // merge_length
            split_vit_embeds = torch.split(vit_embeds, [num_tokens for num_tokens in image_num_tokens])
            split_num_vprompts = torch.split(num_vprompts, [num_img for num_img in num_images])

            object_embeds_in_batch = []
            valid_flag_in_batch = []
            start_idx = 0
            for bidx in range(batch_size):
                # use a per-sample name so the full `num_vprompts` tensor is not
                # shadowed; it is reused when splitting `ot_visual_prompts` below
                num_vprompts_bi = split_num_vprompts[bidx]
                visual_prompts_bi = torch.split(visual_prompts[bidx], [nvp for nvp in num_vprompts_bi])
                split_vit_embeds_bi = split_vit_embeds[start_idx:start_idx + num_images[bidx]]
                start_idx = start_idx + num_images[bidx]

                object_embed_list, valid_flag_list = [], []
                for fidx, visual_prompts_fi in enumerate(visual_prompts_bi):
                    h, w = visual_prompts_fi.shape[-2:]
                    visual_prompts_fi = visual_prompts_fi.reshape(num_vprompts_bi[fidx], h, w)
                    visual_prompts_fi = (visual_prompts_fi > 0.55).to(vit_embeds.dtype)
                    visual_prompts_fi = visual_prompts_fi.reshape(num_vprompts_bi[fidx], -1)

                    num_vp_tokens = torch.sum(visual_prompts_fi, dim=-1, keepdim=False)
                    valid_flag = num_vp_tokens > 0

                    vit_embeds_fi = split_vit_embeds_bi[fidx]
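                    # Masked average pooling (what the two lines below compute):
                    # for each visual prompt j,
                    #   object_embeds[j] = sum_i mask[j, i] * feat[i] / (num_vp_tokens[j] + 1e-4)
                    # i.e. the mean ViT feature over the tokens covered by the
                    # prompt mask; the 1e-4 term guards against empty masks.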
                    object_embeds = (visual_prompts_fi[:, :, None] / (num_vp_tokens[:, None, None] + 1e-4) * vit_embeds_fi[None, :, :])
                    object_embeds = torch.sum(object_embeds, dim=1)

                    object_embed_list.append(object_embeds)
                    valid_flag_list.append(valid_flag)

                object_embeds_in_batch.append(object_embed_list)
                valid_flag_in_batch.append(valid_flag_list)

            # the same pooling, but over the object tokenizer's feature map
            ot_visual_prompts = data['ot_visual_prompts']
            split_ot_visual_prompts = torch.split(ot_visual_prompts, [nvp for nvp in num_vprompts])

            ot_object_embeds_in_batch = []
            ot_valid_flag_in_batch = []
            start_idx = 0
            for bidx in range(batch_size):
                num_vprompt = split_num_vprompts[bidx]
                ot_visual_prompts_bi = split_ot_visual_prompts[start_idx:start_idx + num_images[bidx]]
                ot_embeds_bi = ot_embeds[start_idx:start_idx + num_images[bidx]]
                start_idx = start_idx + num_images[bidx]

                object_embed_list, valid_flag_list = [], []
                for fidx, ot_visual_prompts_fi in enumerate(ot_visual_prompts_bi):
                    h, w = ot_visual_prompts_fi.shape[-2:]
                    ot_visual_prompts_fi = ot_visual_prompts_fi[:, None, :, :]
                    ot_visual_prompts_fi = F.interpolate(ot_visual_prompts_fi.to(ot_embeds.dtype), (ot_num_tokens_h, ot_num_tokens_w), mode="bilinear")
                    ot_visual_prompts_fi = (ot_visual_prompts_fi > 0.55).to(ot_embeds.dtype)
                    ot_visual_prompts_fi = ot_visual_prompts_fi.reshape(num_vprompt[fidx], -1)

                    num_vp_tokens = torch.sum(ot_visual_prompts_fi, dim=-1, keepdim=False)
                    valid_flag = num_vp_tokens > 0

                    ot_embeds_fi = ot_embeds_bi[fidx]
                    object_embeds = (ot_visual_prompts_fi[:, :, None] / (num_vp_tokens[:, None, None] + 1e-4) * ot_embeds_fi[None, :, :])
                    object_embeds = torch.sum(object_embeds, dim=1)

                    object_embed_list.append(object_embeds)
                    valid_flag_list.append(valid_flag)
                ot_object_embeds_in_batch.append(object_embed_list)
                ot_valid_flag_in_batch.append(valid_flag_list)

            contras_loss = torch.zeros(size=(1, ), dtype=torch.float32).cuda()

            valid_contras_sample = 0
            for bidx in range(batch_size):
                region_ids_bi = region_ids[bidx]
                object_embeds_bi = object_embeds_in_batch[bidx]
                ot_object_embeds_bi = ot_object_embeds_in_batch[bidx]
                valid_flags_bi = valid_flag_in_batch[bidx]
                ot_valid_flags_bi = ot_valid_flag_in_batch[bidx]

                for ot_object_embeds, object_embeds, ot_region_ids, _region_ids, ot_valid_flags, valid_flags in zip(
                        ot_object_embeds_bi, object_embeds_bi[::-1],
                        region_ids_bi, region_ids_bi[::-1],
                        ot_valid_flags_bi, valid_flags_bi[::-1],
                ):
                    region_id_to_indices = {region_id: idx for idx, region_id in enumerate(_region_ids)}
                    for anchor_embed, valid_flag, region_id in zip(ot_object_embeds, ot_valid_flags, ot_region_ids):
                        if not valid_flag:
                            continue
                        anchor_embed = anchor_embed.unsqueeze(0)
                        pos_idx = region_id_to_indices[region_id]
                        if not valid_flags[pos_idx]:
                            continue
                        pos_embed = object_embeds[pos_idx].unsqueeze(0)
                        if pos_idx == 0:
                            neg_embeds = object_embeds[1:, :][valid_flags[1:]]
                        elif pos_idx == (len(object_embeds) - 1):
                            neg_embeds = object_embeds[:-1, :][valid_flags[:-1]]
                        else:
                            neg_embeds = torch.cat([
                                object_embeds[:pos_idx, :][valid_flags[:pos_idx]],
                                object_embeds[pos_idx + 1:, :][valid_flags[pos_idx + 1:]]
                            ], dim=0)

                        pos_neg_embeds = torch.cat([pos_embed, neg_embeds], dim=0)
                        pos_neg_label = pos_neg_embeds.new_zeros((pos_neg_embeds.shape[0], ), dtype=torch.int64)
                        pos_neg_label[:1] = 1
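                        # The masking-and-logsumexp construction below is, for
                        # the single-positive case used here, equivalent to an
                        # InfoNCE / cross-entropy term on the similarity scores:
                        #   loss = log(1 + sum_j exp(s_neg_j - s_pos))
                        #        = -log( exp(s_pos) / (exp(s_pos) + sum_j exp(s_neg_j)) )
                        # where s_* are the anchor-candidate dot products.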

                        dot_product = torch.einsum('ac,kc->ak', [anchor_embed, pos_neg_embeds])
                        pos_neg_label = pos_neg_label.unsqueeze(0)
                        pos_inds = (pos_neg_label == 1)
                        neg_inds = (pos_neg_label == 0)
                        pred_pos = dot_product * pos_inds.float()
                        pred_neg = dot_product * neg_inds.float()

                        pred_pos[neg_inds] = pred_pos[neg_inds] + float('inf')
                        pred_neg[pos_inds] = pred_neg[pos_inds] + float('-inf')

                        _pos_expand = torch.repeat_interleave(pred_pos, dot_product.shape[1], dim=1)
                        _neg_expand = pred_neg.repeat(1, dot_product.shape[1])
                        x = F.pad((_neg_expand - _pos_expand), (0, 1), "constant", 0)
                        try:
                            contras_loss += torch.logsumexp(x, dim=1)
                            valid_contras_sample += 1
                        except Exception as e:
                            print("x: ", x.shape)
                            print("sumexp: ", torch.logsumexp(x, dim=1).shape)
                            exit(0)
            if valid_contras_sample == 0 or torch.any(torch.isnan(contras_loss)):
                loss_dict = {"loss": ot_embeds.sum() * 0.0}
            else:
                loss_dict = {"loss": contras_loss / valid_contras_sample}

            return loss_dict

        ot_object_embeds = None
        vprompt_flags = data['vprompt_flags']
        ot_object_embeds_in_batch = []
        skip_this_batch = False
        if has_ot_input and self.object_tokenizer:
            num_vprompts = data['num_vprompts']
            num_images = data['num_images']
            batch_size = len(num_images)
            ot_visual_prompts = data['ot_visual_prompts']

            try:
                split_ot_visual_prompts = torch.split(ot_visual_prompts, [nvp for nvp in num_vprompts])
            except Exception:
                # fall back to one visual prompt per image and mark the batch so
                # its loss is zeroed out later
                nvp_list = [1 for nvp in num_vprompts]
                if ot_visual_prompts.shape[0] >= len(nvp_list):
                    split_ot_visual_prompts = torch.split(ot_visual_prompts[:len(nvp_list)], nvp_list)
                else:
                    split_ot_visual_prompts = torch.stack([ot_visual_prompts[0] for nvp in nvp_list])
                num_vprompts = torch.tensor(nvp_list).to(num_vprompts.dtype).to(num_vprompts.device)
                skip_this_batch = True
            split_num_vprompts = torch.split(num_vprompts, [nimg for nimg in num_images])

            start_idx = 0
            for bidx in range(batch_size):
                num_vprompt = split_num_vprompts[bidx]
                ot_visual_prompts_bi = split_ot_visual_prompts[start_idx:start_idx + num_images[bidx]]
                ot_embeds_bi = ot_embeds[start_idx:start_idx + num_images[bidx]]
                start_idx = start_idx + num_images[bidx]

                ot_object_embeds_list = []
                for fidx, ot_visual_prompts_fi in enumerate(ot_visual_prompts_bi):
                    h, w = ot_visual_prompts_fi.shape[-2:]
                    ot_visual_prompts_fi = ot_visual_prompts_fi.reshape(num_vprompt[fidx], 1, h, w)

                    ot_visual_prompts_fi = F.interpolate(ot_visual_prompts_fi.to(ot_embeds.dtype), (ot_num_tokens_h, ot_num_tokens_w), mode="bilinear")
                    ot_visual_prompts_fi = (ot_visual_prompts_fi > 0.5).to(ot_embeds.dtype)
                    ot_visual_prompts_fi = ot_visual_prompts_fi.reshape(num_vprompt[fidx], -1)

                    num_vp_tokens = torch.sum(ot_visual_prompts_fi, dim=-1, keepdim=False)
                    ot_embeds_fi = ot_embeds_bi[fidx]
                    object_embeds = (ot_visual_prompts_fi[:, :, None] / (num_vp_tokens[:, None, None] + 1e-4) * ot_embeds_fi[None, :, :])
                    object_embeds = torch.sum(object_embeds, dim=1)
                    ot_object_embeds_list.append(object_embeds)
                ot_object_embeds_in_batch.append(ot_object_embeds_list)
            ot_object_embeds = []
            for ele in ot_object_embeds_in_batch:
                ot_object_embeds.extend(ele)
            ot_object_embeds = torch.cat(ot_object_embeds, dim=0)

        if mode == "loss":
            input_ids = data['input_ids']
            position_ids = data['position_ids']
            attention_mask = data['attention_mask']
            image_flags = data['image_flags']

            labels = data['labels']
            use_cache = False

            outputs, _skip_this_case = self._llm_forward(
                input_ids=input_ids,
                position_ids=position_ids,
                attention_mask=attention_mask,
                image_flags=image_flags,
                labels=labels,
                use_cache=use_cache,
                vit_embeds=vit_embeds,
                ot_object_embeds=ot_object_embeds,
                vprompt_flags=vprompt_flags,
            )

            if skip_this_batch or _skip_this_case:
                print("skip this batch!")
                loss_dict = {'loss': outputs.loss * 0.0}
            else:
                loss_dict = {'loss': outputs.loss}
            if not self.contras_loss:
                return loss_dict
            else:
                raise NotImplementedError

    def _llm_forward(
            self,
            vit_embeds: torch.FloatTensor,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            image_flags: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            ot_object_embeds: torch.FloatTensor = None,
            vprompt_flags: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = (return_dict if return_dict is not None
                       else self.model.config.use_return_dict)

        B, N = input_ids.shape
        temp_input_ids = input_ids.clone().flatten()
        temp_input_ids[temp_input_ids == self.vpt_content_token_idx] = self.img_context_token_idx
        input_embeds = self.model.get_input_embeddings()(temp_input_ids.reshape(B, N)).clone()

        vit_embeds = vit_embeds[image_flags == 1]
        vit_batch_size = vit_embeds.shape[0]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)
        input_ids = input_ids.reshape(B * N)
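        # In the flattened view below, each <|image_pad|> slot is overwritten
        # with a ViT embedding and each VPT_CONTEXT_TOKEN slot with a pooled
        # object embedding. The `* 0.0 +` pattern presumably keeps the text
        # embedding output in the autograd graph so gradients can still reach
        # the (possibly hooked) input embedding layer.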

        skip_this_case = False
        if ot_object_embeds is not None:
            try:
                ot_object_embeds = ot_object_embeds[vprompt_flags > 0]
                selected = (input_ids == self.vpt_content_token_idx)
                input_embeds[selected] = input_embeds[selected] * 0.0 + ot_object_embeds
                skip_this_case = False
            except Exception:
                print("The number of provided object embeds does not match vprompt_flags or the number of VPT_CONTEXT_TOKEN slots.")
                selected = (input_ids == self.vpt_content_token_idx)
                input_embeds[selected] = input_embeds[selected] * 0.0 + ot_object_embeds.mean(dim=0, keepdim=True).to(input_embeds.dtype)
                skip_this_case = True

        selected = (input_ids == self.img_context_token_idx)
        try:
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
        except Exception as e:
            vit_embeds = vit_embeds.reshape(-1, C)
            print(f"warning: {e}, input_embeds[selected].shape="
                  f"{input_embeds[selected].shape}, "
                  f"vit_embeds.shape={vit_embeds.shape}")
            n_token = selected.sum()
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
        input_embeds = input_embeds.reshape(B, N, C)

        if torch.distributed.get_rank() == 0 and self._count % 100 == 0:
            print(f"dynamic ViT batch size: {vit_batch_size}, "
                  f"images per sample: {vit_batch_size / B}, "
                  f"dynamic token length: {N}")
        self._count += 1

        outputs = self.model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            labels=labels,
        )

        return outputs, skip_this_case


def Qwen2VLModel_forward(
        self,
        input_ids: torch.LongTensor = None,
        labels: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
):
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if self.gradient_checkpointing and self.training:
        if use_cache:
            # caching is incompatible with gradient checkpointing
            use_cache = False

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    # expand 1-D/2-D position ids to the three M-RoPE axes (temporal, height, width)
    if position_ids is None:
        position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
    elif position_ids.dim() == 2:
        position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

    causal_mask = self._update_causal_mask(
        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
    )

    hidden_states = inputs_embeds

    # rotary position embeddings are shared across the decoder layers
    position_embeddings = self.rotary_emb(hidden_states, position_ids)

    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
                position_embeddings,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None

    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )