Kalpokoch commited on
Commit
e8fe4a8
·
1 Parent(s): f6648b0

updates to policy_vector_db

Browse files
Files changed (1) hide show
  1. app/policy_vector_db.py +38 -13
app/policy_vector_db.py CHANGED
@@ -7,20 +7,32 @@ import chromadb
7
  from chromadb.config import Settings
8
  import logging
9
 
10
- logger = logging.getLogger("app")
 
 
11
 
12
  class PolicyVectorDB:
 
 
 
 
13
  def __init__(self, persist_directory: str, top_k_default: int = 5, relevance_threshold: float = 0.5):
14
  self.persist_directory = persist_directory
15
  self.client = chromadb.PersistentClient(path=persist_directory, settings=Settings(allow_reset=True))
16
  self.collection_name = "neepco_dop_policies"
17
- # ✅ Use 'cuda' if a GPU is available for better performance
18
- self.embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5', device='cpu')
19
- self.collection = self._get_collection()
 
 
 
20
  self.top_k_default = top_k_default
21
  self.relevance_threshold = relevance_threshold
22
 
23
  def _get_collection(self):
 
 
 
24
  if self.collection is None:
25
  self.collection = self.client.get_or_create_collection(
26
  name=self.collection_name,
@@ -29,9 +41,13 @@ class PolicyVectorDB:
29
  return self.collection
30
 
31
  def _flatten_metadata(self, metadata: Dict) -> Dict:
 
32
  return {key: str(value) for key, value in metadata.items()}
33
 
34
  def add_chunks(self, chunks: List[Dict]):
 
 
 
35
  collection = self._get_collection()
36
  if not chunks:
37
  logger.info("No chunks provided to add.")
@@ -47,23 +63,32 @@ class PolicyVectorDB:
47
  new_chunks = [chunk for chunk in chunks_with_ids if str(chunk.get('id')) not in existing_ids]
48
 
49
  if not new_chunks:
50
- logger.info("All provided chunks already exist in the database.")
51
  return
52
 
53
  logger.info(f"Adding {len(new_chunks)} new chunks to the vector database...")
 
 
54
  batch_size = 64
55
  for i in range(0, len(new_chunks), batch_size):
56
  batch = new_chunks[i:i + batch_size]
 
57
  ids = [str(chunk['id']) for chunk in batch]
58
  texts = [chunk['text'] for chunk in batch]
59
  metadatas = [self._flatten_metadata(chunk.get('metadata', {})) for chunk in batch]
60
 
61
  embeddings = self.embedding_model.encode(texts, show_progress_bar=False).tolist()
 
62
  collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
63
  logger.info(f"Added batch {i//batch_size + 1}/{(len(new_chunks) + batch_size - 1) // batch_size}")
 
64
  logger.info(f"Finished adding {len(new_chunks)} chunks.")
65
 
66
  def search(self, query_text: str, top_k: int = None) -> List[Dict]:
 
 
 
 
67
  collection = self._get_collection()
68
  query_embedding = self.embedding_model.encode([query_text]).tolist()
69
  k = top_k if top_k is not None else self.top_k_default
@@ -80,7 +105,6 @@ class PolicyVectorDB:
80
  for i, doc in enumerate(results['documents'][0]):
81
  relevance_score = 1 - results['distances'][0][i]
82
 
83
- # ✅ RECOMMENDED CHANGE: Filter results internally based on the threshold
84
  if relevance_score >= self.relevance_threshold:
85
  search_results.append({
86
  'text': doc,
@@ -88,21 +112,22 @@ class PolicyVectorDB:
88
  'relevance_score': relevance_score
89
  })
90
 
91
- # Return the top k results *after* filtering
92
  return sorted(search_results, key=lambda x: x['relevance_score'], reverse=True)[:k]
93
 
94
- def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str):
 
 
 
95
  try:
96
  if db_instance._get_collection().count() > 0:
97
  logger.info("Vector database already contains data. Skipping population.")
98
  return True
99
-
100
  logger.info("Vector database is empty. Attempting to populate from chunks file.")
101
  if not os.path.exists(chunks_file_path):
102
- logger.error(f"Chunks file not found at {chunks_file_path}. Cannot populate DB.")
103
  return False
104
 
105
- # ✅ CORRECTED CODE: Read the JSONL file line-by-line
106
  chunks_to_add = []
107
  with open(chunks_file_path, 'r', encoding='utf-8') as f:
108
  for line in f:
@@ -112,12 +137,12 @@ def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str):
112
  logger.warning(f"Skipping malformed line in chunks file: {line.strip()}")
113
 
114
  if not chunks_to_add:
115
- logger.warning(f"Chunks file at {chunks_file_path} is empty or invalid. No data to add.")
116
  return False
117
 
118
  db_instance.add_chunks(chunks_to_add)
119
  logger.info("Vector database population attempt complete.")
120
  return True
121
  except Exception as e:
122
- logger.error(f"DB Population Error: {e}", exc_info=True)
123
  return False
 
