# TODO: Enable downloading multiple patent PDFs via their corresponding links.
import sys | |
import os | |
import re | |
import shutil | |
import time | |
import fitz | |
import streamlit as st | |
import nltk | |
import tempfile | |
import subprocess | |
# Pin NLTK to version 3.9.1.
# NOTE: Streamlit re-executes this whole script on every user interaction, and
# `nltk` is already imported above, so a pip install here cannot change the
# module that is running; it only takes effect on the next interpreter start.
# Shell out to pip only when the installed version actually differs, instead
# of blocking every rerun on a pip invocation.
REQUIRED_NLTK_VERSION = "3.9.1"
if getattr(nltk, "__version__", None) != REQUIRED_NLTK_VERSION:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", f"nltk=={REQUIRED_NLTK_VERSION}"],
        check=False,  # best-effort: a failed pip install is not fatal here
    )

# Set up temporary directory for NLTK resources.
nltk_data_path = os.path.join(tempfile.gettempdir(), "nltk_data")
os.makedirs(nltk_data_path, exist_ok=True)
# `nltk.data.path` lives on the cached module object, so without this guard the
# same directory would be appended again on every rerun.
if nltk_data_path not in nltk.data.path:
    nltk.data.path.append(nltk_data_path)

# Download 'punkt_tab' for compatibility, but only when it is not already
# available (nltk.download contacts the package index on every call).
try:
    nltk.data.find("tokenizers/punkt_tab")
except LookupError:
    try:
        print("Ensuring NLTK 'punkt_tab' resource is downloaded...")
        # raise_on_error=True: by default nltk.download reports failure by
        # returning False rather than raising, which would bypass this handler.
        nltk.download("punkt_tab", download_dir=nltk_data_path, raise_on_error=True)
    except Exception as e:
        print(f"Error downloading NLTK 'punkt_tab': {e}")
        raise

# Make modules in the current working directory importable (once per process,
# not once per rerun).
_working_dir = os.path.abspath(".")
if _working_dir not in sys.path:
    sys.path.append(_working_dir)
from langchain.chains import ConversationalRetrievalChain | |
from langchain.memory import ConversationBufferMemory | |
from langchain.llms import OpenAI | |
from langchain.document_loaders import UnstructuredPDFLoader | |
from langchain.vectorstores import Chroma | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.text_splitter import NLTKTextSplitter | |
from patent_downloader import PatentDownloader | |
# Chroma persistence directory, one per browser session.
# Streamlit re-executes this script on every interaction. Calling
# tempfile.mkdtemp() unconditionally would create (and leak) a new, empty
# directory on each rerun. It would also point load_chain at a different store
# than the one that was indexed. Reuse the directory remembered in session state.
PERSISTED_DIRECTORY = st.session_state.get("PERSISTED_DIRECTORY")
if PERSISTED_DIRECTORY is None:
    PERSISTED_DIRECTORY = tempfile.mkdtemp()
    st.session_state["PERSISTED_DIRECTORY"] = PERSISTED_DIRECTORY

# Fetch API key securely from the environment; the app cannot run without it.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    st.error("Critical Error: OpenAI API key not found in the environment variables. Please configure it.")
    st.stop()
def check_poppler_installed():
    """Verify that Poppler's ``pdfinfo`` executable is reachable on PATH.

    Returns:
        None when ``pdfinfo`` is found.

    Raises:
        EnvironmentError: If ``pdfinfo`` cannot be located on PATH.
    """
    if shutil.which("pdfinfo"):
        return
    raise EnvironmentError(
        "Poppler is not installed or not in PATH. Install 'poppler-utils' for PDF processing."
    )
# Fail fast if Poppler's `pdfinfo` is missing: this raises EnvironmentError at
# module level, i.e. on every (re)run of the script, before any UI is built.
check_poppler_installed()
def load_docs(document_path):
    """Load a PDF and split it into NLTK sentence-based chunks of ~1000 chars.

    Args:
        document_path: Path to the PDF file to load.

    Returns:
        The list of split documents. Each document's metadata is reduced to
        values of type str, int, float, or bool; other values are dropped.
        On any failure an error is shown in the UI and ``st.stop()`` is called.
    """
    allowed_value_types = (str, int, float, bool)
    try:
        raw_elements = UnstructuredPDFLoader(
            document_path,
            mode="elements",
            strategy="fast",
            ocr_languages=None,
        ).load()
        chunks = NLTKTextSplitter(chunk_size=1000).split_documents(raw_elements)
        # Keep only scalar metadata values (str/int/float/bool); lists, dicts,
        # and None are removed so the chunks can be stored in Chroma.
        for chunk in chunks:
            meta = getattr(chunk, "metadata", None)
            if isinstance(meta, dict):
                chunk.metadata = {
                    key: value
                    for key, value in meta.items()
                    if isinstance(value, allowed_value_types)
                }
        return chunks
    except Exception as err:
        st.error(f"Failed to load and process PDF: {err}")
        st.stop()
