# Author: Nadezhda Komarova
# first commit
# 4be6b01
import os
import pathlib
import gradio as gr
# LangChain imports
from langchain_community.document_loaders import (
CSVLoader, PyPDFLoader, UnstructuredWordDocumentLoader,
UnstructuredPowerPointLoader, UnstructuredMarkdownLoader,
UnstructuredHTMLLoader, NotebookLoader
)
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.messages import HumanMessage, SystemMessage
# -------------------------
# UTF-8 safe Text Loader
# -------------------------
class SafeTextLoader:
    """Read one text file into a single Document, tolerating invalid UTF-8 bytes.

    Mirrors the minimal LangChain loader interface (constructor taking a path,
    plus a ``load()`` method) so it can sit in LOADER_MAPPING next to the
    langchain_community loaders.
    """

    def __init__(self, file_path):
        # Path is stored as given; it is stringified again in the metadata.
        self.file_path = file_path
        print(f"[Debug] Initialized SafeTextLoader with file_path: {file_path}")

    def load(self):
        """Return a one-element list of Documents, or [] if the file can't be read."""
        try:
            print(f"[Debug] Attempting to load file: {self.file_path}")
            # Binary read + lossy decode: undecodable bytes are silently dropped
            # rather than raising UnicodeDecodeError.
            with open(self.file_path, "rb") as handle:
                content = handle.read().decode("utf-8", errors="ignore")
            print(f"[Debug] Successfully loaded file: {self.file_path}")
            return [Document(page_content=content, metadata={"source": str(self.file_path)})]
        except Exception as exc:
            # Best-effort loader: a broken file is logged and skipped, never fatal.
            print(f"[Error] Failed to read {self.file_path}: {exc}")
            return []
# -------------------------
# Loader mapping
# -------------------------
# Maps a lowercase file extension to the loader class used to parse it.
# SafeTextLoader handles anything that is effectively plain text; the other
# loaders come from langchain_community (see imports at the top of the file).
# Files with extensions not listed here are skipped by create_dataset().
LOADER_MAPPING = {
    # Text
    ".txt": SafeTextLoader,
    ".json": SafeTextLoader,
    ".md": UnstructuredMarkdownLoader,
    ".csv": CSVLoader,
    ".yaml": SafeTextLoader,
    ".yml": SafeTextLoader,
    # Documents
    ".pdf": PyPDFLoader,
    ".docx": UnstructuredWordDocumentLoader,
    ".pptx": UnstructuredPowerPointLoader,
    ".html": UnstructuredHTMLLoader,
    ".htm": UnstructuredHTMLLoader,
    # Code / Notebook
    ".ipynb": NotebookLoader,
    ".py": SafeTextLoader,
    ".js": SafeTextLoader,
    ".sql": SafeTextLoader,
}
# -------------------------
# Dataset creation
# -------------------------
def create_dataset(directory_path: str = "context"):
    """Recursively collect Documents from every supported file under *directory_path*.

    Files whose extension is missing from LOADER_MAPPING are skipped, and a
    loader failure on one file never aborts the scan. Returns a (possibly
    empty) list of Documents; also returns [] when the directory is missing.
    """
    print(f"[Debug] Starting dataset creation for directory: {directory_path}")
    target_dir = pathlib.Path(directory_path).resolve()
    if not (target_dir.exists() and target_dir.is_dir()):
        print(f"[Error] Target directory does not exist: {target_dir}")
        return []
    collected = []
    for candidate in target_dir.rglob("*"):  # walks subdirectories too
        if not candidate.is_file():
            continue
        loader_cls = LOADER_MAPPING.get(candidate.suffix.lower())
        if loader_cls is None:
            print(f"[Skip] Unsupported file type: {candidate}")
            continue
        try:
            print(f"[Debug] Loading file: {candidate}")
            docs = loader_cls(str(candidate)).load()
            collected.extend(docs)
            print(f"[Loaded] {candidate} ({len(docs)} docs)")
        except Exception as exc:
            # Per-file failures are logged and the scan continues.
            print(f"[Error] Failed to load {candidate}: {exc}")
    print(f"[Done] Finished scanning {target_dir}")
    print(f"Total documents loaded: {len(collected)}")
    return collected
