Spaces:
Sleeping
Sleeping
# The Chainlit Python streaming example this code is based on is documented at
# https://docs.chainlit.io/concepts/streaming/python
# OpenAI chat completion
import os | |
import chainlit as cl # importing chainlit for our app | |
from chainlit.prompt import Prompt, PromptMessage # importing prompt tools | |
from chainlit.playground.providers import ChatOpenAI # importing ChatOpenAI tools | |
from dotenv import load_dotenv | |
import arxiv | |
import pinecone | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
from langchain.embeddings.azure_openai import AzureOpenAIEmbeddings | |
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI | |
from langchain.embeddings import CacheBackedEmbeddings | |
from langchain.storage import LocalFileStore, InMemoryStore | |
from utils.store import index_documents, search_and_index | |
from utils.chain import create_chain | |
from langchain.vectorstores import Pinecone | |
from langchain.schema.runnable import RunnableSequence | |
from langchain.schema import format_document | |
from pprint import pprint | |
from langchain_core.vectorstores import VectorStoreRetriever | |
import langchain | |
from langchain.cache import InMemoryCache | |
from langchain.memory import ConversationBufferMemory | |
# Load environment variables via python-dotenv before any of them are read.
load_dotenv()

# Pinecone credentials and the name of the index that stores the paper chunks.
# Indexing os.environ directly means a missing variable fails at import time.
YOUR_API_KEY = os.environ["PINECONE_API_KEY"]
YOUR_ENV = os.environ["PINECONE_ENV"]
INDEX_NAME = "arxiv-paper-index"

# Weights & Biases settings, read from the environment.
WANDB_API_KEY = os.environ["WANDB_API_KEY"]
WANDB_PROJECT = os.environ["WANDB_PROJECT"]
# Registered with Chainlit so that it is executed at the start of a user session.
@cl.on_chat_start
async def start_chat():
    """Greet the user and build the per-session tools used by the message handler.

    Sends a welcome message, then creates a cache-backed Azure OpenAI
    embedder, initialises Pinecone (creating ``INDEX_NAME`` when it is not
    listed), builds a streaming Azure chat model and a conversation memory,
    and stores the results in ``cl.user_session`` under the keys ``"tools"``,
    ``"settings"`` and ``"first_run"``.
    """
    settings = {
        "model": "gpt-3.5-turbo",
        "temperature": 0,
        "max_tokens": 500,
    }

    # NOTE(review): the trailing "π" looks like a mis-encoded emoji — confirm
    # the intended character before changing this user-facing string.
    await cl.Message(
        content="Hi, I am here to help you learn about a topic, what would you like to learn about today? π"
    ).send()

    # Create an embedder through a cache interface (in memory, on start).
    store = InMemoryStore()
    core_embeddings_model = AzureOpenAIEmbeddings(
        api_key=os.environ["AZURE_OPENAI_API_KEY"],
        azure_deployment="text-embedding-ada-002",
        azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    )
    embedder = CacheBackedEmbeddings.from_bytes_store(
        underlying_embeddings=core_embeddings_model,
        document_embedding_cache=store,
        # Namespacing by model name keeps cached vectors of different models apart.
        namespace=core_embeddings_model.model,
    )

    # Instantiate Pinecone (on start) and make sure the index exists.
    pinecone.init(
        api_key=YOUR_API_KEY,
        environment=YOUR_ENV,
    )
    if INDEX_NAME not in pinecone.list_indexes():
        pinecone.create_index(
            name=INDEX_NAME,
            metric="cosine",
            dimension=1536,  # vector size produced by text-embedding-ada-002
        )
    index = pinecone.GRPCIndex(INDEX_NAME)

    llm = AzureChatOpenAI(
        temperature=settings["temperature"],
        max_tokens=settings["max_tokens"],
        api_key=os.environ["AZURE_OPENAI_API_KEY"],
        azure_deployment="gpt-35-turbo-16k",
        api_version="2023-07-01-preview",
        streaming=True,
    )

    # Create a prompt cache (in memory, on start). This assigns an attribute on
    # the langchain module itself, so it is replaced on every session start.
    langchain.llm_cache = InMemoryCache()

    # Weights & Biases logging (on start): WANDB_MODE is set to "disabled"
    # while the LangChain W&B tracing flag is set to "true".
    os.environ["WANDB_MODE"] = "disabled"
    os.environ["LANGCHAIN_WANDB_TRACING"] = "true"

    # Set up conversation memory.
    memory = ConversationBufferMemory(memory_key="chat_history")

    tools = {
        "index": index,
        "embedder": embedder,
        "llm": llm,
        "memory": memory,
    }
    cl.user_session.set("tools", tools)
    cl.user_session.set("settings", settings)
    # The message handler flips this to True after it has handled one message.
    cl.user_session.set("first_run", False)