def already_indexed(vectordb, file_name):
    """Return True if ``vectordb`` holds any chunk whose ``source`` is ``file_name``.

    Args:
        vectordb: A Chroma vector store; its ``get(include=["metadatas"])``
            result is expected to be a mapping with a ``"metadatas"`` list.
        file_name: The source path to look for.

    Returns:
        True if at least one stored metadata dict has ``"source" == file_name``.
        Metadata entries that are None or lack a ``"source"`` key are skipped
        rather than raising.
    """
    result = vectordb.get(include=["metadatas"])
    # Chroma may return None entries for documents stored without metadata.
    metadatas = result.get("metadatas") or []
    return any(
        meta is not None and "source" in meta and meta["source"] == file_name
        for meta in metadatas
    )
def load_chain(file_name=None):
    """Build a conversational retrieval chain over the PDF at ``file_name``.

    The PDF is indexed into the Chroma store in ``PERSISTED_DIRECTORY`` unless
    that store already contains chunks whose ``source`` metadata equals
    ``file_name``. Re-indexing first deletes the existing collection, so the
    store holds one document at a time. ``st.session_state["LOADED_PATENT"]``
    is set to ``file_name``.

    Args:
        file_name: Path to the patent PDF.

    Returns:
        A ``ConversationalRetrievalChain`` built from:
        - a temperature-0 OpenAI LLM;
        - a retriever returning the top 3 chunks;
        - a new ``ConversationBufferMemory``, so every call starts an empty
          conversation.
    """
    # Construct the embedding model once and share it between the handle on
    # the existing store and a rebuilt store; building it twice loads it twice.
    embeddings = HuggingFaceEmbeddings()
    vectordb = Chroma(
        persist_directory=PERSISTED_DIRECTORY,
        embedding_function=embeddings,
    )
    # Decide from the store's actual contents, not from session state:
    # st.session_state["LOADED_PATENT"] may name a file that was indexed into a
    # different directory, which would leave this retriever empty.
    if already_indexed(vectordb, file_name):
        st.write("✅ Already indexed.")
    else:
        vectordb.delete_collection()
        docs = load_docs(file_name)
        st.write("🔍 Number of Documents: ", len(docs))
        vectordb = Chroma.from_documents(
            docs, embeddings, persist_directory=PERSISTED_DIRECTORY
        )
        vectordb.persist()
    st.session_state["LOADED_PATENT"] = file_name

    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        input_key="question",
        output_key="answer",
    )
    return ConversationalRetrievalChain.from_llm(
        OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY),
        vectordb.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=False,
        memory=memory,
    )
def extract_patent_number(url):
    """Extract the publication number from a Google Patents URL.

    The trailing kind code (e.g. ``B1``) is not included in the result.

    Args:
        url: A URL such as ``https://patents.google.com/patent/US8676427B1/en``.

    Returns:
        The publication number (e.g. ``"US8676427"``), or None if ``url`` is
        empty/None or contains no ``/patent/<number>`` segment.

    >>> extract_patent_number("https://patents.google.com/patent/US8676427B1/en")
    'US8676427'
    """
    if not url:
        return None
    # Two-letter country code, then an optional US series prefix:
    #   RE = reissue, PP = plant, D = design, H = statutory invention registration.
    # Then the serial digits.
    match = re.search(r"/patent/([A-Z]{2}(?:RE|PP|D|H)?\d+)", url)
    return match.group(1) if match else None
def download_pdf(patent_number):
    """Download the PDF for ``patent_number`` into the system temp directory.

    Args:
        patent_number: Publication number, e.g. ``"US8676427"``.

    Returns:
        The first element of ``PatentDownloader.download()``'s result
        (presumably the saved PDF's path — confirm against patent_downloader).
        On any failure an error is shown in the UI and ``st.stop()`` is called.
    """
    try:
        results = PatentDownloader(verbose=True).download(
            patents=patent_number,
            output_path=tempfile.gettempdir(),
        )
        return results[0]
    except Exception as err:
        st.error(f"Failed to download patent PDF: {err}")
        st.stop()
