# Multi_Modal_RAG / split.py
import os
import json
import uuid
import requests
import base64
import hashlib
from fastapi import FastAPI, UploadFile, File
from pypdf import PdfReader
import pdfplumber
from PIL import Image
import io
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_core.documents import Document
# ================= JSON File Store =================
class JSONFileStore:
def __init__(self, store_path: str):
self.store_path = store_path
os.makedirs(self.store_path, exist_ok=True)
def mset(self, key_value_pairs: list[tuple[str, Document]]) -> None:
for key, doc in key_value_pairs:
file_path = os.path.join(self.store_path, f"{key}.json")
doc_dict = {"page_content": doc.page_content, "metadata": doc.metadata}
with open(file_path, "w", encoding="utf-8") as f:
json.dump(doc_dict, f, ensure_ascii=False)
    def mget(self, keys: list[str]) -> list[Document | None]:
documents = []
for key in keys:
file_path = os.path.join(self.store_path, f"{key}.json")
if os.path.exists(file_path):
try:
with open(file_path, "r", encoding="utf-8") as f:
doc_dict = json.load(f)
documents.append(
Document(
page_content=doc_dict["page_content"],
metadata=doc_dict["metadata"],
)
)
except Exception as e:
print(f"Error loading {key}: {e}")
documents.append(None)
else:
documents.append(None)
return documents
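# Quick round-trip sketch for JSONFileStore (hypothetical usage, not executed here):
#   store = JSONFileStore("./docstore")
#   store.mset([("some-id", Document(page_content="hello", metadata={}))])
#   [doc] = store.mget(["some-id"])  # doc.page_content == "hello"; missing keys yield None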
# ================= FastAPI Setup =================
app = FastAPI(title="πŸš€ Multimodal RAG Ingestion Service (Text + Tables + Images)")
VECTOR_PATH = "./vectorstore/faiss_index"
DOCSTORE_PATH = "./docstore"
TEMP_DOCS_PATH = "./docs"
QWEN_TEXT_URL = "https://sameer-handsome173-multi-modal.hf.space/summarize_qwen"
BLIP_IMAGE_URL = "https://sameer-handsome173-multi-modal.hf.space/summarize_smol"
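# Contract assumed by the summarizer clients below: the text endpoint takes a
# form field "prompt", the image endpoint takes a multipart "image" plus a form
# field "text", and both return JSON with a "response" key.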
print("πŸ”„ Loading embedding model...")
embedding_fn = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
print("βœ… Embedding model loaded")
# Load or create vectorstore
if os.path.exists(VECTOR_PATH):
vectorstore = FAISS.load_local(
VECTOR_PATH, embedding_fn, allow_dangerous_deserialization=True
)
print("βœ… Loaded existing FAISS vectorstore")
else:
os.makedirs(os.path.dirname(VECTOR_PATH), exist_ok=True)
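    # FAISS.from_texts cannot build an empty index, so seed it with a placeholder text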
vectorstore = FAISS.from_texts(["init"], embedding_fn)
print("βœ… Created new FAISS vectorstore")
# Initialize JSON store
os.makedirs(DOCSTORE_PATH, exist_ok=True)
store = JSONFileStore(DOCSTORE_PATH)
print("βœ… Initialized JSONFileStore")
# ================= Extraction Functions =================
def extract_tables_from_pdf(pdf_path: str) -> list[str]:
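    """Extract tables from each page as pipe-delimited text blocks."""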
tables = []
try:
with pdfplumber.open(pdf_path) as pdf:
for page_num, page in enumerate(pdf.pages):
page_tables = page.extract_tables()
if page_tables:
for table_idx, table in enumerate(page_tables):
table_str = f"Table from page {page_num + 1}:\n"
for row in table:
if row:
table_str += " | ".join(
[str(cell) if cell else "" for cell in row]
) + "\n"
tables.append(table_str)
print(f"πŸ“Š Extracted table from page {page_num + 1}")
except Exception as e:
print(f"⚠️ Error extracting tables: {e}")
return tables
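# Each returned string is a flat pipe-delimited rendering, e.g.:
#   "Table from page 2:\nName | Qty\nWidget | 3\n"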
def extract_text_from_pdf(pdf_path: str) -> list[dict]:
"""Extract text per page"""
texts = []
try:
reader = PdfReader(pdf_path)
for i, page in enumerate(reader.pages):
text = page.extract_text()
if text and text.strip():
texts.append({"page": i + 1, "content": text.strip()})
print(f"πŸ“ Extracted text from page {i+1}")
except Exception as e:
print(f"❌ Error extracting text: {e}")
return texts
def extract_images_from_pdf(pdf_path: str) -> list[dict]:
    """Extract large, unique images as {"page": int, "image_b64": str} dicts."""
    images = []
    image_hashes = set()
    try:
        reader = PdfReader(pdf_path)
        for page_num, page in enumerate(reader.pages):
            if '/Resources' not in page or '/XObject' not in page['/Resources']:
                continue
            xObject = page['/Resources']['/XObject'].get_object()
            for obj in xObject:
                if xObject[obj]['/Subtype'] == '/Image':
                    try:
                        width = xObject[obj]['/Width']
                        height = xObject[obj]['/Height']
                        if width < 100 or height < 100:
                            continue  # skip small images (icons, bullets, logos)
                        data = xObject[obj].get_data()
                        h = hashlib.md5(data).hexdigest()
                        if h in image_hashes:
                            continue  # skip duplicates (repeated headers/logos)
                        image_hashes.add(h)
                        try:
                            # JPEG/JPEG2000 streams (DCTDecode/JPXDecode) are complete image files
                            image = Image.open(io.BytesIO(data))
                        except Exception:
                            # fall back to interpreting the stream as raw pixel data
                            mode = "RGB" if xObject[obj]['/ColorSpace'] == '/DeviceRGB' else "P"
                            image = Image.frombytes(mode, (width, height), data)
                        buffered = io.BytesIO()
                        image.convert("RGB").save(buffered, format="JPEG")
                        img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
                        images.append({"page": page_num + 1, "image_b64": img_b64})
                        print(f"πŸ“Έ Extracted image from page {page_num + 1} ({width}x{height})")
                    except Exception as e:
                        print(f"⚠️ Error extracting image from page {page_num + 1}: {e}")
    except Exception as e:
        print(f"❌ Error extracting images: {e}")
    return images
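# Note: recent pypdf releases also expose page.images, which decodes the common
# image filters internally and could replace the XObject walk above.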
# ================= Summarization =================
def summarize_text(content: str) -> str:
try:
response = requests.post(
QWEN_TEXT_URL,
data={"prompt": f"Summarize the following content:\n\n{content}"},
timeout=30,
)
if response.status_code == 200:
return response.json().get("response", content[:200])
else:
return content[:200]
except Exception as e:
print(f"⚠️ Text summary fallback: {e}")
return content[:200]
def summarize_image(image_b64: str) -> str:
try:
image_bytes = base64.b64decode(image_b64)
files = {"image": ("image.jpg", image_bytes, "image/jpeg")}
data = {"text": "Describe this image in detail"}
response = requests.post(BLIP_IMAGE_URL, files=files, data=data, timeout=30)
if response.status_code == 200:
return response.json().get("response", "No image summary generated")
return "Image extracted from PDF"
except Exception as e:
print(f"⚠️ Image summary fallback: {e}")
return "Image extracted from PDF"
# ================= FastAPI Endpoints =================
@app.get("/")
def home():
return {
"message": "βœ… Multimodal RAG Ingestion Service is running",
"endpoints": {
"ingest": "POST /ingest - Upload PDF file",
"stats": "GET /stats - View system statistics",
},
}
@app.get("/stats")
def get_stats():
vector_count = (
vectorstore.index.ntotal if hasattr(vectorstore, "index") else 0
)
docstore_files = (
len([f for f in os.listdir(DOCSTORE_PATH) if f.endswith(".json")])
if os.path.exists(DOCSTORE_PATH)
else 0
)
return {
"status": "healthy",
"vectorstore_count": vector_count,
"docstore_count": docstore_files,
}
@app.post("/ingest")
async def ingest_pdf(file: UploadFile = File(...)):
    if not file.filename.lower().endswith(".pdf"):
return {"error": "Only PDF files are supported"}
os.makedirs(TEMP_DOCS_PATH, exist_ok=True)
temp_path = os.path.join(TEMP_DOCS_PATH, file.filename)
with open(temp_path, "wb") as f:
content = await file.read()
f.write(content)
print(f"\nπŸ“„ Processing {file.filename}...")
texts = extract_text_from_pdf(temp_path)
images = extract_images_from_pdf(temp_path)
tables = extract_tables_from_pdf(temp_path)
print(f"πŸ“Š Found: {len(texts)} texts, {len(tables)} tables, {len(images)} images")
if not texts and not tables and not images:
return {"error": "No content extracted", "filename": file.filename}
doc_ids, summaries, originals = [], [], []
# Texts
    for item in texts:
page_num = item["page"]
content = item["content"]
summary = summarize_text(content)
doc_id = str(uuid.uuid4())
doc_ids.append(doc_id)
summaries.append(summary)
originals.append(
Document(
page_content=content,
metadata={
"doc_id": doc_id,
"type": "text",
"page": page_num,
"source": file.filename,
"summary": summary,
},
)
)
# Tables
for table in tables:
summary = summarize_text(f"Table content:\n{table}")
doc_id = str(uuid.uuid4())
doc_ids.append(doc_id)
summaries.append(summary)
originals.append(
Document(
page_content=table,
metadata={
"doc_id": doc_id,
"type": "table",
"source": file.filename,
"summary": summary,
},
)
)
# Images
    for item in images:
page_num = item["page"]
img_b64 = item["image_b64"]
summary = summarize_image(img_b64)
doc_id = str(uuid.uuid4())
doc_ids.append(doc_id)
summaries.append(summary)
originals.append(
Document(
page_content=img_b64,
metadata={
"doc_id": doc_id,
"type": "image",
"page": page_num,
"source": file.filename,
"summary": summary,
"is_base64": True,
},
)
)
    # Store: index the summaries in FAISS, keep the originals in the JSON docstore, linked by doc_id
vectorstore.add_texts(
texts=summaries,
metadatas=[{"doc_id": doc_id, "source": file.filename} for doc_id in doc_ids],
ids=doc_ids,
)
store.mset(list(zip(doc_ids, originals)))
vectorstore.save_local(VECTOR_PATH)
print("βœ… Saved to disk")
os.remove(temp_path)
return {
"status": "success",
"filename": file.filename,
"processed": {
"texts": len(texts),
"tables": len(tables),
"images": len(images),
"total": len(originals),
},
"doc_ids_sample": doc_ids[:5],
"message": f"βœ… Processed {len(originals)} components from {file.filename}",
}
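# --------- Retrieval sketch (illustrative only, not an endpoint above) ---------
# A minimal example of how the two stores pair up at query time: summaries are
# searched in FAISS, then the full originals are fetched from the JSON docstore
# via the doc_id stored in each vector's metadata. The helper name is ours.
def retrieve_originals(query: str, k: int = 4) -> list[Document]:
    """Search summaries, then resolve matches to their original documents."""
    hits = vectorstore.similarity_search(query, k=k)
    ids = [h.metadata["doc_id"] for h in hits if "doc_id" in h.metadata]
    return [doc for doc in store.mget(ids) if doc is not None]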