# utils/create_vectordb.py
import os
import uuid
from typing import List

import chromadb
from chromadb.utils import embedding_functions
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader

load_dotenv()
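
# Note: the Gemini embedding calls below read GOOGLE_API_KEY from the
# environment, so the .env file loaded above is assumed to define it, e.g.:
#   GOOGLE_API_KEY=<your-api-key>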

# Configure paths relative to the project root (one level above this file)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
CORPUS_DIR = os.path.join(BASE_DIR, "corpus")
DB_DIR = os.path.join(BASE_DIR, "vectordb")

# Ensure directories exist
os.makedirs(CORPUS_DIR, exist_ok=True)
os.makedirs(DB_DIR, exist_ok=True)

def load_documents(corpus_dir: str = CORPUS_DIR) -> List:
    """Load documents from the corpus directory."""
    if not os.path.exists(corpus_dir):
        raise FileNotFoundError(f"Corpus directory not found: {corpus_dir}")

    print(f"Loading documents from {corpus_dir}...")

    # Initialize loaders for different file types
    loaders = {
        # "txt": DirectoryLoader(corpus_dir, glob="**/*.txt", loader_cls=TextLoader),
        "pdf": DirectoryLoader(corpus_dir, glob="**/*.pdf", loader_cls=PyPDFLoader),
        # "docx": DirectoryLoader(corpus_dir, glob="**/*.docx", loader_cls=Docx2txtLoader),
    }

    documents = []
    for file_type, loader in loaders.items():
        try:
            docs = loader.load()
            print(f"Loaded {len(docs)} {file_type} documents")
            documents.extend(docs)
        except Exception as e:
            print(f"Error loading {file_type} documents: {e}")
    return documents
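

# Note: the "**/*.pdf" glob above matches PDFs in subdirectories of the
# corpus as well, so nested folders are loaded recursively.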


def split_documents(documents, chunk_size=1000, chunk_overlap=200):
    """Split documents into overlapping chunks."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    splits = text_splitter.split_documents(documents)
    print(f"Split {len(documents)} documents into {len(splits)} chunks")
    return splits
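

# A rough sketch of what the defaults above mean: with chunk_size=1000 and
# chunk_overlap=200, consecutive chunks share up to 200 characters of text.
#   splits = split_documents(documents)            # ~1000-char chunks
#   splits = split_documents(documents, 500, 50)   # smaller chunks, less overlap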


def create_chroma_db_and_document(document, collection_name="corpus_collection", db_dir=DB_DIR):
    """Add a single document to the Chroma vector database, creating the collection if needed."""
    # Initialize the Gemini embedding function
    gemini_ef = embedding_functions.GoogleGenerativeAiEmbeddingFunction(
        api_key=os.getenv("GOOGLE_API_KEY"),
        model_name="models/embedding-001",
    )

    # Initialize a persistent Chroma client backed by db_dir
    client = chromadb.PersistentClient(path=db_dir)

    # Get or create the collection; always pass the embedding function so
    # inserts and queries embed consistently
    collection = client.get_or_create_collection(
        name=collection_name,
        embedding_function=gemini_ef,
    )

    try:
        # Fall back to a generated id if the document does not carry one
        doc_id = getattr(document, "id", None) or str(uuid.uuid4())
        collection.add(
            documents=[document.page_content],
            metadatas=[document.metadata] if document.metadata else None,
            ids=[doc_id],
        )
        print("Document added to collection successfully.")
        return True
    except Exception as e:
        print(f"Error adding document to collection: {e}")
        return False
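

# Minimal usage sketch, assuming a recent langchain-core where Document
# accepts an `id` (the fallback above covers chunks whose id is None):
#   from langchain_core.documents import Document
#   doc = Document(page_content="Some summary text.", id="summary-1")
#   create_chroma_db_and_document(doc)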


def query_chroma_db(query: str, collection_name="corpus_collection", n_results=5, db_dir=DB_DIR):
    """Query the Chroma vector database."""
    # Use the same Gemini embedding function the collection was created with
    gemini_ef = embedding_functions.GoogleGenerativeAiEmbeddingFunction(
        api_key=os.getenv("GOOGLE_API_KEY"),
        model_name="models/embedding-001",
    )

    # Initialize Chroma client
    client = chromadb.PersistentClient(path=db_dir)

    # Get collection
    collection = client.get_collection(name=collection_name, embedding_function=gemini_ef)

    # Query collection
    results = collection.query(
        query_texts=[query],
        n_results=n_results,
    )
    return results
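

# Example query, assuming the collection has already been populated and the
# API key is configured:
#   results = query_chroma_db("What is this corpus about?", n_results=3)
#   for doc in results["documents"][0]:
#       print(doc[:100])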


def main():
    """Main function to create and test the vector database."""
    print("Starting vector database creation...")

    # Load documents
    documents = load_documents()
    if not documents:
        print("No documents found in corpus directory. Please add documents to proceed.")
        return

    # Split documents
    splits = split_documents(documents)

    # Add each chunk to the vector database
    for split in splits:
        create_chroma_db_and_document(split)

    # Test query
    test_query = "What is this corpus about?"
    print(f"\nTesting query: '{test_query}'")
    results = query_chroma_db(test_query)
    print(f"Found {len(results['documents'][0])} matching documents")
    for i, (doc, metadata) in enumerate(zip(results["documents"][0], results["metadatas"][0])):
        print(f"\nResult {i+1}:")
        print(f"Document: {doc[:150]}...")
        print(f"Source: {(metadata or {}).get('source', 'Unknown')}")

    print("\nVector database creation and testing complete!")


if __name__ == "__main__":
    main()