# Spaces: Sleeping
# app.py
import os | |
import sys | |
import logging | |
from getpass import getpass | |
from langchain.embeddings import OpenAIEmbeddings | |
from langchain.vectorstores import Chroma | |
from langchain.chat_models import ChatOpenAI | |
from langchain.chains.question_answering import load_qa_chain | |
from langchain.prompts import ChatPromptTemplate | |
import gradio as gr | |
import zipfile | |
import os | |
# Archive to unpack at startup, and the folder it is extracted into.
zip_file = "combined_folders.zip"
destination_folder = "properties_vectors"

# Make sure the target folder is present, then unpack the whole archive there.
os.makedirs(destination_folder, exist_ok=True)
with zipfile.ZipFile(zip_file, mode='r') as archive:
    archive.extractall(path=destination_folder)

# Module-wide logger; the root logger is configured at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Function to get the absolute path
def get_absolute_path(relative_path):
    """
    Resolve relative_path against the application's base directory.

    Args:
        relative_path: Path relative to the application's base directory.

    Returns:
        base directory joined with relative_path, where the base directory is:
        - sys._MEIPASS when running frozen and that attribute exists
          (the PyInstaller bundle directory),
        - the directory of sys.executable when running frozen without
          sys._MEIPASS,
        - the absolute current working directory otherwise.
    """
    if getattr(sys, 'frozen', False):
        # PyInstaller sets sys.frozen and stores the bundle path in
        # sys._MEIPASS. sys.frozen can also be set without _MEIPASS being
        # present; fall back to the executable's folder in that case instead
        # of failing with AttributeError.
        base_path = getattr(sys, '_MEIPASS', None)
        if base_path is None:
            base_path = os.path.dirname(os.path.abspath(sys.executable))
    else:
        base_path = os.path.abspath(".")
    return os.path.join(base_path, relative_path)
# Take the OpenAI API key from the environment; prompt for it (hidden input)
# only when the variable is unset or empty.
openai_api_key = os.getenv("OPENAI_API_KEY") or getpass("Enter your OpenAI API key2: ")
# Publish the key through the environment (presumably where the OpenAI
# clients created below look it up).
os.environ["OPENAI_API_KEY"] = openai_api_key

# Embedding function handed to every Chroma store this module opens.
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
# Function to list available vector store directories
def list_vectorstore_directories(base_path='vectorstores'):
    """
    Lists the subdirectories of base_path that look like Chroma vector stores.

    A subdirectory qualifies when it contains every file named in
    required_files (currently only 'chroma.sqlite3').

    Args:
        base_path: Directory whose immediate children are inspected.

    Returns:
        Sorted list of qualifying paths, each being base_path joined with the
        entry name. The list is empty when nothing qualifies or when base_path
        cannot be listed; in the latter case the error is logged, not raised.
    """
    required_files = ('chroma.sqlite3',)
    try:
        # os.listdir returns entries in arbitrary order; sort them so callers
        # (e.g. a UI choice list) see a stable ordering between runs.
        entries = sorted(os.listdir(base_path))
    except OSError as e:
        logger.error(f"Error listing directories in '{base_path}': {e}")
        return []

    directories = []
    for entry in entries:
        full_path = os.path.join(base_path, entry)
        if not os.path.isdir(full_path):
            continue
        # Keep only directories that contain all Chroma marker files.
        if all(os.path.exists(os.path.join(full_path, name)) for name in required_files):
            directories.append(full_path)
    return directories
# Function to load selected vector stores
def load_selected_vectorstores(selected_dirs):
    """
    Opens a Chroma vector store for each directory in selected_dirs.

    Every store is created with the module-level embeddings object. A
    directory whose store cannot be opened is logged as an error and skipped,
    so the returned list may be shorter than selected_dirs.

    Args:
        selected_dirs: Iterable of persist-directory paths.

    Returns:
        List of the Chroma stores that were opened, in input order.
    """
    loaded = []
    for store_dir in selected_dirs:
        try:
            store = Chroma(persist_directory=store_dir, embedding_function=embeddings)
        except Exception as e:
            logger.error(f"Error loading vectorstore from '{store_dir}': {e}")
            continue
        loaded.append(store)
        logger.info(f"Loaded vectorstore from '{store_dir}'.")
    return loaded
