import torch
from transformers import GenerationConfig

from xtuner.model import LLaVAModel as XtunerLLaVAModel
from xtuner.model.utils import prepare_inputs_labels_for_multimodal
from xtuner.tools.utils import get_stop_criteria, is_cn_string
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
                          PROMPT_TEMPLATE)
class LLaVAModel(XtunerLLaVAModel):
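    """xtuner ``LLaVAModel`` subclass that builds its tokenizer from a config
    dict and adds simple single-image inference helpers
    (``preparing_for_generation`` and ``predict_forward``).
    """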
def __init__(self,
llm,
visual_encoder,
freeze_llm=False,
freeze_visual_encoder=False,
visual_select_layer=-2,
pretrained_pth=None,
projector_depth=2,
llm_lora=None,
visual_encoder_lora=None,
use_activation_checkpointing=True,
max_position_embeddings=None,
tokenizer=None,
inference_dtype=torch.bfloat16,
):
        super().__init__(
llm,
visual_encoder,
freeze_llm=freeze_llm,
freeze_visual_encoder=freeze_visual_encoder,
visual_select_layer=visual_select_layer,
pretrained_pth=pretrained_pth,
projector_depth=projector_depth,
llm_lora=llm_lora,
visual_encoder_lora=visual_encoder_lora,
use_activation_checkpointing=use_activation_checkpointing,
max_position_embeddings=max_position_embeddings)
        if tokenizer is not None:
            # `tokenizer` is an mmengine-style config dict whose 'type' entry
            # is the tokenizer class or builder; instantiate it here without
            # mutating the caller's config.
            tokenizer_cfg = dict(tokenizer)
            tokenizer_type = tokenizer_cfg.pop('type')
            self.tokenizer = tokenizer_type(**tokenizer_cfg)
self.visual_select_layer = visual_select_layer
self.inference_dtype = inference_dtype
def preparing_for_generation(self, metainfo):
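        """Set up the prompt template, stopping criteria and generation
        config used at inference time.

        ``metainfo`` may provide a ``'template'`` dict (defaults to the
        ``internlm2_chat`` prompt template) and ``'generation_kwargs'`` that
        override the defaults below. The visual encoder and projector are
        also cast to ``self.inference_dtype``.
        """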
        assert hasattr(self, 'tokenizer'), \
            'The model has no tokenizer; pass one to __init__ before calling this method.'
self.bot_name = 'BOT'
        template = metainfo.get('template', PROMPT_TEMPLATE['internlm2_chat'])
self.template = template
stop_words = []
stop_words += template.get('STOP_WORDS', [])
stop_criteria = get_stop_criteria(
tokenizer=self.tokenizer, stop_words=stop_words)
self.stop_criteria = stop_criteria
default_generation_kwargs = dict(
max_new_tokens=2048,
do_sample=False,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=(
self.tokenizer.pad_token_id
if self.tokenizer.pad_token_id is not None
else self.tokenizer.eos_token_id
),
)
default_generation_kwargs.update(metainfo.get('generation_kwargs', {}))
self.gen_config = GenerationConfig(**default_generation_kwargs)
self.init_prediction_config = True
self.visual_encoder.to(self.inference_dtype)
self.projector.to(self.inference_dtype)
return
def predict_forward(
self, pixel_values, text_prompts, **kwargs):
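        """Generate an answer for a single image/question pair.

        ``pixel_values`` is a preprocessed image tensor of shape (C, H, W);
        ``text_prompts`` is the question text without the chat template and
        is expected to contain ``DEFAULT_IMAGE_TOKEN`` exactly once (enforced
        by the assert on the encoded chunks below). Returns a dict with a
        single ``'prediction'`` string.
        """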
        assert getattr(self, 'init_prediction_config', False), \
            'Call preparing_for_generation() to set up the generation config before predict_forward().'
# add template for text
        input_text = self.template['INSTRUCTION'].format(
            input=text_prompts, round=1, bot_name=self.bot_name)
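        # Split the templated prompt on the image placeholder, tokenize each
        # text chunk, then splice IMAGE_TOKEN_INDEX back in between chunks;
        # exactly one DEFAULT_IMAGE_TOKEN is expected in the prompt.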
chunk_encode = []
for idx, chunk in enumerate(input_text.split(DEFAULT_IMAGE_TOKEN)):
if idx == 0:
cur_encode = self.tokenizer.encode(chunk)
else:
cur_encode = self.tokenizer.encode(chunk, add_special_tokens=False)
chunk_encode.append(cur_encode)
assert len(chunk_encode) == 2
ids = []
for idx, cur_chunk_encode in enumerate(chunk_encode):
ids.extend(cur_chunk_encode)
if idx != len(chunk_encode) - 1:
ids.append(IMAGE_TOKEN_INDEX)
        ids = torch.tensor(ids).cuda().unsqueeze(0)
        # Cast the image to the dtype the visual encoder was moved to in
        # preparing_for_generation() to avoid a dtype mismatch.
        image = pixel_values.cuda().unsqueeze(0).to(self.inference_dtype)
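        # Run the vision tower, take the selected hidden layer (dropping the
        # first token) and project the features into the LLM embedding space.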
visual_outputs = self.visual_encoder(image, output_hidden_states=True)
pixel_values = self.projector(
visual_outputs.hidden_states[self.visual_select_layer][:, 1:])
mm_inputs = prepare_inputs_labels_for_multimodal(
llm=self.llm, input_ids=ids, pixel_values=pixel_values)
generate_output = self.llm.generate(
**mm_inputs,
generation_config=self.gen_config,
streamer=None,
bos_token_id=self.tokenizer.bos_token_id,
stopping_criteria=self.stop_criteria,
output_hidden_states=False,
return_dict_in_generate=True
)
predict = self.tokenizer.decode(
generate_output.sequences[0], skip_special_tokens=True).strip()
return {'prediction': predict}
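

# Usage sketch (illustrative only): how the two inference helpers above are
# expected to be chained. The llm/visual_encoder/tokenizer config dicts and
# the image preprocessing are placeholders, not part of this file; adapt them
# to the actual configs used alongside this model.
#
#   model = LLaVAModel(llm=llm_cfg,
#                      visual_encoder=visual_encoder_cfg,
#                      tokenizer=dict(type=AutoTokenizer.from_pretrained,
#                                     pretrained_model_name_or_path='...'))
#   model.preparing_for_generation(dict(
#       template=PROMPT_TEMPLATE['internlm2_chat'],
#       generation_kwargs=dict(max_new_tokens=512)))
#   result = model.predict_forward(
#       pixel_values=image_tensor,  # preprocessed (C, H, W) image tensor
#       text_prompts=DEFAULT_IMAGE_TOKEN + '\nDescribe this image.')
#   print(result['prediction'])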