BAAI
/

shunxing1234 committed on
Commit
811643a
·
1 Parent(s): f1c7f0c

Upload 2 files

Browse files
Files changed (2) hide show
  1. chat_test_NBCE.py +132 -0
  2. cyg_conversation.py +131 -0
chat_test_NBCE.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! -*- coding: utf-8 -*-
2
+ # Naive Bayes-based Context Extension (NBCE)
3
+ # 使用朴素贝叶斯增加LLM的Context处理长度
4
+ # 链接:https://kexue.fm/archives/9617
5
+ # Torch 2.0 测试通过
6
+
7
+ import json
8
+ import torch
9
+ from transformers import AutoTokenizer
10
+ from transformers import AquilaForCausalLM
11
+ from transformers import TopPLogitsWarper, LogitsProcessorList
12
+ import pdb
13
+
14
# Load the tokenizer. Padding goes on the left so that the final position of
# every row in a padded batch holds a real token.
# NOTE(review): `model_path` is not defined anywhere in the visible file; it
# must be set before this line (otherwise this raises NameError) -- confirm
# where it is meant to come from.
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.padding_side = 'left'
tokenizer.pad_token = tokenizer.unk_token

# Load the Aquila model in float16 and move it to the GPU.
model = AquilaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
device = torch.device('cuda')
model.to(device)

# Load the example contexts.
from cyg_conversation import default_conversation

conv = default_conversation.copy()
# Use a context manager so the file handle is closed, and read as UTF-8
# explicitly instead of relying on the platform's default encoding.
with open('code_text_2.json', encoding='utf-8') as _contexts_file:
    contexts = json.load(_contexts_file)

question = "请解释这段程序的功能:"

# batch[0] is the prompt that contains only the question (no context).
batch = []
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
batch.append(conv.get_prompt())

# Every further entry is one context concatenated with the question.
for context in contexts:
    conv1 = default_conversation.copy()
    conv1.append_message(conv1.roles[0], context + question)
    conv1.append_message(conv1.roles[1], None)
    batch.append(conv1.get_prompt())

# Report prompt sizes, measured in characters (not tokens).
prompt_lengths = [len(text) for text in batch]
print('Context长度分布:', prompt_lengths)
print('Context总长度:', sum(prompt_lengths))

