Spaces:
Runtime error
Runtime error
Added option for selecting LLM and number of abstracts as input
Browse files
app.py
CHANGED
@@ -5,9 +5,6 @@ from ragatouille import RAGPretrainedModel
|
|
5 |
from huggingface_hub import InferenceClient
|
6 |
|
7 |
retrieve_results = 10
|
8 |
-
llm_results = 5
|
9 |
-
|
10 |
-
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
|
11 |
|
12 |
generate_kwargs = dict(
|
13 |
temperature = None,
|
@@ -24,7 +21,7 @@ try:
|
|
24 |
except:
|
25 |
gr.Warning("Retriever not working!")
|
26 |
|
27 |
-
mark_text = '#
|
28 |
|
29 |
def rag_cleaner(inp):
|
30 |
rank = inp['rank']
|
@@ -47,29 +44,43 @@ def get_rag(message):
|
|
47 |
return get_references(message, RAG)
|
48 |
|
49 |
with gr.Blocks(theme = gr.themes.Soft()) as demo:
|
|
|
50 |
with gr.Group():
|
51 |
msg = gr.Textbox(label = 'Search')
|
|
|
|
|
|
|
|
|
|
|
52 |
output_text = gr.Textbox(show_label = True, container = True, label = 'LLM Answer', visible = True)
|
53 |
input = gr.Textbox(show_label = False, visible = False)
|
54 |
gr_md = gr.Markdown(mark_text)
|
55 |
|
56 |
-
def update_with_rag_md(message):
|
57 |
rag_out = get_rag(message)
|
58 |
md_text_updated = mark_text
|
59 |
-
for i in range(
|
60 |
rag_answer = rag_out[i]
|
61 |
title = rag_answer['document_metadata']['title'].replace('\n','')
|
62 |
-
|
|
|
|
|
63 |
paper_abs = rag_answer['content']
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
66 |
return md_text_updated, prompt
|
67 |
|
68 |
-
def ask_llm(prompt):
    """Generate an answer for the prompt and return it as a visible textbox.

    Args:
        prompt: Full prompt text sent to the module-level inference client.

    Returns:
        A ``gr.Textbox`` holding the generated text, with ``visible=True``.
    """
    output = client.text_generation(prompt, **generate_kwargs, stream=False, details=False, return_full_text=False)
    # Strip only spaces before testing for a leading newline: a bare lstrip()
    # would remove the newline as well, so the test could never succeed.
    if output.lstrip(' ').startswith('\n'):
        output = output.lstrip(' \n')
    return gr.Textbox(output, visible = True)
|
72 |
|
73 |
-
msg.submit(update_with_rag_md, msg, [gr_md, input]).success(ask_llm, input, output_text)
|
74 |
|
75 |
demo.launch(debug = True)
|
|
|
5 |
from huggingface_hub import InferenceClient
|
6 |
|
7 |
retrieve_results = 10
|
|
|
|
|
|
|
8 |
|
9 |
generate_kwargs = dict(
|
10 |
temperature = None,
|
|
|
21 |
except:
|
22 |
gr.Warning("Retriever not working!")
|
23 |
|
24 |
+
mark_text = '# 🔍 Search Results\n'
|
25 |
|
26 |
def rag_cleaner(inp):
|
27 |
rank = inp['rank']
|
|
|
44 |
return get_references(message, RAG)
|
45 |
|
46 |
with gr.Blocks(theme = gr.themes.Soft()) as demo:
|
47 |
+
header = gr.Markdown("# ArXiv RAG")
|
48 |
with gr.Group():
|
49 |
msg = gr.Textbox(label = 'Search')
|
50 |
+
with gr.Accordion("Advanced Settings", open=False):
|
51 |
+
with gr.Row(equal_height = True):
|
52 |
+
llm_model = gr.Dropdown(choices = ['mistralai/Mixtral-8x7B-Instruct-v0.1','mistralai/Mistral-7B-Instruct-v0.2', 'None'], value = 'mistralai/Mistral-7B-Instruct-v0.2', label = 'LLM Model')
|
53 |
+
# Slider whose value is handed to update_with_rag_md as the number of top
# retrieved results included in the LLM prompt (range 4-10, default 5).
llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results to send as context")
|
54 |
+
|
55 |
output_text = gr.Textbox(show_label = True, container = True, label = 'LLM Answer', visible = True)
|
56 |
input = gr.Textbox(show_label = False, visible = False)
|
57 |
gr_md = gr.Markdown(mark_text)
|
58 |
|
59 |
+
def update_with_rag_md(message, llm_results_use = 5):
    """Render the retrieved papers as markdown and build the LLM prompt.

    Args:
        message: The search query typed by the user.
        llm_results_use: Number of top-ranked results whose cleaned text is
            joined into the prompt context.

    Returns:
        A ``(markdown, prompt)`` tuple: the results listing (score, linked
        title, authors and content for each displayed hit) and the prompt
        text produced by ``get_prompt_text``.
    """
    rag_out = get_rag(message)
    # The value comes from a UI slider and is used as a slice bound, which
    # must be an int (it may arrive as a float - TODO confirm for the Gradio
    # version in use).
    llm_results_use = int(llm_results_use)

    sections = [mark_text]
    # Slice rather than index over range(retrieve_results) so that a retriever
    # returning fewer hits than requested does not raise IndexError.
    for rag_answer in rag_out[:retrieve_results]:
        metadata = rag_answer['document_metadata']
        title = metadata['title'].replace('\n','')
        score = round(rag_answer['score'], 2)
        paper_title = f'''### **{score}** | [{title}](https://arxiv.org/abs/{rag_answer['document_id']})\n'''
        paper_abs = rag_answer['content']
        authors = metadata['authors'].replace('\n','')
        authors_formatted = f'*{authors}*' + ' \n\n'
        sections.append(paper_title + authors_formatted + paper_abs + '\n---------------\n'+ '\n')

    md_text_updated = ''.join(sections)
    prompt = get_prompt_text(message, '\n\n'.join(rag_cleaner(out) for out in rag_out[:llm_results_use]))
    return md_text_updated, prompt
|
75 |
|
76 |
+
def ask_llm(prompt, llm_model_picked = 'mistralai/Mistral-7B-Instruct-v0.2'):
    """Query the selected hosted model and return the answer textbox.

    Args:
        prompt: Full prompt text to send to the model.
        llm_model_picked: Model identifier passed to ``InferenceClient``, or
            the literal string ``'None'`` to skip generation entirely.

    Returns:
        A ``gr.Textbox``: hidden when ``'None'`` is selected, otherwise
        visible and holding the generated text.
    """
    if llm_model_picked == 'None':
        return gr.Textbox(visible = False)

    client = InferenceClient(llm_model_picked)
    output = client.text_generation(prompt, **generate_kwargs, stream=False, details=False, return_full_text=False)
    # Strip only spaces before testing for a leading newline: a bare lstrip()
    # would remove the newline as well, so the test could never succeed.
    if output.lstrip(' ').startswith('\n'):
        output = output.lstrip(' \n')
    return gr.Textbox(output, visible = True)
|
83 |
|
84 |
+
# Step 1: on submit, run retrieval; it fills the markdown listing and the
# hidden `input` textbox that carries the prompt.
retrieval_event = msg.submit(update_with_rag_md, [msg, llm_results], [gr_md, input])
# Step 2: chained via .success, feed that prompt and the chosen model to the LLM.
retrieval_event.success(ask_llm, [input, llm_model], output_text)
|
85 |
|
86 |
# Launch the Blocks app with Gradio's debug flag enabled.
demo.launch(debug = True)
|