7
  from chromadb.config import Settings
8
  import logging
9
 
10
# Configure root logging once at import time so module logs are visible by default.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
13
 
14
  class PolicyVectorDB:
15
+ """
16
+ Manages the connection, population, and querying of a ChromaDB vector database
17
+ for policy documents.
18
+ """
19
  def __init__(self, persist_directory: str, top_k_default: int = 5, relevance_threshold: float = 0.5):
20
  self.persist_directory = persist_directory
21
  self.client = chromadb.PersistentClient(path=persist_directory, settings=Settings(allow_reset=True))
22
  self.collection_name = "neepco_dop_policies"
23
+
24
+ # Using a powerful open-source embedding model.
25
+ # Change 'cpu' to 'cuda' if a GPU is available for significantly faster embedding.
26
+ self.embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5', device='cpu')
27
+
28
+ self.collection = None # Initialize collection as None for lazy loading
29
  self.top_k_default = top_k_default
30
  self.relevance_threshold = relevance_threshold
31
 
32
  def _get_collection(self):
33
+ """
34
+ Retrieves or creates the ChromaDB collection. Implements lazy loading.
35
+ """
36
  if self.collection is None:
37
  self.collection = self.client.get_or_create_collection(
38
  name=self.collection_name,
 
41
  return self.collection
42
 
43
  def _flatten_metadata(self, metadata: Dict) -> Dict:
44
+ """Ensures all metadata values are strings, as required by some ChromaDB versions."""
45
  return {key: str(value) for key, value in metadata.items()}
46
 
47
  def add_chunks(self, chunks: List[Dict]):
48
+ """
49
+ Adds a list of chunks to the vector database, skipping any that already exist.
50
+ """
51
  collection = self._get_collection()
52
  if not chunks:
53
  logger.info("No chunks provided to add.")
 
63
  new_chunks = [chunk for chunk in chunks_with_ids if str(chunk.get('id')) not in existing_ids]
64
 
65
  if not new_chunks:
66
+ logger.info("All provided chunks already exist in the database. No new data to add.")
67
  return
68
 
69
  logger.info(f"Adding {len(new_chunks)} new chunks to the vector database...")
70
+
71
+ # Process in batches for efficiency
72
  batch_size = 64
73
  for i in range(0, len(new_chunks), batch_size):
74
  batch = new_chunks[i:i + batch_size]
75
+
76
  ids = [str(chunk['id']) for chunk in batch]
77
  texts = [chunk['text'] for chunk in batch]
78
  metadatas = [self._flatten_metadata(chunk.get('metadata', {})) for chunk in batch]
79
 
80
  embeddings = self.embedding_model.encode(texts, show_progress_bar=False).tolist()
81
+
82
  collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
83
  logger.info(f"Added batch {i//batch_size + 1}/{(len(new_chunks) + batch_size - 1) // batch_size}")
84
+
85
  logger.info(f"Finished adding {len(new_chunks)} chunks.")
86
 
87
  def search(self, query_text: str, top_k: int = None) -> List[Dict]:
88
+ """
89
+ Searches the vector database for a given query text.
90
+ Returns a list of results filtered by a relevance threshold.
91
+ """
92
  collection = self._get_collection()
93
  query_embedding = self.embedding_model.encode([query_text]).tolist()
94
  k = top_k if top_k is not None else self.top_k_default
 
105
  for i, doc in enumerate(results['documents'][0]):
106
  relevance_score = 1 - results['distances'][0][i]
107
 
 
108
  if relevance_score >= self.relevance_threshold:
109
  search_results.append({
110
  'text': doc,
 
112
  'relevance_score': relevance_score
113
  })
114
 
 
115
  return sorted(search_results, key=lambda x: x['relevance_score'], reverse=True)[:k]
116
 
117
+ def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str) -> bool:
118
+ """
119
+ Checks if the DB is empty and populates it from a JSONL file if needed.
120
+ """
121
  try:
122
  if db_instance._get_collection().count() > 0:
123
  logger.info("Vector database already contains data. Skipping population.")
124
  return True
125
+
126
  logger.info("Vector database is empty. Attempting to populate from chunks file.")
127
  if not os.path.exists(chunks_file_path):
128
+ logger.error(f"Chunks file not found at '{chunks_file_path}'. Cannot populate DB.")
129
  return False
130
 
 
131
  chunks_to_add = []
132
  with open(chunks_file_path, 'r', encoding='utf-8') as f:
133
  for line in f:
 
137
  logger.warning(f"Skipping malformed line in chunks file: {line.strip()}")
138
 
139
  if not chunks_to_add:
140
+ logger.warning(f"Chunks file at '{chunks_file_path}' is empty or invalid. No data to add.")
141
  return False
142
 
143
  db_instance.add_chunks(chunks_to_add)
144
  logger.info("Vector database population attempt complete.")
145
  return True
146
  except Exception as e:
147
+ logger.error(f"An error occurred during DB population check: {e}", exc_info=True)
148
  return False