# search_utils.py
import zipfile
from pathlib import Path

import numpy as np
import pandas as pd
import faiss
import streamlit as st
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
class MetadataManager:
    def __init__(self):
        self.shard_dir = Path("metadata_shards")
        self.shard_map = {}
        self.loaded_shards = {}
        self._ensure_unzipped()
        self._build_shard_map()

    def _ensure_unzipped(self):
        """Extract metadata shards from zip if needed."""
        if not self.shard_dir.exists():
            zip_path = Path("metadata_shards.zip")
            if zip_path.exists():
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(self.shard_dir)
                st.toast("✅ Successfully extracted metadata shards!", icon="📦")
            else:
                raise FileNotFoundError("No metadata shards found!")

    def _build_shard_map(self):
        """Map index ranges to shard files (expects names like <prefix>_<start>_<end>.parquet)."""
        for f in self.shard_dir.glob("*.parquet"):
            parts = f.stem.split("_")
            self.shard_map[(int(parts[1]), int(parts[2]))] = f.name

    def get_metadata(self, indices):
        """Retrieve metadata rows for the given global indices."""
        results = []
        shard_groups = {}

        # Group indices by the shard whose [start, end] range contains them
        for idx in indices:
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    shard_groups.setdefault(shard, []).append(idx - start)
                    break

        # Load each required shard once, then slice out the requested rows
        for shard, local_indices in shard_groups.items():
            if shard not in self.loaded_shards:
                self.loaded_shards[shard] = pd.read_parquet(
                    self.shard_dir / shard,
                    columns=["title", "summary", "source"]
                )
            results.append(self.loaded_shards[shard].iloc[local_indices])

        if not results:  # no index fell inside any known shard range
            return pd.DataFrame(columns=["title", "summary", "source"])
        return pd.concat(results).reset_index(drop=True)
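
# Minimal usage sketch (illustrative only; assumes metadata_shards/ holds files
# such as shard_0_999.parquet covering global indices 0-999):
#
#     mgr = MetadataManager()
#     rows = mgr.get_metadata([3, 42])  # DataFrame with title/summary/source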


class SemanticSearch:
    def __init__(self, shard_dir="compressed_shards"):
        self.shard_dir = Path(shard_dir)
        self.shard_dir.mkdir(exist_ok=True, parents=True)
        self.model = None
        self.index_shards = []
        self.metadata_mgr = MetadataManager()

    @st.cache_resource
    def load_model(_self):
        # Leading underscore tells Streamlit's cache not to hash this argument
        return SentenceTransformer('all-MiniLM-L6-v2')

    def initialize_system(self):
        self.model = self.load_model()
        self._load_index_shards()

    def _load_index_shards(self):
        """Load FAISS shards directly from the local directory."""
        for shard_path in sorted(self.shard_dir.glob("*.index")):
            self.index_shards.append(faiss.read_index(str(shard_path)))

    def search(self, query, top_k=5):
        """Search across all shards and merge the results."""
        query_embedding = self.model.encode([query], convert_to_numpy=True)
        all_scores = []
        all_indices = []

        for shard_idx, index in enumerate(self.index_shards):
            distances, indices = index.search(query_embedding, top_k)
            # Convert shard-local indices to global indices
            global_indices = [
                self._calculate_global_index(shard_idx, idx)
                for idx in indices[0]
            ]
            all_scores.extend(distances[0])
            all_indices.extend(global_indices)

        return self._process_results(np.array(all_scores), np.array(all_indices), top_k)

    def _calculate_global_index(self, shard_idx, local_idx):
        """Convert a shard-local index to a global index."""
        # Implement your specific shard indexing logic here
        # Example: return f"{shard_idx}-{local_idx}"
        return local_idx  # Simple version if using unique IDs
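
    # If shard-local IDs are NOT globally unique, a common alternative (sketch,
    # assuming shards are loaded in a fixed order and each FAISS index reports
    # its size via .ntotal) is to offset by the sizes of all preceding shards:
    #
    #     offset = sum(s.ntotal for s in self.index_shards[:shard_idx])
    #     return offset + int(local_idx)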

    def _process_results(self, distances, indices, top_k):
        """Format search results as a DataFrame sorted by similarity."""
        results = pd.DataFrame({
            'global_index': indices,
            'similarity': 1 - (distances / 2)  # L2-to-cosine approximation for normalized embeddings
        })
        return results.sort_values('similarity', ascending=False).head(top_k)

    def search_with_threshold(self, query, top_k=5, similarity_threshold=0.6):
        """Search, then keep only results above the similarity threshold."""
        results = self.search(query, top_k * 2)  # over-fetch so filtering can still fill top_k
        filtered = results[results['similarity'] > similarity_threshold].head(top_k)
        return filtered.reset_index(drop=True)
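
# Minimal usage sketch (illustrative only; the query string is a placeholder):
#
#     searcher = SemanticSearch()
#     searcher.initialize_system()
#     hits = searcher.search_with_threshold("graph neural networks", top_k=5)
#     meta = searcher.metadata_mgr.get_metadata(hits['global_index'].tolist())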


class HubMetadataManager:
    """Hub-backed metadata manager that downloads missing shards on demand.

    Kept under a distinct name so it does not shadow the local-only
    MetadataManager defined above (which SemanticSearch instantiates).
    """

    def __init__(self, repo_id, shard_dir="metadata_shards"):
        self.repo_id = repo_id
        self.shard_dir = Path(shard_dir)
        self.shard_map = {}
        self.loaded_shards = {}
        self._build_shard_map()

    def _build_shard_map(self):
        """Map index ranges to shard files (expects names like <prefix>_<start>_<end>.parquet)."""
        for f in self.shard_dir.glob("*.parquet"):
            parts = f.stem.split("_")
            self.shard_map[(int(parts[1]), int(parts[2]))] = f.name

    def _download_shard(self, shard_name):
        """Download a missing shard from the Hub on demand."""
        if not (self.shard_dir / shard_name).exists():
            # The repo stores shards under metadata_shards/, and hf_hub_download
            # preserves that subpath below local_dir, so download relative to the
            # parent directory to land the file at self.shard_dir / shard_name
            # (assumes shard_dir keeps its default name, matching the repo layout)
            hf_hub_download(
                repo_id=self.repo_id,
                filename=f"metadata_shards/{shard_name}",
                local_dir=self.shard_dir.parent,
                cache_dir="metadata_cache"
            )

    def get_metadata(self, indices):
        """Retrieve metadata rows for the given global indices."""
        results = []

        # Group indices by the shard whose [start, end] range contains them
        shard_groups = {}
        for idx in indices:
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    shard_groups.setdefault(shard, []).append(idx - start)
                    break

        # Download (if needed) and load each required shard, then slice rows
        for shard, local_indices in shard_groups.items():
            if shard not in self.loaded_shards:
                self._download_shard(shard)
                self.loaded_shards[shard] = pd.read_parquet(self.shard_dir / shard)
            results.append(self.loaded_shards[shard].iloc[local_indices])

        if not results:  # no index fell inside any known shard range
            return pd.DataFrame()
        return pd.concat(results).reset_index(drop=True)
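
# Hub-backed usage sketch (illustrative only; "user/dataset-repo" is a
# placeholder repo id, not a real repository):
#
#     hub_mgr = HubMetadataManager(repo_id="user/dataset-repo")
#     rows = hub_mgr.get_metadata([3, 42])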