"""Chainlit chat app: RetrievalQA over local .txt documents.

Serves a sparsified MPT-7B chat model via DeepSparse on CPU, with MiniLM
sentence embeddings and an in-memory Chroma vector store.
"""
import os

import chainlit as cl
import langchain
from langchain.chains import RetrievalQA
from langchain.document_loaders import DirectoryLoader, TextLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import DeepSparse
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
# Sparsified + quantized MPT-7B chat model, pulled from the Hugging Face hub
# and executed by the DeepSparse runtime (CPU-only inference).
MODEL_PATH = "hf:neuralmagic/mpt-7b-chat-pruned50-quant"
generation_config = {"max_new_tokens": 200}

llm = DeepSparse(model=MODEL_PATH, generation_config=generation_config)

# MiniLM sentence embeddings, forced onto CPU to match the DeepSparse setup.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
)

# 600-char chunks with 100-char overlap so retrieval context stays coherent
# across chunk boundaries.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=100)

# Directory containing the .txt documents to index at chat start.
DATA_PATH = "documents"
@cl.on_chat_start  # NOTE(review): decorator restored — without it Chainlit never runs this hook
async def init():
    """Build the RAG chain at chat start and stash it in the user session.

    Loads every ``*.txt`` file under ``DATA_PATH``, splits the documents
    into overlapping chunks, embeds them into an in-memory Chroma store,
    and wires a "stuff"-type RetrievalQA chain over the top-1 retriever.
    """
    loader = DirectoryLoader(DATA_PATH, glob="*.txt", loader_cls=TextLoader)
    documents = loader.load()
    texts = text_splitter.split_documents(documents)
    docsearch = Chroma.from_documents(texts, embeddings)

    chain = RetrievalQA.from_chain_type(
        llm,
        chain_type="stuff",
        return_source_documents=True,
        # k=1: only the single closest chunk is stuffed into the prompt.
        retriever=docsearch.as_retriever(search_kwargs={"k": 1}),
    )

    # Save the metadata and texts in the user session for later display.
    metadatas = [{"source": f"{i}-pl"} for i in range(len(texts))]
    cl.user_session.set("metadatas", metadatas)
    cl.user_session.set("texts", texts)
    cl.user_session.set("chain", chain)
@cl.on_message  # NOTE(review): decorator restored — without it Chainlit never routes messages here
async def main(message: cl.Message):
    """Answer one user message with the session's RetrievalQA chain.

    Streams the final answer when the callback detects the answer prefix,
    appends the names of the retrieved source documents to the reply, and
    attaches each source's text as a Chainlit Text element.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["HELPFUL", "ANSWER"]
    )
    # Force streaming from the first token; don't wait for the prefix match.
    cb.answer_reached = True

    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["result"]
    source_documents = res["source_documents"]
    source_elements = []

    if source_documents:
        found_sources = []
        # Add the sources to the message.
        for source_idx, source in enumerate(source_documents):
            source_name = f"source_{source_idx}"
            found_sources.append(source_name)
            # Create the text element referenced in the message.
            source_elements.append(
                cl.Text(content=str(source.page_content).strip(), name=source_name)
            )
        if found_sources:
            answer += f"\nSources: {', '.join(found_sources)}"
        else:
            answer += "\nNo sources found"

    if cb.has_streamed_final_answer:
        # Rewrite the already-streamed message with sources appended.
        cb.final_stream.content = answer
        cb.final_stream.elements = source_elements
        await cb.final_stream.update()
    else:
        await cl.Message(content=answer, elements=source_elements).send()