File size: 6,581 Bytes
d98144d
 
 
 
 
 
 
f6a270c
 
d98144d
 
 
 
f6a270c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d98144d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d7fd94
 
 
 
 
 
 
 
 
 
d98144d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6a270c
 
d98144d
 
 
 
 
 
 
 
 
 
 
 
 
 
89e2748
d98144d
 
 
 
 
 
 
5d7fd94
 
 
 
aea0dbd
f6a270c
4600a74
5d7fd94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d98144d
 
 
 
 
 
 
 
 
 
72d46f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import os
import pickle
from threading import Lock
from typing import List, Optional, Tuple

import gradio as gr
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.chains import ChatVectorDBChain, ConversationalRetrievalChain

from template import QA_PROMPT, CONDENSE_QUESTION_PROMPT
from pdf2vectorstore import convert_to_vectorstore

def get_chain(api_key, vectorstore, model_name):
    """Build a conversational QA chain over ``vectorstore`` for ``model_name``.

    ``"gpt-4"`` gets a ``ChatOpenAI`` model wired to a
    ``ConversationalRetrievalChain`` with a tuned retriever. Any other model
    name gets an ``OpenAI`` LLM wrapped in a ``ChatVectorDBChain``. Both chains
    use the project's QA and condense-question prompts and run at temperature 0.

    Args:
        api_key: OpenAI API key passed to the LangChain model wrapper.
        vectorstore: Vector store holding the indexed paper.
        model_name: OpenAI model identifier, e.g. ``"gpt-4"``.

    Returns:
        The constructed LangChain chain.
    """
    prompt_kwargs = dict(
        qa_prompt=QA_PROMPT,
        condense_question_prompt=CONDENSE_QUESTION_PROMPT,
    )

    if model_name != "gpt-4":
        completion_llm = OpenAI(model_name=model_name, temperature=0, openai_api_key=api_key)
        return ChatVectorDBChain.from_llm(completion_llm, vectorstore, **prompt_kwargs)

    chat_llm = ChatOpenAI(model_name=model_name, temperature=0, openai_api_key=api_key)
    retriever = vectorstore.as_retriever()
    # NOTE(review): these search kwargs look vectorstore-specific (presumably
    # DeepLake): cosine distance, fetch 100 candidates, MMR re-ranking, keep 10.
    retriever.search_kwargs.update(
        distance_metric="cos",
        fetch_k=100,
        maximal_marginal_relevance=True,
        k=10,
    )
    return ConversationalRetrievalChain.from_llm(chat_llm, retriever, **prompt_kwargs)

def set_openai_api_key(api_key: str, vectorstore, model_name: str):
    """Return a QA chain for ``vectorstore`` if an API key was supplied.

    Args:
        api_key: OpenAI API key; an empty or falsy value means "no key".
        vectorstore: Vector store the chain retrieves from.
        model_name: OpenAI model identifier forwarded to ``get_chain``.

    Returns:
        The chain built by ``get_chain``, or None when ``api_key`` is falsy.
    """
    if not api_key:
        return None
    return get_chain(api_key, vectorstore, model_name)

class ChatWrapper:
    """Gradio callback that answers questions about one arXiv paper at a time.

    The wrapper caches the vectorstore built for the most recent URL and the QA
    chain built on top of it. The chain is rebuilt whenever the URL, API key or
    model changes. All cache reads and writes happen under ``self.lock``,
    because the module uses a single instance for every Gradio session.
    """

    # Model used when the dropdown has no selection (Gradio then passes None).
    default_model = "gpt-3.5-turbo"

    def __init__(self):
        self.lock = Lock()
        self.previous_url = ""  # URL the cached vectorstore was built from
        self.vectorstore_state = None
        self.chain = None
        # (api_key, model_name) the cached chain was built with. A change in
        # either forces a rebuild, so the key and model fields actually apply.
        self.chain_config = None

    def __call__(
        self,
        api_key: str,
        arxiv_url: str,
        inp: str,
        history: Optional[List[Tuple[str, str]]],
        model_name: str,
    ):
        """Answer ``inp`` about the paper at ``arxiv_url``.

        Args:
            api_key: OpenAI API key entered by the user.
            arxiv_url: URL of the paper. A URL different from the cached one
                re-indexes the paper and starts a fresh conversation.
            inp: The user's question.
            history: Prior (question, answer) pairs from Gradio state, or None.
            model_name: Model chosen in the dropdown, or None if unselected.

        Returns:
            ``(history, history)``: the same updated list, once for the Chatbot
            display and once for the State component.
        """
        history = history or []
        if not arxiv_url or not api_key:
            history.append((inp, "Please provide both arXiv URL and API key to begin"))
            return history, history

        model_name = model_name or self.default_model

        with self.lock:
            if arxiv_url != self.previous_url:
                # New paper: fresh conversation and a fresh index. Drop the old
                # chain first; it points at the old paper's vectorstore and must
                # not survive a failure in the steps below.
                history = []
                self.chain = None
                self.vectorstore_state = convert_to_vectorstore(arxiv_url, api_key)
                self.previous_url = arxiv_url

            config = (api_key, model_name)
            if self.chain is None or self.chain_config != config:
                # Invalidate before building so a failed build is retried on the
                # next call instead of leaving a mismatched chain cached.
                self.chain, self.chain_config = None, None
                self.chain = set_openai_api_key(api_key, self.vectorstore_state, model_name)
                self.chain_config = config

            if self.chain is None:
                history.append((inp, "Please paste your OpenAI key to use"))
                return history, history

            import openai
            openai.api_key = api_key
            output = self.chain({"question": inp, "chat_history": history})["answer"]

        history.append((inp, output))
        return history, history

