# decode-elm / app.py
# mehradans92's picture
# Update app.py
# 741dc63
import streamlit as st # Web App
import os
from PIL import Image
from utils import *
import asyncio
import pickle
# ---------------------------------------------------------------------------
# Page setup: shared state, header image, introductory text and API key input.
# ---------------------------------------------------------------------------

# Shared paperqa document collection; answer_callback() creates it on demand.
docs = None

st.set_page_config(layout="wide")
image = Image.open('arxiv_decode.png')
st.image(image, width=1000)

# Title and introductory text
st.title("Answering questions from scientific papers")
st.markdown("##### This tool will allow you to ask questions and get answers based on scientific papers. It uses OpenAI's GPT models, and you must have your own API key. Each query is about 10k tokens, which costs about only $0.20 on your own API key, charged by OpenAI.")
st.markdown("##### Current version searches on different pre-print servers including [arXiv](https://arxiv.org), [chemRxiv](https://chemrxiv.org/engage/chemrxiv/public-dashboard), [bioRxiv](https://www.biorxiv.org/) and [medRxiv](https://www.medrxiv.org/). 🚧Under development🚧")
st.markdown("Used libraries:\n * [PaperQA](https://github.com/whitead/paper-qa) \n* [langchain](https://github.com/hwchase17/langchain)")
st.markdown("See this [tweet](https://twitter.com/MehradAnsari/status/1627649959204888576) for a demo.")

# OpenAI API key input. Surrounding whitespace from copy/paste is dropped so
# it does not end up in the environment variable.
api_key_url = 'https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key'
api_key = st.text_input(
    'OpenAI API Key',
    placeholder='sk-...',
    help=f"['What is that?']({api_key_url})",
    type="password",
    value='',
).strip()
os.environ["OPENAI_API_KEY"] = api_key

# Key length is not fixed across OpenAI key types, so only the common "sk-"
# prefix is checked instead of an exact length of 51 characters.
if not api_key.startswith('sk-'):
    st.warning('Please enter a valid OpenAI API key.', icon="⚠️")

# Default for the "Max papers" input of the search form.
max_results_current = 5
max_results = max_results_current
def search_click_callback(search_query, max_results, XRxiv_servers=None):
    """Run a pre-print search and download the PDFs that were found.

    Args:
        search_query: Search text forwarded to ``XRxivQuery``.
        max_results: Maximum number of papers, forwarded to ``XRxivQuery``.
        XRxiv_servers: Server identifiers to search. ``None`` (the default)
            is treated as an empty list.

    Returns:
        The value returned by ``XRxivQuery.call_API()``. The same value is
        also stored in the module-level ``pdf_info``.
    """
    global pdf_info
    # Build a fresh list per call instead of sharing one mutable default
    # object between all calls.
    servers = [] if XRxiv_servers is None else XRxiv_servers
    search_engines = XRxivQuery(search_query, max_results, XRxiv_servers=servers)
    pdf_info = search_engines.call_API()
    # Download happens after the API call, on the same query object.
    search_engines.download_pdf()
    return pdf_info
# ---------------------------------------------------------------------------
# Search form: query text, number of papers and pre-print server selection.
# ---------------------------------------------------------------------------
with st.form(key='columns_in_form', clear_on_submit=False):
    c1, c2 = st.columns([5, 0.8])
    with c1:
        search_query = st.text_input(
            "Input search query here:",
            placeholder='Keywords for most relevant search...',
            value='',
        )
    with c2:
        # min_value keeps the requested number of papers positive.
        max_results = st.number_input("Max papers", min_value=1, value=max_results_current)
    st.markdown('Pre-print server')
    checks = st.columns(4)
    with checks[0]:
        ArXiv_check = st.checkbox('arXiv')
    with checks[1]:
        ChemArXiv_check = st.checkbox('chemRxiv')
    with checks[2]:
        BioArXiv_check = st.checkbox('bioRxiv')
    with checks[3]:
        MedrXiv_check = st.checkbox('medRxiv')
    searchButton = st.form_submit_button(label='Search')

