# pdf-something / app.py
import gradio as gr
import json
import os
import pdfplumber
import together
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import re
import unicodedata
from dotenv import load_dotenv
load_dotenv()
# The Together.AI API key is read from the environment (e.g. from .env)
assert os.getenv("TOGETHER_API_KEY"), "TOGETHER_API_KEY is missing"
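# Sketch (version-dependent): older `together` SDKs also accept the key
# programmatically, e.g. `together.api_key = os.getenv("TOGETHER_API_KEY")`;
# here the environment variable alone is assumed to be sufficient.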
# Sentence-transformers model for embeddings
# (alternative: "BAAI/bge-base-en-v1.5")
embedding_model = SentenceTransformer(
    "togethercomputer/m2-bert-80M-8k-retrieval",
    trust_remote_code=True,  # the model repo ships custom code
)
# Define dataset storage folder
DATASET_DIR = "/home/user/.cache/huggingface/datasets/my_documents"
os.makedirs(DATASET_DIR, exist_ok=True) # Ensure directory exists
# Define file paths inside dataset folder
INDEX_FILE = os.path.join(DATASET_DIR, "faiss_index.bin") # FAISS index file
METADATA_FILE = os.path.join(DATASET_DIR, "metadata.json") # Metadata file
embedding_dim = 768  # Must match the embedding model's output dimension
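# Sketch: instead of hard-coding the dimension, sentence-transformers can
# usually report it, which keeps the FAISS index in sync with the model:
#   embedding_dim = embedding_model.get_sentence_embedding_dimension()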
# Debugging: check working directory and available files
print("Current working directory:", os.getcwd())
print("Files in dataset directory:", os.listdir(DATASET_DIR))

# Load the FAISS index if it exists, otherwise start with an empty one
if os.path.exists(INDEX_FILE):
    print("FAISS index file exists")
    index = faiss.read_index(INDEX_FILE)
else:
    print("No FAISS index found. Creating a new one.")
    index = faiss.IndexFlatL2(embedding_dim)  # Empty FAISS index
# Load metadata (maps FAISS position -> document file path)
if os.path.exists(METADATA_FILE):
    print("Metadata file exists")
    with open(METADATA_FILE, "r") as f:
        metadata = json.load(f)
else:
    metadata = {}
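
# Note: IndexFlatL2 identifies vectors by insertion position, which is why the
# metadata is keyed by position. A sketch with explicit ids (an alternative,
# not what this app uses) would wrap the index:
#   index = faiss.IndexIDMap(faiss.IndexFlatL2(embedding_dim))
#   index.add_with_ids(embedding, np.array([doc_id], dtype=np.int64))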

def store_document(text):
    print("Storing document...")

    # Generate a unique filename inside the dataset folder
    doc_id = len(metadata) + 1
    filename = os.path.join(DATASET_DIR, f"doc_{doc_id}.txt")
    print(f"Saving document at: {filename}")

    # Save document to file
    with open(filename, "w", encoding="utf-8") as f:
        f.write(text)
    print("Document saved")

    # Generate and store the embedding
    embedding = embedding_model.encode([text]).astype(np.float32)
    index.add(embedding)  # Add to FAISS index
    print("Embedding generated")

    # FAISS position of the new document
    doc_index = index.ntotal - 1

    # Update metadata with the FAISS position
    metadata[str(doc_index)] = filename
    with open(METADATA_FILE, "w") as f:
        json.dump(metadata, f)
    print("Saved metadata")

    # Persist the FAISS index
    faiss.write_index(index, INDEX_FILE)
    print("FAISS index saved")

    return f"Document stored at: {filename}"

def retrieve_document(query):
    print(f"Retrieving document based on:\n{query}")

    # Generate the query embedding
    query_embedding = embedding_model.encode([query]).astype(np.float32)

    # Search for the closest document in the FAISS index
    _, closest_idx = index.search(query_embedding, 1)

    # Check whether a relevant document was found
    if closest_idx[0][0] == -1 or str(closest_idx[0][0]) not in metadata:
        print("No relevant document found")
        return None

    # Retrieve the document file path
    filename = metadata[str(closest_idx[0][0])]

    # Read and return the document content
    with open(filename, "r", encoding="utf-8") as f:
        return f.read()
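
# Sketch: index.search returns (distances, positions); retrieving the top-k
# nearest documents instead of a single one would look like:
#   distances, positions = index.search(query_embedding, 3)
#   filenames = [metadata[str(i)] for i in positions[0] if i != -1]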

def clean_text(text):
    """Cleans extracted text for better processing by the model."""
    print("Cleaning text")
    text = unicodedata.normalize("NFKC", text)  # Normalize Unicode characters
    text = re.sub(r'\s+', ' ', text).strip()  # Collapse whitespace and newlines
    text = re.sub(r'[^a-zA-Z0-9.,!?;:\'\"()\-]', ' ', text)  # Keep letters, digits, basic punctuation
    text = re.sub(r'(?i)(page\s*\d+)', '', text)  # Remove page-number markers
    return text

def extract_text_from_pdf(pdf_file):
    """Extract and clean text from the uploaded PDF."""
    print("Extracting text from PDF")
    try:
        with pdfplumber.open(pdf_file) as pdf:
            text = " ".join(
                clean_text(page_text)
                for page in pdf.pages
                if (page_text := page.extract_text())
            )
        store_document(text)
        return text
    except Exception as e:
        print(f"Error extracting text: {e}")
        return None

def split_text(text, chunk_size=500):
    """Splits text into smaller chunks for better processing."""
    print("Splitting text into chunks")
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
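
# Sketch (not used above): overlapping chunks reduce the chance of an answer
# being cut off at a chunk boundary:
#   def split_text_overlap(text, chunk_size=500, overlap=50):
#       step = chunk_size - overlap
#       return [text[i:i + chunk_size] for i in range(0, len(text), step)]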

def chatbot(pdf_file, user_question):
    """Processes the PDF and answers the user's question."""
    print("Chatbot start")
    if pdf_file:
        # Extract text from the PDF
        text = extract_text_from_pdf(pdf_file)
        if not text:
            return "Could not extract any text from the PDF."

    # Retrieve the document most relevant to the query
    doc = retrieve_document(user_question)
    if doc:
        print(f"Found document:\n{doc}")
        # Split into smaller chunks
        chunks = split_text(doc)
        # Use only the first chunk (to limit token usage)
        prompt = f"Based on this document, answer the question:\n\nDocument:\n{chunks[0]}\n\nQuestion: {user_question}"
        print(f"Prompt:\n{prompt}")
    else:
        prompt = user_question

    # Send to Together.AI (Mistral-7B)
    try:
        print("Asking the model")
        response = together.Completion.create(
            model="mistralai/Mistral-7B-Instruct-v0.1",
            prompt=prompt,
            max_tokens=200,
            temperature=0.7,
        )
        # Return the chatbot's response
        return response.choices[0].text
    except Exception as e:
        return f"Error generating response: {e}"
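
# Local smoke test (sketch, bypasses the UI; assumes documents have already
# been indexed and the API key is set):
#   print(chatbot(None, "What is this document about?"))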

# Gradio interface
iface = gr.Interface(
    fn=chatbot,
    inputs=[gr.File(label="Upload PDF"), gr.Textbox(label="Ask a Question")],
    outputs=gr.Textbox(label="Answer"),
    title="PDF Q&A Chatbot (Powered by Together.AI)",
)
# Launch Gradio app
iface.launch()