# Amsterdam-event / app.py
# mwitiderrick's picture
# Update app.py
# a9ee327 verified
import langchain
from langchain.llms import DeepSparse
from langchain.document_loaders import TextLoader, DirectoryLoader
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
import chainlit as cl
import os
# --- Static configuration -------------------------------------------------
# Model identifier handed to the DeepSparse LLM wrapper.
MODEL_PATH = "hf:neuralmagic/mpt-7b-chat-pruned50-quant"
# Directory scanned for the *.txt files that make up the knowledge base.
DATA_PATH = "documents"
# Generation settings forwarded to DeepSparse (caps new tokens per answer).
generation_config = {"max_new_tokens": 200}

# --- Shared, module-level components --------------------------------------
# These are built once at import time and reused by every chat session.
llm = DeepSparse(
    model=MODEL_PATH,
    generation_config=generation_config,
)
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
)
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=600,
    chunk_overlap=100,
)
@cl.on_chat_start
async def init():
    """Prepare a retrieval QA chain for a new chat and store it in the session.

    Every ``*.txt`` file under ``DATA_PATH`` is loaded, split into chunks
    with the module-level ``text_splitter``, and indexed in a Chroma store
    using the module-level ``embeddings``. A "stuff" ``RetrievalQA`` chain
    around ``llm`` is then built on top of that store.

    The user session receives three entries: ``"metadatas"`` (one
    ``{"source": "<i>-pl"}`` dict per chunk), ``"texts"`` (the chunks) and
    ``"chain"`` (the QA chain).
    """
    directory_loader = DirectoryLoader(DATA_PATH, glob="*.txt", loader_cls=TextLoader)
    chunks = text_splitter.split_documents(directory_loader.load())

    # Index the chunks; the retriever is configured with k=1, so a single
    # document is requested per query.
    vector_store = Chroma.from_documents(chunks, embeddings)
    single_hit_retriever = vector_store.as_retriever(search_kwargs={"k": 1})
    qa_chain = RetrievalQA.from_chain_type(
        llm,
        chain_type="stuff",
        return_source_documents=True,
        retriever=single_hit_retriever,
    )

    # Save the metadata and texts in the user session
    source_tags = [{"source": f"{position}-pl"} for position, _ in enumerate(chunks)]
    cl.user_session.set("metadatas", source_tags)
    cl.user_session.set("texts", chunks)
    cl.user_session.set("chain", qa_chain)
@cl.on_message
async def main(message: cl.Message):
    """Answer an incoming chat message using the session's retrieval QA chain.

    The chain stored under the ``"chain"`` session key is run on the message
    text. The answer is suffixed with a ``Sources: ...`` line naming one
    ``cl.Text`` element per retrieved document, or with ``No sources found``
    when the chain returned none. The result either updates the message
    that was streamed by the callback handler or is sent as a new message.

    Args:
        message: The incoming Chainlit message; ``message.content`` is used
            as the query.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["HELPFUL", "ANSWER"]
    )
    # NOTE(review): presumably makes the handler stream from the first token
    # instead of waiting for the answer prefix tokens - confirm against the
    # Chainlit callback handler implementation.
    cb.answer_reached = True

    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["result"]
    source_documents = res["source_documents"]

    # Both lists are initialised unconditionally so the summary below is
    # well defined (and the "No sources found" branch reachable) even when
    # the retriever returned nothing.
    found_sources = []
    source_elements = []
    for source_idx, source in enumerate(source_documents or []):
        source_name = f"source_{source_idx}"
        found_sources.append(source_name)
        # Create the text element referenced in the message
        source_elements.append(
            cl.Text(content=str(source.page_content).strip(), name=source_name)
        )

    if found_sources:
        answer += f"\nSources: {', '.join(found_sources)}"
    else:
        answer += "\nNo sources found"

    if cb.has_streamed_final_answer:
        # The answer was already streamed: amend that message in place.
        cb.final_stream.content = answer
        cb.final_stream.elements = source_elements
        await cb.final_stream.update()
    else:
        await cl.Message(content=answer, elements=source_elements).send()