# Hugging Face Space status: Running
import gradio as gr | |
import json | |
import os | |
import pdfplumber | |
import together | |
from sentence_transformers import SentenceTransformer | |
import faiss | |
import numpy as np | |
import re | |
import unicodedata | |
from dotenv import load_dotenv | |
# Load environment variables (such as TOGETHER_API_KEY) from a local .env file, if present.
load_dotenv()

# The Together.AI key is never passed explicitly in this file, so the `together`
# SDK is expected to pick it up from the TOGETHER_API_KEY environment variable
# (e.g. a Space secret or the .env file). Fail fast at startup when it is missing.
# An explicit raise is used because `assert` statements are stripped by `python -O`.
if not os.getenv("TOGETHER_API_KEY"):
    raise RuntimeError("api key missing")
# Sentence-embedding model shared by document storage and query retrieval.
# trust_remote_code=True lets the model's own code from the Hub run locally, so
# only point this at model repositories you trust.
# A previously commented-out alternative was "BAAI/bge-base-en-v1.5"; switching
# models may change the embedding dimension.
_EMBEDDING_MODEL_NAME = "togethercomputer/m2-bert-80M-8k-retrieval"
embedding_model = SentenceTransformer(_EMBEDDING_MODEL_NAME, trust_remote_code=True)
# Folder that persists the stored document texts, the FAISS index and its metadata.
# Defaults to the Hugging Face cache location; set DATASET_DIR to override it.
DATASET_DIR = os.getenv(
    "DATASET_DIR", "/home/user/.cache/huggingface/datasets/my_documents"
)
os.makedirs(DATASET_DIR, exist_ok=True)  # Ensure directory exists

INDEX_FILE = os.path.join(DATASET_DIR, "faiss_index.bin")  # serialized FAISS index
METADATA_FILE = os.path.join(DATASET_DIR, "metadata.json")  # str(FAISS position) -> document path

# Take the vector size from the model itself, so the FAISS index always matches
# the embeddings it receives. get_sentence_embedding_dimension() may return None
# when the model cannot report it; fall back to 768 in that case.
embedding_dim = embedding_model.get_sentence_embedding_dimension() or 768

# Debugging: show where the app runs and what is already persisted.
print("Current working directory:", os.getcwd())
print("Files in dataset directory:", os.listdir(DATASET_DIR))
# Load the persisted FAISS index, or start an empty one.
if os.path.exists(INDEX_FILE):
    print(" FAISS index file exists")
    index = faiss.read_index(INDEX_FILE)
    if index.d != embedding_dim:
        # Vectors of a different size can be neither added to nor searched in
        # this index (e.g. after changing the embedding model), so start over.
        print(f" FAISS index has dimension {index.d}, expected {embedding_dim}. Creating a new one.")
        index = faiss.IndexFlatL2(embedding_dim)
else:
    print(" No FAISS index found. Creating a new one.")
    index = faiss.IndexFlatL2(embedding_dim)  # Empty FAISS index

# Load metadata: maps str(FAISS position) -> path of the stored document text.
if os.path.exists(METADATA_FILE):
    print(" Metadata file exists")
    with open(METADATA_FILE, "r", encoding="utf-8") as f:
        metadata = json.load(f)
else:
    metadata = {}

# Keep only entries for vectors that actually exist in the index. Entries can
# point past the end when the index file was lost or rebuilt, or when the app
# stopped after saving metadata.json but before saving the index.
stale_keys = [
    key for key in metadata if not (key.isdigit() and int(key) < index.ntotal)
]
if stale_keys:
    print(f" Dropping {len(stale_keys)} stale metadata entries")
    for key in stale_keys:
        del metadata[key]
