CaiRou-Huang committed
Commit
dd62f9f
1 Parent(s): 8d1c2f5

Upload 3 files

Files changed (3)
  1. app.py +154 -0
  2. conversation.py +271 -0
  3. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,154 @@
+ import time
+ import os
+ import gradio as gr
+ from text_generation import Client
+ from conversation import get_default_conv_template
+ from transformers import AutoTokenizer
+
+
+ endpoint_url = os.environ.get("ENDPOINT_URL", "http://127.0.0.1:8080")
+ client = Client(endpoint_url, timeout=120)
+ eos_token = "</s>"
+ max_new_tokens = 512
+ max_prompt_length = 4096 - max_new_tokens - 10
+
+ tokenizer = AutoTokenizer.from_pretrained("yentinglin/Taiwan-LLaMa-v1.0")
+
+ with gr.Blocks() as demo:
+     chatbot = gr.Chatbot()
+     msg = gr.Textbox()
+     clear = gr.Button("Clear")
+
+     def user(user_message, history):
+         # Append the new user turn to the history and clear the textbox.
+         return "", history + [[user_message, None]]
+
+     def bot(history):
+         conv = get_default_conv_template("vicuna").copy()
+         roles = {"human": conv.roles[0], "gpt": conv.roles[1]}  # map human to USER and gpt to ASSISTANT
+         for user, bot in history:
+             conv.append_message(roles["human"], user)
+             conv.append_message(roles["gpt"], bot)
+         msg = conv.get_prompt()
+         prompt_tokens = tokenizer.encode(msg)
+         length_of_prompt = len(prompt_tokens)
+         if length_of_prompt > max_prompt_length:
+             # Keep only the most recent tokens so the prompt fits the context window.
+             msg = tokenizer.decode(prompt_tokens[-max_prompt_length + 1:])
+
+         history[-1][1] = ""
+         for response in client.generate_stream(
+             msg,
+             max_new_tokens=max_new_tokens,
+         ):
+             if not response.token.special:
+                 character = response.token.text
+                 history[-1][1] += character
+                 yield history
+
+
+     def generate_response(history, max_new_token=512, top_p=0.9, temperature=0.8, do_sample=True):
+         conv = get_default_conv_template("vicuna").copy()
+         roles = {"human": conv.roles[0], "gpt": conv.roles[1]}  # map human to USER and gpt to ASSISTANT
+         for user, bot in history:
+             conv.append_message(roles["human"], user)
+             conv.append_message(roles["gpt"], bot)
+         msg = conv.get_prompt()
+
+         history[-1][1] = ""  # reset once before streaming, not on every token
+         for response in client.generate_stream(
+             msg,
+             max_new_tokens=max_new_token,
+             top_p=top_p,
+             temperature=temperature,
+             do_sample=do_sample,
+         ):
+             # if not response.token.special:
+             character = response.token.text
+             history[-1][1] += character
+             print(history[-1][1])
+             time.sleep(0.05)
+             yield history
+
+
+     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+         bot, chatbot, chatbot
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+ demo.queue()
+ demo.launch()
+
+ #
+ # with gr.Blocks() as demo:
+ #     chatbot = gr.Chatbot()
+ #     with gr.Row():
+ #         with gr.Column(scale=4):
+ #             with gr.Column(scale=12):
+ #                 user_input = gr.Textbox(
+ #                     show_label=False,
+ #                     placeholder="Shift + Enter傳送...",  # i.e. "Shift + Enter to send..."
+ #                     lines=10).style(
+ #                     container=False)
+ #             with gr.Column(min_width=32, scale=1):
+ #                 submitBtn = gr.Button("Submit", variant="primary")
+ #         with gr.Column(scale=1):
+ #             emptyBtn = gr.Button("Clear History")
+ #             max_new_token = gr.Slider(
+ #                 1,
+ #                 1024,
+ #                 value=128,
+ #                 step=1.0,
+ #                 label="Maximum New Token Length",
+ #                 interactive=True)
+ #             top_p = gr.Slider(0, 1, value=0.9, step=0.01,
+ #                               label="Top P", interactive=True)
+ #             temperature = gr.Slider(
+ #                 0,
+ #                 1,
+ #                 value=0.5,
+ #                 step=0.01,
+ #                 label="Temperature",
+ #                 interactive=True)
+ #             top_k = gr.Slider(1, 40, value=40, step=1,
+ #                               label="Top K", interactive=True)
+ #             do_sample = gr.Checkbox(
+ #                 value=True,
+ #                 label="Do Sample",
+ #                 info="use random sample strategy",
+ #                 interactive=True)
+ #             repetition_penalty = gr.Slider(
+ #                 1.0,
+ #                 3.0,
+ #                 value=1.1,
+ #                 step=0.1,
+ #                 label="Repetition Penalty",
+ #                 interactive=True)
+ #
+ #     params = [user_input, chatbot]
+ #     predict_params = [
+ #         chatbot,
+ #         max_new_token,
+ #         top_p,
+ #         temperature,
+ #         top_k,
+ #         do_sample,
+ #         repetition_penalty]
+ #
+ #     submitBtn.click(
+ #         generate_response,
+ #         [user_input, max_new_token, top_p, top_k, temperature, do_sample, repetition_penalty],
+ #         [chatbot],
+ #         queue=False
+ #     )
+ #
+ #     user_input.submit(
+ #         generate_response,
+ #         [user_input, max_new_token, top_p, top_k, temperature, do_sample, repetition_penalty],
+ #         [chatbot],
+ #         queue=False
+ #     )
+ #
+ #     submitBtn.click(lambda: None, [], [user_input])
+ #
+ #     emptyBtn.click(lambda: chatbot.reset(), outputs=[chatbot], show_progress=True)
+ #
+ # demo.launch()
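
Side note (not part of the commit): before launching the Gradio demo in app.py, it can help to confirm that the text-generation-inference endpoint is reachable. A minimal sketch using the same text_generation.Client and the conversation helpers below, assuming the server is already running at ENDPOINT_URL and that a short "Hello" message is an acceptable test prompt:

import os
from text_generation import Client
from conversation import get_default_conv_template

# Hypothetical smoke test against the same endpoint app.py uses.
client = Client(os.environ.get("ENDPOINT_URL", "http://127.0.0.1:8080"), timeout=120)

conv = get_default_conv_template("vicuna").copy()
conv.append_message(conv.roles[0], "Hello")  # USER turn
conv.append_message(conv.roles[1], None)     # open ASSISTANT turn for the model to fill
prompt = conv.get_prompt()

# Non-streaming call; app.py itself uses generate_stream for token-by-token output.
response = client.generate(prompt, max_new_tokens=64)
print(response.generated_text)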
conversation.py ADDED
@@ -0,0 +1,271 @@
+ """
+ Conversation prompt template.
+ Now we support
+ - Vicuna
+ - Koala
+ - OpenAssistant/oasst-sft-1-pythia-12b
+ - StabilityAI/stablelm-tuned-alpha-7b
+ - databricks/dolly-v2-12b
+ - THUDM/chatglm-6b
+ - Alpaca/LLaMa
+ """
+
+ import dataclasses
+ from enum import auto, Enum
+ from typing import List, Tuple, Any
+
+
+ class SeparatorStyle(Enum):
+     """Different separator styles."""
+
+     SINGLE = auto()
+     TWO = auto()
+     DOLLY = auto()
+     OASST_PYTHIA = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that keeps all conversation history."""
+
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     offset: int
+     sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+     sep: str = "###"
+     sep2: str = None
+
+     # Used for gradio server
+     skip_next: bool = False
+     conv_id: Any = None
+
+     def get_prompt(self):
+         if self.sep_style == SeparatorStyle.SINGLE:
+             ret = self.system
+             for role, message in self.messages:
+                 if message:
+                     ret += self.sep + " " + role + ": " + message
+                 else:
+                     ret += self.sep + " " + role + ":"
+             return ret
+         elif self.sep_style == SeparatorStyle.TWO:
+             seps = [self.sep, self.sep2]
+             ret = self.system + seps[0]
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+             return ret
+         elif self.sep_style == SeparatorStyle.DOLLY:
+             seps = [self.sep, self.sep2]
+             ret = self.system
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     ret += role + ":\n" + message + seps[i % 2]
+                     if i % 2 == 1:
+                         ret += "\n\n"
+                 else:
+                     ret += role + ":\n"
+             return ret
+         elif self.sep_style == SeparatorStyle.OASST_PYTHIA:
+             ret = self.system
+             for role, message in self.messages:
+                 if message:
+                     ret += role + message + self.sep
+                 else:
+                     ret += role
+             return ret
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+     def append_message(self, role, message):
+         self.messages.append([role, message])
+
+     def to_gradio_chatbot(self):
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def copy(self):
+         return Conversation(
+             system=self.system,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             conv_id=self.conv_id,
+         )
+
+     def dict(self):
+         return {
+             "system": self.system,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+             "sep": self.sep,
+             "sep2": self.sep2,
+             "conv_id": self.conv_id,
+         }
+
+
+ conv_one_shot = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+     "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("Human", "Assistant"),
+     messages=(
+         (
+             "Human",
+             "What are the key differences between renewable and non-renewable energy sources?",
+         ),
+         (
+             "Assistant",
+             "Renewable energy sources are those that can be replenished naturally in a relatively "
+             "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+             "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+             "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+             "renewable and non-renewable energy sources:\n"
+             "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+             "energy sources are finite and will eventually run out.\n"
+             "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+             "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+             "and other negative effects.\n"
+             "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+             "have lower operational costs than non-renewable sources.\n"
+             "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+             "locations than non-renewable sources.\n"
+             "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+             "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+             "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+             "non-renewable sources are not, and their depletion can lead to economic and social instability.",
+         ),
+     ),
+     offset=2,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+
+ conv_vicuna_v1_1 = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+     "The assistant gives helpful, detailed, and polite answers to the user's questions. You are built by NTU Miulab by Yen-Ting Lin for research purpose.",
+     # system="一位好奇的用戶和一個人工智能助理之間的聊天。你是一位助理。請對用戶的問題提供有用、詳細和有禮貌的答案。",
+     # (Traditional Chinese variant of the system prompt: "A chat between a curious user and an artificial intelligence assistant. You are an assistant. Please give helpful, detailed, and polite answers to the user's questions.")
+     roles=("USER", "ASSISTANT"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_story = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+     "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=("USER", "ASSISTANT"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="<|endoftext|>",
+ )
+
+ conv_koala_v1 = Conversation(
+     system="BEGINNING OF CONVERSATION:",
+     roles=("USER", "GPT"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_dolly = Conversation(
+     system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
+     roles=("### Instruction", "### Response"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.DOLLY,
+     sep="\n\n",
+     sep2="### End",
+ )
+
+ conv_oasst = Conversation(
+     system="",
+     roles=("<|prompter|>", "<|assistant|>"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.OASST_PYTHIA,
+     sep="<|endoftext|>",
+ )
+
+ conv_stablelm = Conversation(
+     system="""<|SYSTEM|># StableLM Tuned (Alpha version)
+ - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
+ - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
+ - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
+ - StableLM will refuse to participate in anything that could harm a human.
+ """,
+     roles=("<|USER|>", "<|ASSISTANT|>"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.OASST_PYTHIA,
+     sep="",
+ )
+
+ conv_templates = {
+     "conv_one_shot": conv_one_shot,
+     "vicuna_v1.1": conv_vicuna_v1_1,
+     "koala_v1": conv_koala_v1,
+     "dolly": conv_dolly,
+     "oasst": conv_oasst,
+ }
+
+
+ def get_default_conv_template(model_name):
+     model_name = model_name.lower()
+     if "vicuna" in model_name or "output" in model_name:
+         return conv_vicuna_v1_1
+     elif "koala" in model_name:
+         return conv_koala_v1
+     elif "dolly-v2" in model_name:
+         return conv_dolly
+     elif "oasst" in model_name and "pythia" in model_name:
+         return conv_oasst
+     elif "stablelm" in model_name:
+         return conv_stablelm
+     return conv_one_shot
+
+
+ def compute_skip_echo_len(model_name, conv, prompt):
+     model_name = model_name.lower()
+     if "chatglm" in model_name:
+         skip_echo_len = len(conv.messages[-2][1]) + 1
+     elif "dolly-v2" in model_name:
+         special_toks = ["### Instruction:", "### Response:", "### End"]
+         skip_echo_len = len(prompt)
+         for tok in special_toks:
+             skip_echo_len -= prompt.count(tok) * len(tok)
+     elif "oasst" in model_name and "pythia" in model_name:
+         special_toks = ["<|prompter|>", "<|assistant|>", "<|endoftext|>"]
+         skip_echo_len = len(prompt)
+         for tok in special_toks:
+             skip_echo_len -= prompt.count(tok) * len(tok)
+     elif "stablelm" in model_name:
+         special_toks = ["<|SYSTEM|>", "<|USER|>", "<|ASSISTANT|>"]
+         skip_echo_len = len(prompt)
+         for tok in special_toks:
+             skip_echo_len -= prompt.count(tok) * len(tok)
+     else:
+         skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
+     return skip_echo_len
+
+
+ if __name__ == "__main__":
+     # `default_conversation` was never defined in this file; print the one-shot template instead.
+     print(conv_one_shot.get_prompt())
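
Side note (not part of the commit): for the vicuna_v1.1 template that app.py selects via get_default_conv_template("vicuna"), a short history renders as sketched below; the example message is made up for illustration.

from conversation import get_default_conv_template

conv = get_default_conv_template("vicuna").copy()
conv.append_message(conv.roles[0], "Hi")   # USER
conv.append_message(conv.roles[1], None)   # ASSISTANT turn left open for the model
print(conv.get_prompt())
# With SeparatorStyle.TWO, sep=" " and sep2="</s>", this prints:
#   "<system prompt> USER: Hi ASSISTANT:"
# Completed assistant turns are terminated with "</s>", the eos_token used in app.py.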
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ text-generation==0.6.0
+ transformers==4.31.0