bingwork committed on
Commit
6642c6e
1 Parent(s): 1f9001d

update model scripts

Browse files
Files changed (2) hide show
  1. README.md +10 -1
  2. inference.py +8 -4
README.md CHANGED
@@ -1,3 +1,12 @@
1
  # MMAlaya
2
- MMAlaya is a multimodal large language model from [Alaya](https://github.com/DataCanvasIO/Alaya).
 
 
 
 
 
 
 
 
 
3
 
 
1
  # MMAlaya
2
+ MMAlaya是基于大语言模型[Alaya](https://github.com/DataCanvasIO/Alaya)的多模态模型。
3
+
4
+ MMAlaya包含以下三个模块:
5
+ <br>1,大语言模型Alaya。
6
+ <br>2,图像文本特征编码器[blip2-opt-2.7b](https://huggingface.co/Salesforce/blip2-opt-2.7b)。
7
+ <br>3,图像文本特征到大语言模型的线性投影器。
8
+
9
+ 模型的训练主要基于[LLaVA](https://github.com/haotian-liu/LLaVA)架构。
10
+
11
+ 2024.01.23 最终在[MMBench](https://mmbench.opencompass.org.cn)线上测试中文测试集分数为56.9,总排名为第20名,7B模型的第9名。英文测试集分数为59.8,总排名为第29名,7B模型的第12名。
12
 
inference.py CHANGED
@@ -12,10 +12,10 @@ import argparse
12
 
13
  def main(args):
14
  disable_torch_init()
15
-
16
  conv_mode = "mmalaya_llama"
17
  model_path = args.model_path
18
- tokenizer, model, image_processor, context_len = load_pretrained_model(
 
19
  model_path=model_path,
20
  )
21
  prompts = [
@@ -27,21 +27,25 @@ def main(args):
27
  time1 = time.time()
28
 
29
  for prompt in prompts:
 
30
  conv = conv_templates[conv_mode].copy()
31
  inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt
32
  conv.append_message(conv.roles[0], inp)
33
  conv.append_message(conv.roles[1], None)
34
  prompt = conv.get_prompt()
 
35
  input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
 
36
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
37
  if conv_mode == 'mmalaya_llama':
38
  stop_str = conv.sep2
39
  keywords = [stop_str]
40
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
41
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=20.0)
42
-
43
  image = Image.open('./data/chang_chen.jpg').convert("RGB")
44
  image_tensor = image_processor(image, return_tensors='pt')['pixel_values'].half().cuda()
 
45
  with torch.inference_mode():
46
  generate_ids = model.generate(
47
  inputs=input_ids,
@@ -55,7 +59,7 @@ def main(args):
55
  use_cache=True,
56
  stopping_criteria=[stopping_criteria],
57
  )
58
- # remove input_ids in generate_ids
59
  input_token_len = input_ids.shape[1]
60
  output = tokenizer.batch_decode(
61
  generate_ids[:, input_token_len:],
 
12
 
13
  def main(args):
14
  disable_torch_init()
 
15
  conv_mode = "mmalaya_llama"
16
  model_path = args.model_path
17
+ # 加载model,tokenizer,image_processor
18
+ tokenizer, model, image_processor, _ = load_pretrained_model(
19
  model_path=model_path,
20
  )
21
  prompts = [
 
27
  time1 = time.time()
28
 
29
  for prompt in prompts:
30
+ # 加载对话模板
31
  conv = conv_templates[conv_mode].copy()
32
  inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt
33
  conv.append_message(conv.roles[0], inp)
34
  conv.append_message(conv.roles[1], None)
35
  prompt = conv.get_prompt()
36
+ # 对prompt进行分词
37
  input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
38
+ # 加载generate stop条件
39
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
40
  if conv_mode == 'mmalaya_llama':
41
  stop_str = conv.sep2
42
  keywords = [stop_str]
43
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
44
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=20.0)
45
+ # 加载图像
46
  image = Image.open('./data/chang_chen.jpg').convert("RGB")
47
  image_tensor = image_processor(image, return_tensors='pt')['pixel_values'].half().cuda()
48
+ # 推理
49
  with torch.inference_mode():
50
  generate_ids = model.generate(
51
  inputs=input_ids,
 
59
  use_cache=True,
60
  stopping_criteria=[stopping_criteria],
61
  )
62
+ # 截断generate_ids中的input_ids,然后解码为文本
63
  input_token_len = input_ids.shape[1]
64
  output = tokenizer.batch_decode(
65
  generate_ids[:, input_token_len:],