from typing import Optional

import streamlit as st
import streamlit_pydantic as sp
from pydantic import BaseModel, Field
from PIL import Image
from pathlib import Path

from src.Surveyor import Surveyor


@st.experimental_singleton(suppress_st_warning=True)
def get_surveyor_instance(_print_fn, _survey_print_fn):
    # Cache a single Surveyor instance across reruns; the underscore-prefixed
    # arguments are excluded from Streamlit's hashing.
    with st.spinner('Loading The-Researcher ...'):
        return Surveyor(print_fn=_print_fn, survey_print_fn=_survey_print_fn, high_gpu=True)


def run_survey(surveyor, download_placeholder, research_keywords=None, arxiv_ids=None,
               max_search=None, num_papers=None):
    import hashlib
    import time

    # Create a unique working directory per run, named by a hash of the current timestamp.
    hasher = hashlib.sha1()
    hasher.update(str(time.time()).encode('utf-8'))
    temp_hash = hasher.hexdigest()
    survey_root = Path(temp_hash).resolve()

    # Sub-directories for the artifacts produced by the survey pipeline.
    dir_args = {f'{dname}_dir': survey_root / dname
                for dname in ['pdf', 'txt', 'img', 'tab', 'dump']}
    for d in dir_args.values():
        d.mkdir(exist_ok=True, parents=True)
    print(survey_root)
    print(dir_args)

    # Surveyor expects directory paths as strings with a trailing slash.
    dir_args = {k: f'{str(v.resolve())}/' for k, v in dir_args.items()}
    zip_file_name, survey_file_name = surveyor.survey(research_keywords, arxiv_ids,
                                                      max_search=max_search,
                                                      num_papers=num_papers,
                                                      **dir_args)
    show_survey_download(zip_file_name, survey_file_name, download_placeholder)


def show_survey_download(zip_file_name, survey_file_name, download_placeholder):
    # Offer the zipped artifacts and the generated survey file for download.
    with open(str(zip_file_name), "rb") as file:
        download_placeholder.download_button(
            label="Download extracted topic-clustered highlights, images and tables as zip",
            data=file,
            file_name=str(zip_file_name)
        )
    with open(str(survey_file_name), "rb") as file:
        download_placeholder.download_button(
            label="Download detailed generated survey file",
            data=file,
            file_name=str(survey_file_name)
        )


class KeywordsModel(BaseModel):
    research_keywords: Optional[str] = Field(
        '', description="Enter your research keywords:"
    )
    max_search: int = Field(
        10, ge=1, le=50, multiple_of=1,
        description="num_papers_to_search:"
    )
    num_papers: int = Field(
        3, ge=1, le=8, multiple_of=1,
        description="num_papers_to_select:"
    )


class ArxivIDsModel(BaseModel):
    arxiv_ids: Optional[str] = Field(
        '', description="Enter comma-separated arXiv IDs for your curated set of papers "
                        "(e.g. 2205.12755, 2205.10937, ...):"
    )


if __name__ == '__main__':
    # Track concurrent sessions so the demo is not overloaded.
    if 'session_count' not in st.session_state:
        st.session_state.session_count = 0
    demo_session_limit = 2

    if st.session_state.session_count > demo_session_limit:
        st.write(f'{st.session_state.session_count} sessions running; this is a demo and only supports '
                 f'{demo_session_limit} parallel sessions, please try again in some time')

    st.sidebar.image(Image.open('logo_landscape.png'), use_column_width='always')
    st.title('Auto-Research')
    st.write('#### A no-code utility to generate a detailed, well-cited survey with topic-clustered sections '
             '(draft paper format) and other interesting artifacts from a single research query or a curated set of papers (arXiv IDs).')
    st.write('##### Data Provider: arXiv Open Archive Initiative OAI')
    st.write('##### GitHub: https://github.com/sidphbot/Auto-Research')
    st.write(f'Note: this is only a demo on CPU with 13 GB RAM, hence it supports a limited number of papers '
             f'and only {demo_session_limit} parallel user sessions')
    download_placeholder = st.container()

    # Sidebar form: search by keywords, or provide a curated list of arXiv IDs.
    with st.sidebar.form(key="survey_keywords_form"):
        session_data = sp.pydantic_input(key="keywords_input_model", model=KeywordsModel)
        st.write('or')
        session_data.update(sp.pydantic_input(key="arxiv_ids_input_model", model=ArxivIDsModel))
        submit = st.form_submit_button(label="Submit")

    st.sidebar.write('#### execution log:')

    run_kwargs = {'surveyor': get_surveyor_instance(_print_fn=st.sidebar.write, _survey_print_fn=st.write),
                  'download_placeholder': download_placeholder}

    if submit:
        if st.session_state.session_count < demo_session_limit:
            st.session_state.session_count += 1
            if session_data['research_keywords'] != '':
                run_kwargs.update({'research_keywords': session_data['research_keywords'],
                                   'max_search': session_data['max_search'],
                                   'num_papers': session_data['num_papers']})
            elif session_data['arxiv_ids'] != '':
                run_kwargs.update({'arxiv_ids': [arxiv_id.strip() for arxiv_id in session_data['arxiv_ids'].split(',')]})
            print(run_kwargs)
            try:
                run_survey(**run_kwargs)
            except Exception as e:
                st.write(f'ERROR: {str(e)}, server might be crowded right now, please try later, thank you for your patience')
        else:
            st.write(f'{st.session_state.session_count} sessions running; this is a demo and only supports '
                     f'{demo_session_limit} parallel sessions, please try again in some time')