File size: 6,322 Bytes
e58c94d
985f660
 
e58c94d
 
 
985f660
 
c1d9fa4
985f660
 
e58c94d
 
985f660
 
 
 
 
 
 
e58c94d
985f660
e58c94d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391ff38
e58c94d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391ff38
e58c94d
 
391ff38
 
e58c94d
 
 
 
 
 
 
 
 
 
391ff38
e58c94d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
985f660
e58c94d
985f660
e58c94d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391ff38
e58c94d
 
 
 
 
 
985f660
e58c94d
985f660
 
e58c94d
985f660
 
 
 
 
 
 
 
e58c94d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import pickle
from pathlib import Path
from typing import List, Dict, Any

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

from core.utils.logger import logger

class OptimizedGazaKnowledgeBase:
    """Optimized knowledge base that loads a pre-built FAISS index and assets.

    Expects ``vector_store_dir`` to contain:
      - ``index.faiss``  : the FAISS index (required)
      - ``chunks.txt``   : chunk texts in ``=== Chunk N ===`` delimited format (required)
      - ``metadata.pkl`` : pickled dict holding a ``'metadata'`` list (optional)
    """

    def __init__(self, vector_store_dir: str = "./vector_store"):
        """Record configuration only; no I/O until initialize() is called."""
        self.vector_store_dir = Path(vector_store_dir)
        self.faiss_index = None        # faiss.Index, set by initialize()
        self.embedding_model = None    # SentenceTransformer, set by initialize()
        self.chunks: List[str] = []    # chunk texts, row-aligned with the index
        self.metadata: List[Dict[str, Any]] = []  # per-chunk metadata dicts
        self.is_initialized = False

    def initialize(self):
        """Load the FAISS index, chunk texts, metadata, and query encoder.

        Safe to call more than once: state is reset up front so a repeat
        call does not append duplicate chunks.

        Raises:
            FileNotFoundError: if the index or chunks file is missing.
        """
        try:
            logger.info("🔄 Loading pre-made FAISS index and assets...")

            # Reset state so initialize() is idempotent — without this a
            # second call would re-append every parsed chunk.
            self.chunks = []
            self.metadata = []

            # 1. Load FAISS index
            index_path = self.vector_store_dir / "index.faiss"
            if not index_path.exists():
                raise FileNotFoundError(f"FAISS index not found at {index_path}")

            self.faiss_index = faiss.read_index(str(index_path))
            logger.info(f"✅ Loaded FAISS index: {self.faiss_index.ntotal} vectors, {self.faiss_index.d} dimensions")

            # 2. Load chunks
            chunks_path = self.vector_store_dir / "chunks.txt"
            if not chunks_path.exists():
                raise FileNotFoundError(f"Chunks file not found at {chunks_path}")

            with open(chunks_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()

            # Parse the "=== Chunk N ===" formatted file: delimiter lines
            # close the previous chunk; Source:/Length: header lines are
            # skipped; body lines are joined with single spaces.
            current_chunk = ""
            for line in lines:
                line = line.strip()
                if line.startswith("=== Chunk") and current_chunk:
                    self.chunks.append(current_chunk.strip())
                    current_chunk = ""
                elif not line.startswith("===") and not line.startswith("Source:") and not line.startswith("Length:"):
                    current_chunk += line + " "

            # Flush the trailing chunk (the file has no closing delimiter)
            if current_chunk:
                self.chunks.append(current_chunk.strip())

            logger.info(f"✅ Loaded {len(self.chunks)} text chunks")

            # 3. Load metadata (optional file)
            metadata_path = self.vector_store_dir / "metadata.pkl"
            if metadata_path.exists():
                # NOTE(review): pickle.load executes arbitrary code from the
                # file — assumes vector_store_dir holds only trusted, locally
                # built assets; confirm this is never user-supplied.
                with open(metadata_path, 'rb') as f:
                    metadata_dict = pickle.load(f)

                if isinstance(metadata_dict, dict) and 'metadata' in metadata_dict:
                    self.metadata = metadata_dict['metadata']
                    logger.info(f"✅ Loaded {len(self.metadata)} metadata entries")
                else:
                    logger.warning("⚠️ Metadata format not recognized, using empty metadata")
                    # Distinct dicts per entry — [{}] * n would alias every
                    # slot to one shared dict.
                    self.metadata = [{} for _ in self.chunks]
            else:
                logger.warning("⚠️ No metadata file found, using empty metadata")
                self.metadata = [{} for _ in self.chunks]

            # 4. Initialize embedding model for query encoding
            logger.info("🔄 Loading embedding model for queries...")
            self.embedding_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
            logger.info("✅ Embedding model loaded")

            # 5. Verify data consistency between parsed chunks and the index
            if len(self.chunks) != self.faiss_index.ntotal:
                logger.warning(f"⚠️ Mismatch: {len(self.chunks)} chunks vs {self.faiss_index.ntotal} vectors")
                # Trim chunks to match index size
                self.chunks = self.chunks[:self.faiss_index.ntotal]
                self.metadata = self.metadata[:self.faiss_index.ntotal]
                logger.info(f"✅ Trimmed to {len(self.chunks)} chunks to match index")

            self.is_initialized = True
            logger.info("🎉 Knowledge base initialization complete!")

        except Exception as e:
            logger.error(f"❌ Failed to initialize knowledge base: {e}")
            raise

    def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Return up to ``k`` chunks most similar to ``query``.

        Each result dict carries: ``text``, ``score`` (1 / (1 + distance)),
        ``source``, ``chunk_index``, ``distance``, ``metadata``. Returns an
        empty list on any internal search error.

        Raises:
            RuntimeError: if initialize() has not been called.
        """
        if not self.is_initialized:
            raise RuntimeError("Knowledge base not initialized")

        try:
            # 1. Encode query as a float32 row matrix (FAISS requirement)
            query_embedding = self.embedding_model.encode([query])
            query_vector = np.array(query_embedding, dtype=np.float32)

            # 2. Search FAISS index
            distances, indices = self.faiss_index.search(query_vector, k)

            # 3. Prepare results; FAISS pads indices with -1 when fewer
            # than k vectors are available, so filter those out.
            results = []
            for distance, idx in zip(distances[0], indices[0]):
                if 0 <= idx < len(self.chunks):  # valid, non-padded index
                    chunk_metadata = self.metadata[idx] if idx < len(self.metadata) else {}

                    result = {
                        "text": self.chunks[idx],
                        "score": float(1.0 / (1.0 + distance)),  # distance -> similarity
                        "source": chunk_metadata.get("source", "unknown"),
                        "chunk_index": int(idx),
                        "distance": float(distance),
                        "metadata": chunk_metadata
                    }
                    results.append(result)

            # Truncate the query for logging — the trailing '...' previously
            # appeared even though the full query was printed.
            logger.info(f"🔍 Search for '{query[:50]}...' returned {len(results)} results")
            return results

        except Exception as e:
            logger.error(f"❌ Search error: {e}")
            return []

    def get_stats(self) -> Dict[str, Any]:
        """Return a status/summary dict for the knowledge base.

        Before initialize(): ``{"status": "not_initialized"}``. After:
        chunk/vector counts, embedding dimension, index type, and the
        deduplicated set of metadata sources.
        """
        if not self.is_initialized:
            return {"status": "not_initialized"}

        return {
            "status": "initialized",
            "total_chunks": len(self.chunks),
            "total_vectors": self.faiss_index.ntotal,
            "embedding_dimension": self.faiss_index.d,
            "index_type": type(self.faiss_index).__name__,
            "sources": list(set(meta.get("source", "unknown") for meta in self.metadata))
        }