import torch
from transformers import GenerationConfig

from xtuner.model import LLaVAModel as XtunerLLaVAModel
from xtuner.model.utils import (
    prepare_inputs_labels_for_multimodal as
    xtuner_prepare_inputs_labels_for_multimodal)
from xtuner.tools.utils import get_stop_criteria, is_cn_string
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
                          PROMPT_TEMPLATE)

from .utils import prepare_inputs_labels_for_multimodal
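

# Evaluation-oriented wrapper around xtuner's `LLaVAModel`: it optionally builds
# a tokenizer from a config dict and adds the `preparing_for_generation` /
# `predict_forward` helpers for single-image inference.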
class LLaVAModel(XtunerLLaVAModel):

    def __init__(self,
                 llm,
                 visual_encoder,
                 freeze_llm=False,
                 freeze_visual_encoder=False,
                 visual_select_layer=-2,
                 pretrained_pth=None,
                 projector_depth=2,
                 llm_lora=None,
                 visual_encoder_lora=None,
                 use_activation_checkpointing=True,
                 max_position_embeddings=None,
                 tokenizer=None,
                 inference_dtype=torch.bfloat16):
        super().__init__(
            llm,
            visual_encoder,
            freeze_llm=freeze_llm,
            freeze_visual_encoder=freeze_visual_encoder,
            visual_select_layer=visual_select_layer,
            pretrained_pth=pretrained_pth,
            projector_depth=projector_depth,
            llm_lora=llm_lora,
            visual_encoder_lora=visual_encoder_lora,
            use_activation_checkpointing=use_activation_checkpointing,
            max_position_embeddings=max_position_embeddings)
        if tokenizer is not None:
            # `tokenizer` is a lazily-built config dict: its `type` entry is the
            # builder and the remaining entries are its keyword arguments. Copy
            # before popping so the caller's config is not mutated.
            tokenizer = dict(tokenizer)
            tokenizer_type = tokenizer.pop('type')
            self.tokenizer = tokenizer_type(**tokenizer)
        self.visual_select_layer = visual_select_layer
        self.inference_dtype = inference_dtype
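
    # A minimal sketch of a `tokenizer` config dict as consumed above; the
    # builder and checkpoint name are illustrative assumptions, not fixed by
    # this module:
    #   tokenizer = dict(
    #       type=AutoTokenizer.from_pretrained,
    #       pretrained_model_name_or_path='internlm/internlm2-chat-7b',
    #       trust_remote_code=True)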

    def forward(self, data, data_samples=None, mode='loss'):
        if self.is_first_iter:
            # Hardcode for QLoRA with DeepSpeed ZeRO3: move buffers and
            # QuantState to the device. Only required in `LLaVAModel`,
            # not in `SupervisedFinetune`.
            self.to(data['input_ids'].device)
            self.is_first_iter = False

        if 'pixel_values' in data:
            visual_outputs = self.visual_encoder(
                data['pixel_values'].to(self.visual_encoder.dtype),
                output_hidden_states=True)
            pixel_values = self.projector(
                visual_outputs.hidden_states[self.visual_select_layer][:, 1:])
            data['pixel_values'] = pixel_values
            data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)

        if mode == 'loss':
            return self.compute_loss(data, data_samples)
        elif mode == 'predict':
            return self.predict(data, data_samples)
        elif mode == 'tensor':
            return self._forward(data, data_samples)
        else:
            raise NotImplementedError

    def preparing_for_generation(self, metainfo):
        # Set the stop criteria and generation config for the model.
        assert hasattr(self, 'tokenizer'), 'The model does not have a tokenizer.'
        self.bot_name = 'BOT'
        if 'template' in metainfo.keys():
            template = metainfo['template']
        else:
            template = PROMPT_TEMPLATE['internlm2_chat']
        self.template = template

        stop_words = []
        stop_words += template.get('STOP_WORDS', [])
        stop_criteria = get_stop_criteria(
            tokenizer=self.tokenizer, stop_words=stop_words)
        self.stop_criteria = stop_criteria

        default_generation_kwargs = dict(
            max_new_tokens=2048,
            do_sample=False,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=(
                self.tokenizer.pad_token_id
                if self.tokenizer.pad_token_id is not None
                else self.tokenizer.eos_token_id),
        )
        default_generation_kwargs.update(metainfo.get('generation_kwargs', {}))
        self.gen_config = GenerationConfig(**default_generation_kwargs)

        self.init_prediction_config = True
        self.visual_encoder.to(self.inference_dtype)
        self.projector.to(self.inference_dtype)
        return
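
    # A minimal sketch of the `metainfo` dict consumed above; the template name
    # and generation settings below are illustrative, not required values:
    #   metainfo = dict(
    #       template=PROMPT_TEMPLATE['internlm2_chat'],
    #       generation_kwargs=dict(max_new_tokens=512, do_sample=False))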

    def predict_forward(self, pixel_values, text_prompts, **kwargs):
        # pixel_values: a single (unbatched) image tensor
        # text_prompts: the question, without the chat template applied
        assert self.init_prediction_config, \
            'Please set the generation config with `preparing_for_generation()` first.'

        # Wrap the question in the chat template.
        input_text = self.template['INSTRUCTION'].format(
            input=text_prompts, round=1, bot_name=self.bot_name)

        # Tokenize the text on either side of the image placeholder and splice
        # in IMAGE_TOKEN_INDEX; exactly one image token is supported.
        chunk_encode = []
        for idx, chunk in enumerate(input_text.split(DEFAULT_IMAGE_TOKEN)):
            if idx == 0:
                cur_encode = self.tokenizer.encode(chunk)
            else:
                cur_encode = self.tokenizer.encode(chunk, add_special_tokens=False)
            chunk_encode.append(cur_encode)
        assert len(chunk_encode) == 2
        ids = []
        for idx, cur_chunk_encode in enumerate(chunk_encode):
            ids.extend(cur_chunk_encode)
            if idx != len(chunk_encode) - 1:
                ids.append(IMAGE_TOKEN_INDEX)
        ids = torch.tensor(ids).cuda().unsqueeze(0)

        # Encode the image (cast to the dtype set in `preparing_for_generation`)
        # and project the visual features into the LLM embedding space.
        image = pixel_values.cuda().unsqueeze(0).to(self.visual_encoder.dtype)
        visual_outputs = self.visual_encoder(image, output_hidden_states=True)
        pixel_values = self.projector(
            visual_outputs.hidden_states[self.visual_select_layer][:, 1:])

        mm_inputs = xtuner_prepare_inputs_labels_for_multimodal(
            llm=self.llm, input_ids=ids, pixel_values=pixel_values)
        generate_output = self.llm.generate(
            **mm_inputs,
            generation_config=self.gen_config,
            streamer=None,
            bos_token_id=self.tokenizer.bos_token_id,
            stopping_criteria=self.stop_criteria,
            output_hidden_states=False,
            return_dict_in_generate=True)
        predict = self.tokenizer.decode(
            generate_output.sequences[0], skip_special_tokens=True).strip()
        return {'prediction': predict}
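

# A minimal usage sketch (assumptions: the model has been built and loaded
# elsewhere, and `image_tensor` is a preprocessed image tensor matching the
# visual encoder's expected resolution):
#   model.preparing_for_generation(dict(template=PROMPT_TEMPLATE['internlm2_chat']))
#   result = model.predict_forward(
#       pixel_values=image_tensor,
#       text_prompts=DEFAULT_IMAGE_TOKEN + '\nDescribe this image.')
#   print(result['prediction'])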