File size: 3,185 Bytes
b0b5da4
 
 
 
 
026cf13
b0b5da4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
026cf13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0b5da4
026cf13
b0b5da4
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import copy


class ChatTemplate:
    """Build ChatML-style token sequences and detect turns the model failed to end.

    A message is rendered as ``im_start + role + nl + text + im_end + nl`` and
    tokenized with ``model.tokenize(..., add_bos=False, special=True)``.
    The model is presumably a llama-cpp-python style object exposing
    ``tokenize``, ``str_detokenize`` and ``_token_eos`` -- TODO confirm.
    """

    def __init__(self, model, im_start=r'<|im_start|>', im_end=r'<|im_end|>', nl='\n'):
        """Pre-tokenize the template markers for ``model``.

        Args:
            model: tokenizer/model used for all tokenization in this template.
            im_start: text that opens a message header.
            im_end: text that closes a message.
            nl: line separator placed after the role and after ``im_end``.
        """
        self.model = model
        # Per-instance state. Cached tokenizations depend on this model, im_start
        # and nl, so sharing them between templates (as class attributes would)
        # returns tokens from the wrong tokenizer.
        self.cache = {}     # role -> tokens of im_start + role + nl
        self.roles = set()  # "\n" + role for every role passed to __call__
        self.nl = nl
        self.im_start = im_start
        self.im_start_token = model.tokenize(self.im_start.encode('utf-8'), add_bos=False, special=True)
        self.im_end = im_end
        self.im_end_nl = model.tokenize((self.im_end + self.nl).encode('utf-8'), add_bos=False, special=True)
        # Stop tokens: the model's EOS plus the first token of "im_end + nl"
        # (assumes im_end tokenizes to a single special token -- TODO confirm).
        self.eos = [model._token_eos, self.im_end_nl[0]]
        # Single-token line breaks: the last token of "im_end + nl", plus "\r" + nl
        # when that happens to be a single token.
        self.onenl = [self.im_end_nl[-1]]
        tmp = model.tokenize(('\r' + self.nl).encode('utf-8'), add_bos=False, special=True)
        if len(tmp) == 1:
            self.onenl.append(tmp[0])
        self.onerl = model.tokenize(b'\r', add_bos=False, special=True)
        # Token for a double line break, or None if it is not a single token.
        self.nlnl = None
        tmp = model.tokenize((self.nl + self.nl).encode('utf-8'), add_bos=False, special=True)
        if len(tmp) == 1:
            self.nlnl = tmp[0]
        print('ChatTemplate', self.eos, self.im_end_nl, self.onerl, self.onenl, self.nlnl)

    def _get(self, key: str):
        """Return the tokens for ``im_start + key + nl``, memoized per role.

        A deep copy is returned on every call so callers may mutate the result
        without corrupting the cache.
        """
        if key not in self.cache:
            self.cache[key] = self.model.tokenize(
                (self.im_start + key + self.nl).encode('utf-8'), add_bos=False, special=True)
        return copy.deepcopy(self.cache[key])

    def _add_role(self, _role):
        """Remember ``"\\n" + _role`` for :meth:`eos_in_role`; empty roles are ignored."""
        if _role:
            self.roles.add('\n' + _role)

    def eos_in_role(self, history: str, t_bot):
        """Detect a role header the model wrote as plain text instead of ending its turn.

        Args:
            history: decoded text to inspect.
            t_bot: list of tokens generated by the bot.

        Returns:
            The number of trailing tokens of ``t_bot`` to discard. It is non-zero
            only when ``history`` ends with a line break, and its stripped text
            ends with a newline followed by a registered role. Otherwise returns 0.
        """
        if not history.endswith(('\n', '\r')):
            return 0
        stripped = history.rstrip()
        for _role in self.roles:
            if not stripped.endswith(_role):
                continue
            n = len(t_bot)
            # Find the shortest token suffix whose text ends with the role header;
            # its length is the number of tokens to discard.
            # NOTE(review): i == n (all of t_bot) is never tried -- confirm intended.
            for i in range(1, n):
                tail = self.model.str_detokenize(t_bot[n - i:])
                if tail.rstrip().endswith(_role):
                    print('eos_in_role', t_bot[n - i:], repr(tail))
                    return i
            print('eos_in_role missing')
            break
        return 0

    def eos_in_nlnl(self, history: str, t_bot):
        """Detect a blank line (double line break) at the end of the bot's output.

        Returns:
            The number of trailing tokens of ``t_bot`` whose text ends with
            ``"\\n\\n"`` or ``"\\n\\r\\n"`` (the tokens to discard). Returns 0
            when ``history`` does not end that way, or when no suffix matches.
            Also returns 0 when the matching suffix starts with ``']'``, which is
            treated as a false positive.
        """
        if not history.endswith(('\n\n', '\n\r\n')):
            return 0
        n = len(t_bot)
        # Find the shortest token suffix ending in a double line break.
        # NOTE(review): i == n (all of t_bot) is never tried -- confirm intended.
        for i in range(1, n):
            tail = self.model.str_detokenize(t_bot[n - i:])
            if tail.endswith(('\n\n', '\n\r\n')):
                if tail.startswith(']'):  # avoid a false positive
                    return 0
                print('eos_in_nlnl', t_bot[n - i:], repr(tail))
                return i
        print('eos_in_nlnl missing')
        return 0

    def __call__(self, _role, prompt=None):
        """Return tokens for a role header, or for a complete message.

        ``_role`` is registered for :meth:`eos_in_role`.

        If ``prompt`` is None, returns a fresh copy of the cached tokens for
        ``im_start + _role + nl``.

        Otherwise returns the tokens for ``im_start + _role + nl + prompt``,
        followed by the tokens of ``im_end + nl``.
        """
        self._add_role(_role)
        if prompt is None:
            return self._get(_role)
        text = self.im_start + _role + self.nl + prompt
        return self.model.tokenize(text.encode('utf-8'), add_bos=False, special=True) + self.im_end_nl