# NOTE: "Spaces: Sleeping" — page header captured along with this source; not part of the program.
import chainlit as cl | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
from langchain.embeddings import CacheBackedEmbeddings | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.vectorstores import FAISS | |
from langchain.chains import RetrievalQA | |
from langchain.chat_models import ChatOpenAI | |
from langchain.storage import LocalFileStore | |
from langchain.prompts.chat import ( | |
ChatPromptTemplate, | |
SystemMessagePromptTemplate, | |
HumanMessagePromptTemplate, | |
) | |
# --- Document chunking -------------------------------------------------------
# Source text is cut into chunks of up to 1000 characters; consecutive chunks
# share 100 characters so text straddling a boundary appears in both.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
)

# --- QA prompt ---------------------------------------------------------------
# System instructions for the answering model. ``{context}`` is the slot the
# "stuff" RetrievalQA chain fills with retrieved chunks; the model is told to
# refuse with a fixed sentence when the context is insufficient.
system_template = """
Use the following pieces of context to answer the user's question. If the question cannot be answered with the supplied context, simply answer "I cannot determine this based on the provided context."
----------------
{context}"""

# Two-turn chat prompt: system instructions, then the user's raw question.
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]
prompt = ChatPromptTemplate(messages=messages)

# Extra keyword arguments forwarded to the combine-documents chain so it uses
# the custom prompt above instead of LangChain's default.
chain_type_kwargs = dict(prompt=prompt)
@cl.author_rename
def rename(orig_author: str) -> str:
    """Map internal author names to the display names shown in the chat UI.

    Registered with Chainlit via ``@cl.author_rename`` (without the decorator
    Chainlit never calls it). Names without a mapping are returned unchanged.

    Args:
        orig_author: Author name as reported by Chainlit, e.g. the chain class
            name ``"RetrievalQA"``.

    Returns:
        The name to display: ``"PageTurn"`` for ``"RetrievalQA"``, otherwise
        ``orig_author`` itself.
    """
    rename_dict = {"RetrievalQA": "PageTurn"}
    return rename_dict.get(orig_author, orig_author)
# Build the retrieval index and QA chain when a chat session starts.
@cl.on_chat_start
async def init():
    """Index the source document and store a RetrievalQA chain in the session.

    Registered via ``@cl.on_chat_start``. Steps:

    1. Read ``./data/aerodynamic_drag.txt`` and split it into chunks with the
       module-level ``text_splitter``.
    2. Embed the chunks with OpenAI embeddings, cached on disk under
       ``./cache/``.
    3. Build a FAISS vector store, and on top of it a GPT-4 ``RetrievalQA``
       ("stuff") chain that uses the module-level ``chain_type_kwargs``
       (custom prompt).
    4. Save the chain in the user session under ``"chain"`` for ``main``.

    A status message is shown while indexing and updated in place when done.
    """
    msg = cl.Message(content="Building Index...")
    await msg.send()

    # The data file is read as Windows-1252 rather than the platform default.
    with open('./data/aerodynamic_drag.txt', 'r', encoding='Windows-1252') as f:
        aerodynamic_drag_data = f.read()
    documents = text_splitter.create_documents([aerodynamic_drag_data])

    # Cache embeddings on disk, namespaced by the embedding model name so
    # vectors from different models are kept apart.
    store = LocalFileStore("./cache/")
    core_embeddings_model = OpenAIEmbeddings()
    embedder = CacheBackedEmbeddings.from_bytes_store(
        core_embeddings_model, store, namespace=core_embeddings_model.model
    )

    # FAISS.from_documents is synchronous (embedding + indexing); make_async
    # runs it off the event loop so the UI stays responsive.
    docsearch = await cl.make_async(FAISS.from_documents)(documents, embedder)

    chain = RetrievalQA.from_chain_type(
        ChatOpenAI(model="gpt-4", temperature=0, streaming=True),
        chain_type="stuff",
        return_source_documents=True,
        retriever=docsearch.as_retriever(),
        # Same {"prompt": prompt} mapping the module already defines.
        chain_type_kwargs=chain_type_kwargs,
    )

    # Store the chain *before* announcing readiness, so a question sent right
    # after "Index built!" always finds it in the session.
    cl.user_session.set("chain", chain)

    # Update the existing status message instead of posting a second one.
    msg.content = "Index built!"
    await msg.update()
# Answer each incoming chat message with the session's RetrievalQA chain.
@cl.on_message
async def main(message):
    """Run the user's question through the session's QA chain and reply.

    Registered via ``@cl.on_message``. Depending on the Chainlit version,
    ``message`` is either the raw question string or a ``cl.Message``; the
    question text is extracted in both cases.

    Args:
        message: The incoming user message, as a ``str`` or an object with a
            ``content`` attribute.
    """
    chain = cl.user_session.get("chain")
    if chain is None:
        # init() has not stored a chain for this session (yet).
        await cl.Message(
            content="The index is still being built. Please try again in a moment."
        ).send()
        return

    # Newer Chainlit passes a cl.Message; older releases pass the text itself.
    query = getattr(message, "content", message)

    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=False, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    # NOTE(review): answer_reached presumably only matters when
    # stream_final_answer=True; kept as in the original — confirm if needed.
    cb.answer_reached = True

    res = await chain.acall(query, callbacks=[cb])
    # Only the answer text is shown; res may also carry "source_documents"
    # when the chain is built with return_source_documents=True.
    answer = res["result"]
    await cl.Message(content=answer).send()