update model scripts
Browse files- README.md +10 -1
- inference.py +8 -4
README.md
CHANGED
@@ -1,3 +1,12 @@
|
|
1 |
# MMAlaya
|
2 |
-
MMAlaya
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
|
|
1 |
# MMAlaya
|
2 |
+
MMAlaya是基于大语言模型[Alaya](https://github.com/DataCanvasIO/Alaya)的多模态模型。
|
3 |
+
|
4 |
+
MMAlaya包含以下三个模块:
|
5 |
+
<br>1,大语言模型Alaya。
|
6 |
+
<br>2,图像文本特征编码器[blip2-opt-2.7b](https://huggingface.co/Salesforce/blip2-opt-2.7b)
|
7 |
+
<br>3,图像文本特征到大语言模型的线性投影器。
|
8 |
+
|
9 |
+
模型的训练主要基于[LLaVA](https://github.com/haotian-liu/LLaVA)架构。
|
10 |
+
|
11 |
+
2024.01.23 最终在[MMBench](https://mmbench.opencompass.org.cn)线上测试中文测试集分数为56.9,总排名为第20名,7B模型的第9名。英文测试集分数为59.8,总排名为第29名,7B模型的第12名。
|
12 |
|
inference.py
CHANGED
@@ -12,10 +12,10 @@ import argparse
|
|
12 |
|
13 |
def main(args):
|
14 |
disable_torch_init()
|
15 |
-
|
16 |
conv_mode = "mmalaya_llama"
|
17 |
model_path = args.model_path
|
18 |
-
|
|
|
19 |
model_path=model_path,
|
20 |
)
|
21 |
prompts = [
|
@@ -27,21 +27,25 @@ def main(args):
|
|
27 |
time1 = time.time()
|
28 |
|
29 |
for prompt in prompts:
|
|
|
30 |
conv = conv_templates[conv_mode].copy()
|
31 |
inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt
|
32 |
conv.append_message(conv.roles[0], inp)
|
33 |
conv.append_message(conv.roles[1], None)
|
34 |
prompt = conv.get_prompt()
|
|
|
35 |
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
|
|
|
36 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
37 |
if conv_mode == 'mmalaya_llama':
|
38 |
stop_str = conv.sep2
|
39 |
keywords = [stop_str]
|
40 |
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
41 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=20.0)
|
42 |
-
|
43 |
image = Image.open('./data/chang_chen.jpg').convert("RGB")
|
44 |
image_tensor = image_processor(image, return_tensors='pt')['pixel_values'].half().cuda()
|
|
|
45 |
with torch.inference_mode():
|
46 |
generate_ids = model.generate(
|
47 |
inputs=input_ids,
|
@@ -55,7 +59,7 @@ def main(args):
|
|
55 |
use_cache=True,
|
56 |
stopping_criteria=[stopping_criteria],
|
57 |
)
|
58 |
-
#
|
59 |
input_token_len = input_ids.shape[1]
|
60 |
output = tokenizer.batch_decode(
|
61 |
generate_ids[:, input_token_len:],
|
|
|
12 |
|
13 |
def main(args):
|
14 |
disable_torch_init()
|
|
|
15 |
conv_mode = "mmalaya_llama"
|
16 |
model_path = args.model_path
|
17 |
+
# 加载model,tokenizer,image_processor
|
18 |
+
tokenizer, model, image_processor, _ = load_pretrained_model(
|
19 |
model_path=model_path,
|
20 |
)
|
21 |
prompts = [
|
|
|
27 |
time1 = time.time()
|
28 |
|
29 |
for prompt in prompts:
|
30 |
+
# 加载对话模板
|
31 |
conv = conv_templates[conv_mode].copy()
|
32 |
inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt
|
33 |
conv.append_message(conv.roles[0], inp)
|
34 |
conv.append_message(conv.roles[1], None)
|
35 |
prompt = conv.get_prompt()
|
36 |
+
# 对prompt进行分词
|
37 |
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
|
38 |
+
# 加载generate stop条件
|
39 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
40 |
if conv_mode == 'mmalaya_llama':
|
41 |
stop_str = conv.sep2
|
42 |
keywords = [stop_str]
|
43 |
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
44 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=20.0)
|
45 |
+
# 加载图像
|
46 |
image = Image.open('./data/chang_chen.jpg').convert("RGB")
|
47 |
image_tensor = image_processor(image, return_tensors='pt')['pixel_values'].half().cuda()
|
48 |
+
# 推理
|
49 |
with torch.inference_mode():
|
50 |
generate_ids = model.generate(
|
51 |
inputs=input_ids,
|
|
|
59 |
use_cache=True,
|
60 |
stopping_criteria=[stopping_criteria],
|
61 |
)
|
62 |
+
# 截断generate_ids中的input_ids,然后解码为文本
|
63 |
input_token_len = input_ids.shape[1]
|
64 |
output = tokenizer.batch_decode(
|
65 |
generate_ids[:, input_token_len:],
|