File size: 2,983 Bytes
8953dfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import chromadb
from chromadb.config import Settings
from typing import List
import os
from gemini_embedding import GeminiEmbeddingFunction
import datetime
# Single embedding function shared by every collection this module opens or creates
# (passed to both get_collection and create_collection below).
# NOTE(review): this is instantiated at import time, so whatever setup
# GeminiEmbeddingFunction performs (credentials, client creation) happens on import —
# confirm that is intended.
embedding_function = GeminiEmbeddingFunction()

def create_chroma_db(
    documents: List[str],
    *,
    persist_directory: str = "chroma_db",
    collection_name: str = "document_collection",
    batch_size: int = 20,
):
    """
    Create (or reset) a persistent Chroma collection and fill it with ``documents``.

    Any documents already stored in the collection are removed first, so afterwards
    the collection contains exactly ``documents``, stored under the ids
    ``doc_0`` .. ``doc_{len(documents) - 1}`` in input order.

    Args:
        documents: Texts to store; they are embedded with the module-level
            ``embedding_function``.
        persist_directory: Directory where ChromaDB keeps its on-disk data.
        collection_name: Name of the collection to create or reset.
        batch_size: Number of documents passed to each ``add`` call (adding in
            batches avoids memory issues with large inputs).

    Returns:
        The populated Chroma collection.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Create the persistence directory up front, then open a client backed by it.
    os.makedirs(persist_directory, exist_ok=True)
    chroma_client = chromadb.PersistentClient(path=persist_directory)

    # get_or_create_collection replaces the former "get, bare except, create" pattern,
    # in which any failure while clearing an existing collection fell through to
    # create_collection on a name that already existed (and which also swallowed
    # KeyboardInterrupt/SystemExit).
    db = chroma_client.get_or_create_collection(
        name=collection_name,
        embedding_function=embedding_function,
    )

    # Clear previous contents. include=[] fetches only the ids rather than loading
    # every stored document; skip the delete entirely when the collection is empty.
    existing_ids = db.get(include=[])["ids"]
    if existing_ids:
        db.delete(ids=existing_ids)

    for start in range(0, len(documents), batch_size):
        batch = documents[start:start + batch_size]
        db.add(
            documents=batch,
            # Ids are globally sequential across batches: doc_0, doc_1, ...
            ids=[f"doc_{index}" for index in range(start, start + len(batch))],
        )

    return db

def get_relevant_passage(query: str, db, n_results: int = 5) -> List[str]:
    """
    Return up to ``n_results`` stored documents closest to ``query``.

    Progress, timing and per-passage debug lines are printed to stdout.

    Args:
        query: Text to search for.
        db: Collection to search (e.g. the one returned by ``create_chroma_db``);
            it must provide ``count()`` and ``query()``.
        n_results: Maximum number of passages to return; capped at the number of
            documents currently in the collection.

    Returns:
        The retrieved documents in the order Chroma returns them (closest first).
        An empty list when the collection is empty, ``n_results`` is not positive,
        nothing is found, or the query raises (the error is printed, not re-raised).
    """
    start_time = datetime.datetime.now()
    print(f"{start_time}: Starting ChromaDB query for question: {query[:50]}...")  # Log query start

    try:
        # Chroma rejects n_results <= 0, which min() produces for an empty collection
        # (or a non-positive request); return early instead of provoking that error.
        limit = min(n_results, db.count())
        if limit <= 0:
            print("No relevant passages found.")
            return []

        results = db.query(
            query_texts=[query],
            n_results=limit,
            include=['documents', 'distances']
        )
        end_time = datetime.datetime.now()
        print(f"{end_time}: ChromaDB query finished. Time taken: {end_time - start_time}")  # Log the time taken

        # Results hold one list per query text; exactly one query text was sent.
        documents_per_query = results.get('documents') if results else None
        if not documents_per_query or not documents_per_query[0]:
            print("No relevant passages found.")
            return []

        documents = documents_per_query[0]
        distances = (results.get('distances') or [[]])[0]

        print(f"Number of relevant passages retrieved: {len(documents)}")
        for rank, (doc, distance) in enumerate(zip(documents, distances), start=1):
            # Report the raw distance (lower = closer). "1 - distance" is only a
            # similarity in cosine space; Chroma's default space is squared L2,
            # where that value is meaningless and can be negative.
            print(f"Passage {rank} (Distance: {distance:.4f}): {doc[:100]}...")

        return documents
    except Exception as e:
        # Best-effort retrieval: callers get an empty list rather than an exception.
        print(f"Error in get_relevant_passage: {str(e)}")
        return []