# PageTurn — app.py
# Author: Jason Caro
# (last change: "Update app.py", commit 4263541)
import chainlit as cl
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import CacheBackedEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.storage import LocalFileStore
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
# --- Module-level configuration ---------------------------------------------
# Text splitter: 1000-char chunks with 100-char overlap so retrieved chunks
# keep a little surrounding context.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

# System prompt: the retrieved document chunks are injected into {context};
# the model is told to refuse rather than guess when the context is silent.
system_template = """
Use the following pieces of context to answer the user's question. If the question cannot be answered with the supplied context, simply answer "I cannot determine this based on the provided context."
----------------
{context}"""

messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]

# Build the prompt via the documented classmethod constructor; the bare
# ChatPromptTemplate(messages=...) constructor is version-fragile (older
# langchain releases required explicit input_variables).
prompt = ChatPromptTemplate.from_messages(messages)

# Keyword arguments forwarded to the "stuff" chain so it uses our prompt.
chain_type_kwargs = {"prompt": prompt}
@cl.author_rename
def rename(orig_author: str):
    """Map internal chain author names to user-facing labels.

    Unknown authors pass through unchanged.
    """
    display_names = {"RetrievalQA": "PageTurn"}
    return display_names.get(orig_author, orig_author)
# Initialize the index and other setup
@cl.on_chat_start
async def init():
    """On chat start: build a FAISS index over the source document and stash
    a RetrievalQA chain in the user session for `main` to use."""
    msg = cl.Message(content="Building Index...")  # plain string; no f-string needed
    await msg.send()

    # Source corpus; assumes the file is Windows-1252 encoded — TODO confirm.
    with open('./data/aerodynamic_drag.txt', 'r', encoding='Windows-1252') as f:
        aerodynamic_drag_data = f.read()
    documents = text_splitter.create_documents([aerodynamic_drag_data])

    # Cache embeddings on disk so repeat chat starts don't re-call the
    # OpenAI embedding API for identical chunks.
    store = LocalFileStore("./cache/")
    core_embeddings_model = OpenAIEmbeddings()
    embedder = CacheBackedEmbeddings.from_bytes_store(
        core_embeddings_model, store, namespace=core_embeddings_model.model
    )

    # FAISS construction is blocking; run it off the event loop.
    docsearch = await cl.make_async(FAISS.from_documents)(documents, embedder)

    chain = RetrievalQA.from_chain_type(
        ChatOpenAI(model="gpt-4", temperature=0, streaming=True),
        chain_type="stuff",
        return_source_documents=True,
        retriever=docsearch.as_retriever(),
        # Reuse the module-level kwargs instead of duplicating {"prompt": prompt}.
        chain_type_kwargs=chain_type_kwargs,
    )

    msg.content = "Index built!"
    await msg.send()
    cl.user_session.set("chain", chain)
# Main function to handle incoming queries
@cl.on_message
async def main(message):
    """Answer an incoming user message with the session's RetrievalQA chain."""
    qa_chain = cl.user_session.get("chain")

    # NOTE(review): stream_final_answer=False disables final-answer streaming,
    # yet answer_reached is forced True just below — these settings look
    # contradictory; confirm the intended streaming behavior.
    handler = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=False,
        answer_prefix_tokens=["FINAL", "ANSWER"],
    )
    handler.answer_reached = True

    response = await qa_chain.acall(message, callbacks=[handler])
    await cl.Message(content=response["result"]).send()