"""Chainlit RAG chat app.

Serves a sparse/quantized MPT-7B chat model via DeepSparse over a Chroma
vector store built from local ``*.txt`` documents, answering user messages
with the retrieved source snippets attached as citations.
"""

import langchain
import os

import chainlit as cl
from langchain.chains import RetrievalQA
from langchain.document_loaders import DirectoryLoader, TextLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import DeepSparse
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

# Sparse, quantized MPT-7B chat checkpoint pulled from the Hugging Face hub.
MODEL_PATH = "hf:neuralmagic/mpt-7b-chat-pruned50-quant"
generation_config = {"max_new_tokens": 200}

llm = DeepSparse(model=MODEL_PATH, generation_config=generation_config)

# Sentence-transformer embeddings, forced onto CPU.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
)

text_splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=100)

# Directory holding the *.txt files to index at chat start.
DATA_PATH = "documents"


@cl.on_chat_start
async def init():
    """Build the vector index and RetrievalQA chain for a new chat session.

    Loads every ``*.txt`` file under ``DATA_PATH``, splits it into chunks,
    embeds the chunks into a Chroma store, and wires a "stuff" RetrievalQA
    chain (top-1 retrieval) that is stashed in the Chainlit user session.
    """
    loader = DirectoryLoader(DATA_PATH, glob='*.txt', loader_cls=TextLoader)
    documents = loader.load()
    texts = text_splitter.split_documents(documents)
    docsearch = Chroma.from_documents(texts, embeddings)
    chain = RetrievalQA.from_chain_type(
        llm,
        chain_type="stuff",
        return_source_documents=True,
        retriever=docsearch.as_retriever(search_kwargs={"k": 1}),
    )

    # Save the metadata and texts in the user session
    metadatas = [{"source": f"{i}-pl"} for i in range(len(texts))]
    cl.user_session.set("metadatas", metadatas)
    cl.user_session.set("texts", texts)
    cl.user_session.set("chain", chain)


@cl.on_message
async def main(message: cl.Message):
    """Answer a user message with the session's RetrievalQA chain.

    Streams the final answer through the LangChain callback handler when
    possible, and attaches each retrieved document as a named ``cl.Text``
    element so the UI can show the cited sources.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["HELPFUL", "ANSWER"]
    )
    # Stream tokens immediately rather than waiting for the prefix tokens.
    cb.answer_reached = True

    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["result"]
    source_documents = res["source_documents"]
    source_elements = []

    if source_documents:
        found_sources = []
        # Add the sources to the message
        for source_idx, source in enumerate(source_documents):
            # Get the index of the source
            source_name = f"source_{source_idx}"
            found_sources.append(source_name)
            # Create the text element referenced in the message
            source_elements.append(
                cl.Text(content=str(source.page_content).strip(), name=source_name)
            )
        if found_sources:
            answer += f"\nSources: {', '.join(found_sources)}"
        else:
            answer += "\nNo sources found"

    if cb.has_streamed_final_answer:
        # The handler already streamed the answer: update it in place with
        # the source suffix and the citation elements.
        cb.final_stream.content = answer
        cb.final_stream.elements = source_elements
        await cb.final_stream.update()
    else:
        await cl.Message(content=answer, elements=source_elements).send()