File size: 5,410 Bytes
2bf46ef
a8d4e3d
2bf46ef
 
a864c58
dc621b3
8442e32
a8d4e3d
8387173
8cfcf51
f310b8b
8b54034
 
7fd0c02
8b54034
 
 
 
 
 
 
 
f07cf21
8b54034
 
 
 
 
 
 
56c991e
8b54034
 
 
71830cc
8b54034
 
f310b8b
a8d4e3d
f310b8b
1f3126c
 
 
 
 
 
 
 
 
 
 
 
 
a8d4e3d
 
2bf46ef
 
 
 
 
 
 
 
 
 
 
 
f310b8b
2bf46ef
 
a8492e7
 
2bf46ef
 
524154e
4143c64
 
 
 
 
 
 
 
524154e
 
 
 
 
 
4988819
524154e
 
8cfcf51
 
 
 
 
524154e
8cfcf51
8b54034
524154e
8cfcf51
4143c64
 
 
 
 
 
 
 
 
4988819
 
dc3f228
 
4988819
4143c64
 
7a5b32f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from typing import List, Optional
import streamlit as st
import streamlit_pydantic as sp
from pydantic import BaseModel, Field
from PIL import Image
import tempfile
from pathlib import Path

from src.Surveyor import Surveyor


@st.experimental_singleton(suppress_st_warning=True)
def get_surveyor_instance(_print_fn, _survey_print_fn):
    """Return the cached Surveyor, constructing it under a spinner on first use.

    ``st.experimental_singleton`` caches the returned object, so the
    (potentially slow) construction shown under the spinner is not repeated
    on every script rerun. The leading underscores on the parameters follow
    Streamlit's convention for arguments that must not be hashed into the
    cache key.

    Args:
        _print_fn: callable forwarded to ``Surveyor`` as ``print_fn``.
        _survey_print_fn: callable forwarded to ``Surveyor`` as ``survey_print_fn``.

    Returns:
        The ``Surveyor`` instance (built with ``high_gpu=True``).
    """
    loading_message = 'Loading The-Researcher ...'
    with st.spinner(loading_message):
        surveyor = Surveyor(
            print_fn=_print_fn,
            survey_print_fn=_survey_print_fn,
            high_gpu=True,
        )
    return surveyor


def run_survey(surveyor, download_placeholder, research_keywords=None, arxiv_ids=None, max_search=None, num_papers=None):
    """Run one survey in a fresh working directory and offer its outputs for download.

    A new, uniquely named directory is created under the current working
    directory, with ``pdf``, ``txt``, ``img``, ``tab`` and ``dump``
    subdirectories. Their absolute paths, each with a trailing ``/``, are
    passed to ``surveyor.survey`` as ``pdf_dir``, ``txt_dir``, ``img_dir``,
    ``tab_dir`` and ``dump_dir``. The two file names it returns are then
    handed to ``show_survey_download``.

    Args:
        surveyor: object whose ``survey(research_keywords, arxiv_ids, *,
            max_search, num_papers, pdf_dir, txt_dir, img_dir, tab_dir,
            dump_dir)`` returns ``(zip_file_name, survey_file_name)``.
        download_placeholder: container that receives the download buttons.
        research_keywords: free-text research query, or None.
        arxiv_ids: list of arXiv id strings, or None.
        max_search: number of papers to search (keyword mode).
        num_papers: number of papers to select (keyword mode).
    """
    # mkdtemp creates the directory atomically under a random name, so two
    # sessions submitting at the same moment can never share (and clobber)
    # a working directory; a hash of time.time() gave no such guarantee.
    survey_root = Path(tempfile.mkdtemp(dir=Path.cwd())).resolve()
    dir_paths = {f'{dname}_dir': survey_root / dname for dname in ('pdf', 'txt', 'img', 'tab', 'dump')}
    for dir_path in dir_paths.values():
        dir_path.mkdir(exist_ok=True, parents=True)
    print(survey_root)
    print(dir_paths)
    # Surveyor is given plain strings ending in '/'; presumably it appends
    # file names directly onto them (TODO confirm against src.Surveyor).
    dir_args = {key: f'{path.resolve()}/' for key, path in dir_paths.items()}
    zip_file_name, survey_file_name = surveyor.survey(research_keywords,
                                                      arxiv_ids,
                                                      max_search=max_search,
                                                      num_papers=num_papers,
                                                      **dir_args)
    show_survey_download(zip_file_name, survey_file_name, download_placeholder)


