File size: 7,155 Bytes
d98144d f6a270c d98144d f6a270c d98144d 5d7fd94 d98144d f6a270c d98144d 89e2748 d98144d 5d7fd94 a58f539 a27f388 45a1acf c092976 f6a270c 5d7fd94 d98144d 72d46f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import os
import pickle
from threading import Lock
from typing import List, Optional, Tuple

import gradio as gr
from langchain.chains import ChatVectorDBChain, ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI

from pdf2vectorstore import convert_to_vectorstore
from template import CONDENSE_QUESTION_PROMPT, QA_PROMPT
def get_chain(api_key, vectorstore, model_name):
    """Build the conversational QA chain used to answer questions about a paper.

    For ``"gpt-4"`` a ``ChatOpenAI`` model drives a
    ``ConversationalRetrievalChain`` over a retriever whose ``search_kwargs``
    are customised below. Any other model name goes through the ``OpenAI``
    LLM wrapper and a ``ChatVectorDBChain`` built directly on the vector store.
    Both variants use the project's QA and condense-question prompts and
    temperature 0.

    Args:
        api_key: OpenAI API key passed straight to the LLM wrapper.
        vectorstore: Vector store for the current paper.
        model_name: Model identifier chosen in the UI.

    Returns:
        The constructed LangChain chain.
    """
    prompt_kwargs = {
        "qa_prompt": QA_PROMPT,
        "condense_question_prompt": CONDENSE_QUESTION_PROMPT,
    }

    if model_name != "gpt-4":
        completion_llm = OpenAI(model_name=model_name, temperature=0, openai_api_key=api_key)
        return ChatVectorDBChain.from_llm(completion_llm, vectorstore, **prompt_kwargs)

    chat_llm = ChatOpenAI(model_name=model_name, temperature=0, openai_api_key=api_key)
    retriever = vectorstore.as_retriever()
    # Retriever search options. The key names suggest a store supporting cosine
    # distance and maximal-marginal-relevance search — presumably the one built
    # by convert_to_vectorstore; confirm there before changing them.
    retriever.search_kwargs.update(
        {
            "distance_metric": "cos",
            "fetch_k": 100,
            "maximal_marginal_relevance": True,
            "k": 10,
        }
    )
    return ConversationalRetrievalChain.from_llm(chat_llm, retriever, **prompt_kwargs)
def set_openai_api_key(api_key: str, vectorstore, model_name: str):
    """Return a QA chain for *vectorstore*, or ``None`` when *api_key* is empty.

    Despite its name, this does not store the key anywhere; it only guards
    ``get_chain`` against a missing key.
    """
    if not api_key:
        return None
    return get_chain(api_key, vectorstore, model_name)
class ChatWrapper:
    """Gradio callback that answers questions about a single arXiv paper.

    The module creates one instance of this class and wires it to every event,
    so its cached URL, vector store and chain are shared between sessions.
    All of that mutable state is read and written only while holding
    ``self.lock``.
    """

    def __init__(self):
        self.lock = Lock()
        # arXiv URL whose vector store is currently loaded ("" = none yet).
        self.previous_url = ""
        # Vector store returned by convert_to_vectorstore for previous_url.
        self.vectorstore_state = None
        # QA chain over vectorstore_state; None until it has been built.
        self.chain = None
        # (api_key, model_name) the current chain was built with, so that a
        # changed key or model selection rebuilds the chain for the same paper.
        self._chain_config = None

    def __call__(
        self,
        api_key: str,
        arxiv_url: str,
        inp: str,
        history: Optional[List[Tuple[str, str]]],
        model_name: str,
    ):
        """Answer *inp* about the paper at *arxiv_url* and extend the chat history.

        Args:
            api_key: OpenAI API key entered by the user.
            arxiv_url: URL of the paper; a URL different from the cached one
                resets the conversation and rebuilds the vector store.
            inp: The user's question.
            history: Previous ``(question, answer)`` pairs held in Gradio
                state, or ``None`` on the first turn.
            model_name: Model chosen in the dropdown; a falsy value falls back
                to ``"gpt-3.5-turbo"``.

        Returns:
            The updated history twice (once for the Chatbot, once for State).
        """
        if not arxiv_url or not api_key:
            history = history or []
            history.append((inp, "Please provide both arXiv URL and API key to begin"))
            return history, history

        # The dropdown may be left unselected, in which case Gradio can pass None.
        model_name = model_name or "gpt-3.5-turbo"

        # Hold the lock for the whole request: the cached URL, vector store and
        # chain are shared between sessions and must change together.
        with self.lock:
            if arxiv_url != self.previous_url:
                # New paper: start a fresh conversation and build its vector store.
                history = []
                self.vectorstore_state = convert_to_vectorstore(arxiv_url, api_key)
                self.previous_url = arxiv_url
                self.chain = None  # the old chain refers to the old vector store

            config = (api_key, model_name)
            if self.chain is None or config != self._chain_config:
                self.chain = set_openai_api_key(api_key, self.vectorstore_state, model_name)
                self._chain_config = config

            history = history or []
            if self.chain is None:
                history.append((inp, "Please paste your OpenAI key to use"))
                return history, history

            import openai

            # Expose the key through the openai module only while the chain
            # runs, then restore the previous value so this user's key does
            # not linger in process-wide state.
            previous_key = openai.api_key
            openai.api_key = api_key
            try:
                output = self.chain({"question": inp, "chat_history": history})["answer"]
            finally:
                openai.api_key = previous_key
            history.append((inp, output))
        return history, history