def preview_pdf(pdf_path):
    """Render the first page of a PDF to a PNG image and return its path.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        Path of a PNG named ``<pdf stem>_preview.png`` in the system temp
        directory, or None if the preview could not be generated. In that case
        an error is shown in the UI.
    """
    try:
        # The context manager closes the document (and its file handle) even
        # if rendering or saving fails.
        with fitz.open(pdf_path) as doc:
            first_page = doc[0]  # IndexError for an empty PDF -> handled below
            pix = first_page.get_pixmap()  # Render the page to a raster image
            # Name the image after the PDF so previews of different PDFs (or
            # from concurrent sessions) do not overwrite one shared file.
            stem = os.path.splitext(os.path.basename(pdf_path))[0]
            temp_image_path = os.path.join(tempfile.gettempdir(), f"{stem}_preview.png")
            pix.save(temp_image_path)
        return temp_image_path
    except Exception as e:
        st.error(f"Error generating PDF preview: {e}")
        return None
if __name__ == "__main__":
    # Page title, icon, and layout for the Streamlit app.
    st.set_page_config(
        page_title="Patent Chat: Google Patents Chat Demo",
        page_icon="📖",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.header("📖 Patent Chat: Google Patents Chat Demo")

    # Fetch query parameters safely: a `?patent_link=...` URL parameter
    # pre-fills the input; otherwise a sample patent link is used.
    query_params = st.query_params
    default_patent_link = query_params.get("patent_link", "https://patents.google.com/patent/US8676427B1/en")

    # Input for Google Patent Link
    patent_link = st.text_area("Enter Google Patent Link:", value=default_patent_link, height=100)

    # Button to start processing. `st.button` is True only on the rerun
    # triggered by the click, so state needed by later reruns (chain, loaded
    # file, chat history) is kept in st.session_state.
    if st.button("Load and Process Patent"):
        if not patent_link:
            st.warning("Please enter a Google patent link to proceed.")
            st.stop()

        # Extract patent number
        patent_number = extract_patent_number(patent_link)
        if not patent_number:
            st.error("Invalid patent link format. Please provide a valid Google patent link.")
            st.stop()
        st.write(f"Patent number: **{patent_number}**")

        # File download handling: reuse <tempdir>/<patent_number>.pdf if present.
        # NOTE(review): assumes PatentDownloader saves under this same file name;
        # if it does not, this cache check never hits — confirm.
        pdf_path = os.path.join(tempfile.gettempdir(), f"{patent_number}.pdf")
        if os.path.isfile(pdf_path):
            st.write("✅ File already downloaded.")
        else:
            st.write("📥 Downloading patent file...")
            pdf_path = download_pdf(patent_number)
            st.write(f"✅ File downloaded: {pdf_path}")

        # Generate and display PDF preview
        st.write("🖼️ Generating PDF preview...")
        preview_image_path = preview_pdf(pdf_path)
        if preview_image_path:
            st.image(preview_image_path, caption="First Page Preview", use_column_width=True)
        else:
            st.warning("Failed to generate a preview for this PDF.")

        # Load the document into the system
        st.write("🔄 Loading document into the system...")
        # Persist the chain in session state to prevent reloading. It is rebuilt
        # (and the chat history reset) only when no chain exists yet or a
        # different PDF path is requested.
        if "chain" not in st.session_state or st.session_state.get("loaded_file") != pdf_path:
            st.session_state.chain = load_chain(pdf_path)
            st.session_state.loaded_file = pdf_path
            st.session_state.messages = [{"role": "assistant", "content": "Hello! How can I assist you with this patent?"}]
        st.success("🚀 Document successfully loaded! You can now start asking questions.")

    # Initialize messages if not already done
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hello! How can I assist you with this patent?"}]

    # Display previous chat messages (the page is redrawn on every rerun, so the
    # whole history is re-rendered from session state each time).
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # User input and chatbot response; only offered once a chain has been loaded.
    if "chain" in st.session_state:
        if user_input := st.chat_input("What is your question?"):
            st.session_state.messages.append({"role": "user", "content": user_input})
            with st.chat_message("user"):
                st.markdown(user_input)
            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                full_response = ""
                with st.spinner("Generating response..."):
                    try:
                        assistant_response = st.session_state.chain({"question": user_input})
                        full_response = assistant_response["answer"]
                    except Exception as e:
                        # Surface the failure as the assistant's reply instead of crashing the app.
                        full_response = f"An error occurred: {e}"
                message_placeholder.markdown(full_response)
            st.session_state.messages.append({"role": "assistant", "content": full_response})
    else:
        st.info("Press the 'Load and Process Patent' button to start processing.")