DrishtiSharma commited on
Commit
0c77c36
·
verified ·
1 Parent(s): 99b856f

Create test.py

Browse files
Files changed (1) hide show
  1. test.py +339 -0
test.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # to-do: Enable downloading multiple patent PDFs via corresponding links
2
+ import sys
3
+ import os
4
+ import re
5
+ import shutil
6
+ import time
7
+ import fitz
8
+ import streamlit as st
9
+ import nltk
10
+ import tempfile
11
+ import subprocess
12
+
13
+ # Pin NLTK to version 3.9.1
14
+ REQUIRED_NLTK_VERSION = "3.9.1"
15
+ subprocess.run([sys.executable, "-m", "pip", "install", f"nltk=={REQUIRED_NLTK_VERSION}"])
16
+
17
+ # Set up temporary directory for NLTK resources
18
+ nltk_data_path = os.path.join(tempfile.gettempdir(), "nltk_data")
19
+ os.makedirs(nltk_data_path, exist_ok=True)
20
+ nltk.data.path.append(nltk_data_path)
21
+
22
+ # Download 'punkt_tab' for compatibility
23
+ try:
24
+ print("Ensuring NLTK 'punkt_tab' resource is downloaded...")
25
+ nltk.download("punkt_tab", download_dir=nltk_data_path)
26
+ except Exception as e:
27
+ print(f"Error downloading NLTK 'punkt_tab': {e}")
28
+ raise e
29
+
30
+ sys.path.append(os.path.abspath("."))
31
+ from langchain.chains import ConversationalRetrievalChain
32
+ from langchain.memory import ConversationBufferMemory
33
+ from langchain.llms import OpenAI
34
+ from langchain.document_loaders import UnstructuredPDFLoader
35
+ from langchain.vectorstores import Chroma
36
+ from langchain.embeddings import HuggingFaceEmbeddings
37
+ from langchain.text_splitter import NLTKTextSplitter
38
+ from patent_downloader import PatentDownloader
39
+ from langchain.document_loaders import PyMuPDFLoader
40
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
41
+
42
+ PERSISTED_DIRECTORY = tempfile.mkdtemp()
43
+
44
+ # Fetch API key securely from the environment
45
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
46
+ if not OPENAI_API_KEY:
47
+ st.error("Critical Error: OpenAI API key not found in the environment variables. Please configure it.")
48
+ st.stop()
49
+
50
+ def check_poppler_installed():
51
+ if not shutil.which("pdfinfo"):
52
+ raise EnvironmentError(
53
+ "Poppler is not installed or not in PATH. Install 'poppler-utils' for PDF processing."
54
+ )
55
+
56
+ check_poppler_installed()
57
+
58
+ def load_docs(document_path):
59
+ """
60
+ Load and clean the PDF content, then split into chunks.
61
+ """
62
+ try:
63
+ import fitz # PyMuPDF for text extraction
64
+
65
+ # Step 1: Extract plain text from PDF
66
+ doc = fitz.open(document_path)
67
+ extracted_text = []
68
+
69
+ for page_num, page in enumerate(doc):
70
+ page_text = page.get_text("text") # Extract text
71
+ clean_page_text = clean_extracted_text(page_text)
72
+ if clean_page_text: # Keep only non-empty cleaned text
73
+ extracted_text.append(clean_page_text)
74
+
75
+ doc.close()
76
+
77
+ # Combine all pages into one text
78
+ full_text = "\n".join(extracted_text)
79
+ st.write(f"📄 Total Cleaned Text Length: {len(full_text)} characters")
80
+
81
+ # Step 2: Chunk the cleaned text
82
+ text_splitter = RecursiveCharacterTextSplitter(
83
+ chunk_size=1000,
84
+ chunk_overlap=100,
85
+ separators=["\n\n", "\n", " ", ""]
86
+ )
87
+ split_docs = text_splitter.create_documents([full_text])
88
+
89
+ # Debug: Show total chunks count and first 3 chunks for verification
90
+ st.write(f"🔍 Total Chunks After Splitting: {len(split_docs)}")
91
+ for i, doc in enumerate(split_docs[:3]): # Show first 3 chunks only
92
+ st.write(f"Chunk {i + 1}: {doc.page_content[:300]}...")
93
+
94
+ return split_docs
95
+ except Exception as e:
96
+ st.error(f"Failed to load and process PDF: {e}")
97
+ st.stop()
98
+
99
+
100
+ def clean_extracted_text(text):
101
+ """
102
+ Cleans extracted text to remove metadata, headers, and irrelevant content.
103
+ """
104
+ lines = text.split("\n")
105
+ cleaned_lines = []
106
+
107
+ for line in lines:
108
+ line = line.strip()
109
+
110
+ # Filter out lines with metadata patterns
111
+ if (
112
+ re.match(r"^(U\.S\.|United States|Sheet|Figure|References|Patent No|Date of Patent)", line)
113
+ or re.match(r"^\(?\d+\)?$", line) # Matches single numbers (page numbers)
114
+ or "Examiner" in line
115
+ or "Attorney" in line
116
+ or len(line) < 30 # Skip very short lines
117
+ ):
118
+ continue
119
+
120
+ cleaned_lines.append(line)
121
+
122
+ return "\n".join(cleaned_lines)
123
+
124
+
125
+ def already_indexed(vectordb, file_name):
126
+ indexed_sources = set(
127
+ x["source"] for x in vectordb.get(include=["metadatas"])["metadatas"]
128
+ )
129
+ return file_name in indexed_sources
130
+
131
+ def load_chain(file_name=None):
132
+ """
133
+ Load cleaned PDF text, split into chunks, and update the vectorstore.
134
+ """
135
+ loaded_patent = st.session_state.get("LOADED_PATENT")
136
+
137
+ # Debug: Show persist directory
138
+ st.write(f"🗂 Using Persisted Directory: {PERSISTED_DIRECTORY}")
139
+
140
+ vectordb = Chroma(
141
+ persist_directory=PERSISTED_DIRECTORY,
142
+ embedding_function=HuggingFaceEmbeddings(),
143
+ )
144
+
145
+ if loaded_patent == file_name or already_indexed(vectordb, file_name):
146
+ st.write("✅ Already indexed.")
147
+ else:
148
+ st.write("🔄 Starting document processing and vectorstore update...")
149
+
150
+ # Remove existing collection and load new docs
151
+ vectordb.delete_collection()
152
+ docs = load_docs(file_name)
153
+
154
+ # Update vectorstore
155
+ vectordb = Chroma.from_documents(
156
+ docs, HuggingFaceEmbeddings(), persist_directory=PERSISTED_DIRECTORY
157
+ )
158
+ vectordb.persist()
159
+ st.write("✅ Vectorstore successfully updated and persisted.")
160
+
161
+ # Save loaded patent in session state
162
+ st.session_state["LOADED_PATENT"] = file_name
163
+
164
+ # Debug: Check vectorstore indexing summary
165
+ indexed_docs = vectordb.get(include=["documents"])
166
+ st.write(f"✅ Total Indexed Documents: {len(indexed_docs['documents'])}")
167
+
168
+ # Test retrieval with a simple query
169
+ retriever = vectordb.as_retriever(search_kwargs={"k": 3})
170
+ test_query = "What is this document about?"
171
+ results = retriever.get_relevant_documents(test_query)
172
+
173
+ st.write("🔍 Test Retrieval Results for Query:")
174
+ if results:
175
+ for i, res in enumerate(results):
176
+ st.write(f"Retrieved Doc {i + 1}: {res.page_content[:200]}...")
177
+ else:
178
+ st.warning("No documents retrieved for test query.")
179
+
180
+ # Configure memory for conversation
181
+ memory = ConversationBufferMemory(
182
+ memory_key="chat_history",
183
+ return_messages=True
184
+ )
185
+
186
+ return ConversationalRetrievalChain.from_llm(
187
+ OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY),
188
+ retriever,
189
+ memory=memory
190
+ )
191
+
192
+
193
+ def extract_patent_number(url):
194
+ pattern = r"/patent/([A-Z]{2}\d+)"
195
+ match = re.search(pattern, url)
196
+ return match.group(1) if match else None
197
+
198
+ def download_pdf(patent_number):
199
+ try:
200
+ patent_downloader = PatentDownloader(verbose=True)
201
+ output_path = patent_downloader.download(patents=patent_number, output_path=tempfile.gettempdir())
202
+ return output_path[0]
203
+ except Exception as e:
204
+ st.error(f"Failed to download patent PDF: {e}")
205
+ st.stop()
206
+
207
+ def preview_pdf(pdf_path, scale_factor=0.5):
208
+ """
209
+ Generate and display a resized preview of the first page of the PDF.
210
+ Args:
211
+ pdf_path (str): Path to the PDF file.
212
+ scale_factor (float): Factor to reduce the image size (default is 0.5).
213
+ Returns:
214
+ str: Path to the resized image preview.
215
+ """
216
+ try:
217
+ # Open the PDF and extract the first page
218
+ doc = fitz.open(pdf_path)
219
+ first_page = doc[0]
220
+
221
+ # Apply scaling using a transformation matrix
222
+ matrix = fitz.Matrix(scale_factor, scale_factor) # Scale down the image
223
+ pix = first_page.get_pixmap(matrix=matrix) # Generate scaled image
224
+
225
+ # Save the preview image
226
+ temp_image_path = os.path.join(tempfile.gettempdir(), "pdf_preview.png")
227
+ pix.save(temp_image_path)
228
+
229
+ doc.close()
230
+ return temp_image_path
231
+
232
+ except Exception as e:
233
+ st.error(f"Error generating PDF preview: {e}")
234
+ return None
235
+
236
+
237
+ if __name__ == "__main__":
238
+ st.set_page_config(
239
+ page_title="Patent Chat: Google Patents Chat Demo",
240
+ page_icon="📖",
241
+ layout="wide",
242
+ initial_sidebar_state="expanded",
243
+ )
244
+ st.header("📖 Patent Chat: Google Patents Chat Demo")
245
+
246
+ # Input for Google Patent Link
247
+ patent_link = st.text_area(
248
+ "Enter Google Patent Link:",
249
+ value="https://patents.google.com/patent/US8676427B1/en",
250
+ height=90
251
+ )
252
+
253
+ # Initialize session state
254
+ for key in ["LOADED_PATENT", "pdf_preview", "loaded_pdf_path", "chain", "messages"]:
255
+ if key not in st.session_state:
256
+ st.session_state[key] = None
257
+
258
+ # Button to load and process patent
259
+ if st.button("Load and Process Patent"):
260
+ if not patent_link:
261
+ st.warning("Please enter a valid Google patent link.")
262
+ st.stop()
263
+
264
+ # Extract patent number
265
+ patent_number = extract_patent_number(patent_link)
266
+ if not patent_number:
267
+ st.error("Invalid patent link format.")
268
+ st.stop()
269
+
270
+ st.write(f"Patent number: **{patent_number}**")
271
+
272
+ # File handling
273
+ pdf_path = os.path.join(tempfile.gettempdir(), f"{patent_number}.pdf")
274
+ if not os.path.isfile(pdf_path):
275
+ with st.spinner("📥 Downloading patent file..."):
276
+ try:
277
+ pdf_path = download_pdf(patent_number)
278
+ st.write(f"✅ File downloaded: {pdf_path}")
279
+ except Exception as e:
280
+ st.error(f"Failed to download patent: {e}")
281
+ st.stop()
282
+ else:
283
+ st.write("✅ File already downloaded.")
284
+
285
+ # Generate PDF preview only if not already displayed
286
+ if not st.session_state.get("pdf_preview_displayed", False):
287
+ with st.spinner("🖼️ Generating PDF preview..."):
288
+ preview_image_path = preview_pdf(pdf_path, scale_factor=0.5)
289
+ if preview_image_path:
290
+ st.session_state.pdf_preview = preview_image_path
291
+ st.image(preview_image_path, caption="First Page Preview", use_container_width=False)
292
+ st.session_state["pdf_preview_displayed"] = True
293
+ else:
294
+ st.warning("Failed to generate PDF preview.")
295
+ st.session_state.pdf_preview = None
296
+
297
+ # Load the document into the system
298
+ with st.spinner("🔄 Loading document into the system..."):
299
+ try:
300
+ st.session_state.chain = load_chain(pdf_path)
301
+ st.session_state.LOADED_PATENT = patent_number
302
+ st.session_state.loaded_pdf_path = pdf_path
303
+ st.session_state.messages = [{"role": "assistant", "content": "Hello! How can I assist you with this patent?"}]
304
+ st.success("🚀 Document successfully loaded! You can now start asking questions.")
305
+ except Exception as e:
306
+ st.error(f"Failed to load the document: {e}")
307
+ st.stop()
308
+
309
+ # Display previous chat messages
310
+ if st.session_state.messages:
311
+ for message in st.session_state.messages:
312
+ with st.chat_message(message["role"]):
313
+ st.markdown(message["content"])
314
+
315
+ # User input for questions
316
+ if st.session_state.chain:
317
+ if user_input := st.chat_input("What is your question?"):
318
+ # User message
319
+ st.session_state.messages.append({"role": "user", "content": user_input})
320
+ with st.chat_message("user"):
321
+ st.markdown(user_input)
322
+
323
+ # Assistant response
324
+ with st.chat_message("assistant"):
325
+ message_placeholder = st.empty()
326
+ full_response = ""
327
+
328
+ with st.spinner("Generating response..."):
329
+ try:
330
+ # Generate response using the chain
331
+ assistant_response = st.session_state.chain({"question": user_input})
332
+ full_response = assistant_response.get("answer", "I'm sorry, I couldn't process that question.")
333
+ except Exception as e:
334
+ full_response = f"An error occurred: {e}"
335
+
336
+ message_placeholder.markdown(full_response)
337
+ st.session_state.messages.append({"role": "assistant", "content": full_response})
338
+ else:
339
+ st.info("Press the 'Load and Process Patent' button to start processing.")