import torch
from xtuner.model import InternVL_V1_5
from typing import List, Optional, Tuple, Union
from transformers.modeling_outputs import CausalLMOutputWithPast

from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
                          LlamaTokenizer)
import torch.nn as nn

from mmengine import print_log
from torch.nn import CrossEntropyLoss
from transformers import (AutoConfig, AutoModel, AutoTokenizer,
                          BitsAndBytesConfig)
from xtuner.model.utils import (find_all_linear_names, get_peft_model_state_dict,
                                guess_load_checkpoint, make_inputs_require_grad)
import os


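# `RANK` and `WORLD_SIZE` are the environment variables set by launchers such as
# `torchrun` / `torch.distributed`; without a launcher they default to a
# single-process setup (rank 0, world size 1).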
def get_rank_and_world_size():
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    return rank, world_size


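# Build a `device_map` that spreads the decoder layers of one model replica across
# the GPUs available to this rank. The first GPU of each replica also hosts the
# vision tower, the MLP projector, the token embeddings and the output head, so it
# is given roughly 80% of a full share of decoder layers (the 0.8 / -0.2 factors
# below); the first and last decoder layers are likewise pinned to the local rank.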
def split_model(model_name):
    import math
    device_map = {}
    num_gpus = torch.cuda.device_count()
    rank, world_size = get_rank_and_world_size()
    num_gpus = num_gpus // world_size

    num_layers = {'InternVL2-8B': 32, 'InternVL2-26B': 48,
                  'InternVL2-40B': 60, 'InternVL2-Llama3-76B': 80}[model_name]

    num_layers_per_gpu = math.ceil(num_layers / (num_gpus - 0.2))
    num_layers_per_gpu = [num_layers_per_gpu] * num_gpus
    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.8)
    layer_cnt = 0
    for i, num_layer in enumerate(num_layers_per_gpu):
        for j in range(num_layer):
            device_map[f'language_model.model.layers.{layer_cnt}'] = rank + world_size * i
            layer_cnt += 1
    device_map['vision_model'] = rank
    device_map['mlp1'] = rank
    device_map['language_model.model.tok_embeddings'] = rank
    device_map['language_model.model.embed_tokens'] = rank
    device_map['language_model.output'] = rank
    device_map['language_model.model.norm'] = rank
    device_map['language_model.lm_head'] = rank
    device_map[f'language_model.model.layers.{num_layers - 1}'] = rank
    return device_map


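# InternVL_Slowfast extends xtuner's InternVL_V1_5 with two extra visual pathways:
# * "fast" frames: densely sampled video frames whose ViT features are average-pooled
#   to a small grid (see `_post_init`) and written into the placeholder positions
#   identified by `fast_token_idx` (the special-token string itself is configured by
#   the caller);
# * visual prompts: per-object masks (`prompt_masks`) that select ViT tokens from an
#   overall tile so region-level features can be injected alongside the image tokens.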
class InternVL_Slowfast(InternVL_V1_5):

    def __init__(self,
                 model_path,
                 freeze_llm=False,
                 freeze_visual_encoder=False,
                 llm_lora=None,
                 visual_encoder_lora=None,
                 quantization_vit=False,
                 quantization_llm=False,
                 pretrained_pth=None,
                 special_tokens=None,
                 model_split=False,
                 ):
        print_log('Start to load InternVL_V1_5 model.', logger='current')
        # Call the parent of InternVL_V1_5 directly; this class re-implements the
        # HF model loading itself below.
        super(InternVL_V1_5, self).__init__()
        self.freeze_llm = freeze_llm
        self.freeze_visual_encoder = freeze_visual_encoder
        self.use_llm_lora = llm_lora is not None
        self.use_visual_encoder_lora = visual_encoder_lora is not None
        self.quantization_vit = quantization_vit
        self.quantization_llm = quantization_llm
        if quantization_vit:
            assert visual_encoder_lora is not None
        if quantization_llm:
            assert llm_lora is not None

        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        # InternLM2's remote config reads `attn_implementation`; other HF configs
        # use the private `_attn_implementation` attribute.
        if config.llm_config.model_type == 'internlm2':
            config.llm_config.attn_implementation = 'flash_attention_2'
        else:
            config.llm_config._attn_implementation = 'flash_attention_2'

        if quantization_vit is False and quantization_llm is False:
            quantization = None
        else:
            llm_int8_skip_modules = ['mlp1']
            if quantization_llm and not quantization_vit:
                llm_int8_skip_modules.append('vision_model')
            if quantization_vit and not quantization_llm:
                llm_int8_skip_modules.append('language_model')

            quantization_config = dict(
                type=BitsAndBytesConfig,
                llm_int8_skip_modules=llm_int8_skip_modules,
                load_in_4bit=True,
                load_in_8bit=False,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4')
            quantization_clazz = quantization_config.pop('type')
            quantization = quantization_clazz(**quantization_config)

        if model_split:
            # Shard one replica across the GPUs visible to this rank and keep the
            # model in eval mode.
            device_map = split_model("InternVL2-26B")
            self.device = 'cuda'
            self.model = AutoModel.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map=device_map).eval()
        else:
            self.model = AutoModel.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                quantization_config=quantization,
                config=config,
                trust_remote_code=True)

        tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True)
        self.tokenizer = tokenizer

        if special_tokens is not None:
            self._add_special_tokens(special_tokens)

        img_context_token_id = tokenizer.convert_tokens_to_ids('<IMG_CONTEXT>')
        self.model.img_context_token_id = img_context_token_id

        if self.freeze_llm:
            self.model.language_model.requires_grad_(False)
        if self.freeze_visual_encoder:
            self.model.vision_model.requires_grad_(False)

        if hasattr(self.model.language_model, 'enable_input_require_grads'):
            self.model.language_model.enable_input_require_grads()
        else:
            self.model.language_model.get_input_embeddings(
            ).register_forward_hook(make_inputs_require_grad)

        self.gradient_checkpointing_enable()

        if self.use_llm_lora:
            self._prepare_llm_for_lora(llm_lora)

        if self.use_visual_encoder_lora:
            self._prepare_visual_encoder_for_lora(visual_encoder_lora)

        if pretrained_pth is not None:
            pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
            self.load_state_dict(pretrained_state_dict, strict=False)
            print(f'Load pretrained weight from {pretrained_pth}')

        self._count = 0
        print_log(self, logger='current')
        print_log('InternVL_V1_5 construction is complete', logger='current')

        self.transfer_to_hf = False

    def _add_special_tokens(self, special_tokens):
        num_new_tokens = self.tokenizer.add_tokens(
            special_tokens, special_tokens=True)

        if num_new_tokens > 0:
            self.model.language_model.resize_token_embeddings(len(self.tokenizer))

    def _post_init(self, fast_pool_size=4, fast_pool=True):
        if fast_pool:
            self.fast_pool = nn.AdaptiveAvgPool2d((fast_pool_size, fast_pool_size))

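    # `forward` expects a collated `data` dict with at least `pixel_values`,
    # `input_ids`, `position_ids`, `attention_mask` and `labels`; `fast_pixel_values`,
    # `vp_overall_mask` and `prompt_masks` are optional and enable the fast-frame and
    # visual-prompt paths respectively.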
    def forward(self, data, data_samples=None, mode='loss', fast_token_idx=None):
        if 'fast_pixel_values' in data.keys():
            assert fast_token_idx is not None
            fast_pixel_values = data['fast_pixel_values']
            if type(fast_pixel_values) is list or fast_pixel_values.ndim == 5:
                if type(fast_pixel_values) is list:
                    fast_pixel_values = [
                        x.unsqueeze(0) if x.ndim == 3 else x for x in fast_pixel_values
                    ]
                # (b * n, c, h, w)
                fast_concat_images = torch.cat(
                    [image.to(self.model.vision_model.dtype) for image in fast_pixel_values], dim=0)
            else:
                raise NotImplementedError()
        else:
            fast_pixel_values = None
            fast_concat_images = None

        pixel_values = data['pixel_values']

        if type(pixel_values) is list or pixel_values.ndim == 5:
            if type(pixel_values) is list:
                pixel_values = [
                    x.unsqueeze(0) if x.ndim == 3 else x for x in pixel_values
                ]
            # (b * n, c, h, w)
            concat_images = torch.cat(
                [image.to(self.model.vision_model.dtype) for image in pixel_values], dim=0)
        else:
            raise NotImplementedError()

        input_ids = data['input_ids']
        position_ids = data['position_ids']
        attention_mask = data['attention_mask']
        # Tiles whose pixels are all zero are padding; mask their ViT features out.
        image_flags = torch.sum(concat_images, dim=(1, 2, 3)) != 0
        image_flags = image_flags.long()

        labels = data['labels']
        use_cache = False

        if 'vp_overall_mask' not in data.keys():
            vp_overall_mask = None
        else:
            vp_overall_mask = data['vp_overall_mask']

        if 'prompt_masks' in data.keys():
            prompt_masks = data['prompt_masks']
        else:
            prompt_masks = None

        outputs = self._llm_forward(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            image_flags=image_flags,
            pixel_values=concat_images,
            labels=labels,
            use_cache=use_cache,
            output_hidden_states=True,
            fast_pixel_values=fast_concat_images,
            fast_token_idx=fast_token_idx,
            vp_overall_mask=vp_overall_mask,
            prompt_masks=prompt_masks,
        )

        return outputs

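    # `_llm_forward` flattens the token embeddings to (B * N, C), overwrites the
    # positions of image-context tokens with ViT features (or with `vp_embeds` when
    # visual prompts are given) and the positions of fast tokens with pooled
    # fast-frame features, then reshapes back to (B, N, C) and runs the language
    # model on `inputs_embeds`.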
    def _llm_forward(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            image_flags: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            fast_pixel_values=None,
            fast_token_idx=None,
            vp_overall_mask=None,
            prompt_masks=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None \
            else self.model.config.use_return_dict

        image_flags = image_flags.squeeze(-1)
        # Clone so the image-token positions can be overwritten without an in-place
        # op on the embedding layer's output.
        input_embeds = self.model.language_model.get_input_embeddings()(
            input_ids).clone()

        if fast_pixel_values is not None:
            n_fast_images = fast_pixel_values.shape[0]
            whole_pixel_values = torch.cat([fast_pixel_values, pixel_values], dim=0)
            vit_embeds = self.model.extract_feature(whole_pixel_values)
            vit_embeds = vit_embeds.to(input_embeds.dtype)
            fast_vit_embeds = vit_embeds[:n_fast_images]  # (n_fast_images, hw, c)
            _size = int(fast_vit_embeds.shape[1] ** 0.5)
            fast_vit_embeds = fast_vit_embeds.reshape(
                fast_vit_embeds.shape[0], _size, _size, fast_vit_embeds.shape[-1])
            # Pool each fast frame down to a fast_pool_size x fast_pool_size grid of tokens.
            fast_vit_embeds = fast_vit_embeds.permute(0, 3, 1, 2)  # (n, c, h, w)
            fast_vit_embeds = self.fast_pool(fast_vit_embeds).flatten(2)  # (n, c, hw)
            fast_vit_embeds = fast_vit_embeds.permute(0, 2, 1)
            vit_embeds = vit_embeds[n_fast_images:]
        else:
            vit_embeds = self.model.extract_feature(pixel_values)
            vit_embeds = vit_embeds.to(input_embeds.dtype)
            fast_vit_embeds = None

        vit_embeds = vit_embeds[image_flags == 1]
        vit_batch_size = pixel_values.shape[0]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        self._count += 1

        if vp_overall_mask is not None and prompt_masks is not None:
            vp_embeds = []
            vp_overall_mask = vp_overall_mask.to(vit_embeds.device).bool()
            prompt_masks = [item.to(vit_embeds.device).bool() for item in prompt_masks]

            vp_overall_mask = vp_overall_mask[image_flags == 1]
            overall_tile_vit_embeds = vit_embeds[vp_overall_mask]  # (n_img, hw, c)

            i_vp_img = 0
            for i_img in range(len(vit_embeds)):
                vp_embeds.append(vit_embeds[i_img].reshape(-1, C))
                if vp_overall_mask[i_img]:
                    tile_vit_embeds = overall_tile_vit_embeds[i_vp_img].reshape(-1, C)
                    objects_prompt_masks = prompt_masks[i_vp_img]
                    n_obj = len(objects_prompt_masks)
                    tile_vit_embeds = tile_vit_embeds.unsqueeze(0).repeat(n_obj, 1, 1)
                    objects_prompt_masks = objects_prompt_masks.reshape(n_obj, -1)
                    vp_embeds.append(tile_vit_embeds[objects_prompt_masks])
                    i_vp_img += 1
            vp_embeds = torch.cat(vp_embeds, dim=0)
        else:
            vp_embeds = None

        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.model.img_context_token_id)

        if vp_embeds is None:
            try:
                input_embeds[selected] = vit_embeds.reshape(-1, C)
            except Exception as e:
                vit_embeds = vit_embeds.reshape(-1, C)
                print(f'warning: {e}, input_embeds[selected].shape='
                      f'{input_embeds[selected].shape}, '
                      f'vit_embeds.shape={vit_embeds.shape}')
                n_token = selected.sum()
                if n_token > len(vit_embeds):
                    print(f"Wrong !!! {n_token} image tokens in text but only {len(vit_embeds)} vit embeds !!!")
                    expand_ratio = n_token // len(vit_embeds) + 1
                    vit_embeds = torch.cat([vit_embeds] * expand_ratio, dim=0)

                input_embeds[selected] = vit_embeds[:n_token]
        else:
            try:
                input_embeds[selected] = vp_embeds.reshape(-1, C)
            except Exception as e:
                vp_embeds = vp_embeds.reshape(-1, C)
                print(f'warning: {e}, input_embeds[selected].shape='
                      f'{input_embeds[selected].shape}, '
                      f'vp_embeds.shape={vp_embeds.shape}')
                n_token = selected.sum()
                if n_token > len(vp_embeds):
                    print(f"Wrong !!! {n_token} image tokens in text but only {len(vp_embeds)} vp embeds !!!")
                    expand_ratio = n_token // len(vp_embeds) + 1
                    vp_embeds = torch.cat([vp_embeds] * expand_ratio, dim=0)

                input_embeds[selected] = vp_embeds[:n_token]

        if fast_vit_embeds is not None:
            selected = (input_ids == fast_token_idx)
            selected_tot = selected.sum().item()
            if selected_tot > fast_vit_embeds.shape[0] * fast_vit_embeds.shape[1]:
                # The assert guarantees divisibility, so integer division is exact.
                assert selected_tot % (fast_vit_embeds.shape[0] * fast_vit_embeds.shape[1]) == 0
                repeat_times = selected_tot // (fast_vit_embeds.shape[0] * fast_vit_embeds.shape[1])
                fast_vit_embeds = fast_vit_embeds.repeat(repeat_times, 1, 1)
            try:
                input_embeds[selected] = fast_vit_embeds.reshape(-1, C)
            except Exception as e:
                fast_vit_embeds = fast_vit_embeds.reshape(-1, C)
                print(f'warning: {e}, input_embeds[selected].shape='
                      f'{input_embeds[selected].shape}, '
                      f'fast_vit_embeds.shape={fast_vit_embeds.shape}')
                n_token = selected.sum()
                input_embeds[selected] = fast_vit_embeds[:n_token]

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.model.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens and compute the causal LM loss.
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(
                -1, self.model.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits, ) + outputs[1:]
            return (loss, ) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

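    # `generate` mirrors the embedding substitution done in `_llm_forward`, but at
    # inference time: there are no labels, a count mismatch between visual-prompt
    # embeddings and placeholder tokens is tolerated, and the call is delegated to
    # `language_model.generate` with `inputs_embeds`.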
    @torch.no_grad()
    def generate(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            fast_token_idx=None,
            fast_pixel_values=None,
            prompt_masks=None,
            vp_overall_mask=None,
            **generate_kwargs,
    ) -> torch.LongTensor:
        device = self.model.device
        assert self.model.img_context_token_id is not None

        if fast_pixel_values is not None:
            assert fast_token_idx is not None
            if type(fast_pixel_values) is list or fast_pixel_values.ndim == 5:
                if type(fast_pixel_values) is list:
                    fast_pixel_values = [
                        x.unsqueeze(0) if x.ndim == 3 else x for x in fast_pixel_values
                    ]
                # (b * n, c, h, w)
                fast_pixel_values = torch.cat(
                    [image.to(self.model.vision_model.dtype) for image in fast_pixel_values], dim=0)

        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
                # No fast-frame features are available when precomputed visual
                # features are passed in.
                fast_vit_embeds = None
            else:
                if type(pixel_values) is list or pixel_values.ndim == 5:
                    if type(pixel_values) is list:
                        pixel_values = [
                            x.unsqueeze(0) if x.ndim == 3 else x for x in pixel_values
                        ]
                    # (b * n, c, h, w)
                    pixel_values = torch.cat(
                        [image.to(self.model.vision_model.dtype) for image in pixel_values], dim=0)

                if fast_pixel_values is not None:
                    n_fast_images = fast_pixel_values.shape[0]
                    whole_pixel_values = torch.cat([fast_pixel_values, pixel_values], dim=0)
                    vit_embeds = self.model.extract_feature(whole_pixel_values.to(device))

                    fast_vit_embeds = vit_embeds[:n_fast_images]  # (n_fast_images, hw, c)
                    _size = int(fast_vit_embeds.shape[1] ** 0.5)
                    fast_vit_embeds = fast_vit_embeds.reshape(
                        fast_vit_embeds.shape[0], _size, _size, fast_vit_embeds.shape[-1])
                    # Pool each fast frame down to a fast_pool_size x fast_pool_size grid of tokens.
                    fast_vit_embeds = fast_vit_embeds.permute(0, 3, 1, 2)
                    fast_vit_embeds = self.fast_pool(fast_vit_embeds).flatten(2)
                    fast_vit_embeds = fast_vit_embeds.permute(0, 2, 1)
                    vit_embeds = vit_embeds[n_fast_images:]
                else:
                    fast_vit_embeds = None
                    vit_embeds = self.model.extract_feature(pixel_values.to(device))
                image_flags = torch.sum(pixel_values, dim=(1, 2, 3)) != 0
                image_flags = image_flags.long()
                vit_embeds = vit_embeds[image_flags == 1]

            input_embeds = self.model.language_model.get_input_embeddings()(input_ids.to(device))
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            if vp_overall_mask is not None and prompt_masks is not None:
                vp_embeds = []
                vp_overall_mask = vp_overall_mask.to(vit_embeds.device).bool()
                prompt_masks = [item.to(vit_embeds.device).bool() for item in prompt_masks]

                vp_overall_mask = vp_overall_mask[image_flags == 1]
                overall_tile_vit_embeds = vit_embeds[vp_overall_mask]  # (n_img, hw, c)

                i_vp_img = 0
                for i_img in range(len(vit_embeds)):
                    vp_embeds.append(vit_embeds[i_img].reshape(-1, C))
                    if vp_overall_mask[i_img]:
                        tile_vit_embeds = overall_tile_vit_embeds[i_vp_img].reshape(-1, C)
                        objects_prompt_masks = prompt_masks[i_vp_img]
                        n_obj = len(objects_prompt_masks)
                        tile_vit_embeds = tile_vit_embeds.unsqueeze(0).repeat(n_obj, 1, 1)
                        objects_prompt_masks = objects_prompt_masks.reshape(n_obj, -1)
                        vp_embeds.append(tile_vit_embeds[objects_prompt_masks])
                        i_vp_img += 1
                vp_embeds = torch.cat(vp_embeds, dim=0)
            else:
                vp_embeds = None

            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.model.img_context_token_id)
            assert selected.sum() != 0
            if vp_embeds is None:
                input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
            else:
                vp_embeds = vp_embeds.reshape(-1, C)
                if len(input_embeds[selected]) != len(vp_embeds):
                    print("Shape mismatch, selected is {}, vp embeds is {} !!!"
                          .format(len(input_embeds[selected]), len(vp_embeds)))
                    min_tokens = min(len(input_embeds[selected]), len(vp_embeds))
                    # `input_embeds[selected][:min_tokens] = ...` would write into a
                    # temporary copy (boolean indexing returns a new tensor), so index
                    # the selected positions explicitly.
                    selected_idx = selected.nonzero(as_tuple=True)[0][:min_tokens]
                    input_embeds[selected_idx] = vp_embeds[:min_tokens].to(input_embeds.device)
                else:
                    input_embeds[selected] = vp_embeds.to(input_embeds.device)

            if fast_vit_embeds is not None:
                selected = (input_ids == fast_token_idx)
                assert selected.sum() != 0
                input_embeds[selected] = fast_vit_embeds.reshape(-1, C).to(input_embeds.device)

            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.model.language_model.get_input_embeddings()(input_ids)

        outputs = self.model.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask.to(device),
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            **generate_kwargs,
        )

        return outputs

    def state_dict(self, *args, **kwargs):
        if self.transfer_to_hf:
            # Bypass InternVL_V1_5's custom state_dict and return the plain
            # nn.Module state dict, e.g. when exporting HF-format weights.
            state_dict = super(InternVL_V1_5, self).state_dict(*args, **kwargs)
            return state_dict
        else:
            return super().state_dict(*args, **kwargs)

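# A minimal usage sketch (hypothetical paths and special tokens; the real entry
# points live in this project's configs and inference scripts, not in this file):
#
#   model = InternVL_Slowfast(
#       model_path='OpenGVLab/InternVL2-26B',      # assumed HF repo id
#       special_tokens=['<FAST_IMG_CONTEXT>'],     # hypothetical fast-frame token
#       model_split=True,                          # shard one replica across local GPUs
#   )
#   model._post_init(fast_pool_size=4)             # enable the fast-frame pooling branch
#   fast_token_idx = model.tokenizer.convert_tokens_to_ids('<FAST_IMG_CONTEXT>')
#   out = model.generate(pixel_values=..., input_ids=..., attention_mask=...,
#                        fast_pixel_values=..., fast_token_idx=fast_token_idx)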