def store_document(text):
    """Persist *text*, embed it and add it to the FAISS index.

    The text is written to ``doc_<n>.txt`` in DATASET_DIR. Its embedding is
    appended to the module-level ``index``, and ``metadata`` maps the new
    vector's FAISS position (as a string) to that file. Both metadata.json and
    the index file are rewritten on every call.

    Args:
        text: Document text to store.

    Returns:
        A status message containing the path of the saved file.
    """
    print(" Storing document...")
    # The new vector will land at position ``index.ntotal``. The file number is
    # derived from that position (position p <-> doc_{p+1}.txt) rather than from
    # len(metadata), so two FAISS positions can never share, and overwrite, one
    # file even if metadata and index have drifted apart.
    doc_index = index.ntotal
    filename = os.path.join(DATASET_DIR, f"doc_{doc_index + 1}.txt")
    print(f"Saving document at: {filename}")

    with open(filename, "w", encoding="utf-8") as f:
        f.write(text)
    print(" Document saved")

    # FAISS expects float32 rows; encode() of a one-element list yields one row.
    embedding = embedding_model.encode([text]).astype(np.float32)
    index.add(embedding)
    print(" Embeddings generated")

    metadata[str(doc_index)] = filename
    with open(METADATA_FILE, "w", encoding="utf-8") as f:
        json.dump(metadata, f)
    print(" Saved Metadata")

    faiss.write_index(index, INDEX_FILE)
    print(" FAISS index saved")
    return f"Document stored at: {filename}"
def retrieve_document(query):
    """Return the text of the stored document closest to *query*.

    The query is embedded with the same model as the documents, and its single
    nearest neighbour is looked up in the FAISS index. The closest document is
    always returned; there is no relevance threshold.

    Args:
        query: Free-text question used as the search key.

    Returns:
        The document text, or None when the index is empty, the hit has no
        metadata entry, or its file cannot be read.
    """
    print(f"Retrieving document based on:\n{query}")
    if index.ntotal == 0:
        # Nothing stored yet; skip the embedding call entirely.
        print("No relevant document found")
        return None

    query_embedding = embedding_model.encode([query]).astype(np.float32)
    _, neighbours = index.search(query_embedding, 1)
    key = str(neighbours[0][0])  # FAISS reports -1 when it has no neighbour
    if key == "-1" or key not in metadata:
        print("No relevant document found")
        return None

    filename = metadata[key]
    try:
        with open(filename, "r", encoding="utf-8") as f:
            return f.read()
    except OSError as e:
        # The file may have been removed or made unreadable; degrade to "no document".
        print(f"Could not read stored document {filename}: {e}")
        return None
# Compiled once at import time because clean_text runs on every PDF page.
# Any character that is not a (Unicode) letter/digit, whitespace or basic
# punctuation. "_" is listed separately because \w would otherwise keep it.
_DISALLOWED_CHARS_RE = re.compile(r"[^\w\s.,!?;:'\"()\-]|_")
# "page 12" / "Page12" as its own token, but not the tail of e.g. "homepage 2".
_PAGE_NUMBER_RE = re.compile(r"\bpage\s*\d+\b", re.IGNORECASE)
_WHITESPACE_RE = re.compile(r"\s+")


def clean_text(text):
    """Clean extracted PDF text for embedding and prompting.

    Applies NFKC normalization (which, for example, expands the ligature
    "\ufb01" to "fi"). It then replaces every character other than letters,
    digits, whitespace and basic punctuation with a space, removes "page N"
    markers, and collapses all whitespace runs to single spaces. Letters and
    digits from any script are kept.

    Args:
        text: Raw text of one PDF page.

    Returns:
        The cleaned text without leading or trailing whitespace.
    """
    print("cleaning")
    text = unicodedata.normalize("NFKC", text)
    text = _DISALLOWED_CHARS_RE.sub(" ", text)
    text = _PAGE_NUMBER_RE.sub(" ", text)
    # Collapse last, so the gaps left by the substitutions above are squeezed too.
    return _WHITESPACE_RE.sub(" ", text).strip()
def extract_text_from_pdf(pdf_file):
    """Extract and clean the text of a PDF, then store it in the document index.

    Each page's text is cleaned with clean_text, and the non-empty results are
    joined with single spaces. Non-empty documents are persisted through
    store_document so that later questions can retrieve them.

    Args:
        pdf_file: Anything pdfplumber.open accepts (a path or a binary file
            object). Presumably this is the value of the Gradio File input;
            verify against the installed Gradio version.

    Returns:
        The cleaned text ("" if no page yielded text), or None if the PDF
        could not be read or the document could not be stored.
    """
    print("extracting")
    try:
        with pdfplumber.open(pdf_file) as pdf:
            cleaned_pages = [
                clean_text(raw) for page in pdf.pages if (raw := page.extract_text())
            ]
    except Exception as e:
        print(f"Error extracting text: {e}")
        return None

    # Pages that end up empty after cleaning would only add stray separators.
    text = " ".join(page_text for page_text in cleaned_pages if page_text)
    if not text:
        # Nothing worth indexing; callers treat "" as "no text extracted".
        return text

    try:
        store_document(text)
    except Exception as e:
        # Keep the original contract (None on failure), but log the real cause
        # rather than reporting a storage problem as an extraction error.
        print(f"Error storing document: {e}")
        return None
    return text
