Testys committed
Commit 017ee94 · 1 Parent(s): d134e08

Update search_utils.py

Files changed (1)
  1. search_utils.py +47 -110
search_utils.py CHANGED
@@ -1,12 +1,9 @@
 import numpy as np
 import pandas as pd
 import faiss
-from pathlib import Path
-from sentence_transformers import SentenceTransformer, util
-import streamlit as st
 import zipfile
-import pandas as pd
 from pathlib import Path
+from sentence_transformers import SentenceTransformer, util
 import streamlit as st
 
 class MetadataManager:
@@ -14,33 +11,39 @@ class MetadataManager:
         self.shard_dir = Path("metadata_shards")
         self.shard_map = {}
         self.loaded_shards = {}
+        self.total_docs = 0
         self._ensure_unzipped()
         self._build_shard_map()
 
     def _ensure_unzipped(self):
-        """Extract metadata shards from zip if needed"""
+        """Handle ZIP extraction automatically"""
         if not self.shard_dir.exists():
             zip_path = Path("metadata_shards.zip")
             if zip_path.exists():
                 with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                     zip_ref.extractall(self.shard_dir)
-                st.toast("✅ Successfully extracted metadata shards!", icon="📦")
+                st.toast("📦 Metadata shards extracted successfully!", icon="✅")
             else:
-                raise FileNotFoundError("No metadata shards found!")
+                st.error("❌ Missing metadata_shards.zip file!")
+                raise FileNotFoundError("Metadata ZIP file not found")
 
     def _build_shard_map(self):
-        """Map index ranges to shard files"""
-        for f in self.shard_dir.glob("*.parquet"):
+        """Create index range to shard mapping"""
+        self.total_docs = 0
+        for f in sorted(self.shard_dir.glob("*.parquet")):
             parts = f.stem.split("_")
-            self.shard_map[(int(parts[1]), int(parts[2]))] = f.name
+            start = int(parts[1])
+            end = int(parts[2])
+            self.shard_map[(start, end)] = f.name
+            self.total_docs = max(self.total_docs, end + 1)
 
-    def get_metadata(self, indices):
-        """Retrieve metadata for specific indices"""
+    def get_metadata(self, global_indices):
+        """Retrieve metadata for global indices"""
         results = []
         shard_groups = {}
 
-        # Group indices by shard
-        for idx in indices:
+        # Organize indices by their respective shards
+        for idx in global_indices:
             for (start, end), shard in self.shard_map.items():
                 if start <= idx <= end:
                     if shard not in shard_groups:
@@ -48,125 +51,59 @@ class MetadataManager:
                     shard_groups[shard].append(idx - start)
                     break
 
-        # Load required shards
+        # Load and process required shards
         for shard, local_indices in shard_groups.items():
             if shard not in self.loaded_shards:
                 self.loaded_shards[shard] = pd.read_parquet(
                     self.shard_dir / shard,
                     columns=["title", "summary", "source"]
                 )
-
             results.append(self.loaded_shards[shard].iloc[local_indices])
 
         return pd.concat(results).reset_index(drop=True)
 
-
 class SemanticSearch:
-    def __init__(self, shard_dir="compressed_shards"):
-        self.shard_dir = Path(shard_dir)
-        self.shard_dir.mkdir(exist_ok=True, parents=True)
+    def __init__(self):
+        self.shard_dir = Path("compressed_shards")
         self.model = None
         self.index_shards = []
         self.metadata_mgr = MetadataManager()
-
+        self.shard_sizes = []
+
     @st.cache_resource
     def load_model(_self):
         return SentenceTransformer('all-MiniLM-L6-v2')
-
+
     def initialize_system(self):
         self.model = self.load_model()
-        self._load_index_shards()
+        self._load_faiss_shards()
 
-    def _load_index_shards(self):
-        """Load FAISS shards directly from local directory"""
+    def _load_faiss_shards(self):
+        """Load all FAISS index shards"""
+        self.shard_sizes = []
         for shard_path in sorted(self.shard_dir.glob("*.index")):
-            self.index_shards.append(faiss.read_index(str(shard_path)))
+            index = faiss.read_index(str(shard_path))
+            self.index_shards.append(index)
+            self.shard_sizes.append(index.ntotal)
+
+    def _global_index(self, shard_idx, local_idx):
+        """Convert local index to global index"""
+        return sum(self.shard_sizes[:shard_idx]) + local_idx
 
     def search(self, query, top_k=5):
-        """Search across all shards"""
+        """Main search functionality"""
         query_embedding = self.model.encode([query], convert_to_numpy=True)
-        all_scores = []
-        all_indices = []
-
+        all_distances = []
+        all_global_indices = []
+
+        # Search across all shards
        for shard_idx, index in enumerate(self.index_shards):
             distances, indices = index.search(query_embedding, top_k)
-            # Convert local indices to global shard offsets
-            global_indices = [
-                self._calculate_global_index(shard_idx, idx)
-                for idx in indices[0]
-            ]
-            all_scores.extend(distances[0])
-            all_indices.extend(global_indices)
-
-        return self._process_results(np.array(all_scores), np.array(all_indices), top_k)
-
-    def _calculate_global_index(self, shard_idx, local_idx):
-        """Convert shard-local index to global index"""
-        # Implement your specific shard indexing logic here
-        # Example: return f"{shard_idx}-{local_idx}"
-        return local_idx  # Simple version if using unique IDs
-
-    def _process_results(self, distances, indices, top_k):
-        """Format search results"""
-        results = pd.DataFrame({
-            'global_index': indices,
-            'similarity': 1 - (distances / 2)  # L2 to cosine approximation
-        })
-        return results.sort_values('similarity', ascending=False).head(top_k)
-
-    def search_with_threshold(self, query, top_k=5, similarity_threshold=0.6):
-        """Threshold-filtered search"""
-        results = self.search(query, top_k*2)
-        filtered = results[results['similarity'] > similarity_threshold].head(top_k)
-        return filtered.reset_index(drop=True)
-
-
-
-
-class MetadataManager:
-    def __init__(self, repo_id, shard_dir="metadata_shards"):
-        self.repo_id = repo_id
-        self.shard_dir = Path(shard_dir)
-        self.shard_map = {}
-        self.loaded_shards = {}
-        self._build_shard_map()
-
-    def _build_shard_map(self):
-        """Map index ranges to shard files"""
-        for f in self.shard_dir.glob("*.parquet"):
-            parts = f.stem.split("_")
-            self.shard_map[(int(parts[1]), int(parts[2]))] = f.name
-
-    def _download_shard(self, shard_name):
-        """Download missing shards on demand"""
-        if not (self.shard_dir/shard_name).exists():
-            hf_hub_download(
-                repo_id=self.repo_id,
-                filename=f"metadata_shards/{shard_name}",
-                local_dir=self.shard_dir,
-                cache_dir="metadata_cache"
-            )
-
-    def get_metadata(self, indices):
-        """Retrieve metadata for specific indices"""
-        results = []
-
-        # Group indices by shard
-        shard_groups = {}
-        for idx in indices:
-            for (start, end), shard in self.shard_map.items():
-                if start <= idx <= end:
-                    if shard not in shard_groups:
-                        shard_groups[shard] = []
-                    shard_groups[shard].append(idx - start)
-                    break
-
-        # Process each required shard
-        for shard, local_indices in shard_groups.items():
-            if shard not in self.loaded_shards:
-                self._download_shard(shard)
-                self.loaded_shards[shard] = pd.read_parquet(self.shard_dir/shard)
-
-            results.append(self.loaded_shards[shard].iloc[local_indices])
-
-        return pd.concat(results).reset_index(drop=True)
+            global_indices = [self._global_index(shard_idx, idx) for idx in indices[0]]
+            all_distances.extend(distances[0])
+            all_global_indices.extend(global_indices)
+
+        # Process and format results
+        results = self.metadata_mgr.get_metadata(all_global_indices)
+        results['similarity'] = 1 - (np.array(all_distances) / 2)  # Convert L2 to cosine
+        return results.sort_values('similarity', ascending=False).head(top_k)
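
For reference, a minimal usage sketch of the updated classes (hypothetical driver code, not part of this commit). It assumes the FAISS shards live in compressed_shards/*.index, that metadata_shards.zip or an already-extracted metadata_shards/ directory sits next to the script with Parquet files whose second and third underscore-separated name fields encode the start and end document indices (e.g. a name like shard_0_49999.parquet, which is an assumption), and that the code runs inside a Streamlit app, since load_model is cached with st.cache_resource; the query string is purely illustrative:

# Hypothetical driver script; file layout and query are assumptions.
from search_utils import SemanticSearch

searcher = SemanticSearch()      # constructs MetadataManager, which unzips and maps the metadata shards
searcher.initialize_system()     # loads the SentenceTransformer model and every FAISS index shard
results = searcher.search("graph neural networks for drug discovery", top_k=5)

# results is a pandas DataFrame with the title, summary and source columns
# read from the metadata shards, plus the derived 'similarity' score.
print(results[["title", "similarity"]])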