OlivierDehaene committed on
Commit
cb5d912
1 Parent(s): 511ba7f

support open-assistant model

Browse files
Files changed (1) hide show
  1. app.py +82 -18
app.py CHANGED
@@ -25,6 +25,8 @@ def get_usernames(model: str):
25
  Returns:
26
  (str, str, str, str): pre-prompt, username, bot name, separator
27
  """
 
 
28
  if model == "Rallio67/joi2_20B_instruct_alpha":
29
  return "", "User: ", "Joi: ", "\n\n"
30
  if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
@@ -35,6 +37,7 @@ def get_usernames(model: str):
35
  def predict(
36
  model: str,
37
  inputs: str,
 
38
  top_p: float,
39
  temperature: float,
40
  top_k: int,
@@ -66,8 +69,16 @@ def predict(
66
 
67
  partial_words = ""
68
 
69
- for i, response in enumerate(
70
- client.generate_stream(
 
 
 
 
 
 
 
 
71
  total_inputs,
72
  top_p=top_p if top_p < 1.0 else None,
73
  top_k=top_k,
@@ -78,7 +89,8 @@ def predict(
78
  max_new_tokens=500,
79
  stop_sequences=[user_name.rstrip(), assistant_name.rstrip()],
80
  )
81
- ):
 
82
  if response.token.special:
83
  continue
84
 
@@ -105,23 +117,46 @@ def reset_textbox():
105
 
106
 
107
  def radio_on_change(
108
- value: str, disclaimer, top_p, top_k, temperature, repetition_penalty, watermark
 
 
 
 
 
 
 
109
  ):
110
- if value == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
111
- top_p = top_p.update(value=0.25)
112
- top_k = top_k.update(value=50)
113
- temperature = temperature.update(value=0.6)
114
- repetition_penalty = repetition_penalty.update(value=1.01)
 
 
 
 
 
 
 
115
  watermark = watermark.update(False)
116
  disclaimer = disclaimer.update(visible=True)
117
  else:
118
- top_p = top_p.update(value=0.95)
119
- top_k = top_k.update(value=4)
120
- temperature = temperature.update(value=0.5)
 
121
  repetition_penalty = repetition_penalty.update(value=1.03)
122
  watermark = watermark.update(True)
123
  disclaimer = disclaimer.update(visible=False)
124
- return disclaimer, top_p, top_k, temperature, repetition_penalty, watermark
 
 
 
 
 
 
 
 
125
 
126
 
127
  title = """<h1 align="center">🔥Large Language Model API 🚀Streaming🚀</h1>"""
@@ -149,8 +184,9 @@ with gr.Blocks(
149
  gr.HTML(title)
150
  with gr.Column(elem_id="col_container"):
151
  model = gr.Radio(
152
- value="togethercomputer/GPT-NeoXT-Chat-Base-20B",
153
  choices=[
 
154
  "togethercomputer/GPT-NeoXT-Chat-Base-20B",
155
  "Rallio67/joi2_20B_instruct_alpha",
156
  "google/flan-t5-xxl",
@@ -167,11 +203,19 @@ with gr.Blocks(
167
  inputs = gr.Textbox(
168
  placeholder="Hi there!", label="Type an input and press Enter"
169
  )
170
- disclaimer = gr.Markdown(openchat_disclaimer)
171
  state = gr.State([])
172
  b1 = gr.Button()
173
 
174
  with gr.Accordion("Parameters", open=False):
 
 
 
 
 
 
 
 
175
  top_p = gr.Slider(
176
  minimum=-0,
177
  maximum=1.0,
@@ -179,6 +223,7 @@ with gr.Blocks(
179
  step=0.05,
180
  interactive=True,
181
  label="Top-p (nucleus sampling)",
 
182
  )
183
  temperature = gr.Slider(
184
  minimum=-0,
@@ -187,6 +232,7 @@ with gr.Blocks(
187
  step=0.1,
188
  interactive=True,
189
  label="Temperature",
 
190
  )
191
  top_k = gr.Slider(
192
  minimum=1,
@@ -195,6 +241,7 @@ with gr.Blocks(
195
  step=1,
196
  interactive=True,
197
  label="Top-k",
 
198
  )
199
  repetition_penalty = gr.Slider(
200
  minimum=0.1,
@@ -204,14 +251,29 @@ with gr.Blocks(
204
  interactive=True,
205
  label="Repetition Penalty",
206
  )
207
- watermark = gr.Checkbox(value=False, label="Text watermarking")
208
 
209
  model.change(
210
  lambda value: radio_on_change(
211
- value, disclaimer, top_p, top_k, temperature, repetition_penalty, watermark
 
 
 
 
 
 
 
212
  ),
213
  inputs=model,
214
- outputs=[disclaimer, top_p, top_k, temperature, repetition_penalty, watermark],
 
 
 
 
 
 
 
 
215
  )
216
 
217
  inputs.submit(
@@ -219,6 +281,7 @@ with gr.Blocks(
219
  [
220
  model,
221
  inputs,
 
222
  top_p,
223
  temperature,
224
  top_k,
@@ -234,6 +297,7 @@ with gr.Blocks(
234
  [
235
  model,
236
  inputs,
 
237
  top_p,
238
  temperature,
239
  top_k,
 
25
  Returns:
26
  (str, str, str, str): pre-prompt, username, bot name, separator
27
  """
28
+ if model == "OpenAssistant/oasst-sft-1-pythia-12b":
29
+ return "", "<|prompter|", "<|assistant|>", "<|endoftext|>"
30
  if model == "Rallio67/joi2_20B_instruct_alpha":
