OlivierDehaene commited on
Commit
f15edb7
1 Parent(s): 173e7dd

openchat support

Browse files
Files changed (1) hide show
  1. app.py +76 -28
app.py CHANGED
@@ -4,6 +4,13 @@ import gradio as gr
4
 
5
  from text_generation import Client, InferenceAPIClient
6
 
 
 
 
 
 
 
 
7
 
8
  def get_client(model: str):
9
  if model == "Rallio67/joi2_20B_instruct_alpha":
@@ -14,26 +21,30 @@ def get_client(model: str):
14
 
15
 
16
  def get_usernames(model: str):
 
 
 
 
17
  if model == "Rallio67/joi2_20B_instruct_alpha":
18
- return "User: ", "Joi: "
19
  if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
20
- return "<human>: ", "<bot>: "
21
- return "User: ", "Assistant: "
22
 
23
 
24
  def predict(
25
- model: str,
26
- inputs: str,
27
- top_p: float,
28
- temperature: float,
29
- top_k: int,
30
- repetition_penalty: float,
31
- watermark: bool,
32
- chatbot,
33
- history,
34
  ):
35
  client = get_client(model)
36
- user_name, assistant_name = get_usernames(model)
37
 
38
  history.append(inputs)
39
 
@@ -43,19 +54,20 @@ def predict(
43
 
44
  if not user_data.startswith(user_name):
45
  user_data = user_name + user_data
46
- if not model_data.startswith("\n\n" + assistant_name):
47
- model_data = "\n\n" + assistant_name + model_data
48
 
49
- past.append(user_data + model_data + "\n\n")
50
 
51
  if not inputs.startswith(user_name):
52
  inputs = user_name + inputs
53
 
54
- total_inputs = "".join(past) + inputs + "\n\n" + assistant_name
55
 
56
  partial_words = ""
57
 
58
- for i, response in enumerate(client.generate_stream(
 
59
  total_inputs,
60
  top_p=top_p if top_p < 1.0 else None,
61
  top_k=top_k,
@@ -65,7 +77,8 @@ def predict(
65
  temperature=temperature,
66
  max_new_tokens=500,
67
  stop_sequences=[user_name.rstrip(), assistant_name.rstrip()],
68
- )):
 
69
  if response.token.special:
70
  continue
71
 
@@ -81,7 +94,8 @@ def predict(
81
  history[-1] = partial_words
82
 
83
  chat = [
84
- (history[i].strip(), history[i + 1].strip()) for i in range(0, len(history) - 1, 2)
 
85
  ]
86
  yield chat, history
87
 
@@ -90,6 +104,26 @@ def reset_textbox():
90
  return gr.update(value="")
91
 
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  title = """<h1 align="center">🔥Large Language Model API 🚀Streaming🚀</h1>"""
94
  description = """Language models can be conditioned to act like dialogue agents through a conversational prompt that typically takes the form:
95
 
@@ -104,17 +138,21 @@ Assistant: <utterance>
104
  In this app, you can explore the outputs of multiple LLMs when prompted in this way.
105
  """
106
 
 
 
 
 
107
  with gr.Blocks(
108
- css="""#col_container {margin-left: auto; margin-right: auto;}
109
  #chatbot {height: 520px; overflow: auto;}"""
110
  ) as demo:
111
  gr.HTML(title)
112
  with gr.Column(elem_id="col_container"):
113
  model = gr.Radio(
114
- value="Rallio67/joi2_20B_instruct_alpha",
115
  choices=[
 
116
  "Rallio67/joi2_20B_instruct_alpha",
117
- # "togethercomputer/GPT-NeoXT-Chat-Base-20B",
118
  "google/flan-t5-xxl",
119
  "google/flan-ul2",
120
  "bigscience/bloom",
@@ -124,10 +162,12 @@ with gr.Blocks(
124
  label="Model",
125
  interactive=True,
126
  )
 
127
  chatbot = gr.Chatbot(elem_id="chatbot")
128
  inputs = gr.Textbox(
129
  placeholder="Hi there!", label="Type an input and press Enter"
130
  )
 
131
  state = gr.State([])
132
  b1 = gr.Button()
133
 
@@ -135,7 +175,7 @@ with gr.Blocks(
135
  top_p = gr.Slider(
136
  minimum=-0,
137
  maximum=1.0,
138
- value=0.95,
139
  step=0.05,
140
  interactive=True,
141
  label="Top-p (nucleus sampling)",
@@ -143,7 +183,7 @@ with gr.Blocks(
143
  temperature = gr.Slider(
144
  minimum=-0,
145
  maximum=5.0,
146
- value=0.5,
147
  step=0.1,
148
  interactive=True,
149
  label="Temperature",
@@ -151,7 +191,7 @@ with gr.Blocks(
151
  top_k = gr.Slider(
152
  minimum=1,
153
  maximum=50,
154
- value=4,
155
  step=1,
156
  interactive=True,
157
  label="Top-k",
@@ -159,12 +199,20 @@ with gr.Blocks(
159
  repetition_penalty = gr.Slider(
160
  minimum=0.1,
161
  maximum=3.0,
162
- value=1.03,
163
  step=0.01,
164
  interactive=True,
165
  label="Repetition Penalty",
166
  )
167
- watermark = gr.Checkbox(value=True, label="Text watermarking")
 
 
 
 
 
 
 
 
168
 
169
  inputs.submit(
170
  predict,
4
 
5
  from text_generation import Client, InferenceAPIClient
6
 
7
+ openchat_preprompt = (
8
+ "\n<human>: Hi!\n<bot>: My name is Bot, model version is 0.15, part of an open-source kit for "
9
+ "fine-tuning new bots! I was created by Together, LAION, and Ontocord.ai and the open-source "
10
+ "community. I am not human, not evil and not alive, and thus have no thoughts and feelings, "
11
+ "but I am programmed to be helpful, polite, honest, and friendly.\n"
12
+ )
13
+
14
 
15
  def get_client(model: str):
16
  if model == "Rallio67/joi2_20B_instruct_alpha":
21
 
22
 
23
  def get_usernames(model: str):
24
+ """
25
+ Returns:
26
+ (str, str, str, str): pre-prompt, username, bot name, separator
27
+ """
28
  if model == "Rallio67/joi2_20B_instruct_alpha":
29
+ return "", "User: ", "Joi: ", "\n\n"
30
  if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
31
+ return openchat_preprompt, "<human>: ", "<bot>: ", "\n"
32
+ return "", "User: ", "Assistant: ", "\n"
33
 
34
 
35
  def predict(
36
+ model: str,
37
+ inputs: str,
38
+ top_p: float,
39
+ temperature: float,
40
+ top_k: int,
41
+ repetition_penalty: float,
42
+ watermark: bool,
43
+ chatbot,
44
+ history,
45
  ):
46
  client = get_client(model)
47
+ preprompt, user_name, assistant_name, sep = get_usernames(model)
48
 
49
  history.append(inputs)
50
 
54
 
55
  if not user_data.startswith(user_name):
56
  user_data = user_name + user_data
57
+ if not model_data.startswith(sep + assistant_name):
58
+ model_data = sep + assistant_name + model_data
59
 
60
+ past.append(user_data + model_data.rstrip() + sep)
61
 
62
  if not inputs.startswith(user_name):
63
  inputs = user_name + inputs
64
 
65
+ total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()
66
 
67
  partial_words = ""
68
 
69
+ for i, response in enumerate(
70
+ client.generate_stream(
71
  total_inputs,
72
  top_p=top_p if top_p < 1.0 else None,
73
  top_k=top_k,
77
  temperature=temperature,
78
  max_new_tokens=500,
79
  stop_sequences=[user_name.rstrip(), assistant_name.rstrip()],
80
+ )
81
+ ):
82
  if response.token.special:
83
  continue
84
 
94
  history[-1] = partial_words
95
 
96
  chat = [
97
+ (history[i].strip(), history[i + 1].strip())
98
+ for i in range(0, len(history) - 1, 2)
99
  ]
100
  yield chat, history
101
 
104
  return gr.update(value="")
105
 
106
 
107
+ def radio_on_change(
108
+ value: str, disclaimer, top_p, top_k, temperature, repetition_penalty, watermark
109
+ ):
110
+ if value == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
111
+ top_p = top_p.update(value=0.25)
112
+ top_k = top_k.update(value=50)
113
+ temperature = temperature.update(value=0.6)
114
+ repetition_penalty = repetition_penalty.update(value=1.01)
115
+ watermark = watermark.update(False)
116
+ disclaimer = disclaimer.update(visible=True)
117
+ else:
118
+ top_p = top_p.update(value=0.95)
119
+ top_k = top_k.update(value=4)
120
+ temperature = temperature.update(value=0.5)
121
+ repetition_penalty = repetition_penalty.update(value=1.03)
122
+ watermark = watermark.update(True)
123
+ disclaimer = disclaimer.update(visible=False)
124
+ return disclaimer, top_p, top_k, temperature, repetition_penalty, watermark
125
+
126
+
127
  title = """<h1 align="center">🔥Large Language Model API 🚀Streaming🚀</h1>"""
128
  description = """Language models can be conditioned to act like dialogue agents through a conversational prompt that typically takes the form:
129
 
138
  In this app, you can explore the outputs of multiple LLMs when prompted in this way.
139
  """
140
 
141
+ openchat_disclaimer = """
142
+ <div align="center">Checkout the official <a href=https://huggingface.co/spaces/togethercomputer/OpenChatKit>OpenChatKit feedback app</a> for the full experience.</div>
143
+ """
144
+
145
  with gr.Blocks(
146
+ css="""#col_container {margin-left: auto; margin-right: auto;}
147
  #chatbot {height: 520px; overflow: auto;}"""
148
  ) as demo:
149
  gr.HTML(title)
150
  with gr.Column(elem_id="col_container"):
151
  model = gr.Radio(
152
+ value="togethercomputer/GPT-NeoXT-Chat-Base-20B",
153
  choices=[
154
+ "togethercomputer/GPT-NeoXT-Chat-Base-20B",
155
  "Rallio67/joi2_20B_instruct_alpha",
 
156
  "google/flan-t5-xxl",
157
  "google/flan-ul2",
158
  "bigscience/bloom",
162
  label="Model",
163
  interactive=True,
164
  )
165
+
166
  chatbot = gr.Chatbot(elem_id="chatbot")
167
  inputs = gr.Textbox(
168
  placeholder="Hi there!", label="Type an input and press Enter"
169
  )
170
+ disclaimer = gr.Markdown(openchat_disclaimer)
171
  state = gr.State([])
172
  b1 = gr.Button()
173
 
175
  top_p = gr.Slider(
176
  minimum=-0,
177
  maximum=1.0,
178
+ value=0.25,
179
  step=0.05,
180
  interactive=True,
181
  label="Top-p (nucleus sampling)",
183
  temperature = gr.Slider(
184
  minimum=-0,
185
  maximum=5.0,
186
+ value=0.6,
187
  step=0.1,
188
  interactive=True,
189
  label="Temperature",
191
  top_k = gr.Slider(
192
  minimum=1,
193
  maximum=50,
194
+ value=50,
195
  step=1,
196
  interactive=True,
197
  label="Top-k",
199
  repetition_penalty = gr.Slider(
200
  minimum=0.1,
201
  maximum=3.0,
202
+ value=1.01,
203
  step=0.01,
204
  interactive=True,
205
  label="Repetition Penalty",
206
  )
207
+ watermark = gr.Checkbox(value=False, label="Text watermarking")
208
+
209
+ model.change(
210
+ lambda value: radio_on_change(
211
+ value, disclaimer, top_p, top_k, temperature, repetition_penalty, watermark
212
+ ),
213
+ inputs=model,
214
+ outputs=[disclaimer, top_p, top_k, temperature, repetition_penalty, watermark],
215
+ )
216
 
217
  inputs.submit(
218
  predict,