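"""SmartManuals-AI: a Gradio RAG app for equipment manuals.

PDF and DOCX files in MANUALS_DIR are chunked, embedded with
all-MiniLM-L6-v2 into a persistent ChromaDB store, and questions are
answered by IBM Granite Vision 3.2-2B using only the retrieved context.
"""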
import os
import fitz  # PyMuPDF
import gradio as gr
import pytesseract
import chromadb
import torch
import nltk
import docx2txt
from PIL import Image
from io import BytesIO
from transformers import AutoProcessor, AutoModelForVision2Seq
from sentence_transformers import SentenceTransformer
from nltk.tokenize import sent_tokenize
# Ensure the NLTK sentence tokenizer is available
# (newer NLTK releases resolve "punkt_tab" instead of "punkt")
for resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{resource}")
    except LookupError:
        nltk.download(resource)
# Configuration
HF_TOKEN = os.getenv("HF_TOKEN")
MANUALS_DIR = "Manuals"
CHROMA_PATH = "chroma_store"
COLLECTION_NAME = "manual_chunks"
CHUNK_SIZE = 750  # chunk budget, counted in whitespace-delimited words
CHUNK_OVERLAP = 100  # words of trailing context carried into the next chunk
MAX_CONTEXT_CHUNKS = 3
MODEL_ID = "ibm-granite/granite-vision-3.2-2b"
device = "cuda" if torch.cuda.is_available() else "cpu"
# ---------------- Text Helpers ----------------
def clean(text):
    """Strip each line and drop blank lines."""
    return "\n".join(line.strip() for line in text.splitlines() if line.strip())

def split_sentences(text):
    try:
        return sent_tokenize(text)
    except LookupError:
        print("Tokenizer fallback: simple split.")
        return text.split(". ")
def split_chunks(sentences, max_tokens=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    """Greedily pack sentences into chunks of at most `max_tokens` words,
    carrying roughly `overlap` words of trailing context into the next chunk."""
    chunks = []
    current_chunk, length = [], 0
    for sent in sentences:
        words = sent.split()
        if length + len(words) > max_tokens and current_chunk:
            chunks.append(" ".join(current_chunk))
            # Keep trailing sentences until ~`overlap` words are retained,
            # so consecutive chunks share a bounded amount of context.
            kept, kept_len = [], 0
            for prev in reversed(current_chunk):
                kept.insert(0, prev)
                kept_len += len(prev.split())
                if kept_len >= overlap:
                    break
            current_chunk, length = kept, kept_len
        current_chunk.append(sent)
        length += len(words)
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks
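# Worked example of the chunker above (word-based overlap):
#   split_chunks(["First sentence.", "Second one."], max_tokens=3, overlap=2)
#   -> ["First sentence.", "First sentence. Second one."]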
# ---------------- File Readers ----------------
def extract_pdf_text(path):
    """Return (path, page_number, text) tuples; OCR pages with no text layer."""
    chunks = []
    try:
        doc = fitz.open(path)
        for i, page in enumerate(doc):
            text = page.get_text().strip()
            if not text:
                # Scanned page: rasterize at 300 DPI and OCR it
                img = Image.open(BytesIO(page.get_pixmap(dpi=300).tobytes("png")))
                text = pytesseract.image_to_string(img)
            chunks.append((path, i + 1, clean(text)))
    except Exception as e:
        print("PDF read error:", path, e)
    return chunks
def extract_docx_text(path):
    try:
        # docx2txt returns the whole document; treat it as a single "page"
        return [(path, 1, clean(docx2txt.process(path)))]
    except Exception as e:
        print("DOCX read error:", path, e)
        return []
# ---------------- Embedding ----------------
def embed_all():
    """(Re)build the Chroma collection from every manual in MANUALS_DIR."""
    try:
        embedder = SentenceTransformer("all-MiniLM-L6-v2")
        embedder.eval()
        client = chromadb.PersistentClient(path=CHROMA_PATH)
        try:
            client.delete_collection(COLLECTION_NAME)  # start from a clean slate
        except Exception:
            pass  # collection did not exist yet
        collection = client.get_or_create_collection(COLLECTION_NAME)
        docs, ids, metas = [], [], []
        print("Processing manuals...")
        for fname in os.listdir(MANUALS_DIR):
            fpath = os.path.join(MANUALS_DIR, fname)
            if fname.lower().endswith(".pdf"):
                pages = extract_pdf_text(fpath)
            elif fname.lower().endswith(".docx"):
                pages = extract_docx_text(fpath)
            else:
                continue
            for path, page, text in pages:
                for i, chunk in enumerate(split_chunks(split_sentences(text))):
                    chunk_id = f"{fname}::{page}::{i}"
                    docs.append(chunk)
                    ids.append(chunk_id)
                    metas.append({"source": fname, "page": page})
                    if len(docs) >= 32:  # flush in batches of 32 for efficiency
                        embs = embedder.encode(docs).tolist()
                        collection.add(documents=docs, ids=ids, metadatas=metas, embeddings=embs)
                        docs, ids, metas = [], [], []
        if docs:  # flush the final partial batch
            embs = embedder.encode(docs).tolist()
            collection.add(documents=docs, ids=ids, metadatas=metas, embeddings=embs)
        print(f"Embedded {collection.count()} chunks.")
        return collection, embedder
    except Exception as e:
        print("Embedding startup failed:", e)
        return None, None
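# Note: delete_collection above means the index is rebuilt from scratch on
# every startup; a persisted collection could be reused for large manual
# sets, at the cost of stale chunks when files change.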
# ---------------- Model Setup ----------------
def load_model():
    try:
        processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
        model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, token=HF_TOKEN).to(device)
        model.eval()
        return model, processor
    except Exception as e:
        print("Model loading failed:", e)
        return None, None
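# Note: the vision model is used text-only here; no images are passed to
# the processor, only the retrieved manual text.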
def ask_model(question, context, model, processor):
    prompt = f"""Use only the following context to answer. If uncertain, say "I don't know."
<context>
{context}
</context>
Q: {question}
A:"""
    inputs = processor(text=prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=256)
    # Decode only the newly generated tokens, not the echoed prompt
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    return processor.decode(new_tokens, skip_special_tokens=True).strip()
# ---------------- Query ----------------
def get_answer(question):
    if embedder is None or db is None or model is None:
        return "System not ready. Try again after initialization."
    try:
        # Embed the query ourselves so retrieval uses the same vectors as indexing
        query_emb = embedder.encode(question).tolist()
        results = db.query(query_embeddings=[query_emb], n_results=MAX_CONTEXT_CHUNKS)
        context = "\n\n".join(results["documents"][0])
        return ask_model(question, context, model, processor)
    except Exception as e:
        print("Query error:", e)
        return f"Error: {e}"
# ---------------- UI ----------------
with gr.Blocks() as demo:
    gr.Markdown("## SmartManuals-AI (Granite 3.2-2B)")
    with gr.Row():
        question = gr.Textbox(label="Ask your question")
        ask = gr.Button("Ask")
    answer = gr.Textbox(label="Answer", lines=8)
    ask.click(fn=get_answer, inputs=question, outputs=answer)
# Startup Initialization
db = None
embedder = None
model = None
processor = None
try:
    db, embedder = embed_all()
except Exception as e:
    print("❌ Embedding failed:", e)
try:
    model, processor = load_model()
except Exception as e:
    print("❌ Model load failed:", e)
# Launch
if __name__ == "__main__":
    demo.launch(share=False)