# Top-p truncation applied to the per-row distributions during generation.
processors = LogitsProcessorList()
processors.append(TopPLogitsWarper(0.95))
46
+
47
# Adapted from https://github.com/bojone/NBCE/blob/main/test.py#L51-L106
@torch.inference_mode()
def generate(max_tokens, beta=0.25, eta=0.1, tau=0.01):
    """Decode an answer with Naive Bayes-based Context Extension (NBCE).

    Uses the module-level ``tokenizer``, ``model``, ``device``, ``processors``
    and ``batch``. Row 0 of ``batch`` is the question-only prompt; every other
    row is one context followed by the question. At each step the context row
    with the lowest-entropy next-token distribution is selected and contrasted
    against row 0, a single token is sampled from the merged distribution, and
    that token is fed back to every row.

    Each token is printed as soon as it is sampled.

    Args:
        max_tokens: Maximum number of tokens to generate.
        beta: Weight of the contrast against the question-only row; the merged
            log-probabilities are ``(1 + beta) * selected - beta * row0``.
        eta: Amount subtracted from the entropy of the row selected at the
            previous step, which favours staying on the same context.
        tau: Sampling temperature; must be positive. ``1`` is plain random
            sampling, values near ``0`` approach greedy decoding.

    Returns:
        The concatenation of all decoded tokens that were printed.

    Raises:
        ValueError: If ``tau`` is not positive.
    """
    if tau <= 0:
        raise ValueError(f"tau must be positive, got {tau}")

    inputs = tokenizer(batch, padding='longest', return_tensors='pt').to(device)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    print('input_ids', input_ids.shape)
    past_key_values = None
    n = input_ids.shape[0]
    k = None  # index of the context row selected at the previous step
    pieces = []

    for i in range(max_tokens):
        # Forward pass; after the first step only the new token is fed in and
        # the cached keys/values are reused.
        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        return_dict=True,
                        use_cache=True,
                        past_key_values=past_key_values
                        )
        past_key_values = outputs.past_key_values

        # ===== core NBCE step: begin =====
        logits = outputs.logits[:, -1]
        # Normalise to log-probabilities, then apply the configured processors.
        logits = logits - logits.logsumexp(dim=-1, keepdim=True)
        logits = processors(input_ids, logits)
        # Per-row entropy; the clip keeps filtered (-inf) entries from
        # producing 0 * -inf = NaN.
        entropy = -(logits.exp() * logits.clip(-100, 0)).sum(dim=-1)
        if i > 0:
            entropy[k] -= eta
        # Row 0 is the question-only prompt, so select among rows 1..n-1.
        k = entropy[1:].argmin() + 1
        logits_max = logits[k]
        logits_uncond = logits[0]
        logits_merged = (1 + beta) * logits_max - beta * logits_uncond
        # Where row 0 has (effectively) no mass on a token, fall back to the
        # selected row's log-probability instead of the contrasted value.
        logits = torch.where(logits_uncond > -100, logits_merged, logits_max)
        # ===== core NBCE step: end =====

        # Build the sampling distribution and draw one token. Truncation has
        # already been applied above through `processors`.
        probas = torch.nn.functional.softmax(logits[None] / tau, dim=-1)
        next_tokens = torch.multinomial(probas, num_samples=1).squeeze(1)
        if next_tokens[0] == tokenizer.eos_token_id:
            break

        ret = tokenizer.batch_decode(next_tokens)
        print(ret[0], flush=True, end='')
        pieces.append(ret[0])

        # Prepare the next iteration: every row receives the same new token,
        # and the attention mask grows by one position.
        input_ids = next_tokens.unsqueeze(-1).tile(n, 1)
        attention_mask = torch.cat(
            [attention_mask, torch.ones(n, 1, dtype=torch.long, device=device)],
            dim=-1,
        )

    return ''.join(pieces)
100
+
101
+
102
# Script entry point: generate (and print) at most 1000 tokens.
if __name__ == "__main__":
    generate(max_tokens=1000)
