yentinglin committed
Commit 0cfd320
1 Parent(s): 107aff6

Upload 2 files

Files changed (2)
  1. app.py +8 -7
  2. conversation.py +271 -0
app.py CHANGED
@@ -2,6 +2,7 @@ import os
 
 import gradio as gr
 from text_generation import Client
+from conversation import get_default_conv_template, SeparatorStyle
 
 
 eos_token = "</s>"
@@ -23,12 +24,12 @@ endpoint_url = os.environ.get("ENDPOINT_URL")
 client = Client(endpoint_url, timeout=120)
 
 def generate_response(user_input, max_new_token, top_p, top_k, temperature, do_sample, repetition_penalty):
-    msg = _concat_messages([
-        # {"role": "system", "content": "You are a large language model developed by the MiuLab lab at National Taiwan University. You are based on the Transformer architecture and have been trained on a large corpus of Taiwanese Mandarin. Your design goal is to understand and generate elegant Traditional Chinese, and you can converse across contexts and domains. Users can ask you any question or raise any topic and expect high-quality answers from you. You should do your best to help users solve their problems, provide the information they need, and offer suggestions when appropriate."},
-        {"role": "user", "content": user_input},
-    ])
-    msg += "<|assistant|>\n"
-    # msg = user_input.strip()
+    user_input = user_input.strip()
+    conv = get_default_conv_template("vicuna").copy()
+    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}  # map "human" to USER and "gpt" to ASSISTANT
+    role = roles["human"]
+    conv.append_message(role, user_input)
+    msg = conv.get_prompt()
 
     res = client.generate(
         msg,
@@ -115,4 +116,4 @@ with gr.Blocks() as demo:
 
     emptyBtn.click(lambda: chatbot.reset(), outputs=[chatbot], show_progress=True)
 
-demo.launch()
+demo.launch()
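
For context, a minimal sketch of what the new prompt-assembly path produces, assuming the `vicuna_v1.1` template that `get_default_conv_template("vicuna")` returns (defined in `conversation.py` below). Note the committed `generate_response` appends only the user turn; FastChat-style callers often also append an empty assistant turn so the prompt ends with `ASSISTANT:`, which this sketch includes for illustration only.

```python
# Sketch of the new prompt assembly (not part of the commit).
from conversation import get_default_conv_template

conv = get_default_conv_template("vicuna").copy()
conv.append_message(conv.roles[0], "Hello!")  # USER turn, as in generate_response
conv.append_message(conv.roles[1], None)      # illustrative extra step: empty ASSISTANT turn

# With SeparatorStyle.TWO, an empty final message renders as a bare "ASSISTANT:":
print(conv.get_prompt())
# A chat between a curious user and an artificial intelligence assistant. The assistant
# gives helpful, detailed, and polite answers to the user's questions. USER: Hello! ASSISTANT:
```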
conversation.py ADDED
@@ -0,0 +1,271 @@
+"""
+Conversation prompt template.
+Now we support
+- Vicuna
+- Koala
+- OpenAssistant/oasst-sft-1-pythia-12b
+- StabilityAI/stablelm-tuned-alpha-7b
+- databricks/dolly-v2-12b
+- THUDM/chatglm-6b
+- Alpaca/LLaMa
+"""
+
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple, Any
+
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+
+    SINGLE = auto()
+    TWO = auto()
+    DOLLY = auto()
+    OASST_PYTHIA = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+    sep: str = "###"
+    sep2: str = None
+
+    # Used for gradio server
+    skip_next: bool = False
+    conv_id: Any = None
+
+    def get_prompt(self):
+        if self.sep_style == SeparatorStyle.SINGLE:
+            ret = self.system
+            for role, message in self.messages:
+                if message:
+                    ret += self.sep + " " + role + ": " + message
+                else:
+                    ret += self.sep + " " + role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.DOLLY:
+            seps = [self.sep, self.sep2]
+            ret = self.system
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + ":\n" + message + seps[i % 2]
+                    if i % 2 == 1:
+                        ret += "\n\n"
+                else:
+                    ret += role + ":\n"
+            return ret
+        elif self.sep_style == SeparatorStyle.OASST_PYTHIA:
+            ret = self.system
+            for role, message in self.messages:
+                if message:
+                    ret += role + message + self.sep
+                else:
+                    ret += role
+            return ret
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def to_gradio_chatbot(self):
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset :]):
+            if i % 2 == 0:
+                ret.append([msg, None])
+            else:
+                ret[-1][-1] = msg
+        return ret
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            conv_id=self.conv_id,
+        )
+
+    def dict(self):
+        return {
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "sep": self.sep,
+            "sep2": self.sep2,
+            "conv_id": self.conv_id,
+        }
+
+
+conv_one_shot = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+    roles=("Human", "Assistant"),
+    messages=(
+        (
+            "Human",
+            "What are the key differences between renewable and non-renewable energy sources?",
+        ),
+        (
+            "Assistant",
+            "Renewable energy sources are those that can be replenished naturally in a relatively "
+            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+            "renewable and non-renewable energy sources:\n"
+            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+            "energy sources are finite and will eventually run out.\n"
+            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+            "and other negative effects.\n"
+            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+            "have lower operational costs than non-renewable sources.\n"
+            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+            "locations than non-renewable sources.\n"
+            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+            "non-renewable sources are not, and their depletion can lead to economic and social instability.",
+        ),
+    ),
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+
+conv_vicuna_v1_1 = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    # system="A chat between a curious user and an artificial intelligence assistant. You are an assistant. Please give helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+)
+
+conv_story = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="<|endoftext|>",
+)
+
+conv_koala_v1 = Conversation(
+    system="BEGINNING OF CONVERSATION:",
+    roles=("USER", "GPT"),
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+)
+
+conv_dolly = Conversation(
+    system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
+    roles=("### Instruction", "### Response"),
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.DOLLY,
+    sep="\n\n",
+    sep2="### End",
+)
+
+conv_oasst = Conversation(
+    system="",
+    roles=("<|prompter|>", "<|assistant|>"),
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.OASST_PYTHIA,
+    sep="<|endoftext|>",
+)
+
+conv_stablelm = Conversation(
+    system="""<|SYSTEM|># StableLM Tuned (Alpha version)
+- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
+- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
+- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
+- StableLM will refuse to participate in anything that could harm a human.
+""",
+    roles=("<|USER|>", "<|ASSISTANT|>"),
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.OASST_PYTHIA,
+    sep="",
+)
+
+conv_templates = {
+    "conv_one_shot": conv_one_shot,
+    "vicuna_v1.1": conv_vicuna_v1_1,
+    "koala_v1": conv_koala_v1,
+    "dolly": conv_dolly,
+    "oasst": conv_oasst,
+}
+
+
+def get_default_conv_template(model_name):
+    model_name = model_name.lower()
+    if "vicuna" in model_name or "output" in model_name:
+        return conv_vicuna_v1_1
+    elif "koala" in model_name:
+        return conv_koala_v1
+    elif "dolly-v2" in model_name:
+        return conv_dolly
+    elif "oasst" in model_name and "pythia" in model_name:
+        return conv_oasst
+    elif "stablelm" in model_name:
+        return conv_stablelm
+    return conv_one_shot
+
+
+def compute_skip_echo_len(model_name, conv, prompt):
+    model_name = model_name.lower()
+    if "chatglm" in model_name:
+        skip_echo_len = len(conv.messages[-2][1]) + 1
+    elif "dolly-v2" in model_name:
+        special_toks = ["### Instruction:", "### Response:", "### End"]
+        skip_echo_len = len(prompt)
+        for tok in special_toks:
+            skip_echo_len -= prompt.count(tok) * len(tok)
+    elif "oasst" in model_name and "pythia" in model_name:
+        special_toks = ["<|prompter|>", "<|assistant|>", "<|endoftext|>"]
+        skip_echo_len = len(prompt)
+        for tok in special_toks:
+            skip_echo_len -= prompt.count(tok) * len(tok)
+    elif "stablelm" in model_name:
+        special_toks = ["<|SYSTEM|>", "<|USER|>", "<|ASSISTANT|>"]
+        skip_echo_len = len(prompt)
+        for tok in special_toks:
+            skip_echo_len -= prompt.count(tok) * len(tok)
+    else:
+        skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
+    return skip_echo_len
+
+
+if __name__ == "__main__":
+    print(conv_one_shot.get_prompt())
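
To illustrate the template mechanics above, a short usage sketch (hypothetical turns, run against this module):

```python
# How SeparatorStyle.TWO serializes a two-turn exchange (illustrative only).
from conversation import conv_vicuna_v1_1

conv = conv_vicuna_v1_1.copy()
conv.append_message(conv.roles[0], "Hi")      # USER
conv.append_message(conv.roles[1], "Hello!")  # ASSISTANT
# sep (" ") and sep2 ("</s>") alternate, so the assistant turn is closed by the EOS string:
print(conv.get_prompt())
# "... the user's questions. USER: Hi ASSISTANT: Hello!</s>"
```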
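
`compute_skip_echo_len` estimates how many characters of the prompt a model will echo back, so a caller can slice the echo off the decoded output. A sketch with a pretend generation:

```python
# Sketch of slicing off the echoed prompt (pretend model output, illustrative only).
from conversation import conv_vicuna_v1_1, compute_skip_echo_len

conv = conv_vicuna_v1_1.copy()
conv.append_message(conv.roles[0], "Hi")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

generated = prompt + " Hello!"  # pretend the model echoed the prompt, then answered
# Default branch: len(prompt) + 1, minus 3 characters per "</s>" in the prompt.
skip = compute_skip_echo_len("vicuna-13b", conv, prompt)
print(generated[skip:])  # -> "Hello!" (the heuristic is character-based)
```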