"""Run MMAlaya image+text inference on a fixed set of sample prompts.

Usage::

    export PYTHONPATH=$PYTHONPATH:/tmp/MMAlaya
    CUDA_VISIBLE_DEVICES=0 python inference.py --model_path /tmp/MMAlaya-v0.1.6.1
"""
import argparse
import time

import torch
from PIL import Image
from mm_builder import load_pretrained_model
from mm_utils import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from mm_utils import conv_templates, SeparatorStyle
from mm_utils import disable_torch_init
from mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
# NOTE(review): not referenced below; kept in case importing it has side
# effects the model loader relies on -- TODO confirm before removing.
from modeling_mmalaya import MMAlayaMPTForCausalLM  # noqa: F401
# Kept for optional streaming output (pass `streamer=` to generate()).
from transformers.generation.streamers import TextIteratorStreamer  # noqa: F401

# Image used when the caller does not supply one.
DEFAULT_IMAGE_PATH = './chang_chen.jpg'


def main(args):
    """Load the model once and answer every sample prompt about one image.

    Args:
        args: ``argparse.Namespace`` providing ``model_path``. An optional
            ``image_path`` attribute selects the input image; when absent it
            falls back to ``DEFAULT_IMAGE_PATH`` so callers that only set
            ``model_path`` keep working.

    Prints each generated answer, then the total and per-prompt wall time.
    """
    disable_torch_init()
    conv_mode = "mmalaya_llama"
    model_path = args.model_path
    image_path = getattr(args, "image_path", DEFAULT_IMAGE_PATH)

    # Load the tokenizer, model and image processor.
    tokenizer, model, image_processor, _ = load_pretrained_model(
        model_path=model_path,
    )

    prompts = [
        "这张图可能是在哪拍的?当去这里游玩时需要注意什么?",
        "Where might this picture have been taken? What should you pay attention to when visiting here?"
    ]

    start = time.perf_counter()

    # Every prompt asks about the same image, so read and preprocess it once
    # instead of on each iteration. The context manager closes the file.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    image_tensor = image_processor(image, return_tensors='pt')['pixel_values'].half().cuda()

    for question in prompts:
        # Build the conversation from the template: the user turn is the image
        # placeholder token plus the question; the assistant turn is left empty.
        conv = conv_templates[conv_mode].copy()
        conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + '\n' + question)
        conv.append_message(conv.roles[1], None)
        full_prompt = conv.get_prompt()

        # Tokenize the templated prompt, add a batch dimension, move to GPU.
        input_ids = tokenizer_image_token(
            full_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).cuda()

        # Stop generating once the template's separator string is produced.
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        if conv_mode == 'mmalaya_llama':
            stop_str = conv.sep2
        stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

        # Inference. Sampling settings (do_sample, temperature, top_p,
        # num_beams) are left at the model's generation defaults.
        with torch.inference_mode():
            generate_ids = model.generate(
                inputs=input_ids,
                images=image_tensor,
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
            )

        # Drop the prompt tokens from generate_ids, then decode the rest to text.
        input_token_len = input_ids.shape[1]
        output = tokenizer.batch_decode(
            generate_ids[:, input_token_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        print(output)

    elapsed = time.perf_counter() - start
    print("cost seconds: ", elapsed)
    print("cost seconds per sample: ", elapsed / len(prompts))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='/tmp/MMAlaya-v0.1.6.1')
    parser.add_argument('--image_path', type=str, default=DEFAULT_IMAGE_PATH)
    args = parser.parse_args()
    main(args)