Shreyas094 commited on
Commit
8da6a04
·
verified ·
1 Parent(s): 8544733

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +178 -106
app.py CHANGED
@@ -17,129 +17,201 @@ huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
17
 
18
  # Memory database to store question-answer pairs
19
  memory_database = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def load_and_split_document_basic(file):
21
- """Loads and splits the document into pages."""
22
- loader = PyPDFLoader(file.name)
23
- data = loader.load_and_split()
24
- return data
 
25
  def load_and_split_document_recursive(file: NamedTemporaryFile) -> List[Document]:
26
- """Loads and splits the document into chunks."""
27
- loader = PyPDFLoader(file.name)
28
- pages = loader.load()
29
- text_splitter = RecursiveCharacterTextSplitter(
30
- chunk_size=1000,
31
- chunk_overlap=200,
32
- length_function=len,
33
- )
34
- chunks = text_splitter.split_documents(pages)
35
- return chunks
 
 
 
36
  def get_embeddings():
37
- return HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
 
38
  def create_or_update_database(data, embeddings):
39
- if os.path.exists("faiss_database"):
40
- db = FAISS.load_local("faiss_database", embeddings, allow_dangerous_deserialization=True)
41
- db.add_documents(data)
42
- else:
43
- db = FAISS.from_documents(data, embeddings)
44
- db.save_local("faiss_database")
 
45
  def clear_cache():
46
- if os.path.exists("faiss_database"):
47
- os.remove("faiss_database")
48
- return "Cache cleared successfully."
49
- else:
50
- return "No cache to clear."
 
51
  prompt = """
52
  Answer the question based only on the following context:
53
  {context}
54
  Question: {question}
 
55
  Provide a concise and direct answer to the question:
56
  """
 
57
  def get_model(temperature, top_p, repetition_penalty):
58
- return HuggingFaceHub(
59
- repo_id="mistralai/Mistral-7B-Instruct-v0.3",
60
- model_kwargs={
61
- "temperature": temperature,
62
- "top_p": top_p,
63
- "repetition_penalty": repetition_penalty,
64
- "max_length": 512
65
- },
66
- huggingfacehub_api_token=huggingface_token
67
- )
 
68
  def generate_chunked_response(model, prompt, max_tokens=500, max_chunks=5):
69
- full_response = ""
70
- for i in range(max_chunks):
71
- chunk = model(prompt + full_response, max_new_tokens=max_tokens)
72
- full_response += chunk
73
- if chunk.strip().endswith((".", "!", "?")):
74
- break
75
- return full_response.strip()
 
76
  def response(database, model, question):
77
- prompt_val = ChatPromptTemplate.from_template(prompt)
78
- retriever = database.as_retriever()
79
- context = retriever.get_relevant_documents(question)
80
- context_str = "\n".join([doc.page_content for doc in context])
81
- formatted_prompt = prompt_val.format(context=context_str, question=question)
82
- ans = generate_chunked_response(model, formatted_prompt)
83
- return ans
 
 
 
 
84
  def update_vectors(files, use_recursive_splitter):
85
- if not files:
86
- return "Please upload at least one PDF file."
87
- embed = get_embeddings()
88
- total_chunks = 0
89
- for file in files:
90
- if use_recursive_splitter:
91
- data = load_and_split_document_recursive(file)
92
- else:
93
- data = load_and_split_document_basic(file)
94
- create_or_update_database(data, embed)
95
- total_chunks += len(data)
96
- return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files."
 
 
 
 
97
  def ask_question(question, temperature, top_p, repetition_penalty):
98
- if not question:
99
- return "Please enter a question."
100
- # Check if the question exists in the memory database
101
- if question in memory_database:
102
- return memory_database[question]
103
- embed = get_embeddings()
104
- database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
105
- model = get_model(temperature, top_p, repetition_penalty)
106
- # Generate response from document database
107
- answer = response(database, model, question)
108
- # Store the question and answer in the memory database
109
- memory_database[question] = answer
110
- return answer
 
 
 
 
 
 
111
  def extract_db_to_excel():
112
- embed = get_embeddings()
113
- database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
114
- documents = database.docstore._dict.values()
115
- data = [{"page_content": doc.page_content, "metadata": json.dumps(doc.metadata)} for doc in documents]
116
- df = pd.DataFrame(data)
117
- with NamedTemporaryFile(delete=False, suffix='.xlsx') as tmp:
118
- excel_path = tmp.name
119
- df.to_excel(excel_path, index=False)
120
- return excel_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  # Gradio interface
122
  with gr.Blocks() as demo:
