Kalpokoch committed on
Commit
a6eea30
·
verified ·
1 Parent(s): 95eb732

Update app/policy_vector_db.py

Browse files
Files changed (1) hide show
  1. app/policy_vector_db.py +28 -59
app/policy_vector_db.py CHANGED
@@ -1,113 +1,82 @@
1
  import os
2
  import json
3
- import shutil
4
- import logging
5
  from typing import List, Dict
6
-
7
- import chromadb
8
  from sentence_transformers import SentenceTransformer
9
- import torch
10
-
11
- logger = logging.getLogger("vector-db")
12
 
13
  class PolicyVectorDB:
14
- def __init__(self, persist_directory: str, top_k_default: int = 5, relevance_threshold: float = 0.65):
15
  self.persist_directory = persist_directory
 
16
  self.collection_name = "neepco_dop_policies"
 
 
17
  self.top_k_default = top_k_default
18
  self.relevance_threshold = relevance_threshold
19
 
20
- self.client = chromadb.PersistentClient(path=self.persist_directory)
21
- self.collection = None
22
-
23
- device = "cuda" if torch.cuda.is_available() else "cpu"
24
- self.embedding_model = SentenceTransformer("BAAI/bge-large-en-v1.5", device=device)
25
- logger.info(f"[INIT] Embedding model loaded on {device.upper()}.")
26
-
27
  def _get_collection(self):
28
  if self.collection is None:
29
  self.collection = self.client.get_or_create_collection(
30
  name=self.collection_name,
31
  metadata={"description": "NEEPCO Delegation of Powers Policy"}
32
  )
33
- logger.info(f"[COLLECTION] Loaded collection '{self.collection_name}'. Count: {self.collection.count()}")
34
  return self.collection
35
 
36
  def _flatten_metadata(self, metadata: Dict) -> Dict:
37
- return {k: str(v) for k, v in metadata.items()}
38
 
39
  def add_chunks(self, chunks: List[Dict]):
40
  collection = self._get_collection()
41
  if not chunks:
42
- logger.warning("[ADD] No chunks to add.")
43
  return
44
-
45
  existing_ids = set(collection.get()['ids'])
46
- new_chunks = [c for c in chunks if c['id'] not in existing_ids]
47
-
48
  if not new_chunks:
49
- logger.info("[ADD] All chunks already exist in DB.")
50
  return
51
-
52
- logger.info(f"[ADD] Adding {len(new_chunks)} new chunks.")
53
  batch_size = 128
54
  for i in range(0, len(new_chunks), batch_size):
55
  batch = new_chunks[i:i + batch_size]
56
- texts = [c['text'] for c in batch]
57
- ids = [c['id'] for c in batch]
58
- metadatas = [self._flatten_metadata(c['metadata']) for c in batch]
59
-
60
  embeddings = self.embedding_model.encode(texts, show_progress_bar=False).tolist()
61
  collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)
62
-
63
- logger.info(f"[ADD] Total docs after insert: {collection.count()}")
64
 
65
  def search(self, query_text: str, top_k: int = None) -> List[Dict]:
66
  collection = self._get_collection()
67
- top_k = top_k or self.top_k_default
68
-
69
  query_embedding = self.embedding_model.encode([query_text]).tolist()
 
70
  results = collection.query(
71
  query_embeddings=query_embedding,
72
  n_results=top_k,
73
  include=["documents", "metadatas", "distances"]
74
  )
75
-
76
  search_results = []
77
- if not results.get("documents"):
78
- logger.warning("[SEARCH] No documents found.")
79
- return []
80
-
81
- for i, doc in enumerate(results["documents"][0]):
82
- score = 1 - results["distances"][0][i]
83
  search_results.append({
84
- "text": doc,
85
- "metadata": results["metadatas"][0][i],
86
- "relevance_score": round(score, 4)
87
  })
88
-
89
- logger.info(f"[SEARCH] Retrieved {len(search_results)} results for query: {query_text}")
90
  return search_results
91
 
92
-
93
- def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str) -> bool:
94
- logger.info("[POPULATE] Checking vector DB...")
95
-
96
  try:
97
  if db_instance._get_collection().count() == 0:
98
  if not os.path.exists(chunks_file_path):
99
- logger.error(f"[ERROR] Chunks file not found at {chunks_file_path}")
100
  return False
101
-
102
- with open(chunks_file_path, "r", encoding="utf-8") as f:
103
- chunks = json.load(f)
104
-
105
- logger.info(f"[POPULATE] Loaded {len(chunks)} chunks. Populating DB...")
106
- db_instance.add_chunks(chunks)
107
- logger.info("[POPULATE] DB population complete.")
108
  else:
109
- logger.info("[POPULATE] DB already populated.")
110
- return True
111
  except Exception as e:
112
- logger.exception(f"[EXCEPTION] During DB population: {str(e)}")
113
  return False
 
1
  import os
2
  import json
3
+ import torch
 
4
  from typing import List, Dict
 
 
5
  from sentence_transformers import SentenceTransformer
