Spaces:
Runtime error
Runtime error
Added download, streaming and initial placeholder
Browse files
app.py
CHANGED
@@ -5,6 +5,7 @@ from ragatouille import RAGPretrainedModel
|
|
5 |
from huggingface_hub import InferenceClient
|
6 |
import re
|
7 |
from datetime import datetime
|
|
|
8 |
|
9 |
retrieve_results = 10
|
10 |
|
@@ -16,15 +17,16 @@ generate_kwargs = dict(
|
|
16 |
)
|
17 |
|
18 |
RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
|
|
|
19 |
try:
|
20 |
gr.Info("Setting up retriever, please wait...")
|
21 |
-
|
22 |
gr.Info("Retriever working successfully!")
|
23 |
except:
|
24 |
gr.Warning("Retriever not working!")
|
25 |
|
26 |
mark_text = '# 🔍 Search Results\n'
|
27 |
-
header_text = "#
|
28 |
try:
|
29 |
with open("README.md", "r") as f:
|
30 |
mdfile = f.read()
|
@@ -36,6 +38,12 @@ try:
|
|
36 |
except:
|
37 |
pass
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
def rag_cleaner(inp):
|
40 |
rank = inp['rank']
|
41 |
title = inp['document_metadata']['title']
|
@@ -59,15 +67,15 @@ def get_rag(message):
|
|
59 |
with gr.Blocks(theme = gr.themes.Soft()) as demo:
|
60 |
header = gr.Markdown(header_text)
|
61 |
with gr.Group():
|
62 |
-
msg = gr.Textbox(label = 'Search')
|
63 |
with gr.Accordion("Advanced Settings", open=False):
|
64 |
with gr.Row(equal_height = True):
|
65 |
llm_model = gr.Dropdown(choices = ['mistralai/Mixtral-8x7B-Instruct-v0.1','mistralai/Mistral-7B-Instruct-v0.2', 'None'], value = 'mistralai/Mistral-7B-Instruct-v0.2', label = 'LLM Model')
|
66 |
llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results to sent as context")
|
67 |
|
68 |
-
output_text = gr.Textbox(show_label = True, container = True, label = 'LLM Answer', visible = True)
|
69 |
input = gr.Textbox(show_label = False, visible = False)
|
70 |
-
gr_md = gr.Markdown(mark_text)
|
71 |
|
72 |
def update_with_rag_md(message, llm_results_use = 5):
|
73 |
rag_out = get_rag(message)
|
@@ -76,8 +84,9 @@ with gr.Blocks(theme = gr.themes.Soft()) as demo:
|
|
76 |
rag_answer = rag_out[i]
|
77 |
title = rag_answer['document_metadata']['title'].replace('\n','')
|
78 |
|
79 |
-
score = round(rag_answer['score'], 2)
|
80 |
-
|
|
|
81 |
paper_abs = rag_answer['content']
|
82 |
authors = rag_answer['document_metadata']['authors'].replace('\n','')
|
83 |
authors_formatted = f'*{authors}*' + ' \n\n'
|
@@ -90,9 +99,16 @@ with gr.Blocks(theme = gr.themes.Soft()) as demo:
|
|
90 |
if llm_model_picked == 'None':
|
91 |
return gr.Textbox(visible = False)
|
92 |
client = InferenceClient(llm_model_picked)
|
93 |
-
output = client.text_generation(prompt, **generate_kwargs, stream=False, details=False, return_full_text=False)
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
msg.submit(update_with_rag_md, [msg, llm_results], [gr_md, input]).success(ask_llm, [input, llm_model], output_text)
|
98 |
|
|
|
5 |
from huggingface_hub import InferenceClient
|
6 |
import re
|
7 |
from datetime import datetime
|
8 |
+
import json
|
9 |
|
10 |
retrieve_results = 10
|
11 |
|
|
|
17 |
)
|
18 |
|
19 |
RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
|
20 |
+
|
21 |
try:
|
22 |
gr.Info("Setting up retriever, please wait...")
|
23 |
+
rag_initial_output = RAG.search("what is Mistral?", k = 1)
|
24 |
gr.Info("Retriever working successfully!")
|
25 |
except:
|
26 |
gr.Warning("Retriever not working!")
|
27 |
|
28 |
mark_text = '# 🔍 Search Results\n'
|
29 |
+
header_text = "# ArXivCS RAG \n"
|
30 |
try:
|
31 |
with open("README.md", "r") as f:
|
32 |
mdfile = f.read()
|
|
|
38 |
except:
|
39 |
pass
|
40 |
|
41 |
+
with open("sample_outputs.json", "r") as f:
|
42 |
+
sample_outputs = json.load(f)
|
43 |
+
output_placeholder = sample_outputs['output_placeholder']
|
44 |
+
md_text_initial = sample_outputs['search_placeholder']
|
45 |
+
|
46 |
+
|
47 |
def rag_cleaner(inp):
|
48 |
rank = inp['rank']
|
49 |
title = inp['document_metadata']['title']
|
|
|
67 |
with gr.Blocks(theme = gr.themes.Soft()) as demo:
|
68 |
header = gr.Markdown(header_text)
|
69 |
with gr.Group():
|
70 |
+
msg = gr.Textbox(label = 'Search', placeholder = 'What is Mistral?')
|
71 |
with gr.Accordion("Advanced Settings", open=False):
|
72 |
with gr.Row(equal_height = True):
|
73 |
llm_model = gr.Dropdown(choices = ['mistralai/Mixtral-8x7B-Instruct-v0.1','mistralai/Mistral-7B-Instruct-v0.2', 'None'], value = 'mistralai/Mistral-7B-Instruct-v0.2', label = 'LLM Model')
|
74 |
llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results to sent as context")
|
75 |
|
76 |
+
output_text = gr.Textbox(show_label = True, container = True, label = 'LLM Answer', visible = True, placeholder = output_placeholder)
|
77 |
input = gr.Textbox(show_label = False, visible = False)
|
78 |
+
gr_md = gr.Markdown(mark_text + md_text_initial)
|
79 |
|
80 |
def update_with_rag_md(message, llm_results_use = 5):
|
81 |
rag_out = get_rag(message)
|
|
|
84 |
rag_answer = rag_out[i]
|
85 |
title = rag_answer['document_metadata']['title'].replace('\n','')
|
86 |
|
87 |
+
#score = round(rag_answer['score'], 2)
|
88 |
+
date = rag_answer['document_metadata']['_time']
|
89 |
+
paper_title = f'''### {date} | [{title}](https://arxiv.org/abs/{rag_answer['document_id']}) | [⬇️](https://arxiv.org/pdf/{rag_answer['document_id']})\n'''
|
90 |
paper_abs = rag_answer['content']
|
91 |
authors = rag_answer['document_metadata']['authors'].replace('\n','')
|
92 |
authors_formatted = f'*{authors}*' + ' \n\n'
|
|
|
99 |
if llm_model_picked == 'None':
|
100 |
return gr.Textbox(visible = False)
|
101 |
client = InferenceClient(llm_model_picked)
|
102 |
+
#output = client.text_generation(prompt, **generate_kwargs, stream=False, details=False, return_full_text=False)
|
103 |
+
stream = client.text_generation(prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
|
104 |
+
#output = output.lstrip(' \n') if output.lstrip().startswith('\n') else output
|
105 |
+
output = ""
|
106 |
+
|
107 |
+
for response in stream:
|
108 |
+
output += response.token.text
|
109 |
+
yield output
|
110 |
+
return output
|
111 |
+
#return gr.Textbox(output, visible = True)
|
112 |
|
113 |
msg.submit(update_with_rag_md, [msg, llm_results], [gr_md, input]).success(ask_llm, [input, llm_model], output_text)
|
114 |
|