123
- gr.Markdown("# Chat with your PDF documents")
124
- with gr.Row():
125
- file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
126
- update_button = gr.Button("Update Vector Store")
127
- use_recursive_splitter = gr.Checkbox(label="Use Recursive Text Splitter", value=False)
128
- update_output = gr.Textbox(label="Update Status")
129
- update_button.click(update_vectors, inputs=[file_input, use_recursive_splitter], outputs=update_output)
130
- with gr.Row():
131
- question_input = gr.Textbox(label="Ask a question about your documents")
132
- temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
133
- top_p_slider = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.1)
134
- repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
135
- submit_button = gr.Button("Submit")
136
- answer_output = gr.Textbox(label="Answer")
137
- submit_button.click(ask_question, inputs=[question_input, temperature_slider, top_p_slider, repetition_penalty_slider], outputs=answer_output)
138
- extract_button = gr.Button("Extract Database to Excel")
139
- excel_output = gr.File(label="Download Excel File")
140
- extract_button.click(extract_db_to_excel, inputs=[], outputs=excel_output)
141
- clear_button = gr.Button("Clear Cache")
142
- clear_output = gr.Textbox(label="Cache Status")
143
- clear_button.click(clear_cache, inputs=[], outputs=clear_output)
 
 
 
 
 
 
 
 
 
 
 
144
  if __name__ == "__main__":
145
- demo.launch()
 
17
 
18
  # Memory database to store question-answer pairs
19
  memory_database = {}
20
+ import os
21
+ import json
22
+ import gradio as gr
23
+ import pandas as pd
24
+ from tempfile import NamedTemporaryFile
25
+ from typing import List
26
+
27
+ from langchain_core.prompts import ChatPromptTemplate
28
+ from langchain_community.vectorstores import FAISS
29
+ from langchain_community.document_loaders import PyPDFLoader
30
+ from langchain_core.output_parsers import StrOutputParser
31
+ from langchain_community.embeddings import HuggingFaceEmbeddings
32
+ from langchain_community.llms import HuggingFaceHub
33
+ from langchain_core.runnables import RunnableParallel, RunnablePassthrough
34
+ from langchain_core.text_splitters import RecursiveCharacterTextSplitter
35
+ from langchain_core.document import Document
36
+
37
+ huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
38
+
39
+ # Memory database to store question-answer pairs
40
+ memory_database = {}
41
+
42
  def load_and_split_document_basic(file):
43
+ """Loads and splits the document into pages."""
44
+ loader = PyPDFLoader(file.name)
45
+ data = loader.load_and_split()
46
+ return data
47
+
48
  def load_and_split_document_recursive(file: NamedTemporaryFile) -> List[Document]:
49
+ """Loads and splits the document into chunks."""
50
+ loader = PyPDFLoader(file.name)
51
+ pages = loader.load()
52
+
53
+ text_splitter = RecursiveCharacterTextSplitter(
54
+ chunk_size=1000,
55
+ chunk_overlap=200,
56
+ length_function=len,
57
+ )
58
+
59
+ chunks = text_splitter.split_documents(pages)
60
+ return chunks
61
+
62
  def get_embeddings():
63
+ return HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
64
+
65
  def create_or_update_database(data, embeddings):
66
+ if os.path.exists("faiss_database"):
67
+ db = FAISS.load_local("faiss_database", embeddings, allow_dangerous_deserialization=True)
68
+ db.add_documents(data)
69
+ else:
70
+ db = FAISS.from_documents(data, embeddings)
71
+ db.save_local("faiss_database")
72
+
73
  def clear_cache():
74
+ if os.path.exists("faiss_database"):
75
+ os.remove("faiss_database")
76
+ return "Cache cleared successfully."
77
+ else:
78
+ return "No cache to clear."
79
+
80
  prompt = """
81
  Answer the question based only on the following context:
82
  {context}
83
  Question: {question}
84
+
85
  Provide a concise and direct answer to the question:
86
  """
87
+
88
  def get_model(temperature, top_p, repetition_penalty):
89
+ return HuggingFaceHub(
90
+ repo_id="mistralai/Mistral-7B-Instruct-v0.3",
91
+ model_kwargs={
92
+ "temperature": temperature,
93
+ "top_p": top_p,
94
+ "repetition_penalty": repetition_penalty,
95
+ "max_length": 512
96
+ },
97
+ huggingfacehub_api_token=huggingface_token
98
+ )
99
+
100
  def generate_chunked_response(model, prompt, max_tokens=500, max_chunks=5):
