# MMAlaya / inference.py
import argparse
import time

import torch
from PIL import Image
from transformers.generation.streamers import TextIteratorStreamer

from mm_builder import load_pretrained_model
from mm_utils import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from mm_utils import conv_templates, SeparatorStyle
from mm_utils import disable_torch_init
from mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
# MMAlayaMPTForCausalLM is not referenced below; the import is kept from the
# original script, presumably so the model class is registered before loading.
from modeling_mmalaya import MMAlayaMPTForCausalLM


def main(args):
    disable_torch_init()
    conv_mode = "mmalaya_llama"
    model_path = args.model_path
    # Load the model, tokenizer, and image processor.
    tokenizer, model, image_processor, _ = load_pretrained_model(
        model_path=model_path,
    )
    # The same question in Chinese and English.
    prompts = [
        "这张图可能是在哪拍的?当去这里游玩时需要注意什么?",
        "Where might this picture have been taken? What should you pay attention to when visiting here?",
    ]
    time1 = time.time()
    for prompt in prompts:
        # Build the conversation from the template: the user turn carries the
        # image placeholder plus the question; the assistant turn is left empty.
        conv = conv_templates[conv_mode].copy()
        inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        # Tokenize the prompt, mapping the image placeholder to IMAGE_TOKEN_INDEX.
        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        # Set up the stop condition for generation.
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        if conv_mode == 'mmalaya_llama':
            stop_str = conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
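        # KeywordsStoppingCriteria halts generation once the decoded tail of the
        # output contains one of the keyword strings (here the conversation
        # separator), i.e. when the model finishes its turn.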
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=20.0)
        # Load the image and preprocess it into a pixel-value tensor.
        image = Image.open('./chang_chen.jpg').convert("RGB")
        image_tensor = image_processor(image, return_tensors='pt')['pixel_values'].half().cuda()
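        # The cast to .half() assumes the model weights were loaded in fp16;
        # for a full-precision model, drop it or use .float() instead.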
        # Run inference.
        with torch.inference_mode():
            generate_ids = model.generate(
                inputs=input_ids,
                images=image_tensor,
                # do_sample=True,
                # temperature=0.2,
                # top_p=1.0,
                max_new_tokens=1024,
                # streamer=streamer,
                # num_beams=2,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
            )
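        # With the sampling arguments commented out, generate() falls back to
        # greedy decoding; uncommenting do_sample/temperature/top_p enables
        # sampling, and num_beams enables beam search.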
        # Strip the prompt tokens from generate_ids, then decode only the newly
        # generated tokens to text.
        input_token_len = input_ids.shape[1]
        output = tokenizer.batch_decode(
            generate_ids[:, input_token_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        print(output)

    time2 = time.time()
    print("cost seconds: ", time2 - time1)
    print("cost seconds per sample: ", (time2 - time1) / len(prompts))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='/tmp/MMAlaya-v0.1.6.1')
    args = parser.parse_args()
    main(args)
"""
export PYTHONPATH=$PYTHONPATH:/tmp/MMAlaya
CUDA_VISIBLE_DEVICES=0 python inference.py --model_path /tmp/MMAlaya-v0.1.6.1
"""