# Function to create a combined retriever
def create_combined_retriever(vectorstores, search_kwargs=None):
    """
    Builds one retriever that queries every given vector store.

    Args:
        vectorstores: Iterable of objects exposing
            as_retriever(search_kwargs=...).
        search_kwargs: Options forwarded to each store's as_retriever call.
            Defaults to {"k": 20} when omitted or None.

    Returns:
        An object with a ``retrievers`` list and a
        get_relevant_documents(query) method. That method concatenates the
        results of all retrievers and drops duplicates, where two documents
        count as duplicates when both page_content and metadata['source']
        match.
    """
    # A None default avoids sharing one mutable dict across calls.
    if search_kwargs is None:
        search_kwargs = {"k": 20}
    # Each retriever gets its own copy, so a retriever that keeps or mutates
    # its options cannot affect the others or the caller's dict.
    retrievers = [vs.as_retriever(search_kwargs=dict(search_kwargs)) for vs in vectorstores]

    class CombinedRetriever:
        """Fans a query out to several retrievers and merges the results."""

        def __init__(self, retrievers):
            self.retrievers = retrievers

        def get_relevant_documents(self, query):
            """Return the de-duplicated documents from all retrievers."""
            docs = []
            for retriever in self.retrievers:
                try:
                    docs.extend(retriever.get_relevant_documents(query))
                except Exception as e:
                    # One failing store must not discard the others' results.
                    logger.error(f"Error retrieving documents: {e}")
            # Remove duplicates based on content and source. A dict keeps the
            # position of the first occurrence and the object of the last one.
            unique_docs = {(doc.page_content, doc.metadata.get('source', '')): doc for doc in docs}
            return list(unique_docs.values())

    return CombinedRetriever(retrievers)
# Define the QA function
def answer_question(selected_dirs, question):
    """
    Answers question from the documents in the selected vector stores.

    Steps: load the selected stores, retrieve documents for the question
    across all of them, prefix each document's text with its source and page
    metadata, and run a "stuff" QA chain over them with the gpt-4o chat model.

    Args:
        selected_dirs: List of vector store directories to search.
        question: The user's question.

    Returns:
        The chain's answer, or a human-readable message when the input is
        incomplete or a step fails. Failures are logged and reported through
        the return value; they are not raised to the caller.
    """
    if not selected_dirs:
        return "Please select at least one vector store directory."
    # Reject a missing or blank question before any store is opened or any
    # API call is made.
    if not question or not question.strip():
        return "Please enter a question."

    # Load the selected vector stores
    vectorstores = load_selected_vectorstores(selected_dirs)
    if not vectorstores:
        return "No vector stores loaded. Please check the selected directories."

    # Create combined retriever
    combined_retriever = create_combined_retriever(vectorstores, search_kwargs={"k": 20})

    # Load the LLM
    try:
        llm = ChatOpenAI(model_name="gpt-4o")
    except Exception as e:
        logger.error(f"Error loading LLM: {e}")
        return "Error loading the language model. Please check your OpenAI API key and access."

    # Define the prompt template. {context} receives the retrieved documents
    # and {input} the question.
    # NOTE(review): the wording below contains typos ("lastest", "ammendment",
    # "are the"); it is kept verbatim because it is sent to the model.
    template = """
You are an AI assistant specialized in extracting precise information from legal documents.
Special emphasis on documents but refer outside if necessary.
Always include the source filename and page number in your response.
If multiple documents are the always prefer the lastest date ones.
If ammendment documents are the always prefer the ammendments.
Context:
{context}
Question: {input}
Answer:
"""
    prompt = ChatPromptTemplate.from_template(template)

    # Create QA chain
    try:
        qa_chain = load_qa_chain(llm, chain_type="stuff", prompt=prompt)
    except Exception as e:
        logger.error(f"Error creating QA chain: {e}")
        return "Error initializing the QA system."

    # Retrieve documents
    try:
        retrieved_docs = combined_retriever.get_relevant_documents(question)
    except Exception as e:
        logger.error(f"Error retrieving documents: {e}")
        return "Error retrieving documents."
    if not retrieved_docs:
        return "No relevant documents found for the question."

    # Put source and page into the document text itself so the model can
    # cite them, as the prompt asks. The documents are modified in place.
    for doc in retrieved_docs:
        source = doc.metadata.get("source", "Unknown Source")
        page_number = doc.metadata.get("page_number", "Unknown Page")
        doc.page_content = f"Source: {source}\nPage: {page_number}\nContent: {doc.page_content}"

    # Generate response using the QA chain
    try:
        response = qa_chain.run(input_documents=retrieved_docs, input=question)
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return "Error generating the response."
    return response
# Set Up the Gradio Interface
# Resolve the vector store folder and collect the store directories in it.
vectorstores_path = get_absolute_path('properties_vectors/vectors')
available_dirs = list_vectorstore_directories(vectorstores_path)
# (A hard-coded fallback list such as "/content/trinity" was sketched here
# for the case where no directories are found; it is not active.)

# Input and output widgets, built separately to keep the Interface call short.
directory_picker = gr.CheckboxGroup(
    choices=available_dirs,
    label="Select Vector Store Directories",
)
question_box = gr.Textbox(
    lines=2,
    placeholder="Enter your question here...",
    label="Your Question",
)
response_box = gr.Textbox(label="Response")

# Wire the widgets to answer_question; the inputs are passed positionally as
# (selected directories, question).
iface = gr.Interface(
    fn=answer_question,
    inputs=[directory_picker, question_box],
    outputs=response_box,
    title="Vector Store QA Assistant",
    description="Select one or more vector store directories and ask your question. The assistant will retrieve relevant documents and provide an answer.",
    allow_flagging="never",
)

# Launch the interface with debug and share both enabled.
iface.launch(debug=True, share=True)