lewtun committed
Commit
e169f05
1 Parent(s): fcec4cf

Refactor for streaming

Files changed (2)
  1. app.py +97 -118
  2. dialogues.py +241 -0
app.py CHANGED
@@ -3,25 +3,21 @@ import os
  import shutil
 
  import gradio as gr
- import requests
  from huggingface_hub import Repository
- from share_btn import community_icon_html, loading_icon_html, share_btn_css, share_js
 
 
- HF_TOKEN = os.environ.get("H4_TOKEN", None)
  API_TOKEN = os.environ.get("API_TOKEN", None)
- STAR_CHAT_API_URL = os.environ.get("STAR_CHAT_API_URL", None)
- STAR_CHAT_GPT_API_URL = os.environ.get("STAR_CHAT_GPT_API_URL", None)
 
- API_TOKEN = "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"
- STAR_CHAT_API_URL = "https://i1qe9e7uv7jzsg8k.us-east-1.aws.endpoints.huggingface.cloud"
- STAR_CHAT_GPT_API_URL = "https://czpdnzuklyfoqjbs.us-east-1.aws.endpoints.huggingface.cloud"
-
- model_to_api = {
-     "StarChat": STAR_CHAT_API_URL,
-     "StarChatGPT": STAR_CHAT_GPT_API_URL,
- }
- PROMPT_TEMPLATE = "<|system|>\n{system}<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>"
 
  theme = gr.themes.Monochrome(
      primary_hue="indigo",
@@ -31,57 +27,25 @@ theme = gr.themes.Monochrome(
      font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
  )
 
- if HF_TOKEN:
-     try:
-         shutil.rmtree("./data/")
-     except:
-         pass
-
-     repo = Repository(
-         local_dir="./data/", clone_from="trl-lib/star-chat-prompts", use_auth_token=HF_TOKEN, repo_type="dataset"
-     )
-     repo.git_pull()
 
 
 
  def save_inputs_and_outputs(inputs, outputs, generate_kwargs):
      with open(os.path.join("data", "prompts.jsonl"), "a") as f:
-         json.dump({"inputs": inputs, "outputs": outputs,
-                    "generate_kwargs": generate_kwargs}, f, ensure_ascii=False)
          f.write("\n")
      repo.push_to_hub()
 
 
- def inference(
-     model, prompt, system_message, user_message, temperature, top_p, top_k, max_new_tokens, do_sample, eos_token_id
- ):
-     headers = {"Authorization": f"Bearer {API_TOKEN}"}
-     api_url = model_to_api[model]
-     print(f"CUSTOM_LOG {model} - {api_url}")
-     response = requests.post(
-         api_url,
-         headers=headers,
-         json={
-             "inputs": prompt,
-             "parameters": {
-                 "do_sample": do_sample,
-                 "temperature": temperature,
-                 "top_p": top_p,
-                 "top_k": top_k,
-                 "max_new_tokens": max_new_tokens,
-                 "eos_token_id": eos_token_id,
-             },
-         },
-     )
-
-     if response.status_code != 200:
-         return None
-     completion = response.json()[0]["generated_text"]
-     if user_message in completion:
-         completion = completion.lstrip()[len(f"{system_message}\n{user_message}\n"):]
-     return completion
-
-
  def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
      past = []
      for data in chatbot:
@@ -107,7 +71,7 @@ def has_no_history(chatbot, history):
 
 
  def generate(
-     model,
      system_message,
      user_message,
      chatbot,
@@ -122,32 +86,74 @@ def generate(
      if not user_message:
          return chatbot, history, user_message, ""
 
-     prompt = PROMPT_TEMPLATE.format(system=system_message, prompt=user_message)
-
      history.append(user_message)
 
      generate_kwargs = {
          "temperature": temperature,
          "top_p": top_p,
          "top_k": top_k,
          "max_new_tokens": max_new_tokens,
-         "do_sample": True,
-         "eos_token_id": [49155, 32003],
      }
 
-     response = inference(model, prompt, system_message, user_message, **generate_kwargs)
 
-     history.append(response)
-     chat = [(history[i].strip(), history[i + 1].strip()) for i in range(0, len(history) - 1, 2)]
 
-     if HF_TOKEN and do_save:
-         try:
-             print("Pushing prompt and completion to the Hub")
-             save_inputs_and_outputs(prompt, output, generate_kwargs)
-         except Exception as e:
-             print(e)
 
-     return chat, history, user_message, ""
 
 
  examples = [
@@ -160,7 +166,6 @@ examples = [
 
 
  def regenerate(
-     model,
      system_message,
      user_message,
      chatbot,
@@ -183,33 +188,20 @@ def regenerate(
      chatbot = chatbot[:-1]
      history = history[:-2]
 
-     return generate(
-         model, system_message, user_message, chatbot, history, temperature, top_p, top_k, max_new_tokens, do_save
-     )
 
 
  def clear_chat():
      return [], []
 
 
- def radio_on_change():
-     return [], []
-
- # def radio_on_change(
- #     model, system_message, user_message, chatbot, history, temperature, top_p, top_k, max_new_tokens, do_save
- # ):
- #     return generate(
- #         model, system_message, user_message, chatbot, history, temperature, top_p, top_k, max_new_tokens, do_save
- #     )
-
-
  def process_example(args):
      for [x, y] in generate(args):
          pass
      return [x, y]
 
 
- title = """<h1 align="center">⭐ StarChat Demo 💬</h1>"""
  custom_css = """
  #banner-image {
      display: block;
@@ -253,42 +245,34 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=custom_css) as demo:
      gr.HTML(title)
      gr.Image("StarCoderBanner.png", elem_id="banner-image", show_label=False)
      gr.Markdown(
-         """
-         StarChat is an instruction fine-tuned model based on [StarCoder](https://huggingface.co/bigcode/starcoder), a 16B parameter model trained on one trillion tokens sourced from 80+ programming languages, GitHub issues, Git commits, and Jupyter notebooks (all permissively licensed). With an enterprise-friendly license, 8,192 token context length, and fast large-batch inference via [multi-query attention](https://arxiv.org/abs/1911.02150), StarCoder is currently the best open-source choice for code-based applications. For more details, check out our [blog post]().
 
-         ⚠️ **Intended Use**: this app and its supporting models ([StarChat](https://huggingface.co/HuggingFaceH4/starchat) and [StarChatGPT](https://huggingface.co/HuggingFaceH4/starchatgpt)) are provided as educational tools to explain instruction fine-tuning; not to serve as replacement for human expertise. For more details on the model's limitations in terms of factuality and biases, see the model cards: [StarChat](https://huggingface.co/HuggingFaceH4/starchat#bias-risks-and-limitations) and [StarChatGPT](https://huggingface.co/HuggingFaceH4/starchatgpt#bias-risks-and-limitations).
-
-         ⚠️ **Data Collection**: by default, we are collecting the prompts entered in this app to further improve and evaluate the model. Do not share any personal or sensitive information while using the app! You can opt out of this data collection by removing the checkbox below.
          """
      )
 
      with gr.Row():
          with gr.Column(scale=1):
              system_message = gr.Textbox(elem_id="system-message", label="System prompt")
 
          with gr.Column(scale=2):
              with gr.Box():
-                 model = gr.Radio(
-                     value="StarChat",
-                     choices=[
-                         "StarChat",
-                         "StarChatGPT",
-                     ],
-                     label="Model",
-                     interactive=True,
-                 )
                  output = gr.Markdown()
                  chatbot = gr.Chatbot(elem_id="chat-message", label="Chat")
 
      with gr.Row():
          with gr.Column(scale=3):
-             do_save = gr.Checkbox(
-                 value=True,
-                 label="Store data",
-                 info="You agree to the storage of your prompt and generated text for research and development purposes:",
-             )
-             user_message = gr.Textbox(placeholder="Enter your message here",
-                                       show_label=False, elem_id="q-input")
              with gr.Row():
                  send_button = gr.Button("Send", elem_id="send-btn", visible=True)
                  regenerate_button = gr.Button("Regenerate", elem_id="send-btn", visible=True)
@@ -298,7 +282,7 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=custom_css) as demo:
              # with gr.Group(elem_id="share-btn-container"):
              #     community_icon = gr.HTML(community_icon_html, visible=True)
              #     loading_icon = gr.HTML(loading_icon_html, visible=True)
-             #     share_button = gr.Button("Share to community", elem_id="share-btn", visible=True)
      with gr.Row():
          gr.Examples(
              examples=examples,
@@ -311,9 +295,9 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=custom_css) as demo:
          with gr.Column(scale=1):
              temperature = gr.Slider(
                  label="Temperature",
-                 value=0.8,
                  minimum=0.0,
-                 maximum=2.0,
                  step=0.1,
                  interactive=True,
                  info="Higher values produce more diverse outputs",
@@ -329,7 +313,7 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=custom_css) as demo:
              )
              top_p = gr.Slider(
                  label="Top-p (nucleus sampling)",
-                 value=0.25,
                  minimum=0.0,
                  maximum=1,
                  step=0.05,
@@ -338,7 +322,7 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=custom_css) as demo:
              )
              max_new_tokens = gr.Slider(
                  label="Max new tokens",
-                 value=512,
                  minimum=0,
                  maximum=2048,
                  step=4,
@@ -353,7 +337,6 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=custom_css) as demo:
      user_message.submit(
          generate,
          inputs=[
-             model,
              system_message,
              user_message,
              chatbot,
@@ -370,7 +353,6 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=custom_css) as demo:
      send_button.click(
          generate,
          inputs=[
-             model,
              system_message,
              user_message,
              chatbot,
@@ -387,7 +369,6 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=custom_css) as demo:
      regenerate_button.click(
          regenerate,
          inputs=[
-             model,
              system_message,
              last_user_message,
              chatbot,
@@ -402,8 +383,6 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=custom_css) as demo:
      )
 
      clear_chat_button.click(clear_chat, outputs=[chatbot, history])
-
-     model.change(radio_on_change, outputs=[chatbot, history])
      # share_button.click(None, [], [], _js=share_js)
 
  demo.queue(concurrency_count=16).launch(debug=True)
 
  import shutil
 
  import gradio as gr
  from huggingface_hub import Repository
+ from text_generation import Client
 
+ from dialogues import DialogueTemplate
+ from share_btn import (community_icon_html, loading_icon_html, share_btn_css,
+                        share_js)
 
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
  API_TOKEN = os.environ.get("API_TOKEN", None)
+ API_URL = os.environ.get("API_URL", None)
 
+ client = Client(
+     API_URL,
+     headers={"Authorization": f"Bearer {API_TOKEN}"},
+ )
 
  theme = gr.themes.Monochrome(
      primary_hue="indigo",
 
      font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
  )
 
+ # if HF_TOKEN:
+ #     try:
+ #         shutil.rmtree("./data/")
+ #     except:
+ #         pass
 
+ #     repo = Repository(
+ #         local_dir="./data/", clone_from="trl-lib/star-chat-prompts", use_auth_token=HF_TOKEN, repo_type="dataset"
+ #     )
+ #     repo.git_pull()
 
 
  def save_inputs_and_outputs(inputs, outputs, generate_kwargs):
      with open(os.path.join("data", "prompts.jsonl"), "a") as f:
+         json.dump({"inputs": inputs, "outputs": outputs, "generate_kwargs": generate_kwargs}, f, ensure_ascii=False)
          f.write("\n")
      repo.push_to_hub()
 
 
  def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
      past = []
      for data in chatbot:
 
 
 
  def generate(
+     # model,
      system_message,
      user_message,
      chatbot,
 
      if not user_message:
          return chatbot, history, user_message, ""
 
      history.append(user_message)
 
+     past_messages = []
+     for data in chatbot:
+         user_data, model_data = data
+
+         past_messages.extend(
+             [{"role": "user", "content": user_data}, {"role": "assistant", "content": model_data.rstrip()}]
+         )
+
+     if len(past_messages) < 1:
+         dialogue_template = DialogueTemplate(
+             system=system_message, messages=[{"role": "user", "content": user_message}]
+         )
+         prompt = dialogue_template.get_inference_prompt()
+     else:
+         dialogue_template = DialogueTemplate(
+             system=system_message, messages=past_messages + [{"role": "user", "content": user_message}]
+         )
+         prompt = dialogue_template.get_inference_prompt()
+
      generate_kwargs = {
          "temperature": temperature,
          "top_p": top_p,
          "top_k": top_k,
          "max_new_tokens": max_new_tokens,
      }
 
+     temperature = float(temperature)
+     if temperature < 1e-2:
+         temperature = 1e-2
+     top_p = float(top_p)
+
+     generate_kwargs = dict(
+         temperature=temperature,
+         max_new_tokens=max_new_tokens,
+         top_p=top_p,
+         do_sample=True,
+         truncate=999,
+         seed=42,
+         stop_sequences=["<|end|>"],
+     )
+
+     stream = client.generate_stream(
+         prompt,
+         **generate_kwargs,
+     )
+
+     output = ""
+     for idx, response in enumerate(stream):
+         if response.token.special:
+             continue
+         output += response.token.text
+         if idx == 0:
+             history.append(" " + output)
+         else:
+             history[-1] = output
 
+         chat = [(history[i].strip(), history[i + 1].strip()) for i in range(0, len(history) - 1, 2)]
 
+         # if HF_TOKEN and do_save:
+         #     try:
+         #         print("Pushing prompt and completion to the Hub")
+         #         save_inputs_and_outputs(prompt, output, generate_kwargs)
+         #     except Exception as e:
+         #         print(e)
 
+         yield chat, history, user_message, ""
 
 
  examples = [
 
 
 
  def regenerate(
      system_message,
      user_message,
      chatbot,
 
      chatbot = chatbot[:-1]
      history = history[:-2]
 
+     return generate(system_message, user_message, chatbot, history, temperature, top_p, top_k, max_new_tokens, do_save)
 
 
  def clear_chat():
      return [], []
 
 
  def process_example(args):
      for [x, y] in generate(args):
          pass
      return [x, y]
 
 
+ title = """<h1 align="center">⭐ Chat with StarCoder Demo 💬</h1>"""
  custom_css = """
  #banner-image {
      display: block;
 
      gr.HTML(title)
      gr.Image("StarCoderBanner.png", elem_id="banner-image", show_label=False)
      gr.Markdown(
+         """
+         This demo showcases an instruction fine-tuned model based on [StarCoder](https://huggingface.co/bigcode/starcoder), a 16B parameter model trained on one trillion tokens sourced from 80+ programming languages, GitHub issues, Git commits, and Jupyter notebooks (all permissively licensed). With an enterprise-friendly license, 8,192 token context length, and fast large-batch inference via [multi-query attention](https://arxiv.org/abs/1911.02150), StarCoder is currently the best open-source choice for code-based applications. For more details, check out our [blog post]().
+
+         ⚠️ **Intended Use**: this app and its [supporting model](https://huggingface.co/HuggingFaceH4/starcoderbase-finetuned-oasst1) are provided as educational tools to explain instruction fine-tuning; not to serve as replacement for human expertise. For more details on the model's limitations in terms of factuality and biases, see the [model card](https://huggingface.co/HuggingFaceH4/starcoderbase-finetuned-oasst1#bias-risks-and-limitations).
 
+         ⚠️ **Data Collection**: by default, we are collecting the prompts entered in this app to further improve and evaluate the model. Do NOT share any personal or sensitive information while using the app! You can opt out of this data collection by removing the checkbox below.
          """
      )
 
+     with gr.Row():
+         do_save = gr.Checkbox(
+             value=True,
+             label="Store data",
+             info="You agree to the storage of your prompt and generated text for research and development purposes:",
+         )
+
      with gr.Row():
          with gr.Column(scale=1):
              system_message = gr.Textbox(elem_id="system-message", label="System prompt")
 
          with gr.Column(scale=2):
              with gr.Box():
                  output = gr.Markdown()
                  chatbot = gr.Chatbot(elem_id="chat-message", label="Chat")
 
      with gr.Row():
          with gr.Column(scale=3):
+             user_message = gr.Textbox(placeholder="Enter your message here", show_label=False, elem_id="q-input")
              with gr.Row():
                  send_button = gr.Button("Send", elem_id="send-btn", visible=True)
                  regenerate_button = gr.Button("Regenerate", elem_id="send-btn", visible=True)
 
              # with gr.Group(elem_id="share-btn-container"):
              #     community_icon = gr.HTML(community_icon_html, visible=True)
              #     loading_icon = gr.HTML(loading_icon_html, visible=True)
+             #     share_button = gr.Button("Share to community", elem_id="share-btn", visible=True)
      with gr.Row():
          gr.Examples(
              examples=examples,
 
          with gr.Column(scale=1):
              temperature = gr.Slider(
                  label="Temperature",
+                 value=0.2,
                  minimum=0.0,
+                 maximum=1.0,
                  step=0.1,
                  interactive=True,
                  info="Higher values produce more diverse outputs",
 
              )
              top_p = gr.Slider(
                  label="Top-p (nucleus sampling)",
+                 value=0.95,
                  minimum=0.0,
                  maximum=1,
                  step=0.05,
 
              )
              max_new_tokens = gr.Slider(
                  label="Max new tokens",
+                 value=384,
                  minimum=0,
                  maximum=2048,
                  step=4,
 
      user_message.submit(
          generate,
          inputs=[
              system_message,
              user_message,
              chatbot,
 
      send_button.click(
          generate,
          inputs=[
              system_message,
              user_message,
              chatbot,
 
      regenerate_button.click(
          regenerate,
          inputs=[
              system_message,
              last_user_message,
              chatbot,
 
      )
 
      clear_chat_button.click(clear_chat, outputs=[chatbot, history])
      # share_button.click(None, [], [], _js=share_js)
 
  demo.queue(concurrency_count=16).launch(debug=True)
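
The heart of this refactor is swapping the blocking `requests.post` call for token-by-token streaming with `text_generation.Client`, which the new `generate` function consumes in a loop and yields back to the Gradio chatbot. A minimal standalone sketch of the same pattern, assuming a text-generation-inference endpoint and token in the `API_URL` / `API_TOKEN` environment variables (the prompt below is a made-up example in the StarChat format):

    import os

    from text_generation import Client

    # Placeholder endpoint and token; any text-generation-inference server works here.
    client = Client(
        os.environ["API_URL"],
        headers={"Authorization": f"Bearer {os.environ['API_TOKEN']}"},
    )

    # Single-turn prompt in the <|system|>/<|user|>/<|assistant|> format built by DialogueTemplate.
    prompt = "<|system|>\n<|end|>\n<|user|>\nWrite hello world in Python<|end|>\n<|assistant|>"

    output = ""
    for response in client.generate_stream(
        prompt,
        do_sample=True,
        temperature=0.2,
        top_p=0.95,
        max_new_tokens=384,
        stop_sequences=["<|end|>"],
    ):
        if response.token.special:
            # Skip special tokens such as <|end|> so they never reach the UI.
            continue
        output += response.token.text
        print(response.token.text, end="", flush=True)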
dialogues.py ADDED
@@ -0,0 +1,241 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import os
+ from dataclasses import asdict, dataclass
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Type, TypeVar, Union
+
+ from huggingface_hub import ModelHubMixin, hf_hub_download
+
+ # Generic variable that is either ModelHubMixin or a subclass thereof
+ T = TypeVar("T", bound="ModelHubMixin")
+
+ TEMPLATE_FILENAME = "dialogue_template.json"
+ IGNORE_INDEX = -100
+
+
+ @dataclass
+ class DialogueTemplate(ModelHubMixin):
+     """Converts all turns of a dialogue between a user and assistant to a standardized format.
+
+     Adapted from OpenAI's ChatML (https://github.com/openai/openai-python/blob/main/chatml.md) and Vicuna (https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py)
+     """
+
+     system: str
+     messages: List[Dict[str, str]] = None
+     system_token: str = "<|system|>"
+     user_token: str = "<|user|>"
+     assistant_token: str = "<|assistant|>"
+     end_token: str = "<|end|>"
+
+     def get_training_prompt(self) -> str:
+         prompt = self.system_token + "\n" + self.system + self.end_token + "\n"
+         if self.messages is None:
+             raise ValueError("Dialogue template must have at least one message.")
+         for message in self.messages:
+             if message["role"] == "user":
+                 prompt += self.user_token + "\n" + message["content"] + self.end_token + "\n"
+             else:
+                 prompt += self.assistant_token + "\n" + message["content"] + self.end_token + "\n"
+         return prompt
+
+     def get_inference_prompt(self) -> str:
+         prompt = self.system_token + "\n" + self.system + self.end_token + "\n"
+         if self.messages is None:
+             raise ValueError("Dialogue template must have at least one message.")
+         for message in self.messages:
+             if message["role"] == "user":
+                 prompt += self.user_token + "\n" + message["content"] + self.end_token + "\n"
+             else:
+                 prompt += self.assistant_token + "\n" + message["content"] + self.end_token + "\n"
+         prompt += self.assistant_token
+         return prompt
+
+     def get_dialogue(self):
+         """Helper function to format the messages as an easy-to-read dialogue."""
+         prompt = ""
+         if self.messages is None:
+             raise ValueError("Dialogue template must have at least one message.")
+         for message in self.messages:
+             if message["role"] == "user":
+                 prompt += "\n\nHuman: " + message["content"]
+             else:
+                 prompt += "\n\nAssistant: " + message["content"]
+         return prompt
+
+     def get_special_tokens(self) -> List[str]:
+         return [self.system_token, self.user_token, self.assistant_token, self.end_token]
+
+     def copy(self):
+         return DialogueTemplate(
+             system=self.system,
+             messages=self.messages,
+             system_token=self.system_token,
+             user_token=self.user_token,
+             assistant_token=self.assistant_token,
+             end_token=self.end_token,
+         )
+
+     def to_dict(self) -> Dict[str, Any]:
+         return {k: v for k, v in asdict(self).items()}
+
+     @classmethod
+     def from_dict(cls, data):
+         return DialogueTemplate(
+             system=data["system"] if "system" in data else "",
+             messages=data["messages"] if "messages" in data else None,
+             system_token=data["system_token"] if "system_token" in data else "<|system|>",
+             user_token=data["user_token"] if "user_token" in data else "<|user|>",
+             assistant_token=data["assistant_token"] if "assistant_token" in data else "<|assistant|>",
+             end_token=data["end_token"] if "end_token" in data else "<|end|>",
+         )
+
+     def _save_pretrained(self, save_directory: Union[str, Path]) -> None:
+         save_directory = Path(save_directory)
+         save_directory.mkdir(exist_ok=True)
+         with open(save_directory / "dialogue_template.json", "w") as f:
+             json.dump(self.to_dict(), f, indent=2)
+
+     @classmethod
+     def _from_pretrained(
+         cls: Type[T],
+         *,
+         model_id: str,
+         revision: Optional[str],
+         cache_dir: Optional[Union[str, Path]],
+         force_download: bool,
+         proxies: Optional[Dict],
+         resume_download: bool,
+         local_files_only: bool,
+         token: Optional[Union[str, bool]],
+         **model_kwargs,
+     ) -> T:
+         """Loads the dialogue template from a local directory or the Huggingface Hub.
+
+         Args:
+             model_id (`str`):
+                 ID of the model to load from the Huggingface Hub (e.g. `bigscience/bloom`).
+             revision (`str`, *optional*):
+                 Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. Defaults to the
+                 latest commit on `main` branch.
+             force_download (`bool`, *optional*, defaults to `False`):
+                 Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding
+                 the existing cache.
+             resume_download (`bool`, *optional*, defaults to `False`):
+                 Whether to delete incompletely received files. Will attempt to resume the download if such a file exists.
+             proxies (`Dict[str, str]`, *optional*):
+                 A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128',
+                 'http://hostname': 'foo.bar:4012'}`).
+             token (`str` or `bool`, *optional*):
+                 The token to use as HTTP bearer authorization for remote files. By default, it will use the token
+                 cached when running `huggingface-cli login`.
+             cache_dir (`str`, `Path`, *optional*):
+                 Path to the folder where cached files are stored.
+             local_files_only (`bool`, *optional*, defaults to `False`):
+                 If `True`, avoid downloading the file and return the path to the local cached file if it exists.
+             model_kwargs:
+                 Additional keyword arguments passed along to the [`~ModelHubMixin._from_pretrained`] method.
+         """
+         if os.path.isdir(model_id):  # Can either be a local directory
+             print("Loading dialogue template from local directory")
+             template_file = os.path.join(model_id, TEMPLATE_FILENAME)
+         else:  # Or a template on the Hub
+             template_file = hf_hub_download(  # Download from the hub, passing same input args
+                 repo_id=model_id,
+                 filename=TEMPLATE_FILENAME,
+                 revision=revision,
+                 cache_dir=cache_dir,
+                 force_download=force_download,
+                 proxies=proxies,
+                 resume_download=resume_download,
+                 token=token,
+                 local_files_only=local_files_only,
+             )
+
+         # Load template
+         with open(template_file, "r") as f:
+             data = json.load(f)
+         return cls.from_dict(data=data)
+
+
+ # A shortened version of the system message in Anthropic's HHH prompt: https://gist.github.com/jareddk/2509330f8ef3d787fc5aaac67aab5f11#file-hhh_prompt-txt
+ default_template = DialogueTemplate(
+     system="Below is a dialogue between a human user and an AI assistant. The assistant is happy to help with almost anything, and will do its best to understand exactly what is needed.",
+ )
+
+ # OpenAI and OpenAssistant train on few to no system messages.
+ # TODO: consider defining this as the `default` template
+ no_system_template = DialogueTemplate(
+     system="",
+ )
+
+ alpaca_template = DialogueTemplate(
+     system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
+     user_token="### Instruction:",
+     assistant_token="### Response:",
+ )
+
+ SUPPORTED_DIALOGUE_TEMPLATES = {
+     "default": default_template,
+     "no_system": no_system_template,
+     "alpaca": alpaca_template,
+ }
+
+
+ def get_dialogue_template(template: str) -> DialogueTemplate:
+     if template not in SUPPORTED_DIALOGUE_TEMPLATES.keys():
+         raise ValueError(f"Template {template} is not supported!")
+     return SUPPORTED_DIALOGUE_TEMPLATES[template].copy()
+
+
+ def prepare_dialogue(example, dialogue_template, is_train=True):
+     """Format example to single- or multi-turn dialogue."""
+     # TODO: make this simpler by just ensuring every dataset has a messages column
+     if "messages" in example.keys() and example["messages"] is not None:
+         dialogue_template.messages = example["messages"]
+     elif all(k in example.keys() for k in ("prompt", "completion")):
+         # Construct single-turn dialogue from prompt and completion
+         dialogue_template.messages = [
+             {"role": "user", "content": example["prompt"]},
+             {"role": "assistant", "content": example["completion"]},
+         ]
+     elif "prompt" in example.keys():
+         # Construct single-turn dialogue from prompt (inference only)
+         dialogue_template.messages = [
+             {"role": "user", "content": example["prompt"]},
+         ]
+     else:
+         raise ValueError(
+             f"Could not format example as dialogue! Require either `messages` or `[prompt, completion]` or `[prompt]` keys but found {list(example.keys())}"
+         )
+     if is_train:
+         example["text"] = dialogue_template.get_training_prompt()
+     else:
+         example["text"] = dialogue_template.get_inference_prompt()
+     return example
+
+
+ def mask_user_labels(tokenizer, dialogue_template, labels):
+     """Masks the user turns of a dialogue from the loss"""
+     user_token_id = tokenizer.convert_tokens_to_ids(dialogue_template.user_token)
+     assistant_token_id = tokenizer.convert_tokens_to_ids(dialogue_template.assistant_token)
+     for idx, label_id in enumerate(labels):
+         if label_id == user_token_id:
+             current_idx = idx
+             while labels[current_idx] != assistant_token_id and current_idx < len(labels):
+                 labels[current_idx] = IGNORE_INDEX
+                 current_idx += 1
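
For reference, a small sketch of how the app's prompt is assembled with `DialogueTemplate.get_inference_prompt()`; the system text and messages below are made up, and the commented expected output follows directly from the template logic above:

    from dialogues import DialogueTemplate

    template = DialogueTemplate(
        system="Below is a dialogue between a human user and an AI assistant.",
        messages=[
            {"role": "user", "content": "What is 2 + 2?"},
            {"role": "assistant", "content": "4"},
            {"role": "user", "content": "And doubled?"},
        ],
    )

    # Each turn is wrapped in its role token and terminated with <|end|>; the prompt
    # ends with a bare <|assistant|> token so the model continues as the assistant.
    print(template.get_inference_prompt())
    # <|system|>
    # Below is a dialogue between a human user and an AI assistant.<|end|>
    # <|user|>
    # What is 2 + 2?<|end|>
    # <|assistant|>
    # 4<|end|>
    # <|user|>
    # And doubled?<|end|>
    # <|assistant|>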