104
+
105
+
106
+ """
107
+ ========= 输出结果参考 =========
108
+
109
+ 1.菲律宾国家电网公司,中国占股多少?
110
+ 答:中国国家电网公司持有菲律宾国家电网公司40%的股份。
111
+
112
+ 2.领英计划裁员多少人?
113
+ 答:领英计划裁员716人。
114
+
115
+ 3.吉利德收购Pharmasset的价格是多少?
116
+ 答:吉利德收购Pharmasset的价格为110亿美元。
117
+
118
+ 4.丙肝神药Sovaldi在哪一年上市?
119
+ 答:丙肝神药Sovaldi于2013年上市。
120
+
121
+ 5.中亚峰会将在哪里举行?由谁主持?
122
+ 答:中亚峰会将在陕西省西安市举行,由国家主席习近平主持。
123
+
124
+ 6.哪个演员由于侮辱人民军队而被立案调查?
125
+ 答:李昊石因在表演中存在侮辱人民军队的言论而被立案调查。
126
+
127
+ 7.哪个项目宣称“能过坦克”的水上道路?
128
+ 答:湖北恩施宣称的“能过坦克”水上道路。
129
+
130
+ 8.如果你是默沙东的CEO,你的首要任务是什么?
131
+ 答:如果我是默沙东的CEO,我的首要任务是如何让基本盘更加坚固,并通过药物联用获得更好的增长。
132
+ """
cyg_conversation.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple, Any
4
+
5
+
6
class SeparatorStyle(Enum):
    """Enumerates how consecutive messages are separated in a prompt."""

    # The same separator follows every completed message.
    SINGLE = auto()
    # Two separators alternate according to the message's position.
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """Stores a conversation's history and the settings used to render it."""

    system: str  # text placed at the very start of the prompt
    instruction: str  # rendered under roles[2] when non-empty
    roles: List[str]
    messages: List[List[str]]  # [role, message] pairs
    offset: int  # leading messages skipped by to_gradio_chatbot
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None  # second separator; only read for SeparatorStyle.TWO

    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        """Render the system text, instruction and messages as one string.

        A message whose text is falsy is rendered as ``"<role>:"`` with no
        separator after it.

        Raises:
            ValueError: If ``sep_style`` is not a known separator style.
        """
        style = self.sep_style
        if style == SeparatorStyle.SINGLE:
            separators = (self.sep, self.sep)
        elif style == SeparatorStyle.TWO:
            separators = (self.sep, self.sep2)
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        chunks = [self.system + separators[0]]
        if self.instruction is not None and len(self.instruction) > 0:
            # The instruction turn always ends with the primary separator.
            chunks.append(self.roles[2] + ": " + self.instruction + self.sep)
        for position, (speaker, text) in enumerate(self.messages):
            if not text:
                chunks.append(speaker + ":")
            else:
                chunks.append(speaker + ": " + text + separators[position % 2])
        return "".join(chunks)

    def append_message(self, role, message):
        """Add a ``[role, message]`` entry to the end of the history."""
        entry = [role, message]
        self.messages.append(entry)

    def to_gradio_chatbot(self):
        """Pair up messages from ``offset`` onward as ``[first, second]`` lists.

        Even-positioned messages open a new pair (second slot ``None``);
        odd-positioned messages fill the second slot of the latest pair.
        """
        pairs = []
        visible = self.messages[self.offset:]
        for position, (_speaker, text) in enumerate(visible):
            if position % 2:
                pairs[-1][-1] = text
            else:
                pairs.append([text, None])
        return pairs

    def copy(self):
        """Return a new Conversation whose message list is independent.

        ``skip_next`` is not carried over; the copy gets the default value.
        """
        duplicated_messages = [[speaker, text] for speaker, text in self.messages]
        return Conversation(
            system=self.system,
            instruction=self.instruction,
            roles=self.roles,
            messages=duplicated_messages,
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id,
        )

    def dict(self):
        """Return selected fields as a plain dictionary (values not copied)."""
        exported = (
            "system",
            "instruction",
            "roles",
            "messages",
            "offset",
            "sep",
            "sep2",
            "conv_id",
        )
        return {field_name: getattr(self, field_name) for field_name in exported}
87
+
88
+
89
# System text and role names shared by the two Human/Assistant templates.
_ASSISTANT_SYSTEM_PROMPT = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions."
)
_HUMAN_ASSISTANT_ROLES = ("Human", "Assistant", "System")

# Single-separator template using "###" after every turn.
conv_v1 = Conversation(
    system=_ASSISTANT_SYSTEM_PROMPT,
    instruction="",
    roles=_HUMAN_ASSISTANT_ROLES,
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# Separate instance with the same field values as conv_v1.
conv_v1_2 = Conversation(
    system=_ASSISTANT_SYSTEM_PROMPT,
    instruction="",
    roles=_HUMAN_ASSISTANT_ROLES,
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# Two-separator template: " " and "</s>" alternate between turns.
conv_bair_v1 = Conversation(
    system="BEGINNING OF CONVERSATION:",
    instruction="",
    roles=("USER", "GPT", "System"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)


# Template used when callers do not pick one, plus the name -> template registry.
default_conversation = conv_v1_2
conv_templates = {
    "v1": conv_v1_2,
    "bair_v1": conv_bair_v1,
}


# When run as a script, print the default template's rendered prompt.
if __name__ == "__main__":
    rendered_prompt = default_conversation.get_prompt()
    print(rendered_prompt)