# Single handler instance bound to both the Send button and the message box
# below; every Gradio session shares it, so its cached URL/vectorstore/chain
# are global rather than per user.
chat = ChatWrapper()

# Page layout: API key / arXiv URL / model inputs, the chat area, example
# prompts and a footer, all wired to the shared ``chat`` handler.
block = gr.Blocks(css=".gradio-container {background-color: #f8f8f8; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif}")

with block:
    gr.HTML("""
    <style>
        body {
            background-color: #f5f5f5;
            font-family: 'Roboto', sans-serif;
            padding: 30px;
        }
    </style>
    """)
    
    gr.HTML("<h1 style='text-align: center;'>ArxivGPT</h1>")
    gr.HTML("<h3 style='text-align: center;'>Ask questions about research papers</h3>")
    
    with gr.Row():
        with gr.Column(width="auto"):
            openai_api_key_textbox = gr.Textbox(
                label="OpenAI API Key",
                placeholder="Paste your OpenAI API key (sk-...)",
                show_label=True,
                lines=1,
                type="password",
            )
        with gr.Column(width="auto"):
            arxiv_url_textbox = gr.Textbox(
                label="Arxiv URL",
                placeholder="Enter the arXiv URL",
                show_label=True,
                lines=1,
            )
        with gr.Column(width="auto"):
            # Pre-select a model; without a default value the dropdown sends
            # None to the handler until the user picks one.
            model_dropdown = gr.Dropdown(
                label="Choose a model",
                choices=["gpt-3.5-turbo", "gpt-4"],
                value="gpt-3.5-turbo",
            )

    chatbot = gr.Chatbot()

    with gr.Row():
        message = gr.Textbox(
            label="What's your question?",
            placeholder="Ask questions about the paper you just linked",
            lines=1,
        )
        submit = gr.Button(value="Send", variant="secondary").style(full_width=False)

    gr.Examples(
        examples=[
            "What's this paper about?",
            "Please give me a brief summary about this paper",
            "Are there any interesting correlations in the given paper?",
            "How can this paper be applied in the real world?",
            "What are the limitations of this paper?",
        ],
        inputs=message,
    )
    gr.HTML("""
            <div style="text-align:center">
                <p>Developed by <a href='https://www.linkedin.com/in/dekay/'>Github and Huggingface: Volkopat</a></p>
                <p>Powered by <a href='https://openai.com/'>OpenAI</a>, <a href='https://arxiv.org/'>arXiv</a> and <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a></p>
                <p>ArxivGPT is a chatbot that answers questions about research papers. It uses a pretrained GPT-4 model to generate answers.</p>
                <p>If you don't get a response for GPT-4, it is likely that you don't have API access, try 3.5</p>
                <p>It can take upto a minute for you to start a conversation for every new paper. Please be patient.</p>
            </div>
            <style>
                p {
                    margin-bottom: 10px;
                    font-size: 16px;
                }
                a {
                    color: #3867d6;
                    text-decoration: none;
                }
                a:hover {
                    text-decoration: underline;
                }
            </style>
            """)

    # Conversation history carried between calls (None on the first call).
    state = gr.State()

    # Both the Send button and pressing Enter in the message box invoke the
    # same handler with the same inputs and outputs, so define them once.
    chat_inputs = [openai_api_key_textbox, arxiv_url_textbox, message, state, model_dropdown]
    chat_outputs = [chatbot, state]
    submit.click(chat, inputs=chat_inputs, outputs=chat_outputs)
    message.submit(chat, inputs=chat_inputs, outputs=chat_outputs)

block.launch(width=800)