# -------------------------
# Prepare RAG (Ollama + FAISS)
# -------------------------
def prepare_RAG(dir_name="context", chunk_size=600, chunk_overlap=50):
    """Build the retrieval stack: load, chunk, embed into FAISS, and create the LLM.

    Parameters:
        dir_name: directory scanned by create_dataset().
        chunk_size / chunk_overlap: character-based splitting parameters.

    Returns:
        A (vectorstore, llm) tuple.

    Raises:
        ValueError: when the context directory yields no documents.
    """
    print(f"[Debug] Preparing RAG with Ollama + FAISS. Context dir={dir_name}")
    documents = create_dataset(dir_name)
    if not documents:
        raise ValueError("No documents loaded. Please add files to the context directory.")
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    print(f"[Debug] Splitting documents into chunks with chunk_size={chunk_size}, chunk_overlap={chunk_overlap}")
    chunks = splitter.split_documents(documents)
    print(f"[Debug] Number of chunks created: {len(chunks)}")
    # Embeddings run locally via Ollama; no external API calls.
    print(f"[Debug] Initializing Ollama embeddings")
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    # In-memory FAISS index built from the chunked documents.
    print(f"[Debug] Creating FAISS vector store")
    vectorstore = FAISS.from_documents(chunks, embeddings)
    print(f"[Debug] Initializing Ollama LLM")
    llm = ChatOllama(model="llama3")  # change model if needed
    return vectorstore, llm
# -------------------------
# Retrieval
# -------------------------
def retrieve_RAG(query, vectorstore, top_k=5):
    """Return the *top_k* documents most similar to *query*.

    Parameters:
        query: natural-language search string.
        vectorstore: any LangChain vector store exposing ``as_retriever()``.
        top_k: number of documents to fetch (default 5).

    Returns:
        The list of retrieved documents.
    """
    print(f"[Debug] Retrieving top {top_k} documents for query: {query}")
    retriever = vectorstore.as_retriever(search_kwargs={"k": top_k})
    # get_relevant_documents() is deprecated since LangChain 0.1.46 and removed
    # in 1.0; retrievers are Runnables, so invoke() is the supported call.
    results = retriever.invoke(query)
    print(f"[Debug] Retrieved {len(results)} documents")
    return results
# -------------------------
# Generation
# -------------------------
def generate_RAG(prompt_message, llm, retrieved_docs):
    """Answer *prompt_message* with *llm*, grounded strictly in *retrieved_docs*.

    The retrieved documents are concatenated into the prompt; the system
    message instructs the model to answer only from that context. Returns the
    LLM's response message object.
    """
    print(f"[Debug] Generating response for prompt: {prompt_message}")
    system_text = (
        "You are an expert assistant. Use ONLY the provided context documents "
        "to answer the question. If the context does not contain the answer, reply with 'I don’t know'."
    )
    joined_context = "\n\n".join(doc.page_content for doc in retrieved_docs)
    print(f"[Debug] Context for generation: {joined_context[:500]}... (truncated)")
    messages = [
        SystemMessage(content=system_text),
        HumanMessage(content=f"Context:\n{joined_context}\n\nQuestion: {prompt_message}"),
    ]
    answer = llm.invoke(messages)
    print(f"[Debug] Generated response: {answer.content}")
    return answer
# -------------------------
# Gradio Chatbot
# -------------------------
def run_chatbot(user_dir="context"):
    """Launch the Gradio chat UI backed by a RAG pipeline over *user_dir*.

    Builds the vector store and LLM once at startup, then serves a two-step
    chat loop: the user's message is echoed into the history first, and the
    assistant's grounded reply is appended in a follow-up event.
    """
    print(f"[Debug] Starting chatbot with user_dir: {user_dir}")
    vectorstore, llm = prepare_RAG(dir_name=user_dir)

    # Step 1: record the user's turn and clear the textbox.
    def on_user_submit(message, history):
        print(f"[Debug] Adding user message: {message}")
        history = history or []
        history.append({"role": "user", "content": message})
        return "", history, history

    # Step 2: produce the assistant's turn via retrieval + generation.
    def on_bot_reply(history):
        if not history or history[-1]["role"] != "user":
            print(f"[Debug] No user message to respond to.")
            return history, history
        user_msg = history[-1]["content"]
        print(f"[Debug] Generating response for user message: {user_msg}")
        docs = retrieve_RAG(user_msg, vectorstore)
        reply = generate_RAG(user_msg, llm, docs)
        history.append({"role": "assistant", "content": reply.content})
        return history, history

    with gr.Blocks() as demo:
        gr.Markdown("# 📚 On-Prem RAG Chatbot (Ollama + FAISS)")
        gr.Markdown("Ask questions about your local documents.")
        chatbot = gr.Chatbot(type="messages")
        msg = gr.Textbox(label="Your message")
        state = gr.State([])
        # Chained events: echo the user message, then generate the reply.
        submit_event = msg.submit(on_user_submit, inputs=[msg, state], outputs=[msg, chatbot, state])
        submit_event.then(on_bot_reply, inputs=[state], outputs=[chatbot, state])
    demo.launch()
# -------------------------
# Main
# -------------------------
if __name__ == "__main__":
    # Optionally narrow the knowledge base to a subfolder of ./context;
    # an empty answer uses the whole ./context directory.
    chosen = input("Enter a subfolder inside 'context' (press Enter for full 'context'): ").strip()
    user_dir = os.path.join("context", chosen) if chosen else "context"
    print(f"[Info] Using context directory: {user_dir}")
    run_chatbot(user_dir)