jharrison27 commited on
Commit
3dbe475
1 Parent(s): 3701fee

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +105 -131
  2. helper.py +19 -24
app.py CHANGED
@@ -2,165 +2,139 @@ import torch
2
  import transformers
3
  import gradio as gr
4
  from ragatouille import RAGPretrainedModel
5
- from huggingface_hub import InferenceClient
6
  import re
7
  from datetime import datetime
8
  import json
9
-
10
  import arxiv
11
- from helper import *
12
-
13
- retrieve_results = 20
14
- show_examples = True
15
- llm_models_to_choose = ['mistralai/Mixtral-8x7B-Instruct-v0.1','mistralai/Mistral-7B-Instruct-v0.2', 'google/gemma-7b-it', 'None']
16
 
17
- generate_kwargs = dict(
18
- temperature = None,
19
- max_new_tokens = 512,
20
- top_p = None,
21
- do_sample = False,
22
- )
 
 
 
 
 
 
23
 
24
- ## RAG Model
25
  RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
26
 
27
  try:
28
- gr.Info("Setting up retriever, please wait...")
29
- rag_initial_output = RAG.search("What is Generative AI in Healthcare?", k = 1)
30
- gr.Info("Retriever working successfully!")
31
-
32
- except:
33
- gr.Warning("Retriever not working!")
34
 
35
- ## Header
36
  mark_text = '# 🩺🔍 Search Results\n'
37
  header_text = "## Arxiv Paper Summary With QA Retrieval Augmented Generation \n"
38
 
39
  try:
40
- with open("README.md", "r") as f:
41
- mdfile = f.read()
42
- date_pattern = r'Index Last Updated : \d{4}-\d{2}-\d{2}'
43
- match = re.search(date_pattern, mdfile)
44
- date = match.group().split(': ')[1]
45
- formatted_date = datetime.strptime(date, '%Y-%m-%d').strftime('%d %b %Y')
46
- header_text += f'Index Last Updated: {formatted_date}\n'
47
- index_info = f"Semantic Search - up to {formatted_date}"
48
- except:
49
- index_info = "Semantic Search"
50
-
51
- database_choices = [index_info,'Arxiv Search - Latest - (EXPERIMENTAL)']
52
-
53
- ## Arxiv API
54
  arx_client = arxiv.Client()
55
  is_arxiv_available = True
56
- check_arxiv_result = get_arxiv_live_search("What is Self Rewarding AI and how can it be used in Multi-Agent Systems?", arx_client, retrieve_results)
57
  if len(check_arxiv_result) == 0:
58
- is_arxiv_available = False
59
- print("Arxiv search not working, switching to default search ...")
60
- database_choices = [index_info]
61
-
62
 
63
-
64
- ## Show examples
65
- sample_outputs = {
66
- 'output_placeholder': 'The LLM will provide an answer to your question here...',
67
- 'search_placeholder': '''
68
- 1. What is MoE?
69
- 2. What are Multi Agent Systems?
70
- 3. What is Self Rewarding AI?
71
- 4. What is Semantic and Episodic memory?
72
- 5. What is AutoGen?
73
- 6. What is ChatDev?
74
- 7. What is Omniverse?
75
- 8. What is Lumiere?
76
- 9. What is SORA?
77
- '''
78
- }
79
-
80
- output_placeholder = sample_outputs['output_placeholder']
81
- md_text_initial = sample_outputs['search_placeholder']
82
-
83
-
84
- with gr.Blocks(theme = gr.themes.Soft()) as demo:
85
  header = gr.Markdown(header_text)
86
 
87
  with gr.Group():
88
- msg = gr.Textbox(label = 'Search', placeholder = 'What is Generative AI in Healthcare?')
89
 
