| | """Knowledge base with RAG capabilities""" |
| |
|
import os
import shutil
from typing import List, Optional

from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

from config import logger_knowledge
| |
|
| |
|
class KnowledgeBase:
    """Knowledge base backed by a FAISS vector store for RAG retrieval.

    On construction, either loads a previously persisted FAISS index from
    disk or builds a new one by loading a PDF, chunking it, and embedding
    the chunks with an OpenAI embedding model. Retrieval is exposed both
    as raw ``Document`` objects and as a single formatted context string.
    """

    def __init__(self, pdf_path: str, index_path: str, embedding_model: str = "text-embedding-3-small", top_k: int = 2, recreate_index: bool = False):
        """
        Initialize knowledge base with FAISS vector store.

        Args:
            pdf_path: Path to the PDF document
            index_path: Path to save/load the FAISS index
            embedding_model: OpenAI embedding model to use
            top_k: Default number of documents to retrieve
            recreate_index: Whether to recreate the FAISS index from scratch

        Raises:
            FileNotFoundError: If the index must be built and the PDF is missing.
        """
        self.pdf_path = pdf_path
        self.index_path = index_path
        self.top_k = top_k

        logger_knowledge.info(f"Initializing KnowledgeBase with embedding_model={embedding_model}, top_k={top_k}")
        logger_knowledge.debug(f"PDF path: {pdf_path}")
        logger_knowledge.debug(f"Index path: {index_path}")

        logger_knowledge.info(f"Loading OpenAI embeddings model: {embedding_model}")
        self.embeddings = OpenAIEmbeddings(model=embedding_model)
        self.vectorstore = self._load_or_create_index(recreate_index)

    def _load_or_create_index(self, recreate: bool = False) -> FAISS:
        """Load an existing FAISS index from disk, or build one from the PDF.

        Args:
            recreate: When True, any existing index directory is removed and
                the index is rebuilt from the source PDF.

        Returns:
            A ready-to-query FAISS vector store.

        Raises:
            FileNotFoundError: If the PDF is missing when a rebuild is needed.
        """
        # Fast path: reuse a previously persisted index unless a rebuild was requested.
        if not recreate and os.path.exists(self.index_path):
            logger_knowledge.info(f"Loading existing FAISS index from {self.index_path}")
            try:
                # allow_dangerous_deserialization is required by LangChain to
                # unpickle our own locally-created index; safe because we only
                # ever load indexes this process wrote.
                vectorstore = FAISS.load_local(
                    self.index_path,
                    self.embeddings,
                    allow_dangerous_deserialization=True
                )
                logger_knowledge.info("FAISS index loaded successfully")
                return vectorstore
            except Exception as e:
                logger_knowledge.error(f"Failed to load FAISS index: {str(e)}")
                raise

        logger_knowledge.info(f"Creating new FAISS index from {self.pdf_path}")

        # Drop the stale index directory so a forced rebuild starts clean.
        # Failure to remove it is non-fatal: save_local below overwrites files.
        if recreate and os.path.exists(self.index_path):
            try:
                shutil.rmtree(self.index_path)
                logger_knowledge.info("Removed old index directory")
            except Exception as e:
                logger_knowledge.warning(f"Could not remove old index: {e}")

        if not os.path.exists(self.pdf_path):
            error_msg = f"PDF file not found: {self.pdf_path}"
            logger_knowledge.error(error_msg)
            raise FileNotFoundError(error_msg)

        logger_knowledge.info(f"Loading PDF from {self.pdf_path}")
        try:
            loader = PyPDFLoader(self.pdf_path)
            documents = loader.load()  # one Document per PDF page
            logger_knowledge.info(f"Loaded {len(documents)} pages from PDF")
        except Exception as e:
            logger_knowledge.error(f"Failed to load PDF: {str(e)}")
            raise

        # Chunk pages for embedding; overlap preserves context across boundaries.
        logger_knowledge.info("Splitting documents into chunks")
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=150,
            length_function=len,
            separators=["\n\n", "\n", ". ", ", ", " ", ""]
        )
        chunks = text_splitter.split_documents(documents)
        logger_knowledge.info(f"Split into {len(chunks)} chunks")

        logger_knowledge.info("Creating FAISS vector store from chunks")
        try:
            vectorstore = FAISS.from_documents(chunks, self.embeddings)
            logger_knowledge.info("FAISS vector store created successfully")
        except Exception as e:
            logger_knowledge.error(f"Failed to create FAISS vector store: {str(e)}")
            raise

        # Persist so subsequent runs take the fast load path above.
        try:
            vectorstore.save_local(self.index_path)
            logger_knowledge.info(f"Saved FAISS index to {self.index_path}")
        except Exception as e:
            logger_knowledge.error(f"Failed to save FAISS index: {str(e)}")
            raise

        return vectorstore

    def retrieve_relevant_docs(self, query: str, k: Optional[int] = None) -> List[Document]:
        """
        Retrieve relevant documents for a query.

        Args:
            query: User question
            k: Number of documents to retrieve (uses top_k if not specified)

        Returns:
            List of relevant document chunks (empty if the store is unset).
        """
        if self.vectorstore is None:
            logger_knowledge.error("Vector store not initialized!")
            return []

        # `is None` (not `or`) so an explicit k=0 is honored rather than
        # silently replaced by the default.
        if k is None:
            k = self.top_k
        logger_knowledge.debug(f"Retrieving top {k} documents for query")

        try:
            results = self.vectorstore.similarity_search(query, k=k)
            logger_knowledge.info(f"Retrieved {len(results)} documents")
            return results
        except Exception as e:
            logger_knowledge.error(f"Document retrieval failed: {str(e)}")
            raise

    def retrieve_relevant(self, query: str, k: Optional[int] = None) -> str:
        """
        Retrieve relevant documents as formatted string with metadata.

        Args:
            query: User question
            k: Number of documents to retrieve (uses top_k if not specified)

        Returns:
            Concatenated text from relevant documents with metadata, or ""
            when nothing is retrieved.
        """
        logger_knowledge.info(f"Retrieving context for query: {query[:50]}..." if len(query) > 50 else f"Retrieving context for query: {query}")

        docs = self.retrieve_relevant_docs(query, k)

        if not docs:
            logger_knowledge.warning("No documents retrieved for query")
            return ""

        formatted_chunks = []
        for i, doc in enumerate(docs, 1):
            chunk_text = f"--- Chunk {i} ---"

            if doc.metadata:
                # Renamed loop vars so they do not shadow the `k` parameter.
                metadata_str = ", ".join([f"{key}: {value}" for key, value in doc.metadata.items()])
                chunk_text += f"\nMetadata: {metadata_str}"
                logger_knowledge.debug(f"Chunk {i} metadata: {doc.metadata}")

            content_preview = doc.page_content[:100] + "..." if len(doc.page_content) > 100 else doc.page_content
            logger_knowledge.debug(f"Chunk {i} content preview: {content_preview}")
            chunk_text += f"\n\n{doc.page_content}"
            formatted_chunks.append(chunk_text)

        total_length = sum(len(chunk) for chunk in formatted_chunks)
        logger_knowledge.info(f"Formatted {len(formatted_chunks)} chunks, total length: {total_length} characters")

        return "\n\n".join(formatted_chunks)
|