awacke1 commited on
Commit
7dc68f9
1 Parent(s): 483b218

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -111
app.py CHANGED
@@ -6,13 +6,12 @@ from huggingface_hub import InferenceClient
6
  import re
7
  from datetime import datetime
8
  import json
9
- import os
10
 
11
  import arxiv
12
  from utils import get_md_text_abstract, search_cleaner, get_arxiv_live_search
13
 
14
- retrieve_results = 10
15
- show_examples = False
16
  llm_models_to_choose = ['mistralai/Mixtral-8x7B-Instruct-v0.1','mistralai/Mistral-7B-Instruct-v0.2', 'google/gemma-7b-it', 'None']
17
 
18
  generate_kwargs = dict(
@@ -22,6 +21,7 @@ generate_kwargs = dict(
22
  do_sample = False,
23
  )
24
 
 
25
  RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
26
 
27
  try:
@@ -32,6 +32,7 @@ try:
32
  except:
33
  gr.Warning("Retriever not working!")
34
 
 
35
  mark_text = '# 🩺🔍 Search Results\n'
36
  header_text = "## Arxiv Paper Summary With QA Retrieval Augmented Generation \n"
37
 
@@ -49,6 +50,7 @@ except:
49
 
50
  database_choices = [index_info,'Arxiv Search - Latest - (EXPERIMENTAL)']
51
 
 
52
  arx_client = arxiv.Client()
53
  is_arxiv_available = True
54
  check_arxiv_result = get_arxiv_live_search("What is Self Rewarding AI and how can it be used in Multi-Agent Systems?", arx_client, retrieve_results)
@@ -57,15 +59,27 @@ if len(check_arxiv_result) == 0:
57
  print("Arxiv search not working, switching to default search ...")
58
  database_choices = [index_info]
59
 
60
- if show_examples:
61
- with open("sample_outputs.json", "r") as f:
62
- sample_outputs = json.load(f)
63
- output_placeholder = sample_outputs['output_placeholder']
64
- md_text_initial = sample_outputs['search_placeholder']
65
-
66
- else:
67
- output_placeholder = None
68
- md_text_initial = ''
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  def rag_cleaner(inp):
71
  rank = inp['rank']
@@ -122,123 +136,87 @@ def SaveResponseAndRead(result):
122
  '''
123
  gr.HTML(documentHTML5)
124
 
125
- def save_search_results(prompt, results, response):
126
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
127
- #filename = f"{timestamp}_{re.sub(r'[^\\w\\-_\\. ]', '_', prompt)}.txt"
128
- filename = f"{timestamp} - {prompt}.txt"
129
- with open(filename, "w") as f:
130
- f.write(f"# {prompt}\n\n")
131
- f.write(f"## Search Results\n\n{results}\n\n")
132
- f.write(f"## LLM Response\n\n{response}\n")
133
- return filename
134
-
135
- def get_past_searches():
136
- txt_files = [f for f in os.listdir(".") if f.endswith(".txt") and f != "requirements.txt"]
137
- return txt_files
138
 
139
  with gr.Blocks(theme = gr.themes.Soft()) as demo:
140
  header = gr.Markdown(header_text)
 
 
 
 
 
 
 
 
 
 
141
 
142
- with gr.Row():
143
- with gr.Column():
144
- msg = gr.Textbox(label = 'Search', placeholder = 'What is Mistral?')
145
-
146
- with gr.Accordion("Advanced Settings", open=False):
147
- with gr.Row(equal_height = True):
148
- llm_model = gr.Dropdown(choices = llm_models_to_choose, value = 'mistralai/Mistral-7B-Instruct-v0.2', label = 'LLM Model')
149
- llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results as context")
150
- database_src = gr.Dropdown(choices = database_choices, value = index_info, label = 'Search Source')
151
- stream_results = gr.Checkbox(value = True, label = "Stream output", visible = False)
152
-
153
- output_text = gr.Textbox(show_label = True, container = True, label = 'LLM Answer', visible = True, placeholder = output_placeholder)
154
- input = gr.Textbox(show_label = False, visible = False)
155
- gr_md = gr.Markdown(mark_text + md_text_initial)
156
-
157
- with gr.Column():
158
- past_searches = gr.Dropdown(choices=get_past_searches(), label="Past Searches")
159
- past_search_content = gr.Textbox(label="Past Search Content", visible=False)
160
-
161
- def update_past_search_content(past_search):
162
- if past_search:
163
- with open(past_search, "r") as f:
164
- content = f.read()
165
- return gr.Textbox.update(value=content, visible=True)
166
- else:
167
- return gr.Textbox.update(visible=False)
168
 
169
- past_searches.change(update_past_search_content, past_searches, past_search_content)
170
-
171
  def update_with_rag_md(message, llm_results_use = 5, database_choice = index_info, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2'):
172
  prompt_text_from_data = ""
173
  database_to_use = database_choice
174
  if database_choice == index_info:
175
- rag_out = get_rag(message)
176
  else:
177
- arxiv_search_success = True
178
- try:
179
- rag_out = get_arxiv_live_search(message, arx_client, retrieve_results)
180
- if len(rag_out) == 0:
181
- arxiv_search_success = False
182
- except:
183
- arxiv_search_success = False
184
-
185
- if not arxiv_search_success:
186
- gr.Warning("Arxiv Search not working, switching to semantic search ...")
187
- rag_out = get_rag(message)
188
- database_to_use = index_info
189
-
 
190
  md_text_updated = mark_text
191
  for i in range(retrieve_results):
192
- rag_answer = rag_out[i]
193
- if i < llm_results_use:
194
- md_text_paper, prompt_text = get_md_text_abstract(rag_answer, source = database_to_use, return_prompt_formatting = True)
195
- prompt_text_from_data += f"{i+1}. {prompt_text}"
196
- else:
197
- md_text_paper = get_md_text_abstract(rag_answer, source = database_to_use)
198
- md_text_updated += md_text_paper
199
  prompt = get_prompt_text(message, prompt_text_from_data, llm_model_picked = llm_model_picked)
200
-
201
- filename = save_search_results(message, md_text_updated, "")
202
-
203
- with open(filename, "r") as f:
204
- md_content = f.read()
205
-
206
- return md_content, prompt, get_past_searches()
207
-
208
  def ask_llm(prompt, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2', stream_outputs = False):
209
- model_disabled_text = "LLM Model is disabled"
210
- output = ""
211
 
212
- if llm_model_picked == 'None':
213
- if stream_outputs:
214
- for out in model_disabled_text:
215
- output += out
216
- yield output
217
- return output
218
- else:
219
- return model_disabled_text
220
 
221
- client = InferenceClient(llm_model_picked)
222
- try:
223
- stream = client.text_generation(prompt, **generate_kwargs, stream=stream_outputs, details=False, return_full_text=False)
224
 
225
- except:
226
- gr.Warning("LLM Inference rate limit reached, try again later!")
227
- return ""
228
 
229
- if stream_outputs:
230
- for response in stream:
231
- output += response
232
- SaveResponseAndRead(response)
233
- yield output
234
- return output
235
- else:
236
- return stream
237
 
238
- msg.submit(update_with_rag_md, [msg, llm_results, database_src, llm_model], [gr_md, input, past_searches]).success(ask_llm, [input, llm_model, stream_results], output_text).then(
239
- lambda response: save_search_results(msg.value, gr_md.value, response),
240
- [msg, gr_md, output_text],
241
- None
242
- )
243
 
244
  demo.queue().launch()
 
6
  import re
7
  from datetime import datetime
8
  import json
 
9
 
10
  import arxiv
11
  from utils import get_md_text_abstract, search_cleaner, get_arxiv_live_search
12
 
13
+ retrieve_results = 20
14
+ show_examples = True
15
  llm_models_to_choose = ['mistralai/Mixtral-8x7B-Instruct-v0.1','mistralai/Mistral-7B-Instruct-v0.2', 'google/gemma-7b-it', 'None']
16
 
17
  generate_kwargs = dict(
 
21
  do_sample = False,
22
  )
23
 
24
+ ## RAG Model
25
  RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
26
 
27
  try:
 
32
  except:
33
  gr.Warning("Retriever not working!")
34
 
35
+ ## Header
36
  mark_text = '# 🩺🔍 Search Results\n'
37
  header_text = "## Arxiv Paper Summary With QA Retrieval Augmented Generation \n"
38
 
 
50
 
51
  database_choices = [index_info,'Arxiv Search - Latest - (EXPERIMENTAL)']
52
 
53
+ ## Arxiv API
54
  arx_client = arxiv.Client()
55
  is_arxiv_available = True
56
  check_arxiv_result = get_arxiv_live_search("What is Self Rewarding AI and how can it be used in Multi-Agent Systems?", arx_client, retrieve_results)
 
59
  print("Arxiv search not working, switching to default search ...")
60
  database_choices = [index_info]
61
 
62
+
63
+
64
+ ## Show examples
65
+ sample_outputs = {
66
+ 'output_placeholder': 'The LLM will provide an answer to your question here...',
67
+ 'search_placeholder': '''
68
+ 1. What is MoE?
69
+ 2. What are Multi Agent Systems?
70
+ 3. What is Self Rewarding AI?
71
+ 4. What is Semantic and Episodic memory?
72
+ 5. What is AutoGen?
73
+ 6. What is ChatDev?
74
+ 7. What is Omniverse?
75
+ 8. What is Lumiere?
76
+ 9. What is SORA?
77
+ '''
78
+ }
79
+
80
+ output_placeholder = sample_outputs['output_placeholder']
81
+ md_text_initial = sample_outputs['search_placeholder']
82
+
83
 
84
  def rag_cleaner(inp):
85
  rank = inp['rank']
 
136
  '''
137
  gr.HTML(documentHTML5)
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  with gr.Blocks(theme = gr.themes.Soft()) as demo:
141
  header = gr.Markdown(header_text)
142
+
143
+ with gr.Group():
144
+ msg = gr.Textbox(label = 'Search', placeholder = 'What is Mistral?')
145
+
146
+ with gr.Accordion("Advanced Settings", open=False):
147
+ with gr.Row(equal_height = True):
148
+ llm_model = gr.Dropdown(choices = llm_models_to_choose, value = 'mistralai/Mistral-7B-Instruct-v0.2', label = 'LLM Model')
149
+ llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results as context")
150
+ database_src = gr.Dropdown(choices = database_choices, value = index_info, label = 'Search Source')
151
+ stream_results = gr.Checkbox(value = True, label = "Stream output", visible = False)
152
 
153
+ output_text = gr.Textbox(show_label = True, container = True, label = 'LLM Answer', visible = True, placeholder = output_placeholder)
154
+ input = gr.Textbox(show_label = False, visible = False)
155
+ gr_md = gr.Markdown(mark_text + md_text_initial)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
 
 
157
  def update_with_rag_md(message, llm_results_use = 5, database_choice = index_info, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2'):
158
  prompt_text_from_data = ""
159
  database_to_use = database_choice
160
  if database_choice == index_info:
161
+ rag_out = get_rag(message)
162
  else:
163
+ arxiv_search_success = True
164
+ try:
165
+ rag_out = get_arxiv_live_search(message, arx_client, retrieve_results)
166
+ if len(rag_out) == 0:
167
+ arxiv_search_success = False
168
+ except:
169
+ arxiv_search_success = False
170
+
171
+
172
+ if not arxiv_search_success:
173
+ gr.Warning("Arxiv Search not working, switching to semantic search ...")
174
+ rag_out = get_rag(message)
175
+ database_to_use = index_info
176
+
177
  md_text_updated = mark_text
178
  for i in range(retrieve_results):
179
+ rag_answer = rag_out[i]
180
+ if i < llm_results_use:
181
+ md_text_paper, prompt_text = get_md_text_abstract(rag_answer, source = database_to_use, return_prompt_formatting = True)
182
+ prompt_text_from_data += f"{i+1}. {prompt_text}"
183
+ else:
184
+ md_text_paper = get_md_text_abstract(rag_answer, source = database_to_use)
185
+ md_text_updated += md_text_paper
186
  prompt = get_prompt_text(message, prompt_text_from_data, llm_model_picked = llm_model_picked)
187
+ return md_text_updated, prompt
188
+
 
 
 
 
 
 
189
  def ask_llm(prompt, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2', stream_outputs = False):
190
+ model_disabled_text = "LLM Model is disabled"
191
+ output = ""
192
 
193
+ if llm_model_picked == 'None':
194
+ if stream_outputs:
195
+ for out in model_disabled_text:
196
+ output += out
197
+ yield output
198
+ return output
199
+ else:
200
+ return model_disabled_text
201
 
202
+ client = InferenceClient(llm_model_picked)
203
+ try:
204
+ stream = client.text_generation(prompt, **generate_kwargs, stream=stream_outputs, details=False, return_full_text=False)
205
 
206
+ except:
207
+ gr.Warning("LLM Inference rate limit reached, try again later!")
208
+ return ""
209
 
210
+ if stream_outputs:
211
+ for response in stream:
212
+ output += response
213
+ SaveResponseAndRead(response)
214
+ yield output
215
+ return output
216
+ else:
217
+ return stream
218
 
219
+
220
+ msg.submit(update_with_rag_md, [msg, llm_results, database_src, llm_model], [gr_md, input]).success(ask_llm, [input, llm_model, stream_results], output_text)
 
 
 
221
 
222
  demo.queue().launch()