# Single handler instance shared by every session of the app.
chat = ChatWrapper()

block = gr.Blocks(css=".gradio-container {background-color: #f8f8f8; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif}")

with block:
    gr.HTML("""
<style>
body {
background-color: #f5f5f5;
font-family: 'Roboto', sans-serif;
padding: 30px;
}
</style>
""")
    gr.HTML("<h1 style='text-align: center;'>ArxivGPT</h1>")
    gr.HTML("<h3 style='text-align: center;'>Ask questions about research papers</h3>")

    # Inputs row: API key, paper URL, model selection.
    with gr.Row():
        with gr.Column(width="auto"):
            openai_api_key_textbox = gr.Textbox(
                label="OpenAI API Key",
                placeholder="Paste your OpenAI API key (sk-...)",
                show_label=True,
                lines=1,
                type="password",
            )
        with gr.Column(width="auto"):
            arxiv_url_textbox = gr.Textbox(
                label="Arxiv URL",
                placeholder="Enter the arXiv URL",
                show_label=True,
                lines=1,
            )
        with gr.Column(width="auto"):
            model_dropdown = gr.Dropdown(
                label="Choose a model",
                choices=["gpt-3.5-turbo", "gpt-4"],
                # Preselect a model so an untouched dropdown does not send None
                # through to get_chain.
                value="gpt-3.5-turbo",
            )

    chatbot = gr.Chatbot()

    with gr.Row():
        message = gr.Textbox(
            label="What's your question?",
            placeholder="Ask questions about the paper you just linked",
            lines=1,
        )
        submit = gr.Button(value="Send", variant="secondary").style(full_width=False)

    gr.Examples(
        examples=[
            "What's this paper about?",
            "Please give me a brief summary about this paper",
            "Are there any interesting correlations in the given paper?",
            "How can this paper be applied in the real world?",
            "What are the limitations of this paper?",
        ],
        inputs=message,
    )

    gr.HTML("""
<div style="text-align:center">
<p>Developed by <a href='https://www.linkedin.com/in/dekay/'>Github and Huggingface: Volkopat</a></p>
<p>Powered by <a href='https://openai.com/'>OpenAI</a>, <a href='https://arxiv.org/'>arXiv</a> and <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a></p>
<p>ArxivGPT is a chatbot that answers questions about research papers. It uses a pretrained GPT-3.5 model to generate answers.</p>
<p>Currently, it can answer questions about the paper you just linked.</p>
<p>It's still in development, so please report any bugs you find. </p>
<p>It can take up to a minute to start a conversation for every new paper as this is just a demo hosted on a lightweight service.</p>
<p>For best results, test it on better hardware. Took 20 seconds to start on M1 Chip</p>
<p>The answers can be quite limited as there is a 4096 token limit for GPT-3.5, hence wait for GPT-4 access for better quality.</p>
<p>If you don't get a response for GPT-4, it is likely that you don't have API access, try 3.5</p>
<p>Possible upgrades coming up: faster parsing, status messages, other research paper hubs.</p>
</div>
<style>
p {
margin-bottom: 10px;
font-size: 16px;
}
a {
color: #3867d6;
text-decoration: none;
}
a:hover {
text-decoration: underline;
}
</style>
""")

    # Per-session chat history, passed to and returned from the handler.
    state = gr.State()

    # The Send button and pressing Enter in the question box run the same
    # handler with the same inputs and outputs.
    chat_inputs = [openai_api_key_textbox, arxiv_url_textbox, message, state, model_dropdown]
    submit.click(chat, inputs=chat_inputs, outputs=[chatbot, state])
    message.submit(chat, inputs=chat_inputs, outputs=[chatbot, state])

block.launch(width=800)