Spaces:
Sleeping
Sleeping
import gradio as gr | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain_community.vectorstores import FAISS | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.chains import RetrievalQA | |
from langchain_community.llms import HuggingFaceHub | |
try: | |
from langchain_huggingface import HuggingFaceEndpoint | |
HUGGINGFACE_ENDPOINT_AVAILABLE = True | |
except ImportError: | |
HUGGINGFACE_ENDPOINT_AVAILABLE = False | |
print("langchain-huggingface not available, using fallback") | |
from langchain_community.embeddings import HuggingFaceEmbeddings | |
# Silence warnings emitted by this app and by the libraries it imports.
import warnings


def warn(*args, **kwargs):
    """Do nothing; stands in for ``warnings.warn`` and ignores every argument."""
    return None


# Replace the module-level function first, then also drop everything that
# still reaches the filter machinery.
warnings.warn = warn
warnings.filterwarnings('ignore')
# Hugging Face API token.
# Set HUGGINGFACEHUB_API_TOKEN in the environment (on Hugging Face Spaces, as a
# secret). A value that is already set is kept; the placeholder below is only
# filled in when nothing is set, so the token checks elsewhere can detect that
# no real token was provided. Unconditionally assigning here would clobber a
# real token and make every LLM call fail.
import os
os.environ.setdefault("HUGGINGFACEHUB_API_TOKEN", "hf_YOUR_HUGGINGFACE_TOKEN")
## LLM - Using an open-source model from Hugging Face
def get_llm(temperature=0.1, max_new_tokens=512):
    """Initialize and return a Hugging Face LLM for answering legal questions.

    Candidate models are tried in order. For each one, ``HuggingFaceEndpoint``
    (from ``langchain-huggingface``) is tried first when that package imported
    successfully, then the legacy ``HuggingFaceHub`` wrapper. The first
    wrapper whose constructor succeeds is returned.

    Args:
        temperature: Sampling temperature passed to the model.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        A LangChain LLM instance.

    Raises:
        ValueError: If HUGGINGFACEHUB_API_TOKEN is unset or still holds the
            placeholder value, or if every candidate fails to initialize.
    """
    # Check if API token is properly set
    api_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    if not api_token or api_token == "hf_YOUR_HUGGINGFACE_TOKEN":
        raise ValueError("Please set a valid HUGGINGFACEHUB_API_TOKEN environment variable. You can get one from https://huggingface.co/settings/tokens")

    # Try different models in order of preference.
    # NOTE(review): only constructor failures trigger a fallback here; an error
    # raised later at inference time is not retried with the next model.
    models_to_try = [
        ("mistralai/Mixtral-8x7B-Instruct-v0.1", "text-generation"),
        ("microsoft/DialoGPT-medium", "text-generation"),
        ("google/flan-t5-base", "text2text-generation"),
        ("huggingface/CodeBERTa-small-v1", "text-generation"),
    ]
    for repo_id, task in models_to_try:
        if HUGGINGFACE_ENDPOINT_AVAILABLE:
            try:
                llm = HuggingFaceEndpoint(
                    repo_id=repo_id,
                    task=task,
                    # HuggingFaceEndpoint's length limit is ``max_new_tokens``.
                    # An unknown ``max_length`` is forwarded to
                    # InferenceClient.text_generation() and fails at call time.
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    huggingfacehub_api_token=api_token,
                )
                print(f"Successfully initialized HuggingFaceEndpoint with {repo_id}")
                return llm
            except Exception as e:
                print(f"HuggingFaceEndpoint with {repo_id} failed: {e}")
        try:
            # The legacy wrapper forwards model_kwargs to the Inference API as-is.
            llm = HuggingFaceHub(
                repo_id=repo_id,
                task=task,
                model_kwargs={
                    "temperature": temperature,
                    "max_length": max_new_tokens,
                },
                huggingfacehub_api_token=api_token,
            )
            print(f"Successfully initialized HuggingFaceHub with {repo_id}")
            return llm
        except Exception as e:
            print(f"HuggingFaceHub with {repo_id} failed: {e}")
    raise ValueError("All LLM initialization attempts failed. Please check your API token and internet connection.")
