# ChatitoArXiv / app.py
# You can find this code for Chainlit python streaming here (https://docs.chainlit.io/concepts/streaming/python)
# OpenAI Chat completion
import os
import chainlit as cl # importing chainlit for our app
from dotenv import load_dotenv
import arxiv
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pinecone
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from utils.store import index_documents
from utils.chain import create_chain
from langchain.vectorstores import Pinecone
from langchain.chat_models import ChatOpenAI
from langchain.schema.runnable import RunnableSequence
from langchain.schema import format_document
from pprint import pprint
from langchain_core.vectorstores import VectorStoreRetriever
import langchain
from langchain.cache import InMemoryCache
from langchain_core.messages.human import HumanMessage
from langchain.memory import ConversationBufferMemory
load_dotenv()

YOUR_API_KEY = os.environ["PINECONE_API_KEY"]
YOUR_ENV = os.environ["PINECONE_ENV"]
INDEX_NAME = "arxiv-paper-index"
WANDB_API_KEY = os.environ["WANDB_API_KEY"]
WANDB_PROJECT = os.environ["WANDB_PROJECT"]

@cl.on_chat_start # marks a function that will be executed at the start of a user session
async def start_chat():
    settings = {
        "model": "gpt-3.5-turbo",
        "temperature": 0,
        "max_tokens": 500,
    }
    await cl.Message(
        content="What would you like to learn about today? 😊"
    ).send()

    # instantiate arXiv client (on start)
    arxiv_client = arxiv.Client()

    # create an embedder through a cache interface (locally) (on start)
    store = LocalFileStore("./cache/")
    core_embeddings_model = OpenAIEmbeddings(
        api_key=os.environ["OPENAI_API_KEY"]
    )
    embedder = CacheBackedEmbeddings.from_bytes_store(
        underlying_embeddings=core_embeddings_model,
        document_embedding_cache=store,
        namespace=core_embeddings_model.model,
    )
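
    # NOTE: the cache-backed embedder writes each text's vector to ./cache/,
    # keyed by a hash of the text and namespaced by model name, so re-embedding
    # a paper seen before is a disk read instead of an OpenAI call. A sketch of
    # the effect (hypothetical input text):
    #
    #   embedder.embed_documents(["transformers"])  # first call hits the API
    #   embedder.embed_documents(["transformers"])  # second call hits ./cache/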

    # instantiate pinecone (on start)
    pinecone.init(
        api_key=YOUR_API_KEY,
        environment=YOUR_ENV
    )
    if INDEX_NAME not in pinecone.list_indexes():
        pinecone.create_index(
            name=INDEX_NAME,
            metric="cosine",
            dimension=1536
        )
    index = pinecone.GRPCIndex(INDEX_NAME)
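
    # NOTE: dimension=1536 matches text-embedding-ada-002, the default model
    # behind OpenAIEmbeddings. The gRPC handle created here is what the
    # indexing step upserts into; retrieval later goes through a plain
    # pinecone.Index (see main below).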

    # setup your ChatOpenAI model (on start)
    llm = ChatOpenAI(
        model=settings["model"],
        temperature=settings["temperature"],
        max_tokens=settings["max_tokens"],
        api_key=os.environ["OPENAI_API_KEY"],
        streaming=True
    )

    # create a prompt cache (locally) (on start)
    langchain.llm_cache = InMemoryCache()
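
    # NOTE: InMemoryCache memoizes LLM responses keyed on the exact prompt and
    # model parameters, so repeating an identical question skips the API call.
    # It lives in process memory and is emptied whenever the app restarts.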

    # log traces to Weights & Biases (on start)
    os.environ["LANGCHAIN_WANDB_TRACING"] = "true"

    # setup memory
    memory = ConversationBufferMemory(memory_key="chat_history")

    tools = {
        "arxiv_client": arxiv_client,
        "index": index,
        "embedder": embedder,
        "llm": llm,
        "memory": memory
    }
    cl.user_session.set("tools", tools)
    cl.user_session.set("settings", settings)
    cl.user_session.set("first_run", False)
@cl.on_message # marks a function that should be run each time the chatbot receives a message from a user
async def main(message: cl.Message):
    settings = cl.user_session.get("settings")
    tools: dict = cl.user_session.get("tools")
    first_run = cl.user_session.get("first_run")
    retrieval_augmented_qa_chain = cl.user_session.get("chain", None)
    memory: ConversationBufferMemory = cl.user_session.get("memory")

    sys_message = cl.Message(content="")
    await sys_message.send()  # renders a loader

    # first message of the session: build the corpus and the chain before answering
    if not first_run:
        arxiv_client: arxiv.Client = tools["arxiv_client"]
        index: pinecone.GRPCIndex = tools["index"]
        embedder: CacheBackedEmbeddings = tools["embedder"]
        llm: ChatOpenAI = tools["llm"]
        memory: ConversationBufferMemory = tools["memory"]

        # using query search for ArXiv documents (on message)
        search = arxiv.Search(
            query=message.content,
            max_results=10,
            sort_by=arxiv.SortCriterion.Relevance
        )
        paper_urls = []
        for result in arxiv_client.results(search):
            paper_urls.append(result.pdf_url)
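
        # NOTE: Client.results() is a lazy generator; with max_results=10 this
        # fetches a single results page. Each pdf_url points at the paper's
        # PDF, which PyPDFLoader downloads and parses page by page.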

        # load them and split them (on message)
        docs = []  # one entry per paper: the list of page Documents from load()
        for paper_url in paper_urls:
            try:
                loader = PyPDFLoader(paper_url)
                docs.append(loader.load())
            except Exception as e:
                print(f"Error loading {paper_url}: {e}")
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=400,
            chunk_overlap=30,
            length_function=len
        )
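
        # NOTE: chunk_size and chunk_overlap are counted with length_function,
        # i.e. characters here rather than tokens; 400-character chunks with a
        # 30-character overlap keep retrieval granular while preserving a bit
        # of context across chunk boundaries.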

        # create an index using pinecone (on message)
        index_documents(docs, text_splitter, embedder, index)
        text_field = "source_document"
        index = pinecone.Index(INDEX_NAME)  # plain (non-gRPC) handle for queries
        vectorstore = Pinecone(
            index=index,
            embedding=embedder.embed_query,
            text_key=text_field
        )
        retriever: VectorStoreRetriever = vectorstore.as_retriever()

        # create the chain (on message)
        retrieval_augmented_qa_chain: RunnableSequence = create_chain(retriever=retriever, llm=llm)
        cl.user_session.set("chain", retrieval_augmented_qa_chain)
        sys_message.content = "I found some papers and studied them 😉 \n"
        await sys_message.update()

    # run the chain; stream tokens into the chat message as they arrive
    for chunk in retrieval_augmented_qa_chain.stream(
        {"question": message.content, "chat_history": memory.buffer_as_messages}
    ):
        pprint(chunk)
        if res := chunk.get("response"):
            await sys_message.stream_token(res.content)
    await sys_message.send()

    # persist this turn so follow-up questions carry the conversation history
    memory.chat_memory.add_user_message(message.content)
    memory.chat_memory.add_ai_message(sys_message.content)
    print(memory.buffer_as_str)
    cl.user_session.set("memory", memory)
    cl.user_session.set("first_run", True)