31
  return "", "User: ", "Joi: ", "\n\n"
32
  if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
 
37
  def predict(
38
  model: str,
39
  inputs: str,
40
+ typical_p: float,
41
  top_p: float,
42
  temperature: float,
43
  top_k: int,
 
69
 
70
  partial_words = ""
71
 
72
+ if model == "OpenAssistant/oasst-sft-1-pythia-12b":
73
+ iterator = client.generate_stream(
74
+ total_inputs,
75
+ typical_p=typical_p,
76
+ repetition_penalty=repetition_penalty,
77
+ watermark=watermark,
78
+ max_new_tokens=500,
79
+ )
80
+ else:
81
+ iterator = client.generate_stream(
82
  total_inputs,
83
  top_p=top_p if top_p < 1.0 else None,
84
  top_k=top_k,
 
89
  max_new_tokens=500,
90
  stop_sequences=[user_name.rstrip(), assistant_name.rstrip()],
91
  )
92
+
93
+ for i, response in enumerate(iterator):
94
  if response.token.special:
95
  continue
96
 
 
117
 
118
 
119
  def radio_on_change(
120
+ value: str,
121
+ disclaimer,
122
+ typical_p,
123
+ top_p,
124
+ top_k,
125
+ temperature,
126
+ repetition_penalty,
127
+ watermark,
128
  ):
129
+ if value == "OpenAssistant/oasst-sft-1-pythia-12b":
130
+ typical_p = typical_p.update(value=0.2, visible=True)
131
+ top_p = top_p.update(visible=False)
132
+ top_k = top_k.update(visible=False)
133
+ temperature = temperature.update(visible=False)
134
+ disclaimer = disclaimer.update(visible=False)
135
+ elif value == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
136
+ typical_p = typical_p.update(visible=False)
137
+ top_p = top_p.update(value=0.25, visible=True)
138
+ top_k = top_k.update(value=50, visible=True)
139
+ temperature = temperature.update(value=0.6, visible=True)
140
+ repetition_penalty = repetition_penalty.update(value=1.01, visible=True)
141
  watermark = watermark.update(False)
142
  disclaimer = disclaimer.update(visible=True)
143
  else:
144
+ typical_p = typical_p.update(visible=False)
145
+ top_p = top_p.update(value=0.95, visible=True)
146
+ top_k = top_k.update(value=4, visible=True)
147
+ temperature = temperature.update(value=0.5, visible=True)
148
  repetition_penalty = repetition_penalty.update(value=1.03)
149
  watermark = watermark.update(True)
150
  disclaimer = disclaimer.update(visible=False)
151
+ return (
152
+ disclaimer,
153
+ typical_p,
154
+ top_p,
155
+ top_k,
156
+ temperature,
157
+ repetition_penalty,
158
+ watermark,
159
+ )
160
 
161
 
162
  title = """<h1 align="center">🔥Large Language Model API 🚀Streaming🚀</h1>"""
 
184
  gr.HTML(title)
185
  with gr.Column(elem_id="col_container"):
186
  model = gr.Radio(
187
+ value="OpenAssistant/oasst-sft-1-pythia-12b",
188
  choices=[
189
+ "OpenAssistant/oasst-sft-1-pythia-12b",
190
  "togethercomputer/GPT-NeoXT-Chat-Base-20B",
191
  "Rallio67/joi2_20B_instruct_alpha",
192
  "google/flan-t5-xxl",
 
203
  inputs = gr.Textbox(
204
  placeholder="Hi there!", label="Type an input and press Enter"
205
  )
206
+ disclaimer = gr.Markdown(openchat_disclaimer, visible=False)
207
  state = gr.State([])
208
  b1 = gr.Button()
209
 
210
  with gr.Accordion("Parameters", open=False):
211
+ typical_p = gr.Slider(
212
+ minimum=-0,
213
+ maximum=1.0,
214
+ value=0.2,
215
+ step=0.05,
216
+ interactive=True,
217
+ label="Typical P mass",
218
+ )
219
  top_p = gr.Slider(
220
  minimum=-0,
221
  maximum=1.0,
 
223
  step=0.05,
224
  interactive=True,
225
  label="Top-p (nucleus sampling)",
226
+ visible=False,
227
  )
228
  temperature = gr.Slider(
229
  minimum=-0,
 
232
  step=0.1,
233
  interactive=True,
234
  label="Temperature",
235
+ visible=False,
236
  )
237
  top_k = gr.Slider(
238
  minimum=1,
 
241
  step=1,
242
  interactive=True,
243
  label="Top-k",
244
+ visible=False,
245
  )
246
  repetition_penalty = gr.Slider(
247
  minimum=0.1,
 
251
  interactive=True,
252
  label="Repetition Penalty",
253
  )
254
+ watermark = gr.Checkbox(value=True, label="Text watermarking")
255
 
256
  model.change(
257
  lambda value: radio_on_change(
258
+ value,
259
+ disclaimer,
260
+ typical_p,
261
+ top_p,
262
+ top_k,
263
+ temperature,
264
+ repetition_penalty,
265
+ watermark,
266
  ),
267
  inputs=model,
268
+ outputs=[
269
+ disclaimer,
270
+ typical_p,
271
+ top_p,
272
+ top_k,
273
+ temperature,
274
+ repetition_penalty,
275
+ watermark,
276
+ ],
277
  )
278
 
279
  inputs.submit(
 
281
  [
282
  model,
283
  inputs,
284
+ typical_p,
285
  top_p,
286
  temperature,
287
  top_k,
 
297
  [
298
  model,
299
  inputs,
300
+ typical_p,
301
  top_p,
302
  temperature,
303
  top_k,