import argparse
import time

import torch
from PIL import Image

from mm_builder import load_pretrained_model
from mm_utils import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    conv_templates,
    SeparatorStyle,
    disable_torch_init,
    tokenizer_image_token,
    KeywordsStoppingCriteria,
)
# Appears unused here, but kept from the original script: importing the module
# may register the MMAlaya model class so load_pretrained_model can resolve it.
from modeling_mmalaya import MMAlayaMPTForCausalLM
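

# Minimal demo: ask MMAlaya the same question about one image, in Chinese and
# in English, and print each generated answer.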
def main(args):
    # Skip torch's default weight initialization to speed up model loading.
    disable_torch_init()
    conv_mode = "mmalaya_llama"
    model_path = args.model_path

    tokenizer, model, image_processor, _ = load_pretrained_model(
        model_path=model_path,
    )

    # The same question in Chinese and in English: "Where might this picture
    # have been taken? What should you pay attention to when visiting here?"
    prompts = [
        "这张图可能是在哪拍的?当去这里游玩时需要注意什么?",
        "Where might this picture have been taken? What should you pay attention to when visiting here?",
    ]
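
    # Load and preprocess the test image once; both prompts reuse the tensor.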
    image = Image.open('./data/chang_chen.jpg').convert("RGB")
    image_tensor = image_processor(image, return_tensors='pt')['pixel_values'].half().cuda()

    time1 = time.time()

    for prompt in prompts:
        # Build a single-turn conversation: the image token placeholder goes
        # first, followed by the user's question.
        conv = conv_templates[conv_mode].copy()
        inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        full_prompt = conv.get_prompt()

        # Tokenize; the image placeholder becomes IMAGE_TOKEN_INDEX.
        input_ids = tokenizer_image_token(
            full_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).cuda()
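
        # Stop generation at the conversation separator.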
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        if conv_mode == 'mmalaya_llama':
            stop_str = conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
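
        # Run generation without gradients; use_cache enables the KV cache.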
        with torch.inference_mode():
            generate_ids = model.generate(
                inputs=input_ids,
                images=image_tensor,
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
            )
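
        # Decode only the newly generated tokens; the prompt is sliced off.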
        input_token_len = input_ids.shape[1]
        output = tokenizer.batch_decode(
            generate_ids[:, input_token_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        print(output)

    time2 = time.time()
    print("elapsed seconds:", time2 - time1)
    print("elapsed seconds per sample:", (time2 - time1) / len(prompts))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='/tmp/MMAlaya-v0.1.6.1')
    args = parser.parse_args()
    main(args)


""" |
|
export PYTHONPATH=$PYTHONPATH:/tmp/MMAlaya |
|
CUDA_VISIBLE_DEVICES=0 python inference.py --model_path /tmp/MMAlaya-v0.1.6.1 |
|
""" |