6
+ import chromadb
7
+ from chromadb.config import Settings
 
8
 
9
class PolicyVectorDB:
    """Chroma-backed vector store for NEEPCO Delegation of Powers policy chunks.

    Embeds text with a SentenceTransformer model and persists the vectors in a
    local ChromaDB collection for cosine-style similarity search.
    """

    def __init__(self, persist_directory: str, top_k_default: int = 5, relevance_threshold: float = 0.5):
        """
        Args:
            persist_directory: Filesystem path where ChromaDB persists its data.
            top_k_default: Result count used when search() is called without top_k.
            relevance_threshold: Minimum relevance score. Stored but not applied
                in this class — presumably consumed by callers; verify upstream.
        """
        self.persist_directory = persist_directory
        self.client = chromadb.PersistentClient(path=persist_directory, settings=Settings(allow_reset=True))
        self.collection_name = "neepco_dop_policies"
        # Use the GPU when available; CPU encoding works but is slower.
        self.embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5', device='cuda' if torch.cuda.is_available() else 'cpu')
        self.collection = None  # created lazily by _get_collection()
        self.top_k_default = top_k_default
        self.relevance_threshold = relevance_threshold

    def _get_collection(self):
        """Return the policy collection, creating it on first access."""
        if self.collection is None:
            self.collection = self.client.get_or_create_collection(
                name=self.collection_name,
                metadata={"description": "NEEPCO Delegation of Powers Policy"}
            )
        return self.collection

    def _flatten_metadata(self, metadata: Dict) -> Dict:
        """Stringify every metadata value (Chroma accepts only scalar metadata)."""
        return {key: str(value) for key, value in metadata.items()}

    def add_chunks(self, chunks: List[Dict]):
        """Embed and insert chunks that are not already stored, in batches of 128.

        Each chunk is expected to be a dict with 'id', 'text' and 'metadata'
        keys; chunks whose 'id' already exists in the collection are skipped.
        """
        collection = self._get_collection()
        if not chunks:
            print("No chunks provided to add.")
            return
        existing_ids = set(collection.get()['ids'])
        # .get('id') tolerates a chunk missing 'id' here, but such a chunk would
        # still raise KeyError below — assumed not to occur upstream; confirm.
        new_chunks = [chunk for chunk in chunks if chunk.get('id') not in existing_ids]
        if not new_chunks:
            print("No new chunks to add.")
            return
        batch_size = 128
        for i in range(0, len(new_chunks), batch_size):
            batch = new_chunks[i:i + batch_size]
            texts = [chunk['text'] for chunk in batch]
            ids = [chunk['id'] for chunk in batch]
            metadatas = [self._flatten_metadata(chunk['metadata']) for chunk in batch]
            embeddings = self.embedding_model.encode(texts, show_progress_bar=False).tolist()
            collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)

    def search(self, query_text: str, top_k: int = None) -> List[Dict]:
        """Return up to top_k chunks most similar to query_text.

        Each result is a dict {'text', 'metadata', 'relevance_score'} where
        relevance_score = 1 - distance (higher means more relevant).
        Returns an empty list when no documents are retrieved.
        """
        collection = self._get_collection()
        query_embedding = self.embedding_model.encode([query_text]).tolist()
        top_k = top_k if top_k else self.top_k_default
        results = collection.query(
            query_embeddings=query_embedding,
            n_results=top_k,
            include=["documents", "metadatas", "distances"]
        )
        search_results = []
        # Guard against an empty/missing result set so the [0] indexing below
        # cannot raise on an empty collection (guard existed in the previous
        # revision of this file and was lost in the rewrite).
        documents = results.get('documents') or []
        if not documents or not documents[0]:
            return search_results
        for i, doc in enumerate(documents[0]):
            relevance_score = 1 - results['distances'][0][i]
            search_results.append({
                'text': doc,
                'metadata': results['metadatas'][0][i],
                'relevance_score': relevance_score
            })
        return search_results
67
 
68
def ensure_db_populated(db_instance: PolicyVectorDB, chunks_file_path: str) -> bool:
    """Populate the vector DB from a JSON chunks file if it is currently empty.

    Args:
        db_instance: The PolicyVectorDB to check and, if needed, populate.
        chunks_file_path: Path to a JSON file containing a list of chunk dicts.

    Returns:
        True when the DB is already populated or population succeeded;
        False when the chunks file is missing or any error occurs.
    """
    try:
        if db_instance._get_collection().count() == 0:
            if not os.path.exists(chunks_file_path):
                print(f"Chunks file not found at {chunks_file_path}")
                return False
            with open(chunks_file_path, 'r', encoding='utf-8') as f:
                chunks_to_add = json.load(f)
            db_instance.add_chunks(chunks_to_add)
        # Either already populated, or population just completed successfully.
        return True
    except Exception as e:
        # Broad catch is deliberate: this is a startup boundary — report the
        # failure and signal it via the return value rather than crashing.
        print(f"DB Population Error: {e}")
        return False