90
- with gr.Accordion("Advanced Settings", open=False):
91
- with gr.Row(equal_height = True):
92
- llm_model = gr.Dropdown(choices = llm_models_to_choose, value = 'mistralai/Mistral-7B-Instruct-v0.2', label = 'LLM Model')
93
- llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results as context")
94
- database_src = gr.Dropdown(choices = database_choices, value = index_info, label = 'Search Source')
95
- stream_results = gr.Checkbox(value = True, label = "Stream output", visible = False)
96
-
97
- output_text = gr.Textbox(show_label = True, container = True, label = 'LLM Answer', visible = True, placeholder = output_placeholder)
98
- input = gr.Textbox(show_label = False, visible = False)
99
- gr_md = gr.Markdown(mark_text + md_text_initial)
100
-
101
- def update_with_rag_md(message, llm_results_use = 5, database_choice = index_info, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2'):
102
  prompt_text_from_data = ""
103
  database_to_use = database_choice
 
104
  if database_choice == index_info:
105
- rag_out = get_rag(message)
106
  else:
107
- arxiv_search_success = True
108
- try:
109
- rag_out = get_arxiv_live_search(message, arx_client, retrieve_results)
110
- if len(rag_out) == 0:
111
- arxiv_search_success = False
112
- except:
113
- arxiv_search_success = False
114
-
115
-
116
- if not arxiv_search_success:
117
- gr.Warning("Arxiv Search not working, switching to semantic search ...")
118
- rag_out = get_rag(message)
119
- database_to_use = index_info
120
-
121
  md_text_updated = mark_text
122
- for i in range(retrieve_results):
123
- rag_answer = rag_out[i]
124
- if i < llm_results_use:
125
- md_text_paper, prompt_text = get_md_text_abstract(rag_answer, source = database_to_use, return_prompt_formatting = True)
126
- prompt_text_from_data += f"{i+1}. {prompt_text}"
127
- else:
128
- md_text_paper = get_md_text_abstract(rag_answer, source = database_to_use)
129
- md_text_updated += md_text_paper
130
- prompt = get_prompt_text(message, prompt_text_from_data, llm_model_picked = llm_model_picked)
131
  return md_text_updated, prompt
132
-
133
- def ask_llm(prompt, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2', stream_outputs = False):
134
- model_disabled_text = "LLM Model is disabled"
135
- output = ""
136
 
137
- if llm_model_picked == 'None':
138
- if stream_outputs:
139
- for out in model_disabled_text:
140
- output += out
141
- yield output
142
- return output
143
- else:
144
- return model_disabled_text
145
-
146
- client = InferenceClient(llm_model_picked)
147
- try:
148
- stream = client.text_generation(prompt, **generate_kwargs, stream=stream_outputs, details=False, return_full_text=False)
149
-
150
- except:
151
- gr.Warning("LLM Inference rate limit reached, try again later!")
152
- return ""
153
-
154
- if stream_outputs:
155
- for response in stream:
156
- output += response
157
- SaveResponseAndRead(response)
158
- yield output
159
- return output
160
- else:
161
- return stream
162
-
163
-
164
- msg.submit(update_with_rag_md, [msg, llm_results, database_src, llm_model], [gr_md, input]).success(ask_llm, [input, llm_model, stream_results], output_text)
165
 
166
  demo.queue().launch()
 
2
  import transformers
3
  import gradio as gr
4
  from ragatouille import RAGPretrainedModel
 
5
  import re
6
  from datetime import datetime
7
  import json
 
8
  import arxiv
 
 
 
 
 
9
 
10
+ from helper import rag_cleaner, get_prompt_text, get_references, get_rag, SaveResponseAndRead, get_md_text_abstract, search_cleaner, get_arxiv_live_search
11
+
12
+ # Constants
13
+ RETRIEVE_RESULTS = 20
14
+ LLM_MODELS = ['mistralai/Mixtral-8x7B-Instruct-v0.1', 'mistralai/Mistral-7B-Instruct-v0.2', 'google/gemma-7b-it', 'None']
15
+ DEFAULT_LLM_MODEL = 'mistralai/Mistral-7B-Instruct-v0.2'
16
+ GENERATE_KWARGS = {
17
+ "temperature": None,
18
+ "max_new_tokens": 512,
19
+ "top_p": None,
20
+ "do_sample": False,
21
+ }
22
 
23
+ # RAG Model setup
24
  RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
25
 
26
  try:
27
+ gr.Info("Setting up retriever, please wait...")
28
+ rag_initial_output = RAG.search("What is Generative AI in Healthcare?", k=1)
29
+ gr.Info("Retriever working successfully!")
30
+ except Exception as e:
31
+ gr.Warning(f"Retriever not working: {str(e)}")
 
32
 
33
+ # Header setup
34
  mark_text = '# 🩺🔍 Search Results\n'
35
  header_text = "## Arxiv Paper Summary With QA Retrieval Augmented Generation \n"
36
 
37
  try:
38
+ with open("README.md", "r") as f:
39
+ mdfile = f.read()
40
+ date_pattern = r'Index Last Updated : \d{4}-\d{2}-\d{2}'
41
+ match = re.search(date_pattern, mdfile)
42
+ date = match.group().split(': ')[1]
43
+ formatted_date = datetime.strptime(date, '%Y-%m-%d').strftime('%d %b %Y')
44
+ header_text += f'Index Last Updated: {formatted_date}\n'
45
+ index_info = f"Semantic Search - up to {formatted_date}"
46
+ except FileNotFoundError:
47
+ index_info = "Semantic Search"
48
+
49
+ database_choices = [index_info, 'Arxiv Search - Latest - (EXPERIMENTAL)']
50
+
51
+ # Arxiv API setup
52
  arx_client = arxiv.Client()
53
  is_arxiv_available = True
54
+ check_arxiv_result = get_arxiv_live_search("What is Self Rewarding AI and how can it be used in Multi-Agent Systems?", arx_client, RETRIEVE_RESULTS)
55
  if len(check_arxiv_result) == 0:
56
+ is_arxiv_available = False
57
+ print("Arxiv search not working, switching to default search ...")
58
+ database_choices = [index_info]
 
59
 
60
+ # Gradio UI setup
61
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  header = gr.Markdown(header_text)
63
 
64
  with gr.Group():
65
+ search_query = gr.Textbox(label='Search', placeholder='What is Generative AI in Healthcare?')
66
 
67
+ with gr.Accordion("Advanced Settings", open=False):
68
+ with gr.Row(equal_height=True):
69
+ llm_model = gr.Dropdown(choices=LLM_MODELS, value=DEFAULT_LLM_MODEL, label='LLM Model')
70
+ llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results as context")
71
+ database_src = gr.Dropdown(choices=database_choices, value=index_info, label='Search Source')
72
+ stream_results = gr.Checkbox(value=True, label="Stream output", visible=False)
73
+
74
+ output_text = gr.Textbox(show_label=True, container=True, label='LLM Answer', visible=True)
75
+ input = gr.Textbox(show_label=False, visible=False)
76
+ gr_md = gr.Markdown(mark_text)
77
+
78
+ def update_with_rag_md(search_query, llm_results_use=5, database_choice=index_info, llm_model_picked=DEFAULT_LLM_MODEL):
79
  prompt_text_from_data = ""
80
  database_to_use = database_choice
81
+
82
  if database_choice == index_info:
83
+ rag_out = get_rag(search_query, RAG, RETRIEVE_RESULTS)
84
  else:
85
+ arxiv_search_success = True
86
+ try:
87
+ rag_out = get_arxiv_live_search(search_query, arx_client, RETRIEVE_RESULTS)
88
+ if len(rag_out) == 0:
89
+ arxiv_search_success = False
90
+ except Exception as e:
91
+ arxiv_search_success = False
92
+ gr.Warning(f"Arxiv Search not working: {str(e)}, switching to semantic search ...")
93
+
94
+ if not arxiv_search_success:
95
+ rag_out = get_rag(search_query, RAG, RETRIEVE_RESULTS)
96
+ database_to_use = index_info
97
+
 
98
  md_text_updated = mark_text
99
+ for i, rag_answer in enumerate(rag_out):
100
+ if i < llm_results_use:
101
+ md_text_paper, prompt_text = get_md_text_abstract(rag_answer, source=database_to_use, return_prompt_formatting=True)
102
+ prompt_text_from_data += f"{i+1}. {prompt_text}"
103
+ else:
104
+ md_text_paper = get_md_text_abstract(rag_answer, source=database_to_use)
105
+ md_text_updated += md_text_paper
106
+
107
+ prompt = get_prompt_text(search_query, prompt_text_from_data, llm_model_picked=llm_model_picked)
108
  return md_text_updated, prompt
109
+
110
+ def ask_llm(prompt, llm_model_picked=DEFAULT_LLM_MODEL, stream_outputs=False):
111
+ model_disabled_text = "LLM Model is disabled"
112
+ output = ""
113
 
114
+ if llm_model_picked == 'None':
115
+ if stream_outputs:
116
+ for out in model_disabled_text:
117
+ output += out
118
+ yield output
119
+ else:
120
+ return model_disabled_text
121
+
122
+ client = InferenceClient(llm_model_picked)
123
+ try:
124
+ response = client.text_generation(prompt, stream=stream_outputs, details=False, return_full_text=False, **GENERATE_KWARGS)
125
+
126
+ if stream_outputs:
127
+ for token in response:
128
+ output += token
129
+ yield SaveResponseAndRead(output)
130
+ else:
131
+ output = response
132
+ except Exception as e:
133
+ gr.Warning(f"LLM Inference failed: {str(e)}")
134
+ output = ""
135
+
136
+ return output
137
+
138
+ search_query.submit(update_with_rag_md, [search_query, llm_results, database_src, llm_model], [gr_md, input]).success(ask_llm, [input, llm_model, stream_results], output_text)
 
 
 
139
 
140
  demo.queue().launch()
helper.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import datetime
2
  import string
3
  import nltk
@@ -5,11 +8,6 @@ nltk.download('stopwords')
5
  from nltk.corpus import stopwords
6
  stop_words = stopwords.words('english')
7
  import arxiv
8
- import gradio as gr
9
- import re
10
- from datetime import datetime
11
- import json
12
-
13
 
14
  def rag_cleaner(inp):
15
  rank = inp['rank']
@@ -18,28 +16,27 @@ def rag_cleaner(inp):
18
  date = inp['document_metadata']['_time']
19
  return f"{rank}. <b> {title} </b> \n Date : {date} \n Abstract: {content}"
20
 
21
- def get_prompt_text(question, context, formatted = True, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2'):
 
 
 
22
  if formatted:
23
- sys_instruction = f"Context:\n {context} \n Given the following scientific paper abstracts, take a deep breath and lets think step by step to answer the question. Cite the titles of your sources when answering, do not cite links or dates."
24
- message = f"Question: {question}"
25
-
26
- if 'mistralai' in llm_model_picked:
27
- return f"<s>" + f"[INST] {sys_instruction}" + f" {message}[/INST]"
28
-
29
- elif 'gemma' in llm_model_picked:
30
- return f"<bos><start_of_turn>user\n{sys_instruction}" + f" {message}<end_of_turn>\n"
31
-
32
  return f"Context:\n {context} \n Given the following info, take a deep breath and lets think step by step to answer the question: {question}. Cite the titles of your sources when answering.\n\n"
33
 
34
- def get_references(question, retriever, k = retrieve_results):
35
  rag_out = retriever.search(query=question, k=k)
36
  return rag_out
37
 
38
- def get_rag(message):
39
- return get_references(message, RAG)
40
 
41
  def SaveResponseAndRead(result):
42
- documentHTML5='''
43
  <!DOCTYPE html>
44
  <html>
45
  <head>
@@ -56,17 +53,15 @@ def SaveResponseAndRead(result):
56
  <h1>🔊 Read It Aloud</h1>
57
  <textarea id="textArea" rows="10" cols="80">
58
  '''
59
- documentHTML5 = documentHTML5 + result
60
- documentHTML5 = documentHTML5 + '''
61
  </textarea>
62
  <br>
63
  <button onclick="readAloud()">🔊 Read Aloud</button>
64
  </body>
65
  </html>
66
  '''
67
- gr.HTML(documentHTML5)
68
-
69
-
70
 
71
 
72
  def get_md_text_abstract(rag_answer, source = ['Arxiv Search', 'Semantic Search'][1], return_prompt_formatting = False):
 
1
+ import sys
2
+ import gradio as gr
3
+ from huggingface_hub import InferenceClient
4
  import datetime
5
  import string
6
  import nltk
 
8
  from nltk.corpus import stopwords
9
  stop_words = stopwords.words('english')
10
  import arxiv
 
 
 
 
 
11
 
12
  def rag_cleaner(inp):
13
  rank = inp['rank']
 
16
  date = inp['document_metadata']['_time']
17
  return f"{rank}. <b> {title} </b> \n Date : {date} \n Abstract: {content}"
18
 
19
+ def get_prompt_text(question, context, formatted=True, llm_model_picked='mistralai/Mistral-7B-Instruct-v0.2'):
20
+ sys_instruction = f"Context:\n {context} \n Given the following scientific paper abstracts, take a deep breath and lets think step by step to answer the question. Cite the titles of your sources when answering, do not cite links or dates."
21
+ message = f"Question: {question}"
22
+
23
  if formatted:
24
+ if 'mistralai' in llm_model_picked:
25
+ return f"<s>[INST] {sys_instruction} {message}[/INST]"
26
+ elif 'gemma' in llm_model_picked:
27
+ return f"<bos><start_of_turn>user\n{sys_instruction} {message}<end_of_turn>\n"
28
+
 
 
 
 
29
  return f"Context:\n {context} \n Given the following info, take a deep breath and lets think step by step to answer the question: {question}. Cite the titles of your sources when answering.\n\n"
30
 
31
+ def get_references(question, retriever, k):
32
  rag_out = retriever.search(query=question, k=k)
33
  return rag_out
34
 
35
+ def get_rag(message, RAG, RETRIEVE_RESULTS):
36
+ return get_references(message, RAG, k=RETRIEVE_RESULTS)
37
 
38
  def SaveResponseAndRead(result):
39
+ documentHTML5 = '''
40
  <!DOCTYPE html>
41
  <html>
42
  <head>
 
53
  <h1>🔊 Read It Aloud</h1>
54
  <textarea id="textArea" rows="10" cols="80">
55
  '''
56
+ documentHTML5 += result
57
+ documentHTML5 += '''
58
  </textarea>
59
  <br>
60
  <button onclick="readAloud()">🔊 Read Aloud</button>
61
  </body>
62
  </html>
63
  '''
64
+ return gr.HTML(documentHTML5)
 
 
65
 
66
 
67
  def get_md_text_abstract(rag_answer, source = ['Arxiv Search', 'Semantic Search'][1], return_prompt_formatting = False):