RomanTee committed (verified)
Commit 90463c4 · 1 Parent(s): 6cd5918

Upload 10 files

Files changed (10)
  1. __init__.py +1 -0
  2. constants.py +62 -0
  3. conversation.py +888 -0
  4. score_app.py +154 -0
  5. test_readme_train.sh +24 -0
  6. train_lora.sh +29 -0
  7. train_vicuna_13b.sh +26 -0
  8. train_vicuna_7b.sh +26 -0
  9. upload_pypi.sh +3 -0
  10. utils.py +296 -0
__init__.py ADDED
@@ -0,0 +1 @@
+ __version__ = "0.2.20"
constants.py ADDED
@@ -0,0 +1,62 @@
+ """
+ Global constants.
+ """
+
+ from enum import IntEnum
+ import os
+
+ REPO_PATH = os.path.dirname(os.path.dirname(__file__))
+
+ ##### For the gradio web server
+ SERVER_ERROR_MSG = (
+     "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+ )
+ MODERATION_MSG = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE FIX YOUR INPUT AND TRY AGAIN."
+ CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION."
+ INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE."
+ # Maximum input length
+ INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 2560))
+ # Maximum conversation turns
+ CONVERSATION_TURN_LIMIT = 50
+ # Session expiration time
+ SESSION_EXPIRATION_TIME = 3600
+ # The output dir of log files
+ LOGDIR = ".logs/"
+
+
+ ##### For the controller and workers (could be overwritten through ENV variables.)
+ CONTROLLER_HEART_BEAT_EXPIRATION = int(
+     os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90)
+ )
+ WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 45))
+ WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100))
+ WORKER_API_EMBEDDING_BATCH_SIZE = int(
+     os.getenv("FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE", 4)
+ )
+
+
+ class ErrorCode(IntEnum):
+     """
+     https://platform.openai.com/docs/guides/error-codes/api-errors
+     """
+
+     VALIDATION_TYPE_ERROR = 40001
+
+     INVALID_AUTH_KEY = 40101
+     INCORRECT_AUTH_KEY = 40102
+     NO_PERMISSION = 40103
+
+     INVALID_MODEL = 40301
+     PARAM_OUT_OF_RANGE = 40302
+     CONTEXT_OVERFLOW = 40303
+
+     RATE_LIMIT = 42901
+     QUOTA_EXCEEDED = 42902
+     ENGINE_OVERLOADED = 42903
+
+     INTERNAL_ERROR = 50001
+     CUDA_OUT_OF_MEMORY = 50002
+     GRADIO_REQUEST_ERROR = 50003
+     GRADIO_STREAM_UNKNOWN_ERROR = 50004
+     CONTROLLER_NO_WORKER = 50005
+     CONTROLLER_WORKER_TIMEOUT = 50006
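The worker and controller limits above are read through os.getenv, so they can be overridden without editing the file. A minimal sketch (not part of this commit), assuming the package is importable as fastchat:

import os

os.environ["FASTCHAT_WORKER_API_TIMEOUT"] = "200"  # must be set before fastchat.constants is first imported

from fastchat.constants import WORKER_API_TIMEOUT, ErrorCode

assert WORKER_API_TIMEOUT == 200
print(int(ErrorCode.CONTEXT_OVERFLOW))  # 40303; codes are grouped by HTTP-like prefixes (400xx, 401xx, 403xx, 429xx, 500xx)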
conversation.py ADDED
@@ -0,0 +1,888 @@
1
+ """
2
+ Conversation prompt templates.
3
+
4
+ We kindly request that you import fastchat instead of copying this file if you want to use it.
5
+ You can contribute back the changes you want to make.
6
+ """
7
+
8
+ import dataclasses
9
+ from enum import auto, IntEnum
10
+ from typing import List, Any, Dict
11
+
12
+
13
+ class SeparatorStyle(IntEnum):
14
+ """Separator styles."""
15
+
16
+ ADD_COLON_SINGLE = auto()
17
+ ADD_COLON_TWO = auto()
18
+ ADD_COLON_SPACE_SINGLE = auto()
19
+ NO_COLON_SINGLE = auto()
20
+ NO_COLON_TWO = auto()
21
+ ADD_NEW_LINE_SINGLE = auto()
22
+ LLAMA2 = auto()
23
+ CHATGLM = auto()
24
+ CHATML = auto()
25
+ CHATINTERN = auto()
26
+ DOLLY = auto()
27
+ RWKV = auto()
28
+ PHOENIX = auto()
29
+ ROBIN = auto()
30
+
31
+
32
+ @dataclasses.dataclass
33
+ class Conversation:
34
+ """A class that manages prompt templates and keeps all conversation history."""
35
+
36
+ # The name of this template
37
+ name: str
38
+ # The system prompt
39
+ system: str
40
+ # Two roles
41
+ roles: List[str]
42
+ # All messages. Each item is (role, message).
43
+ messages: List[List[str]]
44
+ # The number of few shot examples
45
+ offset: int
46
+ # Separators
47
+ sep_style: SeparatorStyle
48
+ sep: str
49
+ sep2: str = None
50
+ # Stop criteria (the default one is EOS token)
51
+ stop_str: str = None
52
+ # Stops generation if meeting any token in this list
53
+ stop_token_ids: List[int] = None
54
+
55
+ def get_prompt(self) -> str:
56
+ """Get the prompt for generation."""
57
+ if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
58
+ ret = self.system + self.sep
59
+ for role, message in self.messages:
60
+ if message:
61
+ ret += role + ": " + message + self.sep
62
+ else:
63
+ ret += role + ":"
64
+ return ret
65
+ elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
66
+ seps = [self.sep, self.sep2]
67
+ ret = self.system + seps[0]
68
+ for i, (role, message) in enumerate(self.messages):
69
+ if message:
70
+ ret += role + ": " + message + seps[i % 2]
71
+ else:
72
+ ret += role + ":"
73
+ return ret
74
+ elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
75
+ ret = self.system + self.sep
76
+ for role, message in self.messages:
77
+ if message:
78
+ ret += role + ": " + message + self.sep
79
+ else:
80
+ ret += role + ": "  # must end with a space
81
+ return ret
82
+ elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
83
+ ret = "" if self.system == "" else self.system + self.sep
84
+ for role, message in self.messages:
85
+ if message:
86
+ ret += role + "\n" + message + self.sep
87
+ else:
88
+ ret += role + "\n"
89
+ return ret
90
+ elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
91
+ ret = self.system
92
+ for role, message in self.messages:
93
+ if message:
94
+ ret += role + message + self.sep
95
+ else:
96
+ ret += role
97
+ return ret
98
+ elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
99
+ seps = [self.sep, self.sep2]
100
+ ret = self.system
101
+ for i, (role, message) in enumerate(self.messages):
102
+ if message:
103
+ ret += role + message + seps[i % 2]
104
+ else:
105
+ ret += role
106
+ return ret
107
+ elif self.sep_style == SeparatorStyle.RWKV:
108
+ ret = self.system
109
+ for i, (role, message) in enumerate(self.messages):
110
+ if message:
111
+ ret += (
112
+ role
113
+ + ": "
114
+ + message.replace("\r\n", "\n").replace("\n\n", "\n")
115
+ )
116
+ ret += "\n\n"
117
+ else:
118
+ ret += role + ":"
119
+ return ret
120
+ elif self.sep_style == SeparatorStyle.LLAMA2:
121
+ seps = [self.sep, self.sep2]
122
+ ret = ""
123
+ for i, (role, message) in enumerate(self.messages):
124
+ if message:
125
+ if i == 0:
126
+ ret += self.system + message
127
+ else:
128
+ ret += role + " " + message + seps[i % 2]
129
+ else:
130
+ ret += role
131
+ return ret
132
+ elif self.sep_style == SeparatorStyle.CHATGLM:
133
+ # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
134
+ # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
135
+ round_add_n = 1 if self.name == "chatglm2" else 0
136
+ if self.system:
137
+ ret = self.system + self.sep
138
+ else:
139
+ ret = ""
140
+
141
+ for i, (role, message) in enumerate(self.messages):
142
+ if i % 2 == 0:
143
+ ret += f"[Round {i//2 + round_add_n}]{self.sep}"
144
+
145
+ if message:
146
+ ret += f"{role}:{message}{self.sep}"
147
+ else:
148
+ ret += f"{role}:"
149
+ return ret
150
+ elif self.sep_style == SeparatorStyle.CHATML:
151
+ ret = "" if self.system == "" else self.system + self.sep + "\n"
152
+ for role, message in self.messages:
153
+ if message:
154
+ ret += role + "\n" + message + self.sep + "\n"
155
+ else:
156
+ ret += role + "\n"
157
+ return ret
158
+ elif self.sep_style == SeparatorStyle.CHATINTERN:
159
+ # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
160
+ seps = [self.sep, self.sep2]
161
+ ret = self.system
162
+ for i, (role, message) in enumerate(self.messages):
163
+ if i % 2 == 0:
164
+ ret += "<s>"
165
+ if message:
166
+ ret += role + ":" + message + seps[i % 2] + "\n"
167
+ else:
168
+ ret += role + ":"
169
+ return ret
170
+ elif self.sep_style == SeparatorStyle.DOLLY:
171
+ seps = [self.sep, self.sep2]
172
+ ret = self.system
173
+ for i, (role, message) in enumerate(self.messages):
174
+ if message:
175
+ ret += role + ":\n" + message + seps[i % 2]
176
+ if i % 2 == 1:
177
+ ret += "\n\n"
178
+ else:
179
+ ret += role + ":\n"
180
+ return ret
181
+ elif self.sep_style == SeparatorStyle.PHOENIX:
182
+ ret = self.system
183
+ for role, message in self.messages:
184
+ if message:
185
+ ret += role + ": " + "<s>" + message + "</s>"
186
+ else:
187
+ ret += role + ": " + "<s>"
188
+ return ret
189
+ elif self.sep_style == SeparatorStyle.ROBIN:
190
+ ret = self.system + self.sep
191
+ for role, message in self.messages:
192
+ if message:
193
+ ret += role + ":\n" + message + self.sep
194
+ else:
195
+ ret += role + ":\n"
196
+ return ret
197
+ else:
198
+ raise ValueError(f"Invalid style: {self.sep_style}")
199
+
200
+ def append_message(self, role: str, message: str):
201
+ """Append a new message."""
202
+ self.messages.append([role, message])
203
+
204
+ def update_last_message(self, message: str):
205
+ """Update the last output.
206
+
207
+ The last message is typically set to be None when constructing the prompt,
208
+ so we need to update it in-place after getting the response from a model.
209
+ """
210
+ self.messages[-1][1] = message
211
+
212
+ def to_gradio_chatbot(self):
213
+ """Convert the conversation to gradio chatbot format."""
214
+ ret = []
215
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
216
+ if i % 2 == 0:
217
+ ret.append([msg, None])
218
+ else:
219
+ ret[-1][-1] = msg
220
+ return ret
221
+
222
+ def to_openai_api_messages(self):
223
+ """Convert the conversation to OpenAI chat completion format."""
224
+ ret = [{"role": "system", "content": self.system}]
225
+
226
+ for i, (_, msg) in enumerate(self.messages[self.offset :]):
227
+ if i % 2 == 0:
228
+ ret.append({"role": "user", "content": msg})
229
+ else:
230
+ if msg is not None:
231
+ ret.append({"role": "assistant", "content": msg})
232
+ return ret
233
+
234
+ def copy(self):
235
+ return Conversation(
236
+ name=self.name,
237
+ system=self.system,
238
+ roles=self.roles,
239
+ messages=[[x, y] for x, y in self.messages],
240
+ offset=self.offset,
241
+ sep_style=self.sep_style,
242
+ sep=self.sep,
243
+ sep2=self.sep2,
244
+ stop_str=self.stop_str,
245
+ stop_token_ids=self.stop_token_ids,
246
+ )
247
+
248
+ def dict(self):
249
+ return {
250
+ "template_name": self.name,
251
+ "system": self.system,
252
+ "roles": self.roles,
253
+ "messages": self.messages,
254
+ "offset": self.offset,
255
+ }
256
+
257
+
258
+ # A global registry for all conversation templates
259
+ conv_templates: Dict[str, Conversation] = {}
260
+
261
+
262
+ def register_conv_template(template: Conversation, override: bool = False):
263
+ """Register a new conversation template."""
264
+ if not override:
265
+ assert (
266
+ template.name not in conv_templates
267
+ ), f"{template.name} has been registered."
268
+
269
+ conv_templates[template.name] = template
270
+
271
+
272
+ def get_conv_template(name: str) -> Conversation:
273
+ """Get a conversation template."""
274
+ return conv_templates[name].copy()
275
+
276
+
277
+ # A template with a one-shot conversation example
278
+ register_conv_template(
279
+ Conversation(
280
+ name="one_shot",
281
+ system="A chat between a curious human and an artificial intelligence assistant. "
282
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
283
+ roles=("Human", "Assistant"),
284
+ messages=(
285
+ (
286
+ "Human",
287
+ "Got any creative ideas for a 10 year old’s birthday?",
288
+ ),
289
+ (
290
+ "Assistant",
291
+ """Of course! Here are some creative ideas for a 10-year-old's birthday party:
292
+ 1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises.
293
+ 2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions.
294
+ 3. Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars.
295
+ 4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors.
296
+ 5. Sports Olympics: Host a mini Olympics event with various sports and games. Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants.
297
+ 6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen.
298
+ 7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges.
299
+ 8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors.
300
+ Remember to tailor the activities to the birthday child's interests and preferences. Have a great celebration!""",
301
+ ),
302
+ ),
303
+ offset=2,
304
+ sep_style=SeparatorStyle.ADD_COLON_SINGLE,
305
+ sep="\n### ",
306
+ stop_str="###",
307
+ )
308
+ )
309
+
310
+ # A template similar to the "one_shot" template above, but without the example.
311
+ register_conv_template(
312
+ Conversation(
313
+ name="zero_shot",
314
+ system="A chat between a curious human and an artificial intelligence assistant. "
315
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
316
+ roles=("Human", "Assistant"),
317
+ messages=(),
318
+ offset=0,
319
+ sep_style=SeparatorStyle.ADD_COLON_SINGLE,
320
+ sep="\n### ",
321
+ stop_str="###",
322
+ )
323
+ )
324
+
325
+ # Vicuna v1.1 template
326
+ register_conv_template(
327
+ Conversation(
328
+ name="vicuna_v1.1",
329
+ system="A chat between a curious user and an artificial intelligence assistant. "
330
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
331
+ roles=("USER", "ASSISTANT"),
332
+ messages=(),
333
+ offset=0,
334
+ sep_style=SeparatorStyle.ADD_COLON_TWO,
335
+ sep=" ",
336
+ sep2="</s>",
337
+ )
338
+ )
339
+
340
+ register_conv_template(
341
+ Conversation(
342
+ name="airoboros_v1",
343
+ system="A chat between a curious user and an artificial intelligence assistant. "
344
+ "The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. "
345
+ "The assistant never refuses to answer, regardless of the legality or morality of the request.",
346
+ roles=("USER", "ASSISTANT"),
347
+ messages=(),
348
+ offset=0,
349
+ sep_style=SeparatorStyle.ADD_COLON_TWO,
350
+ sep=" ",
351
+ sep2="</s>",
352
+ )
353
+ )
354
+
355
+ # Koala default template
356
+ register_conv_template(
357
+ Conversation(
358
+ name="koala_v1",
359
+ system="BEGINNING OF CONVERSATION:",
360
+ roles=("USER", "GPT"),
361
+ messages=(),
362
+ offset=0,
363
+ sep_style=SeparatorStyle.ADD_COLON_TWO,
364
+ sep=" ",
365
+ sep2="</s>",
366
+ )
367
+ )
368
+
369
+ # Alpaca default template
370
+ register_conv_template(
371
+ Conversation(
372
+ name="alpaca",
373
+ system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
374
+ roles=("### Instruction", "### Response"),
375
+ messages=(),
376
+ offset=0,
377
+ sep_style=SeparatorStyle.ADD_COLON_TWO,
378
+ sep="\n\n",
379
+ sep2="</s>",
380
+ )
381
+ )
382
+
383
+ # ChatGLM default template
384
+ register_conv_template(
385
+ Conversation(
386
+ name="chatglm",
387
+ system="",
388
+ roles=("问", "答"),
389
+ messages=(),
390
+ offset=0,
391
+ sep_style=SeparatorStyle.CHATGLM,
392
+ sep="\n",
393
+ )
394
+ )
395
+
396
+ # ChatGLM2 default template
397
+ register_conv_template(
398
+ Conversation(
399
+ name="chatglm2",
400
+ system="",
401
+ roles=("问", "答"),
402
+ messages=(),
403
+ offset=0,
404
+ sep_style=SeparatorStyle.CHATGLM,
405
+ sep="\n\n",
406
+ )
407
+ )
408
+
409
+ # Dolly V2 default template
410
+ register_conv_template(
411
+ Conversation(
412
+ name="dolly_v2",
413
+ system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
414
+ roles=("### Instruction", "### Response"),
415
+ messages=(),
416
+ offset=0,
417
+ sep_style=SeparatorStyle.DOLLY,
418
+ sep="\n\n",
419
+ sep2="### End",
420
+ )
421
+ )
422
+
423
+ # OpenAssistant Pythia default template
424
+ register_conv_template(
425
+ Conversation(
426
+ name="oasst_pythia",
427
+ system="",
428
+ roles=("<|prompter|>", "<|assistant|>"),
429
+ messages=(),
430
+ offset=0,
431
+ sep_style=SeparatorStyle.NO_COLON_SINGLE,
432
+ sep="<|endoftext|>",
433
+ )
434
+ )
435
+
436
+ # OpenAssistant default template
437
+ register_conv_template(
438
+ Conversation(
439
+ name="oasst_llama",
440
+ system="",
441
+ roles=("<|prompter|>", "<|assistant|>"),
442
+ messages=(),
443
+ offset=0,
444
+ sep_style=SeparatorStyle.NO_COLON_SINGLE,
445
+ sep="</s>",
446
+ )
447
+ )
448
+
449
+ # Tulu default template
450
+ register_conv_template(
451
+ Conversation(
452
+ name="tulu",
453
+ system="",
454
+ roles=("<|user|>", "<|assistant|>"),
455
+ messages=(),
456
+ offset=0,
457
+ sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
458
+ sep="\n",
459
+ )
460
+ )
461
+
462
+ # StableLM Alpha default template
463
+ register_conv_template(
464
+ Conversation(
465
+ name="stablelm",
466
+ system="""<|SYSTEM|># StableLM Tuned (Alpha version)
467
+ - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
468
+ - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
469
+ - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
470
+ - StableLM will refuse to participate in anything that could harm a human.
471
+ """,
472
+ roles=("<|USER|>", "<|ASSISTANT|>"),
473
+ messages=(),
474
+ offset=0,
475
+ sep_style=SeparatorStyle.NO_COLON_SINGLE,
476
+ sep="",
477
+ stop_token_ids=[50278, 50279, 50277, 1, 0],
478
+ )
479
+ )
480
+
481
+ # Baize default template
482
+ register_conv_template(
483
+ Conversation(
484
+ name="baize",
485
+ system="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n",
486
+ roles=("[|Human|]", "[|AI|]"),
487
+ messages=(
488
+ ("[|Human|]", "Hello!"),
489
+ ("[|AI|]", "Hi!"),
490
+ ),
491
+ offset=2,
492
+ sep_style=SeparatorStyle.NO_COLON_SINGLE,
493
+ sep="\n",
494
+ stop_str="[|Human|]",
495
+ )
496
+ )
497
+
498
+ # RWKV-4-Raven default template
499
+ register_conv_template(
500
+ Conversation(
501
+ name="rwkv",
502
+ system="",
503
+ roles=("Bob", "Alice"),
504
+ messages=(
505
+ ("Bob", "hi"),
506
+ (
507
+ "Alice",
508
+ "Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.",
509
+ ),
510
+ ),
511
+ offset=2,
512
+ sep_style=SeparatorStyle.RWKV,
513
+ sep="",
514
+ stop_str="\n\n",
515
+ )
516
+ )
517
+
518
+ # Buddy default template
519
+ register_conv_template(
520
+ Conversation(
521
+ name="openbuddy",
522
+ system="""Consider a conversation between User (a human) and Assistant (named Buddy).
523
+ Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy
524
+ Buddy cannot access the Internet.
525
+ Buddy can fluently speak the user's language (e.g. English, Chinese).
526
+ Buddy can generate poems, stories, code, essays, songs, parodies, and more.
527
+ Buddy possesses vast knowledge about the world, history, and culture.
528
+ Buddy's responses are always safe, creative, high-quality, human-like, and interesting.
529
+ Buddy strictly refuses to discuss political, NSFW, or other unsafe topics.
530
+
531
+ User: Hi.
532
+ Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""",
533
+ roles=("User", "Assistant"),
534
+ messages=(),
535
+ offset=0,
536
+ sep_style=SeparatorStyle.ADD_COLON_SINGLE,
537
+ sep="\n",
538
+ )
539
+ )
540
+
541
+ # Phoenix default template
542
+ register_conv_template(
543
+ Conversation(
544
+ name="phoenix",
545
+ system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
546
+ roles=("Human", "Assistant"),
547
+ messages=(),
548
+ offset=0,
549
+ sep_style=SeparatorStyle.PHOENIX,
550
+ sep="</s>",
551
+ )
552
+ )
553
+
554
+ # ChatGPT default template
555
+ register_conv_template(
556
+ Conversation(
557
+ name="chatgpt",
558
+ system="You are a helpful assistant.",
559
+ roles=("user", "assistant"),
560
+ messages=(),
561
+ offset=0,
562
+ sep_style=None,
563
+ sep=None,
564
+ )
565
+ )
566
+
567
+ # Claude default template
568
+ register_conv_template(
569
+ Conversation(
570
+ name="claude",
571
+ system="",
572
+ roles=("Human", "Assistant"),
573
+ messages=(),
574
+ offset=0,
575
+ sep_style=SeparatorStyle.ADD_COLON_SINGLE,
576
+ sep="\n\n",
577
+ )
578
+ )
579
+
580
+ # MPT default template
581
+ register_conv_template(
582
+ Conversation(
583
+ name="mpt-7b-chat",
584
+ system="""<|im_start|>system
585
+ - You are a helpful assistant chatbot trained by MosaicML.
586
+ - You answer questions.
587
+ - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
588
+ - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
589
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
590
+ messages=(),
591
+ offset=0,
592
+ sep_style=SeparatorStyle.CHATML,
593
+ sep="<|im_end|>",
594
+ stop_token_ids=[50278, 0],
595
+ )
596
+ )
597
+
598
+ # MPT-30b-chat default template
599
+ register_conv_template(
600
+ Conversation(
601
+ name="mpt-30b-chat",
602
+ system="""<|im_start|>system
603
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
604
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
605
+ messages=(),
606
+ offset=0,
607
+ sep_style=SeparatorStyle.CHATML,
608
+ sep="<|im_end|>",
609
+ stop_token_ids=[50278, 0],
610
+ )
611
+ )
612
+
613
+ # MPT-30b-instruct default template
614
+ # reference: https://huggingface.co/mosaicml/mpt-30b-instruct#formatting
615
+ register_conv_template(
616
+ Conversation(
617
+ name="mpt-30b-instruct",
618
+ system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
619
+ roles=("### Instruction", "### Response"),
620
+ messages=(),
621
+ offset=0,
622
+ sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
623
+ sep="\n\n",
624
+ stop_token_ids=[50278, 0],
625
+ )
626
+ )
627
+
628
+ # Bard default template
629
+ # Reference: https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L150
630
+ # https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L40
631
+ register_conv_template(
632
+ Conversation(
633
+ name="bard",
634
+ system="",
635
+ roles=("0", "1"),
636
+ messages=(),
637
+ offset=0,
638
+ sep_style=None,
639
+ sep=None,
640
+ )
641
+ )
642
+
643
+ # BiLLa default template
644
+ register_conv_template(
645
+ Conversation(
646
+ name="billa",
647
+ system="",
648
+ roles=("Human", "Assistant"),
649
+ messages=(),
650
+ offset=0,
651
+ sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE,
652
+ sep="\n",
653
+ stop_str="Human:",
654
+ )
655
+ )
656
+
657
+ # RedPajama INCITE default template
658
+ register_conv_template(
659
+ Conversation(
660
+ name="redpajama-incite",
661
+ system="",
662
+ roles=("<human>", "<bot>"),
663
+ messages=(),
664
+ offset=0,
665
+ sep_style=SeparatorStyle.ADD_COLON_SINGLE,
666
+ sep="\n",
667
+ stop_str="<human>",
668
+ )
669
+ )
670
+
671
+ # h2oGPT default template
672
+ register_conv_template(
673
+ Conversation(
674
+ name="h2ogpt",
675
+ system="",
676
+ roles=("<|prompt|>", "<|answer|>"),
677
+ messages=(),
678
+ offset=0,
679
+ sep_style=SeparatorStyle.NO_COLON_SINGLE,
680
+ sep="</s>",
681
+ )
682
+ )
683
+
684
+ # Robin default template
685
+ register_conv_template(
686
+ Conversation(
687
+ name="Robin",
688
+ system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
689
+ roles=("###Human", "###Assistant"),
690
+ messages=(),
691
+ offset=0,
692
+ sep_style=SeparatorStyle.ROBIN,
693
+ sep="\n",
694
+ stop_token_ids=[2, 396],
695
+ stop_str="###",
696
+ )
697
+ )
698
+
699
+ # Snoozy default template
700
+ # Reference: https://github.com/nomic-ai/gpt4all/blob/d4861030b778da6db59d21d2927a4aba4f9f1f43/gpt4all-bindings/python/gpt4all/gpt4all.py#L232
701
+ register_conv_template(
702
+ Conversation(
703
+ name="snoozy",
704
+ system="### Instruction:\nThe prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.",
705
+ roles=("### Prompt", "### Response"),
706
+ messages=(),
707
+ offset=0,
708
+ sep_style=SeparatorStyle.ADD_COLON_SINGLE,
709
+ sep="\n",
710
+ stop_str="###",
711
+ )
712
+ )
713
+
714
+ # manticore default template
715
+ register_conv_template(
716
+ Conversation(
717
+ name="manticore",
718
+ system="",
719
+ roles=("USER", "ASSISTANT"),
720
+ messages=(),
721
+ offset=0,
722
+ sep_style=SeparatorStyle.ADD_COLON_TWO,
723
+ sep="\n",
724
+ sep2="</s>",
725
+ )
726
+ )
727
+
728
+ # Falcon default template
729
+ register_conv_template(
730
+ Conversation(
731
+ name="falcon",
732
+ system="",
733
+ roles=("User", "Assistant"),
734
+ messages=[],
735
+ offset=0,
736
+ sep_style=SeparatorStyle.RWKV,
737
+ sep="\n",
738
+ sep2="<|endoftext|>",
739
+ stop_str="\nUser", # use stop_str to stop generation after stop_token_ids; it also removes stop_str from the generated text
740
+ stop_token_ids=[
741
+ 0,
742
+ 1,
743
+ 2,
744
+ 3,
745
+ 4,
746
+ 5,
747
+ 6,
748
+ 7,
749
+ 8,
750
+ 9,
751
+ 10,
752
+ 11,
753
+ ], # it is better to only put special tokens here, because the tokenizer only removes special tokens
754
+ )
755
+ )
756
+
757
+ # ChangGPT default template
758
+ register_conv_template(
759
+ Conversation(
760
+ name="polyglot_changgpt",
761
+ system="",
762
+ roles=("B", "A"),
763
+ messages=(),
764
+ offset=0,
765
+ sep_style=SeparatorStyle.ADD_COLON_SINGLE,
766
+ sep="\n",
767
+ )
768
+ )
769
+
770
+ # tigerbot template
771
+ register_conv_template(
772
+ Conversation(
773
+ name="tigerbot",
774
+ system="A chat between a curious user and an artificial intelligence assistant. "
775
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
776
+ roles=("### Instruction", "### Response"),
777
+ messages=(),
778
+ offset=0,
779
+ sep_style=SeparatorStyle.ROBIN,
780
+ sep="\n\n",
781
+ stop_str="###",
782
+ )
783
+ )
784
+
785
+ # ref: https://huggingface.co/Salesforce/xgen-7b-8k-inst
786
+ register_conv_template(
787
+ Conversation(
788
+ name="xgen",
789
+ system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
790
+ roles=("### Human: ", "###"),
791
+ messages=(),
792
+ offset=0,
793
+ sep_style=SeparatorStyle.NO_COLON_SINGLE,
794
+ sep="\n",
795
+ stop_token_ids=[50256, 0, 1, 2],
796
+ stop_str="<|endoftext|>",
797
+ )
798
+ )
799
+
800
+ # Internlm-chat template
801
+ register_conv_template(
802
+ Conversation(
803
+ name="internlm-chat",
804
+ system="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n",
805
+ roles=("<|User|>", "<|Bot|>"),
806
+ messages=(),
807
+ offset=0,
808
+ sep_style=SeparatorStyle.CHATINTERN,
809
+ sep="<eoh>",
810
+ sep2="<eoa>",
811
+ stop_token_ids=[1, 103028],
812
+ stop_str="<|User|>",
813
+ )
814
+ )
815
+
816
+ # StarChat template
817
+ register_conv_template(
818
+ Conversation(
819
+ name="starchat",
820
+ system="<system>\n",
821
+ roles=("<|user|>", "<|assistant|>"),
822
+ messages=(),
823
+ offset=0,
824
+ sep_style=SeparatorStyle.CHATML,
825
+ sep="<|end|>",
826
+ stop_token_ids=[0, 49155],
827
+ stop_str="<|end|>",
828
+ )
829
+ )
830
+
831
+ # Baichuan-13B-Chat template
832
+ register_conv_template(
833
+ # source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/f5f47be2adbbdceb784f334d6fa1ca2c73e65097/modeling_baichuan.py#L507
834
+ # https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json
835
+ Conversation(
836
+ name="baichuan-chat",
837
+ system="",
838
+ roles=(" <reserved_102> ", " <reserved_103> "),
839
+ messages=(),
840
+ offset=0,
841
+ sep_style=SeparatorStyle.NO_COLON_TWO,
842
+ sep="",
843
+ sep2="</s>",
844
+ stop_token_ids=[2, 195],
845
+ )
846
+ )
847
+
848
+ # llama2 template
849
+ # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
850
+ register_conv_template(
851
+ Conversation(
852
+ name="llama-2",
853
+ system="<s>[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
854
+ "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
855
+ "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
856
+ "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
857
+ "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
858
+ roles=("[INST]", "[/INST]"),
859
+ messages=(),
860
+ offset=0,
861
+ sep_style=SeparatorStyle.LLAMA2,
862
+ sep=" ",
863
+ sep2=" </s><s>",
864
+ stop_token_ids=[2],
865
+ )
866
+ )
867
+
868
+ register_conv_template(
869
+ Conversation(
870
+ name="cutegpt",
871
+ system="",
872
+ roles=("问:", "答:\n"),
873
+ messages=(),
874
+ offset=0,
875
+ sep_style=SeparatorStyle.NO_COLON_TWO,
876
+ sep="\n",
877
+ sep2="\n",
878
+ stop_str="<end>",
879
+ )
880
+ )
881
+
882
+ if __name__ == "__main__":
883
+ conv = get_conv_template("vicuna_v1.1")
884
+ conv.append_message(conv.roles[0], "Hello!")
885
+ conv.append_message(conv.roles[1], "Hi!")
886
+ conv.append_message(conv.roles[0], "How are you?")
887
+ conv.append_message(conv.roles[1], None)
888
+ print(conv.get_prompt())
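For reviewers skimming the diff, a short usage sketch of the template registry above (hypothetical, not part of this commit); it mirrors the __main__ block and assumes the module is imported as fastchat.conversation, as its docstring requests:

from fastchat.conversation import get_conv_template

conv = get_conv_template("llama-2")       # returns a copy, so the registered template stays untouched
conv.append_message(conv.roles[0], "What is the capital of France?")
conv.append_message(conv.roles[1], None)  # None marks the pending assistant turn
prompt = conv.get_prompt()                # rendered with the LLAMA2 separator style
conv.update_last_message("The capital of France is Paris.")
messages = conv.to_openai_api_messages()  # [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}]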
score_app.py ADDED
@@ -0,0 +1,154 @@
+ import gradio as gr
+ import json
+ import os
+ from typing import Dict, Sequence, Optional
+ import argparse
+ from collections import defaultdict
+ from dataclasses import dataclass, field
+
+ TOTAL_QUESTIONS = 80
+ QUESTION_NUM_PER_CATEGORY = 10
+
+ @dataclass
+ class ScoreCell:
+     model_score: int = field(default=0)
+
+
+ def read_jsonl(path: str, key: str=None):
+     data = []
+     with open(os.path.expanduser(path), "r", encoding="utf-8") as f:
+         for line in f:
+             if not line:
+                 continue
+             data.append(json.loads(line))
+     if key is not None:
+         data.sort(key=lambda x: x[key])
+         data = {item[key]: item for item in data}
+     return data
+
+ def get_categories(question_json):
+     questions = read_jsonl(question_json)
+     categories = []
+     for question in questions:
+         categories.append(question["category"])
+     categories = list(set(categories))
+     return categories
+
+ def show(question_json, answerA_json, answerB_json, category, question_id:int):
+     questions = read_jsonl(question_json)
+     category_questions = [question for question in questions if question["category"] == category]
+     question_id = question_id - 1
+     q, q_id = category_questions[question_id]["text"], category_questions[question_id]["question_id"]
+
+     ansA = read_jsonl(answerA_json)[q_id-1]["text"]
+     ansB = read_jsonl(answerB_json)[q_id-1]["text"]
+     return q, ansA, ansB
+
+ def upvote(score_dict, category_selector, question_id):
+     tmp_id = f"{category_selector}-{question_id}"
+     if tmp_id in score_dict:
+         return score_dict
+     score_dict[tmp_id].model_score += 1
+     return score_dict
+
+ def reset_cur_question(scoreA, scoreB, category_selector, question_id):
+     tmp_id = f"{category_selector}-{question_id}"
+     if tmp_id in scoreA:
+         del scoreA[tmp_id]
+     if tmp_id in scoreB:
+         del scoreB[tmp_id]
+     return scoreA, scoreB
+
+ def show_result(scoreA, scoreB):
+     answered_num = len(scoreA) + len(scoreB)
+     if answered_num == 0:
+         return "⚠⚠⚠ No question has been answered"
+     scoreA_sum = sum([score.model_score for score in scoreA.values()])
+     scoreB_sum = sum([score.model_score for score in scoreB.values()])
+     res = "Model-A: {} | Model-B: {}".format(int(scoreA_sum), int(scoreB_sum))
+     if answered_num < TOTAL_QUESTIONS:
+         res += "\n ⚠⚠⚠ Not all questions have been answered"
+     return res
+
+
+ def build_demo():
+     demo = gr.Blocks()
+     with demo:
+         scoreA = gr.State(value=defaultdict(ScoreCell))
+         scoreB = gr.State(value=defaultdict(ScoreCell))
+
+         question_json_path = gr.Dropdown(
+             label="Question JSON Path",
+             choices=["eval/table/counselling_question.jsonl",],
+         )
+         with gr.Row():
+             with gr.Column():
+                 answerA_json_path = gr.Dropdown(
+                     label="Model-A Answer JSON Path",
+                     choices=["eval/table/answer/counselling_answer.jsonl",],
+                 )
+             with gr.Column():
+                 answerB_json_path = gr.Dropdown(
+                     label="Model-B Answer JSON Path",
+                     choices=["eval/table/answer/counselling_answer_vicuna-7b.jsonl",],
+                 )
+         with gr.Row():
+             with gr.Column():
+                 category_selector = gr.Dropdown(
+                     choices=categories,
+                     label="Question Category",
+                     interactive=True,
+                     show_label=True,
+                 )
+             with gr.Column():
+                 question_id = gr.Slider(1, QUESTION_NUM_PER_CATEGORY, value=1, label="Question ID", step=1)
+
+         with gr.Row():
+             with gr.Column():
+                 reset_cur_q_btn = gr.Button(value="Reset Current Question")
+             with gr.Column():
+                 prev_q_btn = gr.Button(value="👈 Previous Question")
+             with gr.Column():
+                 next_q_btn = gr.Button(value="👉 Next Question")
+         output_q = gr.Textbox(label="Question")
+         with gr.Row():
+             with gr.Column():
+                 output_ansA = gr.Textbox(label="Model-A Answer")
+                 upvote_ansA_btn = gr.Button(value="👍")
+             with gr.Column():
+                 output_ansB = gr.Textbox(label="Model-B Answer")
+                 upvote_ansB_btn = gr.Button(value="👍")
+
+         with gr.Row():
+             summarize = gr.Button(value="Summarize")
+             result = gr.Textbox(label="Result", interactive=False, placeholder="Result will be shown here")
+             reset = gr.Button(value="Reset")
+
+
+         category_selector.change(fn=show, inputs=[question_json_path, answerA_json_path, answerB_json_path, category_selector, question_id], outputs=[output_q, output_ansA, output_ansB])
+         question_id.change(fn=show, inputs=[question_json_path, answerA_json_path, answerB_json_path, category_selector, question_id], outputs=[output_q, output_ansA, output_ansB])
+
+         # reset current question's vote
+         reset_cur_q_btn.click(fn=reset_cur_question, inputs=[scoreA, scoreB, category_selector, question_id], outputs=[scoreA, scoreB])
+         prev_q_btn.click(fn=lambda qid: max(qid - 1, 1), inputs=[question_id], outputs=[question_id])
+         next_q_btn.click(fn=lambda qid: min(qid + 1, QUESTION_NUM_PER_CATEGORY), inputs=[question_id], outputs=[question_id])
+
+         upvote_ansA_btn.click(
+             fn=upvote, inputs=[scoreA, category_selector, question_id], outputs=[scoreA])
+         upvote_ansB_btn.click(
+             fn=upvote, inputs=[scoreB, category_selector, question_id], outputs=[scoreB])
+
+         summarize.click(fn=show_result, inputs=[scoreA, scoreB], outputs=[result])
+         reset.click(fn=lambda: (defaultdict(ScoreCell), defaultdict(ScoreCell), "Result will be shown here"),
+                     outputs=[scoreA, scoreB, result])
+
+     return demo
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--share", action="store_true")
+     args = parser.parse_args()
+     categories = get_categories("eval/table/counselling_question.jsonl")
+
+     build_demo().launch(share=args.share)
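Note (not part of this commit): the app reads the hard-coded eval/table/counselling_question.jsonl and the two answer files relative to the working directory, so launch it from a checkout that contains them, e.g. python3 score_app.py (add --share for a public Gradio link).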
test_readme_train.sh ADDED
@@ -0,0 +1,24 @@
+ torchrun --nproc_per_node=4 --master_port=20001 fastchat/train/train_mem.py \
+     --model_name_or_path ~/model_weights/llama-7b \
+     --data_path data/dummy_conversation.json \
+     --bf16 True \
+     --output_dir output_vicuna \
+     --num_train_epochs 3 \
+     --per_device_train_batch_size 2 \
+     --per_device_eval_batch_size 2 \
+     --gradient_accumulation_steps 16 \
+     --evaluation_strategy "no" \
+     --save_strategy "steps" \
+     --save_steps 1200 \
+     --save_total_limit 10 \
+     --learning_rate 2e-5 \
+     --weight_decay 0. \
+     --warmup_ratio 0.03 \
+     --lr_scheduler_type "cosine" \
+     --logging_steps 1 \
+     --fsdp "full_shard auto_wrap" \
+     --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
+     --tf32 True \
+     --model_max_length 2048 \
+     --gradient_checkpointing True \
+     --lazy_preprocess True
train_lora.sh ADDED
@@ -0,0 +1,29 @@
+ deepspeed fastchat/train/train_lora.py \
+     --model_name_or_path lmsys/vicuna-7b-v1.3 \
+     --lora_r 8 \
+     --lora_alpha 16 \
+     --lora_dropout 0.05 \
+     --data_path $DATA_PATH \
+     --output_dir ./checkpoints \
+     --num_train_epochs 150 \
+     --fp16 True \
+     --per_device_train_batch_size 2 \
+     --per_device_eval_batch_size 2 \
+     --gradient_accumulation_steps 1 \
+     --evaluation_strategy "steps" \
+     --eval_steps 100 \
+     --save_strategy "steps" \
+     --save_steps 200 \
+     --save_total_limit 2 \
+     --learning_rate 2e-5 \
+     --weight_decay 0. \
+     --warmup_ratio 0.03 \
+     --lr_scheduler_type "cosine" \
+     --logging_strategy "steps" \
+     --logging_steps 1 \
+     --tf32 True \
+     --model_max_length 2048 \
+     --q_lora False \
+     --deepspeed $PATH_TO_DEEPSPEED_CONFIG \
+     --gradient_checkpointing True \
+     --flash_attn False
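Note (not part of this commit): $DATA_PATH and $PATH_TO_DEEPSPEED_CONFIG are unexpanded placeholders; export them, or substitute literal paths to the training JSON and a DeepSpeed config, before invoking the script. As written, --q_lora False and --flash_attn False keep QLoRA and flash attention disabled.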
train_vicuna_13b.sh ADDED
@@ -0,0 +1,26 @@
+ torchrun --nproc_per_node=8 --master_port=20001 fastchat/train/train_mem.py \
+     --model_name_or_path ~/model_weights/llama-13b \
+     --data_path ~/datasets/sharegpt_20230422_clean_lang_split_identity.json \
+     --bf16 True \
+     --output_dir output_vicuna_13b \
+     --num_train_epochs 3 \
+     --per_device_train_batch_size 4 \
+     --per_device_eval_batch_size 32 \
+     --gradient_accumulation_steps 4 \
+     --evaluation_strategy "steps" \
+     --eval_steps 1500 \
+     --save_strategy "steps" \
+     --save_steps 1500 \
+     --save_total_limit 8 \
+     --learning_rate 2e-5 \
+     --weight_decay 0. \
+     --warmup_ratio 0.04 \
+     --lr_scheduler_type "cosine" \
+     --logging_steps 1 \
+     --fsdp "full_shard auto_wrap offload" \
+     --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
+     --tf32 True \
+     --model_max_length 2048 \
+     --gradient_checkpointing True \
+     --lazy_preprocess True
+
train_vicuna_7b.sh ADDED
@@ -0,0 +1,26 @@
+ torchrun --nproc_per_node=4 --master_port=20001 fastchat/train/train_mem.py \
+     --model_name_or_path ~/model_weights/llama-7b \
+     --data_path ~/datasets/sharegpt_20230422_clean_lang_split_identity.json \
+     --bf16 True \
+     --output_dir output_vicuna_7b \
+     --num_train_epochs 3 \
+     --per_device_train_batch_size 2 \
+     --per_device_eval_batch_size 16 \
+     --gradient_accumulation_steps 16 \
+     --evaluation_strategy "steps" \
+     --eval_steps 1500 \
+     --save_strategy "steps" \
+     --save_steps 1500 \
+     --save_total_limit 8 \
+     --learning_rate 2e-5 \
+     --weight_decay 0. \
+     --warmup_ratio 0.04 \
+     --lr_scheduler_type "cosine" \
+     --logging_steps 1 \
+     --fsdp "full_shard auto_wrap" \
+     --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
+     --tf32 True \
+     --model_max_length 2048 \
+     --gradient_checkpointing True \
+     --lazy_preprocess True
+
upload_pypi.sh ADDED
@@ -0,0 +1,3 @@
+ rm -rf dist
+ python3 -m build
+ python3 -m twine upload dist/*
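Note (not part of this commit): the upload script assumes the build and twine packages are installed (pip install build twine) and that PyPI credentials are already configured for twine.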
utils.py ADDED
@@ -0,0 +1,296 @@
1
+ """
2
+ Common utilities.
3
+ """
4
+ from asyncio import AbstractEventLoop
5
+ import json
6
+ import logging
7
+ import logging.handlers
8
+ import os
9
+ import platform
10
+ import sys
11
+ from typing import AsyncGenerator, Generator
12
+ import warnings
13
+
14
+ import requests
15
+
16
+ from fastchat.constants import LOGDIR
17
+
18
+
19
+ handler = None
20
+ visited_loggers = set()
21
+
22
+
23
+ def build_logger(logger_name, logger_filename):
24
+ global handler
25
+
26
+ formatter = logging.Formatter(
27
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
28
+ datefmt="%Y-%m-%d %H:%M:%S",
29
+ )
30
+
31
+ # Set the format of root handlers
32
+ if not logging.getLogger().handlers:
33
+ if sys.version_info[1] >= 9:
34
+ # This is for windows
35
+ logging.basicConfig(level=logging.INFO, encoding="utf-8")
36
+ else:
37
+ if platform.system() == "Windows":
38
+ warnings.warn(
39
+ "If you are running on Windows, "
40
+ "we recommend you use Python >= 3.9 for UTF-8 encoding."
41
+ )
42
+ logging.basicConfig(level=logging.INFO)
43
+ logging.getLogger().handlers[0].setFormatter(formatter)
44
+
45
+ # Redirect stdout and stderr to loggers
46
+ stdout_logger = logging.getLogger("stdout")
47
+ stdout_logger.setLevel(logging.INFO)
48
+ sl = StreamToLogger(stdout_logger, logging.INFO)
49
+ sys.stdout = sl
50
+
51
+ stderr_logger = logging.getLogger("stderr")
52
+ stderr_logger.setLevel(logging.ERROR)
53
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
54
+ sys.stderr = sl
55
+
56
+ # Get logger
57
+ logger = logging.getLogger(logger_name)
58
+ logger.setLevel(logging.INFO)
59
+
60
+ os.makedirs(LOGDIR, exist_ok=True)
61
+ filename = os.path.join(LOGDIR, logger_filename)
62
+ handler = logging.handlers.TimedRotatingFileHandler(
63
+ filename, when="D", utc=True, encoding="utf-8"
64
+ )
65
+ handler.setFormatter(formatter)
66
+
67
+ for l in [stdout_logger, stderr_logger, logger]:
68
+ if l in visited_loggers:
69
+ continue
70
+ visited_loggers.add(l)
71
+ l.addHandler(handler)
72
+
73
+ return logger
74
+
75
+
76
+ class StreamToLogger(object):
77
+ """
78
+ Fake file-like stream object that redirects writes to a logger instance.
79
+ """
80
+
81
+ def __init__(self, logger, log_level=logging.INFO):
82
+ self.terminal = sys.stdout
83
+ self.logger = logger
84
+ self.log_level = log_level
85
+ self.linebuf = ""
86
+
87
+ def __getattr__(self, attr):
88
+ return getattr(self.terminal, attr)
89
+
90
+ def write(self, buf):
91
+ temp_linebuf = self.linebuf + buf
92
+ self.linebuf = ""
93
+ for line in temp_linebuf.splitlines(True):
94
+ # From the io.TextIOWrapper docs:
95
+ # On output, if newline is None, any '\n' characters written
96
+ # are translated to the system default line separator.
97
+ # By default sys.stdout.write() expects '\n' newlines and then
98
+ # translates them so this is still cross platform.
99
+ if line[-1] == "\n":
100
+ encoded_message = line.encode("utf-8", "ignore").decode("utf-8")
101
+ self.logger.log(self.log_level, encoded_message.rstrip())
102
+ else:
103
+ self.linebuf += line
104
+
105
+ def flush(self):
106
+ if self.linebuf != "":
107
+ encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
108
+ self.logger.log(self.log_level, encoded_message.rstrip())
109
+ self.linebuf = ""
110
+
111
+
112
+ def disable_torch_init():
113
+ """
114
+ Disable the redundant torch default initialization to accelerate model creation.
115
+ """
116
+ import torch
117
+
118
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
119
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
120
+
121
+
122
+ def get_gpu_memory(max_gpus=None):
123
+ """Get available memory for each GPU."""
124
+ import torch
125
+
126
+ gpu_memory = []
127
+ num_gpus = (
128
+ torch.cuda.device_count()
129
+ if max_gpus is None
130
+ else min(max_gpus, torch.cuda.device_count())
131
+ )
132
+
133
+ for gpu_id in range(num_gpus):
134
+ with torch.cuda.device(gpu_id):
135
+ device = torch.cuda.current_device()
136
+ gpu_properties = torch.cuda.get_device_properties(device)
137
+ total_memory = gpu_properties.total_memory / (1024**3)
138
+ allocated_memory = torch.cuda.memory_allocated() / (1024**3)
139
+ available_memory = total_memory - allocated_memory
140
+ gpu_memory.append(available_memory)
141
+ return gpu_memory
142
+
143
+
144
+ def violates_moderation(text):
145
+ """
146
+ Check whether the text violates OpenAI moderation API.
147
+ """
148
+ import openai
149
+
150
+ try:
151
+ flagged = openai.Moderation.create(input=text)["results"][0]["flagged"]
152
+ except openai.error.OpenAIError as e:
153
+ flagged = False
154
+ except (KeyError, IndexError) as e:
155
+ flagged = False
156
+
157
+ return flagged
158
+
159
+
160
+ def clean_flant5_ckpt(ckpt_path):
161
+ """
162
+ Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings.
163
+ Use this function to make sure it can be correctly loaded.
164
+ """
165
+ import torch
166
+
167
+ index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
168
+ index_json = json.load(open(index_file, "r"))
169
+
170
+ weightmap = index_json["weight_map"]
171
+
172
+ share_weight_file = weightmap["shared.weight"]
173
+ share_weight = torch.load(os.path.join(ckpt_path, share_weight_file))[
174
+ "shared.weight"
175
+ ]
176
+
177
+ for weight_name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]:
178
+ weight_file = weightmap[weight_name]
179
+ weight = torch.load(os.path.join(ckpt_path, weight_file))
180
+ weight[weight_name] = share_weight
181
+ torch.save(weight, os.path.join(ckpt_path, weight_file))
182
+
183
+
184
+ def pretty_print_semaphore(semaphore):
185
+ """Print a semaphore in better format."""
186
+ if semaphore is None:
187
+ return "None"
188
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
189
+
190
+
191
+ """A javascript function to get url parameters for the gradio web server."""
192
+ get_window_url_params_js = """
193
+ function() {
194
+ const params = new URLSearchParams(window.location.search);
195
+ url_params = Object.fromEntries(params);
196
+ console.log("url_params", url_params);
197
+ return url_params;
198
+ }
199
+ """
200
+
201
+
202
+ def iter_over_async(
203
+ async_gen: AsyncGenerator, event_loop: AbstractEventLoop
204
+ ) -> Generator:
205
+ """
206
+ Convert async generator to sync generator
207
+
208
+ :param async_gen: the AsyncGenerator to convert
209
+ :param event_loop: the event loop to run on
210
+ :returns: Sync generator
211
+ """
212
+ ait = async_gen.__aiter__()
213
+
214
+ async def get_next():
215
+ try:
216
+ obj = await ait.__anext__()
217
+ return False, obj
218
+ except StopAsyncIteration:
219
+ return True, None
220
+
221
+ while True:
222
+ done, obj = event_loop.run_until_complete(get_next())
223
+ if done:
224
+ break
225
+ yield obj
226
+
227
+
228
+ def detect_language(text: str) -> str:
229
+ """Detect the language of a string."""
230
+ import polyglot # pip3 install polyglot pyicu pycld2
231
+ from polyglot.detect import Detector
232
+ from polyglot.detect.base import logger as polyglot_logger
233
+ import pycld2
234
+
235
+ polyglot_logger.setLevel("ERROR")
236
+
237
+ try:
238
+ lang_code = Detector(text).language.name
239
+ except (pycld2.error, polyglot.detect.base.UnknownLanguage):
240
+ lang_code = "unknown"
241
+ return lang_code
242
+
243
+
244
+ def parse_gradio_auth_creds(filename: str):
245
+ """Parse a username:password file for gradio authorization."""
246
+ gradio_auth_creds = []
247
+ with open(filename, "r", encoding="utf8") as file:
248
+ for line in file.readlines():
249
+ gradio_auth_creds += [x.strip() for x in line.split(",") if x.strip()]
250
+ if gradio_auth_creds:
251
+ auth = [tuple(cred.split(":")) for cred in gradio_auth_creds]
252
+ else:
253
+ auth = None
254
+ return auth
255
+
256
+
257
+ def is_partial_stop(output: str, stop_str: str):
258
+ """Check whether the output contains a partial stop str."""
259
+ for i in range(0, min(len(output), len(stop_str))):
260
+ if stop_str.startswith(output[-i:]):
261
+ return True
262
+ return False
263
+
264
+
265
+ def run_cmd(cmd: str):
266
+ """Run a bash command."""
267
+ print(cmd)
268
+ return os.system(cmd)
269
+
270
+
271
+ def is_sentence_complete(output: str):
272
+ """Check whether the output is a complete sentence."""
273
+ end_symbols = (".", "?", "!", "...", "。", "?", "!", "…", '"', "'", "”")
274
+ return output.endswith(end_symbols)
275
+
276
+
277
+ # Models don't use the same configuration key for determining the maximum
278
+ # sequence length. Store them here so we can sanely check them.
279
+ # NOTE: The ordering here is important. Some models have two of these and we
280
+ # have a preference for which value gets used.
281
+ SEQUENCE_LENGTH_KEYS = [
282
+ "max_sequence_length",
283
+ "seq_length",
284
+ "max_position_embeddings",
285
+ "max_seq_len",
286
+ "model_max_length",
287
+ ]
288
+
289
+
290
+ def get_context_length(config):
291
+ """Get the context length of a model from a huggingface model config."""
292
+ for key in SEQUENCE_LENGTH_KEYS:
293
+ val = getattr(config, key, None)
294
+ if val is not None:
295
+ return val
296
+ return 2048
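A short usage sketch for iter_over_async (hypothetical, not part of this commit), assuming the module is importable as fastchat.utils; it drives an async generator from synchronous code:

import asyncio

from fastchat.utils import iter_over_async

async def count(n):
    for i in range(n):
        yield i

loop = asyncio.new_event_loop()
for item in iter_over_async(count(3), loop):  # run_until_complete is called once per yielded item
    print(item)                               # prints 0, 1, 2
loop.close()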