BAAI
/

shunxing1234 commited on
Commit
152ec2a
·
1 Parent(s): 0cc796c

Upload cyg_conversation.py

Browse files
Files changed (1) hide show
  1. cyg_conversation.py +155 -0
cyg_conversation.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple, Any
4
+
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+ SINGLE = auto()
9
+ TWO = auto()
10
+
11
+
12
+ @dataclasses.dataclass
13
+ class Conversation:
14
+ """A class that keeps all conversation history."""
15
+ system: str
16
+ instruction: str
17
+ roles: List[str]
18
+ messages: List[List[str]]
19
+ offset: int
20
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
21
+ sep: str = "###"
22
+ sep2: str = None
23
+
24
+ skip_next: bool = False
25
+ conv_id: Any = None
26
+
27
+ def get_prompt(self):
28
+ if self.sep_style == SeparatorStyle.SINGLE:
29
+ ret = self.system + self.sep
30
+ if self.instruction is not None and len(self.instruction) > 0:
31
+ ret += self.roles[2] + ": " + self.instruction + self.sep
32
+ for role, message in self.messages:
33
+ if message:
34
+ ret += role + ": " + message + self.sep
35
+ else:
36
+ ret += role + ":"
37
+ return ret
38
+ elif self.sep_style == SeparatorStyle.TWO:
39
+ seps = [self.sep, self.sep2]
40
+ ret = self.system + seps[0]
41
+ if self.instruction is not None and len(self.instruction) > 0:
42
+ ret += self.roles[2] + ": " + self.instruction + self.sep
43
+ for i, (role, message) in enumerate(self.messages):
44
+ if message:
45
+ ret += role + ": " + message + seps[i % 2]
46
+ else:
47
+ ret += role + ":"
48
+ return ret
49
+ else:
50
+ raise ValueError(f"Invalid style: {self.sep_style}")
51
+
52
+ def append_message(self, role, message):
53
+ self.messages.append([role, message])
54
+
55
+ def to_gradio_chatbot(self):
56
+ ret = []
57
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
58
+ if i % 2 == 0:
59
+ ret.append([msg, None])
60
+ else:
61
+ ret[-1][-1] = msg
62
+ return ret
63
+
64
+ def copy(self):
65
+ return Conversation(
66
+ system=self.system,
67
+ instruction=self.instruction,
68
+ roles=self.roles,
69
+ messages=[[x, y] for x, y in self.messages],
70
+ offset=self.offset,
71
+ sep_style=self.sep_style,
72
+ sep=self.sep,
73
+ sep2=self.sep2,
74
+ conv_id=self.conv_id)
75
+
76
+ def dict(self):
77
+ return {
78
+ "system": self.system,
79
+ "instruction": self.instruction,
80
+ "roles": self.roles,
81
+ "messages": self.messages,
82
+ "offset": self.offset,
83
+ "sep": self.sep,
84
+ "sep2": self.sep2,
85
+ "conv_id": self.conv_id,
86
+ }
87
+
88
+
89
+ conv_v1 = Conversation(
90
+ system="A chat between a curious human and an artificial intelligence assistant. "
91
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
92
+ instruction="",
93
+ roles=("Human", "Assistant", "System"),
94
+ messages=(),
95
+ offset=0,
96
+ sep_style=SeparatorStyle.SINGLE,
97
+ sep="###",
98
+ )
99
+
100
+ conv_v1_2 = Conversation(
101
+ system="A chat between a curious human and an artificial intelligence assistant. "
102
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
103
+ instruction="",
104
+ roles=("Human", "Assistant", "System"),
105
+ messages=(),
106
+ offset=0,
107
+ sep_style=SeparatorStyle.SINGLE,
108
+ sep="###",
109
+ )
110
+
111
+ conv_bair_v1 = Conversation(
112
+ system="BEGINNING OF CONVERSATION:",
113
+ instruction="",
114
+ roles=("USER", "GPT", "System"),
115
+ messages=(),
116
+ offset=0,
117
+ sep_style=SeparatorStyle.TWO,
118
+ sep=" ",
119
+ sep2="</s>",
120
+ )
121
+
122
+
123
+ default_conversation = conv_v1_2
124
+ conv_templates = {
125
+ "v1": conv_v1_2,
126
+ "bair_v1": conv_bair_v1,
127
+ }
128
+
129
+ def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token):
130
+ conv = default_conversation.copy()
131
+
132
+ conv.append_message(conv.roles[1], None)
133
+ conv.append_message(conv.roles[0], text)
134
+
135
+ example = tokenizer.encode_plus(f"{conv.get_prompt()}", None, max_length=None)['input_ids']
136
+
137
+ while(len(history) > 0 and (len(example) < max_token)):
138
+ tmp = history.pop()
139
+ if tmp[0] == 'ASSISTANT':
140
+ conv.append_message(conv.roles[1], tmp[1])
141
+ else:
142
+ conv.append_message(conv.roles[0], tmp[1])
143
+ example = tokenizer.encode_plus(f"{conv.get_prompt()}", None, max_length=None)['input_ids']
144
+
145
+ if len(example) >= max_token:
146
+ conv.messages.pop()
147
+ conv.messages = conv.messages[::-1]
148
+ print('model in:', conv.get_prompt())
149
+ example = tokenizer.encode_plus(f"{conv.get_prompt()}", None, max_length=None)['input_ids']
150
+ example = example[1:-1]
151
+
152
+ return example
153
+
154
+ if __name__ == "__main__":
155
+ print(default_conversation.get_prompt())