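"""Minimal MMAlaya inference script.

Loads the pretrained multimodal model, wraps each prompt in the
`mmalaya_llama` conversation template together with an image token,
generates an answer about a local image for the same question asked
in Chinese and in English, and reports timing.
"""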
import argparse
import time

import torch
from PIL import Image
from transformers.generation.streamers import TextIteratorStreamer

from mm_builder import load_pretrained_model
from mm_utils import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from mm_utils import conv_templates, SeparatorStyle
from mm_utils import disable_torch_init
from mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
# Not referenced directly below; likely kept for its import side effects
# (making the model class available when the checkpoint is loaded).
from modeling_mmalaya import MMAlayaMPTForCausalLM


def main(args):
    disable_torch_init()
    conv_mode = "mmalaya_llama"
    model_path = args.model_path
    # Load the model, tokenizer, and image processor.
    tokenizer, model, image_processor, _ = load_pretrained_model(
        model_path=model_path,
    )
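    # The same question asked in Chinese and in English.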
    prompts = [
        "这张图可能是在哪拍的?当去这里游玩时需要注意什么?",
        "Where might this picture have been taken? What should you pay attention to when visiting here?"
    ]

    start_time = time.time()

    for prompt in prompts:
        # Build the conversation prompt: image token followed by the question.
        conv = conv_templates[conv_mode].copy()
        inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        full_prompt = conv.get_prompt()
        # Tokenize the prompt, mapping the image placeholder to IMAGE_TOKEN_INDEX.
        input_ids = tokenizer_image_token(
            full_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).cuda()
        # Set up the stopping criteria for generation.
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        if conv_mode == 'mmalaya_llama':
            stop_str = conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        # Optional token streamer; currently unused because the streamer
        # argument below is commented out.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=20.0)
        # Load the image and preprocess it to a half-precision tensor on GPU.
        image = Image.open('./chang_chen.jpg').convert("RGB")
        image_tensor = image_processor(image, return_tensors='pt')['pixel_values'].half().cuda()
        # Run inference.
        with torch.inference_mode():
            generate_ids = model.generate(
                inputs=input_ids,
                images=image_tensor,
                # do_sample=True,
                # temperature=0.2,
                # top_p=1.0,
                max_new_tokens=1024,
                # streamer=streamer,
                # num_beams=2,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
            )
            # Strip the prompt tokens from generate_ids, then decode the
            # newly generated tokens to text.
            input_token_len = input_ids.shape[1]
            output = tokenizer.batch_decode(
                generate_ids[:, input_token_len:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]
            print(output)
    
    elapsed = time.time() - start_time
    print("elapsed seconds:", elapsed)
    print("elapsed seconds per sample:", elapsed / len(prompts))
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='/tmp/MMAlaya-v0.1.6.1')
    args = parser.parse_args()
    main(args)


"""
export PYTHONPATH=$PYTHONPATH:/tmp/MMAlaya
CUDA_VISIBLE_DEVICES=0 python inference.py --model_path /tmp/MMAlaya-v0.1.6.1
"""