## Document loader
def document_loader(file_path):
    """Load a PDF from *file_path* and return its pages.

    Args:
        file_path: Path to the PDF file on disk.

    Returns:
        The list of documents produced by ``PyPDFLoader.load()``.

    Raises:
        ValueError: If the loader fails (the original error is chained as
            ``__cause__``), if no pages are returned, or if the pages contain
            no extractable text.
    """
    # Only the loader call is guarded; the validation below raises its own,
    # already user-friendly, ValueErrors instead of having them re-wrapped.
    try:
        loaded_document = PyPDFLoader(file_path).load()
    except Exception as e:
        print(f"Error loading document: {e}")
        raise ValueError(f"Failed to load PDF: {e}") from e

    if not loaded_document:
        raise ValueError("No content could be extracted from the PDF")
    print(f"Successfully loaded {len(loaded_document)} pages from PDF")

    # Scanned/image-only PDFs load fine but yield only whitespace.
    total_content = sum(len(doc.page_content.strip()) for doc in loaded_document)
    if total_content == 0:
        raise ValueError("PDF appears to be empty or contains no extractable text")
    print(f"Total content length: {total_content} characters")
    return loaded_document
## Text splitter
def text_splitter(data, chunk_size=1000, chunk_overlap=200, min_chunk_chars=50):
    """Split loaded documents into overlapping chunks for embedding.

    Args:
        data: Documents to split (as returned by ``document_loader``).
        chunk_size: Maximum characters per chunk.
        chunk_overlap: Characters shared between consecutive chunks.
        min_chunk_chars: Chunks whose stripped text is not longer than this
            are discarded as noise.

    Returns:
        The list of retained chunks (never empty).

    Raises:
        ValueError: If splitting fails (the cause is chained) or if no chunk
            survives the length filter.
    """
    try:
        # Named ``splitter`` so it does not shadow this function's own name.
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            separators=["\n\n", "\n", " ", ""],
        )
        chunks = splitter.split_documents(data)
    except Exception as e:
        print(f"Error in text splitting: {e}")
        raise ValueError(f"Failed to split document into chunks: {e}") from e

    # Filter out very small chunks
    filtered_chunks = [
        chunk for chunk in chunks if len(chunk.page_content.strip()) > min_chunk_chars
    ]
    print(f"Created {len(filtered_chunks)} chunks (filtered from {len(chunks)} total)")
    if not filtered_chunks:
        raise ValueError("No meaningful content chunks could be created from the document")
    return filtered_chunks
## Vector db and Embedding model
def _build_embedding_model():
    """Return the local sentence-transformers embedding model used for indexing."""
    # Both the primary and fallback paths use this so their vectors are
    # produced with identical settings (including normalization).
    return HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'},  # Use CPU for compatibility
        encode_kwargs={'normalize_embeddings': True},
    )


def vector_database(chunks):
    """Build a FAISS vector store from document chunks using local embeddings.

    ``FAISS.from_documents`` is tried first. If it fails, the chunk texts and
    metadata are extracted and ``FAISS.from_texts`` is tried with a freshly
    built embedding model of the same configuration.

    Args:
        chunks: Documents exposing ``page_content`` and ``metadata``.

    Returns:
        The FAISS vector store.

    Raises:
        ValueError: If both approaches fail; the fallback error is chained.
    """
    try:
        embedding_model = _build_embedding_model()
        print(f"Processing {len(chunks)} chunks for embedding...")
        vectordb = FAISS.from_documents(chunks, embedding_model)
        print("Vector database created successfully!")
        return vectordb
    except Exception as e:
        print(f"Error creating vector database: {e}")
        print(f"Error type: {type(e)}")
        # Try alternative approach with text extraction
        try:
            print("Trying alternative approach with text extraction...")
            texts = [chunk.page_content for chunk in chunks]
            metadatas = [chunk.metadata for chunk in chunks]
            vectordb = FAISS.from_texts(texts, _build_embedding_model(), metadatas=metadatas)
            print("Alternative approach succeeded!")
            return vectordb
        except Exception as e2:
            print(f"Alternative approach also failed: {e2}")
            raise ValueError(f"Failed to create embeddings. Original error: {e}. Alternative error: {e2}") from e2
## Retriever
def retriever(file_path):
    """Turn the PDF at *file_path* into a retriever over its content.

    Runs the full indexing pipeline: load pages, split them into chunks,
    embed the chunks into a FAISS store, and expose that store as a retriever.

    Raises:
        ValueError: If the document produced no chunks, or if any pipeline
            stage raises one.
    """
    pages = document_loader(file_path)
    doc_chunks = text_splitter(pages)
    # Defensive guard: refuse to index an empty chunk list.
    if not doc_chunks:
        raise ValueError("The uploaded document could not be processed. Please try another file.")
    print(f"Created {len(doc_chunks)} chunks from the document")
    store = vector_database(doc_chunks)
    return store.as_retriever()
