Krishnachaitanya2004 committed
Commit 99cdfe6 • Parent(s): 0a1dd02

Publish Document Chatbot to Hugging Face

Files changed:
- document_chatbot.py +122 -0
- requirements.txt +5 -0
document_chatbot.py
ADDED
@@ -0,0 +1,122 @@
# !pip install langchain
# !pip install sentence-transformers
# !pip install accelerate
# !pip install chromadb
# !pip install "unstructured[all-docs]"

from langchain.vectorstores import Chroma
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import pipeline
import torch
from langchain.llms import HuggingFacePipeline
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.chains import RetrievalQA
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain.text_splitter import CharacterTextSplitter
import streamlit as st
import os


def main_process(uploaded_file):
    # st.file_uploader returns an UploadedFile object, not a dict
    file_name = uploaded_file.name

    # Create a temporary directory
    temp_dir = "temp"
    os.makedirs(temp_dir, exist_ok=True)

    # Save the uploaded file to the temporary directory
    temp_path = os.path.join(temp_dir, file_name)
    with open(temp_path, "wb") as temp_file:
        temp_file.write(uploaded_file.getbuffer())

    # Process the uploaded file
    loader = UnstructuredFileLoader(temp_path)
    documents = loader.load()
    # Debug: dump the extracted text to the console
    for document in documents:
        print(document.page_content)

    # We can't feed the whole PDF to the model at once, so we split it into chunks.
    # CharacterTextSplitter produces chunks of 1000 characters that overlap by
    # 400 characters (you can change these values to suit your needs).
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=400)
    texts = text_splitter.split_documents(documents)

    # SentenceTransformerEmbeddings embeds the text chunks; the embeddings are
    # used to measure similarity between the query and the chunks.
    # We use the multi-qa-mpnet-base-dot-v1 model, and persist_directory tells
    # Chroma where to save the embeddings on disk.
    embeddings = SentenceTransformerEmbeddings(model_name="multi-qa-mpnet-base-dot-v1")
    persist_directory = "chroma"  # relative path; the original Colab path /content/chroma/ does not exist outside Colab

    # Chroma stores the embeddings; from_documents embeds the chunks and
    # writes the index under persist_directory.
    db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory)

    # To save and reload the vector store later (if needed):
    # db.persist()
    # db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)

    checkpoint = "MBZUAI/LaMini-Flan-T5-783M"

    # Initialize the tokenizer and base model for text generation
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    base_model = AutoModelForSeq2SeqLM.from_pretrained(
        checkpoint,
        device_map="auto",
        torch_dtype=torch.float32,
    )

    pipe = pipeline(
        "text2text-generation",
        model=base_model,
        tokenizer=tokenizer,
        max_length=512,
        do_sample=True,
        temperature=0.3,
        top_p=0.95,
    )

    # Wrap the local pipeline so LangChain can use it as an LLM
    local_llm = HuggingFacePipeline(pipeline=pipe)
    # Create a RetrievalQA chain that stuffs the top-2 most similar chunks into the prompt
    qa_chain = RetrievalQA.from_chain_type(
        llm=local_llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_type="similarity", search_kwargs={"k": 2}),
        return_source_documents=True,
    )
    return qa_chain


st.title("Document Chatbot")
st.write("Upload a PDF file to get started")

uploaded_file = st.file_uploader("Choose a file", type=["pdf"])

if uploaded_file is not None:
    qa_chain = main_process(uploaded_file)
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Accept user input
    if prompt := st.chat_input("What is up?"):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})
        # Display user message in chat message container
        with st.chat_message("user"):
            st.markdown(prompt)
        # The chain returns a dict; the answer text is under the "result" key
        with st.chat_message("assistant"):
            response = qa_chain(prompt)["result"]
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})
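Since the vector store is persisted to disk, it can be reloaded in a later session without re-embedding the document. Below is a minimal sketch (not part of this commit), assuming main_process() above has already written the index to the "chroma" directory; the query string is illustrative only.

# Sketch: reload the persisted Chroma index and run a similarity search
# directly, without the Streamlit UI. Assumes the index exists in "chroma".
from langchain.vectorstores import Chroma
from langchain.embeddings import SentenceTransformerEmbeddings

embeddings = SentenceTransformerEmbeddings(model_name="multi-qa-mpnet-base-dot-v1")
db = Chroma(persist_directory="chroma", embedding_function=embeddings)

# Fetch the two most similar chunks for a query, mirroring the chain's retriever (k=2)
docs = db.similarity_search("What is this document about?", k=2)
for doc in docs:
    print(doc.page_content[:200])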
requirements.txt
ADDED
@@ -0,0 +1,5 @@
unstructured==0.11.8
langchain==0.0.336
sentence-transformers==2.2.2
accelerate==0.25.0
chromadb==0.4.22
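To try the app locally (standard Streamlit usage, not spelled out in the commit itself), install the pinned dependencies plus Streamlit, then launch the script:

pip install -r requirements.txt streamlit
streamlit run document_chatbot.py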