chansung commited on
Commit
d82a572
1 Parent(s): a5ad5a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -11
app.py CHANGED
@@ -107,9 +107,40 @@ def reset_chat(idx, ld, state):
107
  gr.update(interactive=False),
108
  )
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  async def chat_stream(
111
  idx, local_data, instruction_txtbox, chat_state,
112
- global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv
 
113
  ):
114
  res = [
115
  chat_state["ppmanager_type"].from_json(json.dumps(ppm))
@@ -121,6 +152,14 @@ async def chat_stream(
121
  PingPong(instruction_txtbox, "")
122
  )
123
  prompt = build_prompts(ppm, global_context, ctx_num_lconv)
 
 
 
 
 
 
 
 
124
  async for result in gen_text(
125
  prompt, hf_model=MODEL_ID, hf_token=TOKEN,
126
  parameters={
@@ -283,14 +322,15 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
283
  elem_id="global-context"
284
  )
285
 
286
- # gr.Markdown("#### Internet search")
287
- # with gr.Row():
288
- # internet_option = gr.Radio(choices=["on", "off"], value="off", label="mode")
289
- # serper_api_key = gr.Textbox(
290
- # value= "" if args.serper_api_key is None else args.serper_api_key,
291
- # placeholder="Get one by visiting serper.dev",
292
- # label="Serper api key"
293
- # )
 
294
 
295
  gr.Markdown("#### GenConfig for **response** text generation")
296
  with gr.Row():
@@ -315,7 +355,8 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
315
  ).then(
316
  chat_stream,
317
  [idx, local_data, instruction_txtbox, chat_state,
318
- global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv],
 
319
  [instruction_txtbox, context_inspector, chatbot, local_data, regenerate]
320
  ).then(
321
  None, local_data, None,
@@ -346,7 +387,8 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
346
  regen_event = regenerate.click(
347
  rollback_last,
348
  [idx, local_data, chat_state,
349
- global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv],
 
350
  [context_inspector, chatbot, local_data, regenerate]
351
  ).then(
352
  None, local_data, None,
 
107
  gr.update(interactive=False),
108
  )
109
 
110
+ def internet_search(ppmanager, serper_api_key, global_context, ctx_num_lconv, device="cpu"):
111
+ internet_search_ppm = copy.deepcopy(ppm)
112
+ internet_search_prompt = f"My question is '{user_msg}'. Based on the conversation history, "
113
+ f"give me an appropriate query to answer my question for google search. "
114
+ f"You should not say more than query. You should not say any words except the query."
115
+
116
+ internet_search_ppm.pingpongs[-1].ping = internet_search_prompt
117
+ internet_search_prompt = build_prompts(internet_search_ppm, "", win_size=ctx_num_lconv)
118
+
119
+ instruction = gen_text_none_stream(internet_search_prompt, hf_model=MODEL_ID, hf_token=TOKEN)
120
+ ###
121
+
122
+ searcher = SimilaritySearcher.from_pretrained(device=device)
123
+ iss = InternetSearchStrategy(
124
+ searcher,
125
+ instruction=instruction,
126
+ serper_api_key=serper_api_key
127
+ )(ppmanager)
128
+
129
+ step_ppm = None
130
+ while True:
131
+ try:
132
+ step_ppm, _ = next(iss)
133
+ yield "", step_ppm.build_uis()
134
+ except StopIteration:
135
+ break
136
+
137
+ search_prompt = build_prompts(step_ppm, global_context, ctx_num_lconv)
138
+ yield search_prompt, ppmanager.build_uis()
139
+
140
  async def chat_stream(
141
  idx, local_data, instruction_txtbox, chat_state,
142
+ global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
143
+ internet_option, serper_api_key
144
  ):
145
  res = [
146
  chat_state["ppmanager_type"].from_json(json.dumps(ppm))
 
152
  PingPong(instruction_txtbox, "")
153
  )
154
  prompt = build_prompts(ppm, global_context, ctx_num_lconv)
155
+
156
+ #######
157
+ if internet_option:
158
+ search_prompt = None
159
+ for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
160
+ search_prompt = tmp_prompt
161
+ yield "", uis, prompt, str(res)
162
+
163
  async for result in gen_text(
164
  prompt, hf_model=MODEL_ID, hf_token=TOKEN,
165
  parameters={
 
322
  elem_id="global-context"
323
  )
324
 
325
+ gr.Markdown("#### Internet search")
326
+ with gr.Row():
327
+ internet_option = gr.Radio(choices=["on", "off"], value="off", label="mode")
328
+ serper_api_key = gr.Textbox(
329
+ value= os.getenv("SERPER_API_KEY"),
330
+ placeholder="Get one by visiting serper.dev",
331
+ label="Serper api key",
332
+ visible=False
333
+ )
334
 
335
  gr.Markdown("#### GenConfig for **response** text generation")
336
  with gr.Row():
 
355
  ).then(
356
  chat_stream,
357
  [idx, local_data, instruction_txtbox, chat_state,
358
+ global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
359
+ internet_option, serper_api_key],
360
  [instruction_txtbox, context_inspector, chatbot, local_data, regenerate]
361
  ).then(
362
  None, local_data, None,
 
387
  regen_event = regenerate.click(
388
  rollback_last,
389
  [idx, local_data, chat_state,
390
+ global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
391
+ internet_option, serper_api_key],
392
  [context_inspector, chatbot, local_data, regenerate]
393
  ).then(
394
  None, local_data, None,