Limour committed on
Commit
b0b5da4
1 Parent(s): f4e6998

Upload chat_template.py

Browse files
Files changed (1) hide show
  1. chat_template.py +41 -0
chat_template.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+
4
class ChatTemplate:
    """Turn chat roles and prompts into token lists framed by start/end markers.

    A message is rendered as ``im_start + role + nl + prompt`` followed by the
    tokens of ``im_end + nl``. All tokenization is delegated to
    ``model.tokenize(bytes, add_bos=False, special=True)``.

    NOTE(review): ``model`` is presumably a llama-cpp-python ``Llama`` object
    (it must provide ``tokenize`` and a ``_token_eos`` attribute) -- confirm
    against the caller.

    Attributes set by ``__init__``:
        im_start_token: tokens of ``im_start``.
        im_end_nl: tokens of ``im_end + nl``.
        eos: ``[model._token_eos, im_end_nl[0]]``.
        onenl: last token of ``im_end_nl``, plus the token for ``'\\r' + nl``
            when that pair tokenizes to exactly one token.
        onerl: tokens of ``'\\r'``.
        nlnl: the single token for ``nl + nl``, or ``None`` when the pair
            does not tokenize to exactly one token.
        cache: per-instance mapping of role name -> role header tokens.
    """

    # Kept only so existing references to ``ChatTemplate.cache`` still resolve.
    # Every instance shadows it with its own dict in ``__init__``: the cached
    # tokens depend on the instance's model and markers, so sharing one dict
    # across instances would hand out tokens produced by a different template.
    cache = {}

    def __init__(self, model, im_start=r'<|im_start|>', im_end=r'<|im_end|>', nl='\n'):
        """Precompute the marker tokens used to frame and terminate messages.

        Args:
            model: object providing ``tokenize(bytes, add_bos=..., special=...)``
                and a ``_token_eos`` attribute.
            im_start: text marker that opens a message.
            im_end: text marker that closes a message.
            nl: newline text placed after the role and after ``im_end``.
        """
        self.model = model
        self.nl = nl
        self.cache = {}  # per-instance; see the class-level note above
        self.im_start = im_start
        self.im_start_token = self._tokenize(self.im_start)
        self.im_end = im_end
        self.im_end_nl = self._tokenize(self.im_end + self.nl)
        self.eos = [model._token_eos, self.im_end_nl[0]]
        self.onenl = [self.im_end_nl[-1]]
        # Some vocabularies have a dedicated token for '\r' + nl; record it
        # alongside the plain newline token when it exists.
        crlf = self._tokenize('\r' + self.nl)
        if len(crlf) == 1:
            self.onenl.append(crlf[0])
        self.onerl = self._tokenize('\r')
        # Likewise, remember the single token for a doubled newline, if any.
        double_nl = self._tokenize(self.nl + self.nl)
        self.nlnl = double_nl[0] if len(double_nl) == 1 else None
        print('ChatTemplate', self.eos, self.im_end_nl, self.onerl, self.onenl, self.nlnl)

    def _tokenize(self, text: str):
        """Tokenize ``text`` as UTF-8 bytes without BOS, parsing special tokens."""
        return self.model.tokenize(text.encode('utf-8'), add_bos=False, special=True)

    def _get(self, key: str):
        """Return the tokens of ``im_start + key + nl``, caching them per role.

        A copy is returned so that callers extending the result in place
        cannot corrupt the cached entry.
        """
        if key not in self.cache:
            self.cache[key] = self._tokenize(self.im_start + key + self.nl)
        return copy.deepcopy(self.cache[key])

    def __call__(self, _role, prompt=None):
        """Tokenize a role header, or a complete message when ``prompt`` is given.

        Args:
            _role: role name placed right after ``im_start``.
            prompt: message text; when ``None`` only the role header
                (``im_start + _role + nl``) is returned.

        Returns:
            The token list of the header, or of the whole message followed by
            the tokens of ``im_end + nl``.
        """
        if prompt is None:
            return self._get(_role)
        message = self.im_start + _role + self.nl + prompt
        return self._tokenize(message) + self.im_end_nl