def show_survey_download(zip_file_name, survey_file_name, download_placeholder):
    """Add download buttons for the survey artifacts to ``download_placeholder``.

    The zip button is added first, then the survey-file button. Each file is
    opened in binary mode and passed to ``download_button`` as ``data``.

    Args:
        zip_file_name: path (str or Path) to the zip of extracted highlights,
            images and tables.
        survey_file_name: path (str or Path) to the generated survey file.
        download_placeholder: object with a
            ``download_button(label=..., data=..., file_name=...)`` method
            (the app passes a Streamlit container).
    """
    downloads = (
        ("Download extracted topic-clustered-highlights, images and tables as zip", zip_file_name),
        ("Download detailed generated survey file", survey_file_name),
    )
    for label, file_path in downloads:
        with open(str(file_path), "rb") as file:
            download_placeholder.download_button(
                label=label,
                data=file,
                # Offer only the base name: the full server-side path is
                # neither a sensible client file name nor something to leak.
                file_name=Path(file_path).name,
            )


class KeywordsModel(BaseModel):
    # Sidebar form model for keyword-driven search. Field declaration order
    # is kept unchanged, presumably because streamlit_pydantic renders the
    # inputs in that order.

    # Free-text research query; '' means "not provided".
    research_keywords: Optional[str] = Field(
        default='',
        description="Enter your research keywords:",
    )
    # Number of papers to search; validated to the integer range 1..50.
    max_search: int = Field(
        default=10,
        description="num_papers_to_search:",
        ge=1,
        le=50,
        multiple_of=1,
    )
    # Number of papers to select; validated to the integer range 1..8.
    num_papers: int = Field(
        default=3,
        description="num_papers_to_select:",
        ge=1,
        le=8,
        multiple_of=1,
    )


class ArxivIDsModel(BaseModel):
    # Sidebar form model for a curated paper set given as one comma-separated
    # string of arXiv ids; '' means "not provided".
    arxiv_ids: Optional[str] = Field(
        default='',
        description="Enter comma_separated arxiv ids for your curated set of papers (e.g. 2205.12755, 2205.10937, ...):",
    )

if __name__ == '__main__':
    import traceback

    # Number of surveys currently running in this browser session. It is
    # incremented just before run_survey and decremented in `finally`, so a
    # finished, failed or interrupted survey releases its slot.
    # NOTE(review): st.session_state is scoped to a single browser session,
    # so this does not cap concurrent surveys across different users; a true
    # global cap would need state shared between sessions.
    if 'session_count' not in st.session_state:
        st.session_state.session_count = 0

    demo_session_limit = 2

    if st.session_state.session_count >= demo_session_limit:
        st.write(f'{st.session_state.session_count} sessions running, this is a demo and only supports {demo_session_limit} parallel sessions, \n please try in sometime')

    st.sidebar.image(Image.open('logo_landscape.png'), use_column_width = 'always')
    st.title('Auto-Research')
    st.write('#### A no-code utility to generate a detailed well-cited survey with topic clustered sections '
             '(draft paper format) and other interesting artifacts from a single research query or a curated set of papers (arxiv ids).')
    st.write('##### Data Provider: arXiv Open Archive Initiative OAI')
    st.write('##### GitHub: https://github.com/sidphbot/Auto-Research')
    st.write(f'Note: this is only a demo on cpu-13GB RAM, hence it supports limited number of papers & only {demo_session_limit} parallel user sessions')
    download_placeholder = st.container()

    with st.sidebar.form(key="survey_keywords_form"):
        session_data = sp.pydantic_input(key="keywords_input_model", model=KeywordsModel)
        st.write('or')
        session_data.update(sp.pydantic_input(key="arxiv_ids_input_model", model=ArxivIDsModel))
        submit = st.form_submit_button(label="Submit")
    st.sidebar.write('#### execution log:')

    run_kwargs = {'surveyor': get_surveyor_instance(_print_fn=st.sidebar.write, _survey_print_fn=st.write),
                  'download_placeholder': download_placeholder}
    if submit:
        # Treat None / whitespace-only input as empty, and drop the blank
        # entries that stray or trailing commas leave in the id list.
        research_keywords = (session_data.get('research_keywords') or '').strip()
        arxiv_ids = [arxiv_id.strip()
                     for arxiv_id in (session_data.get('arxiv_ids') or '').split(',')
                     if arxiv_id.strip()]
        if not research_keywords and not arxiv_ids:
            st.write('Please enter research keywords or arxiv ids, then submit again.')
        elif st.session_state.session_count < demo_session_limit:
            # Keywords take precedence when both inputs are filled in.
            if research_keywords:
                run_kwargs.update({'research_keywords': research_keywords,
                                   'max_search': session_data['max_search'],
                                   'num_papers': session_data['num_papers']})
            else:
                run_kwargs.update({'arxiv_ids': arxiv_ids})
            print(run_kwargs)
            st.session_state.session_count += 1
            try:
                run_survey(**run_kwargs)
            except Exception as e:
                # Top-level UI boundary: keep the traceback in the server log
                # and show a short message to the user.
                traceback.print_exc()
                st.write(f'ERROR: {str(e)}, server might be crowded right now, please try later, thank you for your patience')
            finally:
                st.session_state.session_count -= 1
        else:
            st.write(f'{st.session_state.session_count} sessions running, this is a demo and only supports {demo_session_limit} parallel sessions, \n please try in sometime')