ATDCChatbot / app.py
CSAle's picture
Adding ATDC App
a1b25e5
import chainlit as cl
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.document_loaders import BSHTMLLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
import chainlit as cl
# Document splitter: chunk_size=1000 with chunk_overlap=100, measured by the
# splitter's default length function.
# NOTE(review): not referenced in the visible code — the FAISS index is loaded
# prebuilt in init(); possibly a leftover from the indexing step.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
)
@cl.author_rename
def rename(orig_author: str):
    """Return the display name Chainlit should show for a message author.

    Registered through ``@cl.author_rename``. Known internal author names are
    mapped to user-facing labels; any other name is passed through unchanged.
    """
    display_names = {
        "RetrievalQA": "ATDC Website",
        "Chatbot": "The ATDC Webpage",
    }
    if orig_author in display_names:
        return display_names[orig_author]
    return orig_author
@cl.on_chat_start
async def init():
    """Chainlit chat-start hook: load the FAISS index and store a QA chain.

    Loads the vector store from the local ``faiss_index`` directory with
    ``OpenAIEmbeddings``, wraps it in a "stuff" ``RetrievalQA`` chain backed by
    a streaming ``gpt-3.5-turbo`` model, and saves the chain in the user
    session under the key ``"chain"`` for the message handler to use.
    The index is loaded from disk, not rebuilt, despite the status text.
    """
    msg = cl.Message(content="Building Index...")
    await msg.send()

    # NOTE(review): assumes faiss_index was built with OpenAIEmbeddings too —
    # queries are embedded with this model at retrieval time.
    core_embeddings_model = OpenAIEmbeddings()
    new_db = FAISS.load_local("faiss_index", core_embeddings_model)

    chain = RetrievalQA.from_chain_type(
        ChatOpenAI(model="gpt-3.5-turbo", temperature=0, streaming=True),
        chain_type="stuff",
        return_source_documents=True,
        retriever=new_db.as_retriever(),
    )

    # Store the chain BEFORE announcing readiness: the `await` below yields to
    # the event loop, and a message handled meanwhile must find the chain.
    cl.user_session.set("chain", chain)

    # Edit the existing status message in place instead of sending it again.
    msg.content = "Index built!"
    await msg.update()
@cl.on_message
async def main(message):
    """Chainlit message hook: answer the user's question with the session chain.

    Args:
        message: The incoming user message — a plain ``str`` in older Chainlit
            releases, a ``cl.Message`` (text in ``.content``) in newer ones.
            Both forms are accepted.

    The answer is streamed through ``cl.AsyncLangchainCallbackHandler``; if
    nothing was streamed, the chain's ``"result"`` is sent as a new message.
    """
    chain = cl.user_session.get("chain")
    if chain is None:
        # init() has not stored a chain for this session (yet); avoid crashing
        # on `None.acall` and tell the user instead.
        await cl.Message(
            content="The index is still loading, please try again in a moment."
        ).send()
        return

    # Newer Chainlit passes a cl.Message; older versions pass the text itself.
    query = getattr(message, "content", message)

    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    # The handler normally waits for the "FINAL ANSWER" prefix tokens before
    # streaming; presumably the QA prompt never emits them, so mark the answer
    # as already reached to stream from the first token.
    cb.answer_reached = True

    res = await chain.acall(query, callbacks=[cb])
    answer = res["result"]

    if cb.has_streamed_final_answer:
        # Tokens already went to cb.final_stream; finalize that message.
        await cb.final_stream.update()
    else:
        await cl.Message(content=answer).send()