chansung commited on
Commit
629e8c6
1 Parent(s): 8248004

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -29
app.py CHANGED
@@ -61,7 +61,10 @@ def fill_up_placeholders(txt):
61
  "" if len(placeholders) >= 1 else txt
62
  )
63
 
64
- async def chat_stream(idx, local_data, instruction_txtbox, chat_state):
 
 
 
65
  res = [
66
  chat_state["ppmanager_type"].from_json(json.dumps(ppm))
67
  for ppm in local_data
@@ -71,8 +74,18 @@ async def chat_stream(idx, local_data, instruction_txtbox, chat_state):
71
  ppm.add_pingpong(
72
  PingPong(instruction_txtbox, "")
73
  )
74
- prompt = build_prompts(ppm, "global context", 3)
75
- async for result in gen_text(prompt, hf_model=MODEL_ID, hf_token=TOKEN):
 
 
 
 
 
 
 
 
 
 
76
  ppm.append_pong(result)
77
  yield ppm.build_uis(), str(res)
78
 
@@ -198,7 +211,7 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
198
  with gr.Column():
199
  with gr.Column():
200
  gr.Markdown("#### Global context")
201
- with gr.Accordion("global context will persist during conversation, and it is placed at the top of the prompt", open=False):
202
  global_context = gr.Textbox(
203
  "global context",
204
  lines=5,
@@ -218,31 +231,12 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
218
 
219
  gr.Markdown("#### GenConfig for **response** text generation")
220
  with gr.Row():
221
- res_temp = gr.Slider(0.0, 2.0, 0, step=0.1, label="temp", interactive=True)
222
- res_topp = gr.Slider(0.0, 2.0, 0, step=0.1, label="top_p", interactive=True)
223
- res_topk = gr.Slider(20, 1000, 0, step=1, label="top_k", interactive=True)
224
- res_rpen = gr.Slider(0.0, 2.0, 0, step=0.1, label="rep_penalty", interactive=True)
225
- res_mnts = gr.Slider(64, 8192, 0, step=1, label="new_tokens", interactive=True)
226
- res_beams = gr.Slider(1, 4, 0, step=1, label="beams")
227
- res_cache = gr.Radio([True, False], value=0, label="cache", interactive=True)
228
  res_sample = gr.Radio([True, False], value=0, label="sample", interactive=True)
229
- res_eosid = gr.Number(value=0, visible=False, precision=0)
230
- res_padid = gr.Number(value=0, visible=False, precision=0)
231
-
232
- with gr.Column(visible=False):
233
- gr.Markdown("#### GenConfig for **summary** text generation")
234
- with gr.Row():
235
- sum_temp = gr.Slider(0.0, 2.0, 0, step=0.1, label="temp", interactive=True)
236
- sum_topp = gr.Slider(0.0, 2.0, 0, step=0.1, label="top_p", interactive=True)
237
- sum_topk = gr.Slider(20, 1000, 0, step=1, label="top_k", interactive=True)
238
- sum_rpen = gr.Slider(0.0, 2.0, 0, step=0.1, label="rep_penalty", interactive=True)
239
- sum_mnts = gr.Slider(64, 8192, 0, step=1, label="new_tokens", interactive=True)
240
- sum_beams = gr.Slider(1, 8, 0, step=1, label="beams", interactive=True)
241
- sum_cache = gr.Radio([True, False], value=0, label="cache", interactive=True)
242
- sum_sample = gr.Radio([True, False], value=0, label="sample", interactive=True)
243
- sum_eosid = gr.Number(value=0, visible=False, precision=0)
244
- sum_padid = gr.Number(value=0, visible=False, precision=0)
245
-
246
  with gr.Column():
247
  gr.Markdown("#### Context managements")
248
  with gr.Row():
@@ -255,7 +249,8 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
255
 
256
  instruction_txtbox.submit(
257
  chat_stream,
258
- [idx, local_data, instruction_txtbox, chat_state],
 
259
  [chatbot, local_data]
260
  )
261
 
 
61
  "" if len(placeholders) >= 1 else txt
62
  )
63
 
64
+ async def chat_stream(
65
+ idx, local_data, instruction_txtbox, chat_state,
66
+ global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv
67
+ ):
68
  res = [
69
  chat_state["ppmanager_type"].from_json(json.dumps(ppm))
70
  for ppm in local_data
 
74
  ppm.add_pingpong(
75
  PingPong(instruction_txtbox, "")
76
  )
77
+ prompt = build_prompts(ppm, global_context, ctx_num_lconv)
78
+ async for result in gen_text(
79
+ prompt, hf_model=MODEL_ID, hf_token=TOKEN,
80
+ parameters={
81
+ 'max_new_tokens': res_mnts,
82
+ 'do_sample': res_sample,
83
+ 'return_full_text': False,
84
+ 'temperature': res_temp,
85
+ 'top_k': res_topk,
86
+ 'repetition_penalty': res_rpen
87
+ }
88
+ ):
89
  ppm.append_pong(result)
90
  yield ppm.build_uis(), str(res)
91
 
 
211
  with gr.Column():
212
  with gr.Column():
213
  gr.Markdown("#### Global context")
214
+ with gr.Accordion("global context will persist during conversation, and it is placed at the top of the prompt", open=True):
215
  global_context = gr.Textbox(
216
  "global context",
217
  lines=5,
 
231
 
232
  gr.Markdown("#### GenConfig for **response** text generation")
233
  with gr.Row():
234
+ res_temp = gr.Slider(0.0, 2.0, 1.0, step=0.1, label="temp", interactive=True)
235
+ res_topk = gr.Slider(20, 1000, 50, step=1, label="top_k", interactive=True)
236
+ res_rpen = gr.Slider(0.0, 2.0, 1.2, step=0.1, label="rep_penalty", interactive=True)
237
+ res_mnts = gr.Slider(64, 8192, 512, step=1, label="new_tokens", interactive=True)
 
 
 
238
  res_sample = gr.Radio([True, False], value=0, label="sample", interactive=True)
239
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  with gr.Column():
241
  gr.Markdown("#### Context managements")
242
  with gr.Row():
 
249
 
250
  instruction_txtbox.submit(
251
  chat_stream,
252
+ [idx, local_data, instruction_txtbox, chat_state,
253
+ global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv],
254
  [chatbot, local_data]
255
  )
256