# Hugging Face Space upload metadata: "Update kb_embed.py" by Jaita, commit 740f68a (verified)
# kb_embed.py
from pathlib import Path
import os
from docx import Document
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.config import Settings
import logging
logging.basicConfig(level=logging.INFO)
# --- Paths --------------------------------------------------------------
BASE_DIR = Path(__file__).resolve().parent
# Where the persistent Chroma database lives (created below if missing).
CHROMA_DIR = BASE_DIR / "chroma_db"
MODEL_DIR = BASE_DIR / "all-MiniLM-L6-v2" # optional local cache
# Folder of .docx files to ingest by default (see main()).
DOCS_DIR = BASE_DIR / "GenericSOPsForTesting"

# Ensure the DB directory exists before handing it to PersistentClient.
CHROMA_DIR.mkdir(parents=True, exist_ok=True)

# Persistent on-disk Chroma client with anonymized telemetry disabled.
client = chromadb.PersistentClient(
    path=str(CHROMA_DIR),
    settings=Settings(anonymized_telemetry=False)
)
collection = client.get_or_create_collection(name="knowledge_base")

# Use default HF cache (simpler on Spaces). If you must use local folder, keep cache_folder.
try:
    # Prefer auto-download and cache:
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    # If you want to use local cache dir: uncomment
    # model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=str(MODEL_DIR))
except Exception as e:
    # Log the full traceback, then re-raise: the module is unusable without the model.
    logging.exception(f"Failed to load embedding model: {e}")
    raise
def extract_text_from_docx(file_path: str) -> str:
    """Return the text of a .docx file: non-empty paragraphs joined by newlines."""
    paragraphs = Document(file_path).paragraphs
    kept = [p.text for p in paragraphs if p.text.strip()]
    return "\n".join(kept)
def chunk_text(text: str, max_words: int = 300):
    """Split *text* into chunks of at most *max_words* whitespace-separated words.

    Returns a list of chunk strings; empty/whitespace-only chunks are dropped,
    so empty input yields an empty list.
    """
    words = text.split()
    chunks = []
    for start in range(0, len(words), max_words):
        piece = " ".join(words[start:start + max_words])
        if piece.strip():
            chunks.append(piece)
    return chunks
def ingest_documents(folder_path: str):
    """Embed every .docx file in *folder_path* and add its chunks to the collection.

    Each file is split into ~300-word chunks (see chunk_text). Every chunk is
    stored under id ``"<filename>_<chunk_index>"`` with metadata
    ``{"filename", "chunk_index"}``. Logs a warning and returns early when the
    folder is missing or contains no .docx files.
    """
    logging.info(f"πŸ“‚ Checking folder: {folder_path}")
    if not os.path.isdir(folder_path):
        logging.warning(f"❌ Invalid folder path: {folder_path}")
        return
    files = [f for f in os.listdir(folder_path) if f.lower().endswith(".docx")]
    logging.info(f"Found {len(files)} Word files: {files}")
    if not files:
        logging.warning("⚠️ No .docx files found. Please check the folder path.")
        return
    added = 0
    for file in files:
        file_path = os.path.join(folder_path, file)
        text = extract_text_from_docx(file_path)
        chunks = chunk_text(text)
        if not chunks:
            logging.warning(f"⚠️ No text chunks extracted from {file}")
            continue
        logging.info(f"πŸ“„ Ingesting {file} with {len(chunks)} chunks")
        # Batch-encode the whole file in ONE model call instead of one call per
        # chunk — identical vectors, dramatically faster (encode() accepts a list).
        embeddings = model.encode(chunks)
        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
            doc_id = f"{file}_{i}"
            # Avoid duplicate ids (if re-ingesting)
            try:
                collection.add(
                    ids=[doc_id],
                    embeddings=[embedding.tolist()],
                    documents=[chunk],
                    metadatas=[{"filename": file, "chunk_index": i}]
                )
                added += 1
            except Exception as e:
                logging.warning(f"Skipping duplicate or failed add for {doc_id}: {e}")
    logging.info(f"βœ… Documents ingested. Added entries: {added}. Total entries: {collection.count()}")
def search_knowledge_base(query: str, top_k: int = 3):
    """Embed *query* and return the *top_k* nearest chunks from the collection.

    The Chroma result dict includes documents, metadatas, and distances.
    """
    embedded = model.encode(query).tolist()
    return collection.query(
        query_embeddings=[embedded],
        n_results=top_k,
        include=["documents", "metadatas", "distances"],
    )
def main():
    """Ingest the bundled SOP documents when the docs folder exists."""
    # Plain if/else instead of a conditional expression evaluated only for
    # its side effects — the original `f() if cond else g()` is an anti-idiom.
    if DOCS_DIR.exists():
        ingest_documents(str(DOCS_DIR))
    else:
        logging.error(f"❌ Invalid folder path: {DOCS_DIR}")


if __name__ == "__main__":
    main()