def split_text(text, chunk_size=500):
    """Split *text* into consecutive chunks of at most *chunk_size* characters.

    Chunks are cut at fixed character offsets, so words may be split across
    chunk boundaries. Joining the chunks reproduces *text* exactly.

    Args:
        text: The text to split.
        chunk_size: Maximum length of each chunk. Must be at least 1.

    Returns:
        A list of chunks; empty when *text* is empty.

    Raises:
        ValueError: If chunk_size is less than 1.
    """
    print("splitting")
    if chunk_size < 1:
        # range() would fail obscurely for 0 and silently yield [] for negatives.
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")
    return [text[start:start + chunk_size] for start in range(0, len(text), chunk_size)]
def _most_relevant_chunk(chunks, question):
    """Return the chunk closest to *question* by L2 distance between embeddings."""
    if len(chunks) == 1:
        return chunks[0]
    chunk_vectors = embedding_model.encode(chunks).astype(np.float32)
    question_vector = embedding_model.encode([question]).astype(np.float32)
    # Same metric as the IndexFlatL2 used for whole-document retrieval.
    distances = np.linalg.norm(chunk_vectors - question_vector, axis=1)
    return chunks[int(np.argmin(distances))]


def _build_prompt(question, context=None):
    """Build the completion prompt, optionally grounding *question* in *context*."""
    if context is None:
        instruction = question
    else:
        instruction = f"Based on this document, answer the question:\n\nDocument:\n{context}\n\nQuestion: {question}"
    # The Mistral-7B-Instruct model card asks for instructions wrapped in
    # [INST] ... [/INST]; the raw completion endpoint does not add that format.
    return f"[INST] {instruction} [/INST]"


def chatbot(pdf_file, user_question):
    """Answer *user_question*, using a document as context when one is available.

    When *pdf_file* is given, its text is extracted (extract_text_from_pdf also
    stores it in the index) and used as the context. Without a PDF, the closest
    previously stored document is retrieved instead. The chunk of that document
    most similar to the question is sent to Mistral-7B-Instruct on Together.AI.
    When there is no document at all, the question is sent on its own.

    Args:
        pdf_file: Value of the "Upload PDF" input; falsy when nothing was uploaded.
        user_question: Text of the "Ask a Question" input.

    Returns:
        The model's answer, or a human-readable error message.
    """
    print("chatbot start")
    if not user_question or not user_question.strip():
        return "Please enter a question."

    if pdf_file:
        doc = extract_text_from_pdf(pdf_file)
        if not doc:
            return "Could not extract any text from the PDF."
        # Answer from the PDF that was just uploaded. A nearest-neighbour search
        # over the whole index could return a different, older document.
    else:
        doc = retrieve_document(user_question)

    if doc:
        print(f"found doc ({len(doc)} characters)")
        # Only one chunk is sent, to bound token usage. Pick the chunk closest
        # to the question rather than always the first one.
        context = _most_relevant_chunk(split_text(doc), user_question)
        prompt = _build_prompt(user_question, context)
        print(f"prompt: \n{prompt}")
    else:
        prompt = _build_prompt(user_question)

    # Send to Together.AI (Mistral-7B).
    try:
        print("asking")
        response = together.Completion.create(
            model="mistralai/Mistral-7B-Instruct-v0.1",
            prompt=prompt,
            max_tokens=200,
            temperature=0.7,
        )
        return response.choices[0].text
    except Exception as e:
        # Deliberately broad: this is the UI boundary, so any API/network failure
        # is shown to the user instead of failing the request.
        return f"Error generating response: {e}"
# Gradio interface: an optional PDF upload plus a question box; chatbot() answers.
iface = gr.Interface(
    fn=chatbot,
    inputs=[
        # Restrict the picker to PDFs; other files would only fail in pdfplumber.
        gr.File(label="Upload PDF", file_types=[".pdf"]),
        gr.Textbox(label="Ask a Question"),
    ],
    outputs=gr.Textbox(label="Answer"),
    title="PDF Q&A Chatbot (Powered by Together.AI)",
)

if __name__ == "__main__":
    # Start the server only when run as a script, not when the module is imported.
    iface.launch()