# Commit 51a3d33: Add save_summary and get_summaries endpoints to FastAPI app;
# refactor create_chroma_db to handle single document input

import os
import uuid
from typing import List

import chromadb
from chromadb.utils import embedding_functions
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from dotenv import load_dotenv

load_dotenv()

# Configure paths
CORPUS_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "corpus")
DB_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "vectordb")

# Ensure directories exist
os.makedirs(CORPUS_DIR, exist_ok=True)
os.makedirs(DB_DIR, exist_ok=True)

def load_documents(corpus_dir: str = CORPUS_DIR) -> List:
    """Load documents from the corpus directory."""
    if not os.path.exists(corpus_dir):
        raise FileNotFoundError(f"Corpus directory not found: {corpus_dir}")

    print(f"Loading documents from {corpus_dir}...")

    # Initialize loaders for different file types
    loaders = {
        # "txt": DirectoryLoader(corpus_dir, glob="**/*.txt", loader_cls=TextLoader),
        "pdf": DirectoryLoader(corpus_dir, glob="**/*.pdf", loader_cls=PyPDFLoader),
        # "docx": DirectoryLoader(corpus_dir, glob="**/*.docx", loader_cls=Docx2txtLoader),
    }

    documents = []
    for file_type, loader in loaders.items():
        try:
            docs = loader.load()
            print(f"Loaded {len(docs)} {file_type} documents")
            documents.extend(docs)
        except Exception as e:
            print(f"Error loading {file_type} documents: {e}")

    return documents

def split_documents(documents, chunk_size=1000, chunk_overlap=200):
    """Split documents into chunks."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    splits = text_splitter.split_documents(documents)
    print(f"Split {len(documents)} documents into {len(splits)} chunks")
    return splits

def create_chroma_db_and_document(document, collection_name="corpus_collection", db_dir=DB_DIR):
    """Add a single document (e.g. one chunk) to a persistent Chroma collection."""
    # Initialize the Gemini embedding function
    gemini_ef = embedding_functions.GoogleGenerativeAiEmbeddingFunction(
        api_key=os.getenv("GOOGLE_API_KEY"),
        model_name="models/embedding-001"
    )

    # Initialize Chroma client
    client = chromadb.PersistentClient(path=db_dir)

    # Create or get collection; pass the embedding function in both branches so
    # an existing collection embeds new documents the same way as a fresh one
    try:
        collection = client.get_collection(name=collection_name, embedding_function=gemini_ef)
        print(f"Using existing collection: {collection_name}")
    except Exception:
        collection = client.create_collection(
            name=collection_name,
            embedding_function=gemini_ef
        )
        print(f"Created new collection: {collection_name}")

    try:
        collection.add(
            documents=[document.page_content],
            # Store the loader metadata (e.g. source/page) so queries can
            # report provenance; skip it if empty, since Chroma rejects
            # empty metadata dicts
            metadatas=[document.metadata] if document.metadata else None,
            # Chunks produced by the text splitter may not carry an id,
            # so fall back to a generated UUID
            ids=[getattr(document, "id", None) or str(uuid.uuid4())]
        )
        print("Document added to collection successfully.")
        return True
    except Exception as e:
        print(f"Error adding document to collection: {e}")
        return False

def query_chroma_db(query: str, collection_name="corpus_collection", n_results=5, db_dir=DB_DIR):
    """Query the Chroma vector database."""
    # Initialize the Gemini embedding function
    gemini_ef = embedding_functions.GoogleGenerativeAiEmbeddingFunction(
        api_key=os.getenv("GOOGLE_API_KEY"),
        model_name="models/embedding-001"
    )

    # Initialize Chroma client
    client = chromadb.PersistentClient(path=db_dir)

    # Get collection
    collection = client.get_collection(name=collection_name, embedding_function=gemini_ef)

    # Query collection
    results = collection.query(
        query_texts=[query],
        n_results=n_results
    )
    return results

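# For reference, Chroma's query() returns parallel lists keyed by "ids",
# "documents", "metadatas", and "distances", with one inner list per query
# text. A hypothetical snippet (assuming the collection is already populated):
#
#     results = query_chroma_db("example question", n_results=3)
#     for doc, meta in zip(results["documents"][0], results["metadatas"][0]):
#         print(meta.get("source", "Unknown") if meta else "Unknown", doc[:80])
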
def main():
    """Main function to create and test the vector database."""
    print("Starting vector database creation...")

    # Load documents
    documents = load_documents()
    if not documents:
        print("No documents found in corpus directory. Please add documents to proceed.")
        return

    # Split documents
    splits = split_documents(documents)

    # Create vector database by adding each chunk individually, since the
    # refactored helper takes a single document
    for split in splits:
        create_chroma_db_and_document(split)

    # Test query
    test_query = "What is this corpus about?"
    print(f"\nTesting query: '{test_query}'")
    results = query_chroma_db(test_query)

    print(f"Found {len(results['documents'][0])} matching documents")
    for i, (doc, metadata) in enumerate(zip(results['documents'][0], results['metadatas'][0])):
        print(f"\nResult {i+1}:")
        print(f"Document: {doc[:150]}...")
        print(f"Source: {metadata.get('source', 'Unknown') if metadata else 'Unknown'}")

    print("\nVector database creation and testing complete!")

if __name__ == "__main__":
    main()
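
# The commit message above also mentions save_summary and get_summaries
# endpoints on the FastAPI app. Those endpoints live outside this module, so
# the following is only a hypothetical sketch of how they might wrap this
# file's helpers: the route names come from the commit message, but the
# request model, "summaries" collection name, and field names are assumptions.
#
# from fastapi import FastAPI
# from pydantic import BaseModel
# from langchain_core.documents import Document
#
# app = FastAPI()
#
# class SummaryIn(BaseModel):  # hypothetical request body
#     text: str
#
# @app.post("/save_summary")
# def save_summary(summary: SummaryIn):
#     # Reuse the single-document helper; wrapping the text in a bare
#     # Document is an assumption about the app's glue code
#     ok = create_chroma_db_and_document(Document(page_content=summary.text),
#                                        collection_name="summaries")
#     return {"saved": ok}
#
# @app.get("/get_summaries")
# def get_summaries(query: str, n_results: int = 5):
#     return query_chroma_db(query, collection_name="summaries", n_results=n_results)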