## QA Chain | |
def retriever_qa(file, query): | |
""" | |
Sets up a RetrievalQA chain to answer questions based on the document. | |
""" | |
# Check if a file was uploaded | |
if not file: | |
return "Please upload a valid PDF file before asking a question." | |
# Check if query is provided | |
if not query or query.strip() == "": | |
return "Please enter a question to get started." | |
# Use the file path from the Gradio file object | |
file_path = file.name if hasattr(file, 'name') else str(file) | |
try: | |
llm = get_llm() | |
retriever_obj = retriever(file_path) | |
# Simplified prompt - let the RetrievalQA chain handle the context properly | |
qa = RetrievalQA.from_chain_type( | |
llm=llm, | |
chain_type="stuff", | |
retriever=retriever_obj, | |
return_source_documents=True, | |
) | |
# Create a proper prompt for legal advice | |
legal_prompt = f"""Based on the document content, please provide professional legal guidance for the following question. | |
Be conversational, clear, and cite relevant sections when possible. | |
Question: {query} | |
Please provide a helpful and accurate response based on the document content.""" | |
response = qa.invoke({"query": legal_prompt}) | |
# Extract the result | |
result_text = response.get('result', 'No response generated.') | |
# Clean up the response if needed | |
if result_text.startswith("Legal Advisor's Answer:"): | |
result_text = result_text.replace("Legal Advisor's Answer:", "").strip() | |
return result_text | |
except ValueError as ve: | |
# Handle specific ValueError (like API token issues) | |
if "API token" in str(ve): | |
return f"Configuration Error: {ve}\n\nPlease:\n1. Get a HuggingFace API token from https://huggingface.co/settings/tokens\n2. Set it as HUGGINGFACEHUB_API_TOKEN environment variable" | |
else: | |
return f"Error: {ve}" | |
except Exception as e: | |
error_msg = str(e) | |
if "API token" in error_msg or "authentication" in error_msg.lower(): | |
return "Error: Please check your Hugging Face API token configuration." | |
elif "embedding" in error_msg.lower(): | |
return "Error: Failed to create document embeddings. Please try uploading a different PDF file." | |
elif "InferenceClient" in error_msg: | |
return "Error: HuggingFace library compatibility issue. Please try updating your dependencies or contact support." | |
else: | |
return f"Error processing your request: {error_msg}" | |
# Create Gradio interface with better error handling
def create_interface():
    """Assemble the Gradio UI that sends a PDF upload and a question to ``retriever_qa``.

    Returns:
        The configured ``gr.Interface``; the caller is responsible for launching it.
    """
    pdf_input = gr.File(
        label="Upload PDF File",
        file_count="single",
        file_types=['.pdf'],
    )
    question_input = gr.Textbox(
        label="Input Query",
        lines=3,
        placeholder="Type your legal question here...",
        info="Ask questions about the uploaded document",
    )
    answer_output = gr.Textbox(
        label="Legal Advisor's Response",
        lines=10,
        max_lines=20,
    )
    description = (
        "\n"
        "Upload a PDF document (like the Nigerian Constitution) and ask legal questions about it.\n"
        "The AI will analyze the document and provide contextual legal guidance.\n"
        "**Note:** Make sure to set your Hugging Face API token in the environment variables.\n"
    )
    # Example rows leave the file slot empty; the user must upload a PDF.
    sample_questions = (
        "What are the fundamental rights guaranteed by this constitution?",
        "What is the process for constitutional amendments?",
        "What are the powers of the federal government?",
    )
    return gr.Interface(
        fn=retriever_qa,
        allow_flagging="never",
        inputs=[pdf_input, question_input],
        outputs=answer_output,
        title="Nigerian Constitution Legal Advisor Chatbot",
        description=description,
        examples=[[None, question] for question in sample_questions],
    )
# Launch the app
if __name__ == "__main__":
    # Warn early (but still start the UI) when no real token is configured.
    configured_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    if configured_token in (None, "", "hf_YOUR_HUGGINGFACE_TOKEN"):
        print("WARNING: Please set your actual Hugging Face API token in the HUGGINGFACEHUB_API_TOKEN environment variable")
    create_interface().launch(share=True)