mutea committed
Commit 8f0aa9f
1 Parent(s): aa3fed2

create app.py

Files changed (1)
app.py +162 -0
app.py ADDED
@@ -0,0 +1,162 @@
+ import streamlit as st
+ from pypdf import PdfReader
+ import torch
+ import textwrap
+ from langchain.prompts import PromptTemplate
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS
+ from langchain.chains import RetrievalQA
+ from langchain.llms.huggingface_pipeline import HuggingFacePipeline
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
+
+ # load environment variables (e.g. API tokens) from a .env file
+ from dotenv import load_dotenv
+ load_dotenv()
+
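+ # Note: the langchain.* import paths above match older (pre-0.1) LangChain
+ # releases; newer versions expose these classes under langchain_community.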
+
+ # DEFINE SOME VARIABLES
+ CHUNK_SIZE = 1000
+
+ # Using HuggingFaceEmbeddings with the chosen embedding model
+ embeddings = HuggingFaceEmbeddings(
+     model_name="sentence-transformers/all-mpnet-base-v2",
+     model_kwargs={"device": "cuda"})
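+ # Note: {"device": "cuda"} assumes a GPU is available; on a CPU-only
+ # machine this can be switched to {"device": "cpu"}.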
+
+ # transformer model quantization configuration: load the LLM weights in
+ # 4-bit NF4 with double quantization, computing in bfloat16
+ quant_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16
+ )
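+ # Note: 4-bit loading relies on the bitsandbytes package, which must be
+ # installed alongside transformers and requires a CUDA-capable GPU.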
+
+
+ # CREATE A VECTOR DATABASE - FAISS
+ def create_vector_db(uploaded_pdfs) -> FAISS:
+     """Read multiple PDFs, split the text, embed the chunks and store the
+     embeddings in a FAISS vector store."""
+     text = ""
+     for pdf in uploaded_pdfs:
+         pdf_reader = PdfReader(pdf)
+         for page in pdf_reader.pages:
+             text += page.extract_text()
+
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE,
+                                                    chunk_overlap=100)
+     texts = text_splitter.split_text(text)
+
+     vector_db = FAISS.from_texts(texts, embeddings)  # create vector db for similarity search
+     vector_db.save_local("faiss_index")  # save the vector db to avoid repeated calls to it
+     return vector_db
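+ # Note: all pages are concatenated into one string before splitting, so
+ # chunks may span page boundaries; the 100-character chunk_overlap keeps
+ # some shared context between consecutive chunks.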
+
+ # LOAD LLM
+ def load_llm():
+     model_id = "Deci/DeciLM-6b"
+
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+     model = AutoModelForCausalLM.from_pretrained(model_id,
+                                                  trust_remote_code=True,
+                                                  device_map="auto",
+                                                  quantization_config=quant_config)
+
+     pipe = pipeline("text-generation",
+                     model=model,
+                     tokenizer=tokenizer,
+                     temperature=0.1,
+                     return_full_text=True,
+                     max_new_tokens=40,
+                     repetition_penalty=1.1)
+
+     llm = HuggingFacePipeline(pipeline=pipe)
+     return llm
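+ # Note: max_new_tokens=40 keeps generation fast but caps each answer at
+ # roughly a sentence or two; a larger value may suit the "detailed answer"
+ # the prompt below asks for.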
+
+
+ # RESPONSE INSTRUCTIONS
+ def set_custom_prompt():
+     """Instructions to the LLM for text response generation."""
+     custom_prompt_template = """You have been given the following documents to answer the user's question.
+ If the given information does not contain the answer, just say 'I don't know the answer' and don't try to make up an answer.
+ Context: {context}
+ Question: {question}
+ Give a detailed helpful answer and nothing more.
+ Helpful answer:
+ """
+     prompt = PromptTemplate(template=custom_prompt_template,
+                             input_variables=["context", "question"])
+     return prompt
+
+
+ # QUESTION ANSWERING CHAIN
+ def retrieval_qa_chain(prompt, vector_db):
+     """Chain to retrieve answers. The chain retrieves the relevant documents
+     and makes a call to the DeciLM-6b LLM."""
+     llm = load_llm()
+
+     qa_chain = RetrievalQA.from_chain_type(
+         llm=llm,
+         chain_type="stuff",
+         retriever=vector_db.as_retriever(),
+         return_source_documents=True,
+         chain_type_kwargs={"prompt": prompt}
+     )
+     return qa_chain
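+ # Note: chain_type="stuff" inserts all retrieved chunks into a single
+ # prompt; this is simple but can overflow the model's context window if
+ # the retriever returns many long chunks.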
+
+ # QUESTION ANSWER BOT
+ def qa_bot():
+     vector_db = FAISS.load_local("faiss_index", embeddings)
+     conversation_prompt = set_custom_prompt()
+     conversation = retrieval_qa_chain(conversation_prompt, vector_db)
+     return conversation
+
+ # RESPONSE FROM BOT
+ def bot_response(query):
+     conversation_result = qa_bot()
+     response = conversation_result({"query": query})
+     return response["result"]
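+ # Note: qa_bot() reloads the FAISS index and the 6B-parameter LLM on every
+ # query; caching the chain (e.g. with st.cache_resource or st.session_state)
+ # would avoid repeating that work for each question.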
+
+
+ def main():
+     st.set_page_config(page_title="Multiple PDFs chat with DeciLM-6b and LangChain",
+                        page_icon=":file_folder:")
+
+     # page side panel
+     with st.sidebar:
+         st.subheader("Hello, welcome!")
+
+         pdfs = st.file_uploader(label="Upload your PDFs here and click Process!",
+                                 accept_multiple_files=True)
+
+         if st.button("Process"):
+             with st.spinner("Processing file(s)..."):
+                 # create a vector store from the uploaded PDFs
+                 create_vector_db(pdfs)
+             st.write("Your files are processed. You are set to ask questions!")
+
+     st.header("Chat with Multiple PDFs using the DeciLM-6b LLM")
+
+     # Query side
+     query = st.text_input(label="Type your question based on the PDFs",
+                           placeholder="Type question...")
+
+     if query:
+         st.write(f"Query: {query}")
+         st.text(textwrap.fill(bot_response(query), width=80))
+
+
+ if __name__ == "__main__":
+     main()