import datetime
import os
from typing import List

import chromadb
from chromadb.config import Settings

from gemini_embedding import GeminiEmbeddingFunction

# Single embedding function instance, created once at import time and handed
# to Chroma whenever create_chroma_db() opens or creates its collection.
embedding_function = GeminiEmbeddingFunction()


def create_chroma_db(
    documents: List[str],
    persist_directory: str = "chroma_db",
    collection_name: str = "document_collection",
    batch_size: int = 20,
):
    """Create (or reset) a persistent Chroma collection and fill it with documents.

    The collection is opened if it already exists, otherwise created. Any
    entries it already holds are deleted, so after the call it contains
    exactly ``documents``, stored under the ids ``doc_0``, ``doc_1``, ...

    Args:
        documents: Texts to store, in order; the list index becomes the id
            suffix.
        persist_directory: Directory for Chroma's on-disk storage; created
            if it does not exist.
        collection_name: Name of the collection to open or create.
        batch_size: Number of documents sent per ``add`` call.

    Returns:
        The Chroma collection object holding the documents.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    os.makedirs(persist_directory, exist_ok=True)
    chroma_client = chromadb.PersistentClient(path=persist_directory)

    # get_or_create avoids the try/except dance: a failure while clearing an
    # existing collection must not be mistaken for "collection is missing".
    db = chroma_client.get_or_create_collection(
        name=collection_name,
        embedding_function=embedding_function,
    )

    # Only issue the delete when there is something to remove; an empty id
    # list is not a meaningful delete request.
    existing_ids = db.get()["ids"]
    if existing_ids:
        db.delete(ids=existing_ids)

    # Insert in batches so a large corpus is not embedded in a single request.
    for start in range(0, len(documents), batch_size):
        batch = documents[start:start + batch_size]
        db.add(
            documents=batch,
            ids=[f"doc_{j}" for j in range(start, start + len(batch))],
        )

    return db


def get_relevant_passage(query: str, db, n_results: int = 5) -> List[str]:
    """Return the stored documents that best match ``query``.

    Progress, timing and a preview of each hit are printed to stdout. The
    function is best-effort: any exception raised while querying is printed
    and an empty list is returned instead of propagating.

    Args:
        query: Question text to search for.
        db: Collection-like object exposing ``count()`` and
            ``query(query_texts=..., n_results=..., include=...)``.
        n_results: Maximum number of passages to return; capped at the
            number of entries in ``db``.

    Returns:
        The matching document texts in the order returned by ``db.query``,
        or ``[]`` when nothing can be retrieved.
    """
    start_time = datetime.datetime.now()
    print(f"{start_time}: Starting ChromaDB query for question: {query[:50]}...")

    try:
        # Never ask for more results than the collection holds, and skip the
        # query entirely when there is nothing to ask for (empty collection
        # or non-positive n_results) instead of relying on the query failing.
        wanted = min(n_results, db.count())
        if wanted <= 0:
            print("No relevant passages found.")
            return []

        results = db.query(
            query_texts=[query],
            n_results=wanted,
            include=['documents', 'distances'],
        )
        end_time = datetime.datetime.now()
        print(f"{end_time}: ChromaDB query finished. Time taken: {end_time - start_time}")

        if (
            not results
            or 'documents' not in results
            or not results['documents']
            or not results['documents'][0]
        ):
            print("No relevant passages found.")
            return []

        # One query text was sent, so the hits are in the first result row.
        documents = results['documents'][0]
        distances = results['distances'][0]

        print(f"Number of relevant passages retrieved: {len(documents)}")
        for i, (doc, distance) in enumerate(zip(documents, distances)):
            # NOTE(review): 1 - distance is only a true similarity score for
            # cosine distance; confirm the metric the collection was built with.
            similarity = 1 - distance
            print(f"Passage {i+1} (Similarity: {similarity:.4f}): {doc[:100]}...")

        return documents
    except Exception as e:
        # Deliberate best-effort boundary: report the problem, return nothing.
        print(f"Error in get_relevant_passage: {str(e)}")
        return []