if searchButton:
    # Map each ticked checkbox to the server identifier passed to the search.
    server_choices = (
        (ArXiv_check, 'rxiv'),
        (ChemArXiv_check, 'chemrxiv'),
        (BioArXiv_check, 'biorxiv'),
        (MedrXiv_check, 'medrxiv'),
    )
    XRxiv_servers = [server for checked, server in server_choices if checked]

    pdf_info = search_click_callback(search_query, max_results, XRxiv_servers=XRxiv_servers)
    if 'pdf_info' not in st.session_state:
        # NOTE(review): this stores the string 'pdf_info' under the session
        # entry named "key"; kept as-is because it is unclear whether
        # anything outside this file reads it - confirm and remove if unused.
        st.session_state.key = 'pdf_info'
    # Keep the search result across Streamlit reruns for answer_callback().
    st.session_state['pdf_info'] = pdf_info
def answer_callback(question_query, word_count):
    """Answer ``question_query`` using the papers stored in the session.

    When the module-level ``docs`` is still ``None``, a ``paperqa.Docs``
    collection is created and every entry of ``st.session_state['pdf_info']``
    is added to it as a (PDF path, citation) pair. The text index is then
    built and the collection is queried with a length prompt derived from
    ``word_count``.

    Args:
        question_query: The question passed to ``docs.query``.
        word_count: Integer used in the ``'use N words'`` length prompt.

    Returns:
        The ``formatted_answer`` attribute of the query result.
    """
    import paperqa  # imported only when a question is actually asked

    global docs
    if docs is None:
        entries = st.session_state['pdf_info']
        docs = paperqa.Docs()
        # ':' '/' and '.' are removed from entry[0] to form the file name.
        drop_chars = str.maketrans('', '', ':/.')
        sources = [
            (f"{entry[4]}/{entry[0].translate(drop_chars)}.pdf", entry[5])
            for entry in entries
        ]
        print(sources)
        for pdf_path, citation in sources:
            docs.add(pdf_path, citation)
    docs._build_texts_index()
    length_prompt = f'use {word_count:d} words'
    result = docs.query(question_query, length_prompt=length_prompt)
    st.success('Voila! 😃')
    return result.formatted_answer
# ---------------------------------------------------------------------------
# Question form: ask a question about the papers found by the search above.
# ---------------------------------------------------------------------------
with st.form(key='question_form', clear_on_submit=False):
    c1, c2 = st.columns([6, 2])
    with c1:
        question_query = st.text_input(
            "What do you wanna know from these papers?",
            placeholder='Input questions here...',
            value='',
        )
    with c2:
        word_count = st.slider("Suggested number of words in your answer?", 30, 300, 100)
    submitButton = st.form_submit_button('Submit')

if submitButton:
    # Both session entries are read below ('all_reference_text' here,
    # 'pdf_info' inside answer_callback); without them the lookups would
    # raise KeyError, so ask the user to search first instead.
    if 'pdf_info' not in st.session_state or 'all_reference_text' not in st.session_state:
        st.warning('Please search for papers before asking a question.', icon="⚠️")
    elif not question_query.strip():
        # Avoid spending an API call on an empty question.
        st.warning('Please enter a question.', icon="⚠️")
    else:
        with st.expander("Found papers:", expanded=True):
            st.write(f"{st.session_state['all_reference_text']}")
        with st.spinner('⏳ Please wait...'):
            # NOTE(review): `time` is not imported explicitly in this file;
            # presumably it arrives via `from utils import *` - confirm.
            start = time.time()
            final_answer = answer_callback(question_query, word_count)
            length_answer = len(final_answer)
            # Text area height grows with the answer, with a floor of 100.
            st.text_area("Answer:", final_answer, height=max(length_answer//4, 100))
            end = time.time()
            clock_time = end - start
            with st.empty():
                st.write(f"✔️ Task completed in {clock_time:.2f} seconds.")