101
+ full_response = ""
102
+ for i in range(max_chunks):
103
+ chunk = model(prompt + full_response, max_new_tokens=max_tokens)
104
+ full_response += chunk
105
+ if chunk.strip().endswith((".", "!", "?")):
106
+ break
107
+ return full_response.strip()
108
+
109
  def response(database, model, question):
110
+ prompt_val = ChatPromptTemplate.from_template(prompt)
111
+ retriever = database.as_retriever()
112
+
113
+ context = retriever.get_relevant_documents(question)
114
+ context_str = "\n".join([doc.page_content for doc in context])
115
+
116
+ formatted_prompt = prompt_val.format(context=context_str, question=question)
117
+
118
+ ans = generate_chunked_response(model, formatted_prompt)
119
+ return ans # Only return the answer
120
+
121
  def update_vectors(files, use_recursive_splitter):
122
+ if not files:
123
+ return "Please upload at least one PDF file."
124
+
125
+ embed = get_embeddings()
126
+ total_chunks = 0
127
+
128
+ for file in files:
129
+ if use_recursive_splitter:
130
+ data = load_and_split_document_recursive(file)
131
+ else:
132
+ data = load_and_split_document_basic(file)
133
+ create_or_update_database(data, embed)
134
+ total_chunks += len(data)
135
+
136
+ return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files."
137
+
138
  def ask_question(question, temperature, top_p, repetition_penalty):
139
+ if not question:
140
+ return "Please enter a question."
141
+
142
+ # Check if the question exists in the memory database
143
+ if question in memory_database:
144
+ return memory_database[question]
145
+
146
+ embed = get_embeddings()
147
+ database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
148
+ model = get_model(temperature, top_p, repetition_penalty)
149
+
150
+ # Generate response from document database
151
+ answer = response(database, model, question)
152
+
153
+ # Store the question and answer in the memory database
154
+ memory_database[question] = answer
155
+
156
+ return answer
157
+
158
  def extract_db_to_excel():
159
+ embed = get_embeddings()
160
+ database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
161
+
162
+ documents = database.docstore._dict.values()
163
+ data = [{"page_content": doc.page_content, "metadata": json.dumps(doc.metadata)} for doc in documents]
164
+ df = pd.DataFrame(data)
165
+
166
+ with NamedTemporaryFile(delete=False, suffix='.xlsx') as tmp:
167
+ excel_path = tmp.name
168
+ df.to_excel(excel_path, index=False)
169
+
170
+ return excel_path
171
+
172
+ def export_memory_db_to_excel():
173
+ data = [{"question": question, "answer": answer} for question, answer in memory_database.items()]
174
+ df = pd.DataFrame(data)
175
+
176
+ with NamedTemporaryFile(delete=False, suffix='.xlsx') as tmp:
177
+ excel_path = tmp.name
178
+ df.to_excel(excel_path, index=False)
179
+
180
+ return excel_path
181
+
182
  # Gradio interface
183
  with gr.Blocks() as demo:
184
+ gr.Markdown("# Chat with your PDF documents")
185
+
186
+ with gr.Row():
187
+ file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
188
+ update_button = gr.Button("Update Vector Store")
189
+ use_recursive_splitter = gr.Checkbox(label="Use Recursive Text Splitter", value=False)
190
+
191
+ update_output = gr.Textbox(label="Update Status")
192
+ update_button.click(update_vectors, inputs=[file_input, use_recursive_splitter], outputs=update_output)
193
+
194
+ with gr.Row():
195
+ question_input = gr.Textbox(label="Ask a question about your documents")
196
+ temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
197
+ top_p_slider = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.1)
198
+ repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
199
+ submit_button = gr.Button("Submit")
200
+
201
+ answer_output = gr.Textbox(label="Answer")
202
+ submit_button.click(ask_question, inputs=[question_input, temperature_slider, top_p_slider, repetition_penalty_slider], outputs=answer_output)
203
+
204
+ extract_button = gr.Button("Extract Database to Excel")
205
+ excel_output = gr.File(label="Download Excel File")
206
+ extract_button.click(extract_db_to_excel, inputs=[], outputs=excel_output)
207
+
208
+ export_memory_button = gr.Button("Export Memory Database to Excel")
209
+ memory_excel_output = gr.File(label="Download Memory Excel File")
210
+ export_memory_button.click(export_memory_db_to_excel, inputs=[], outputs=memory_excel_output)
211
+
212
+ clear_button = gr.Button("Clear Cache")
213
+ clear_output = gr.Textbox(label="Cache Status")
214
+ clear_button.click(clear_cache, inputs=[], outputs=clear_output)
215
+
216
  if __name__ == "__main__":
217
+ demo.launch()