# Registered with Chainlit so that it runs each time the chatbot receives a
# message from a user.
@cl.on_message
async def main(message: cl.Message):
    """Answer a user message with the retrieval-augmented QA chain.

    While the session's ``"first_run"`` flag is falsy, the handler calls
    ``search_and_index`` with the incoming message (``quantity=10``), wraps
    the Pinecone index in a vector-store retriever, builds the chain with
    ``create_chain`` and stores it in the session under ``"chain"``.
    For every message it then streams the chain's ``"response"`` chunks into
    the reply, pretty-prints any ``"context"`` chunks, appends the exchange
    to the conversation memory and stores the memory and ``"first_run"=True``
    back into the session.

    Args:
        message: The incoming Chainlit message; its ``content`` is the question.
    """
    tools: dict = cl.user_session.get("tools")
    first_run = cl.user_session.get("first_run")
    retrieval_augmented_qa_chain = cl.user_session.get("chain", None)
    # Only present in the session after one message has been handled; before
    # that the memory object is taken from "tools" below.
    memory: ConversationBufferMemory = cl.user_session.get("memory")

    sys_message = cl.Message(content="")
    await sys_message.send()  # renders a loader

    # NOTE(review): the extent of this branch was reconstructed from a
    # flattened listing — confirm it matches the deployed file.
    if not first_run:
        grpc_index: pinecone.GRPCIndex = tools["index"]
        embedder: CacheBackedEmbeddings = tools["embedder"]
        llm: AzureChatOpenAI = tools["llm"]
        memory = tools["memory"]

        # Use the query to search for arXiv documents and index them (on message).
        # search_and_index is a plain function, so it is wrapped with make_async.
        await cl.make_async(search_and_index)(
            message=message,
            quantity=10,
            embedder=embedder,
            index=grpc_index,
        )

        vectorstore = Pinecone(
            index=pinecone.Index(INDEX_NAME),
            embedding=embedder.embed_query,
            text_key="source_document",
        )
        retriever: VectorStoreRetriever = vectorstore.as_retriever()

        # Create the chain (on message) and keep it for later messages.
        retrieval_augmented_qa_chain = create_chain(retriever=retriever, llm=llm)
        cl.user_session.set("chain", retrieval_augmented_qa_chain)

        # NOTE(review): the "π" looks like a mis-encoded emoji — confirm the
        # intended character before changing this user-facing string.
        sys_message.content = "\nI found some papers and studied them π \n"
        await sys_message.update()

    # Run the chain and stream the answer into the reply as it arrives.
    chain_input = {
        "question": message.content,
        "chat_history": memory.buffer_as_messages,
    }
    async for chunk in retrieval_augmented_qa_chain.astream(chain_input):
        if res := chunk.get("response"):
            await sys_message.stream_token(res.content)
        if context := chunk.get("context"):
            pprint(context)
    await sys_message.send()

    # Record the exchange so later questions see it as chat history.
    memory.chat_memory.add_user_message(message.content)
    memory.chat_memory.add_ai_message(sys_message.content)
    print(memory.buffer_as_str)

    cl.user_session.set("memory", memory)
    cl.user_session.set("first_run", True)