# You can find this code for Chainlit python streaming here (https://docs.chainlit.io/concepts/streaming/python)
# OpenAI Chat completion
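# Typically launched with: chainlit run app.py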
import os
import chainlit as cl # importing chainlit for our app
from chainlit.prompt import Prompt, PromptMessage # importing prompt tools
from dotenv import load_dotenv
import arxiv
import pinecone
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings.azure_openai import AzureOpenAIEmbeddings
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore, InMemoryStore
from utils.store import index_documents, search_and_index
from utils.chain import create_chain
from langchain.vectorstores import Pinecone
from langchain.schema.runnable import RunnableSequence
from langchain.schema import format_document
from pprint import pprint
from langchain_core.vectorstores import VectorStoreRetriever
import langchain
from langchain.cache import InMemoryCache
from langchain.memory import ConversationBufferMemory
load_dotenv()
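# Required environment variables (typically loaded from a local .env file):
#   PINECONE_API_KEY, PINECONE_ENV, WANDB_API_KEY, WANDB_PROJECT,
#   AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT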
YOUR_API_KEY = os.environ["PINECONE_API_KEY"]
YOUR_ENV = os.environ["PINECONE_ENV"]
INDEX_NAME = 'arxiv-paper-index'
WANDB_API_KEY = os.environ["WANDB_API_KEY"]
WANDB_PROJECT = os.environ["WANDB_PROJECT"]

@cl.on_chat_start  # marks a function that will be executed at the start of a user session
async def start_chat():
    settings = {
        "model": "gpt-3.5-turbo",
        "temperature": 0,
        "max_tokens": 500
    }
    await cl.Message(
        content="Hi, I am here to help you learn about a topic, what would you like to learn about today? 😊"
    ).send()
    # create an embedder behind an in-memory cache interface (on start)
    store = InMemoryStore()
    core_embeddings_model = AzureOpenAIEmbeddings(
        api_key=os.environ['AZURE_OPENAI_API_KEY'],
        azure_deployment="text-embedding-ada-002",
        azure_endpoint=os.environ['AZURE_OPENAI_ENDPOINT']
    )
    embedder = CacheBackedEmbeddings.from_bytes_store(
        underlying_embeddings=core_embeddings_model,
        document_embedding_cache=store,
        namespace=core_embeddings_model.model
    )
    # instantiate pinecone (on start)
    pinecone.init(
        api_key=YOUR_API_KEY,
        environment=YOUR_ENV
    )
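    # create the index on first use; dimension 1536 matches text-embedding-ada-002 vectors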
    if INDEX_NAME not in pinecone.list_indexes():
        pinecone.create_index(
            name=INDEX_NAME,
            metric='cosine',
            dimension=1536
        )
    index = pinecone.GRPCIndex(INDEX_NAME)
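    # chat model served from an Azure OpenAI deployment; streaming=True lets tokens be relayed to the UI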
    llm = AzureChatOpenAI(
        temperature=settings['temperature'],
        max_tokens=settings['max_tokens'],
        api_key=os.environ['AZURE_OPENAI_API_KEY'],
        azure_deployment="gpt-35-turbo-16k",
        api_version="2023-07-01-preview",
        streaming=True
    )
    # create a prompt cache (in memory) (on start)
    langchain.llm_cache = InMemoryCache()
    # Weights & Biases tracing config (on start); WANDB_MODE is set to "disabled" so runs are not uploaded
    os.environ["WANDB_MODE"] = "disabled"
    os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
    # setup memory
    memory = ConversationBufferMemory(memory_key="chat_history")
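    # stash the shared resources in the user session so the message handler can reuse them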
    tools = {
        "index": index,
        "embedder": embedder,
        "llm": llm,
        "memory": memory
    }
    cl.user_session.set("tools", tools)
    cl.user_session.set("settings", settings)
cl.user_session.set("first_run", False)

@cl.on_message  # marks a function that should be run each time the chatbot receives a message from a user
async def main(message: cl.Message):
    settings = cl.user_session.get("settings")
    tools: dict = cl.user_session.get("tools")
    first_run = cl.user_session.get("first_run")
    retrieval_augmented_qa_chain = cl.user_session.get("chain", None)
    memory: ConversationBufferMemory = cl.user_session.get("memory")
    sys_message = cl.Message(content="")
    await sys_message.send()  # renders a loader
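    # on the first message: search ArXiv, index the results and build the RAG chain;
    # subsequent messages reuse the chain stored in the user session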
    if not first_run:
        index: pinecone.GRPCIndex = tools['index']
        embedder: CacheBackedEmbeddings = tools['embedder']
        llm: AzureChatOpenAI = tools['llm']
        memory: ConversationBufferMemory = tools['memory']
        # use the query to search for ArXiv documents and index their contents (on message)
        await cl.make_async(search_and_index)(message=message, quantity=10, embedder=embedder, index=index)
        text_field = "source_document"
        index = pinecone.Index(INDEX_NAME)
        vectorstore = Pinecone(
            index=index,
            embedding=embedder.embed_query,
            text_key=text_field
        )
        retriever: VectorStoreRetriever = vectorstore.as_retriever()
        # create the chain (on message)
        retrieval_augmented_qa_chain: RunnableSequence = create_chain(retriever=retriever, llm=llm)
        cl.user_session.set("chain", retrieval_augmented_qa_chain)
        sys_message.content = """
I found some papers and studied them 😉 \n"""
        await sys_message.update()
    # run the chain, streaming the answer token by token and logging any retrieved context
    async for chunk in retrieval_augmented_qa_chain.astream(
        {"question": f"{message.content}", "chat_history": memory.buffer_as_messages}
    ):
        if res := chunk.get('response'):
            await sys_message.stream_token(res.content)
        if chunk.get("context"):
            pprint(chunk.get("context"))
    await sys_message.send()
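    # persist this turn in conversation memory so follow-up questions carry the chat history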
    memory.chat_memory.add_user_message(message.content)
    memory.chat_memory.add_ai_message(sys_message.content)
    print(memory.buffer_as_str)
    cl.user_session.set("memory", memory)
    cl.user_session.set("first_run", True)