# semantic-search / search_utils.py
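"""Sharded semantic-search utilities: MetadataManager maps global document
indices onto Parquet metadata shards, and SemanticSearch queries sharded FAISS
indexes with SentenceTransformer embeddings, reporting errors via Streamlit."""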
import numpy as np
import pandas as pd
import faiss
import zipfile
from pathlib import Path
from sentence_transformers import SentenceTransformer, util
import streamlit as st


class MetadataManager:
    def __init__(self):
        self.shard_dir = Path("metadata_shards")
        self.shard_map = {}
        self.loaded_shards = {}
        self.total_docs = 0
        self._ensure_unzipped()  # extraction only; no Streamlit UI calls in the constructor
        self._build_shard_map()

    def _ensure_unzipped(self):
        """Extract metadata_shards.zip into the shard directory if it is missing."""
        if not self.shard_dir.exists():
            zip_path = Path("metadata_shards.zip")
            if zip_path.exists():
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(self.shard_dir)
            else:
                raise FileNotFoundError("Metadata ZIP file not found")

    def _build_shard_map(self):
        """Map (start, end) global-index ranges to shard filenames.

        Shard files are expected to be named like ``<prefix>_<start>_<end>.parquet``,
        which is what the stem parsing below relies on.
        """
        self.total_docs = 0
        for f in sorted(self.shard_dir.glob("*.parquet")):
            parts = f.stem.split("_")
            start = int(parts[1])
            end = int(parts[2])
            self.shard_map[(start, end)] = f.name
            self.total_docs = max(self.total_docs, end + 1)

    def get_metadata(self, global_indices):
        """Retrieve metadata rows for a set of global document indices."""
        # Handle an empty numpy array explicitly (truth-testing it would raise).
        if isinstance(global_indices, np.ndarray) and global_indices.size == 0:
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])

        # Work with a plain Python list regardless of the input type.
        indices_list = global_indices.tolist() if isinstance(global_indices, np.ndarray) else global_indices

        # Keep only indices that fall inside the known document range.
        valid_indices = [idx for idx in indices_list if 0 <= idx < self.total_docs]
        if not valid_indices:
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])

        # Group indices by shard, converting each to a shard-local offset.
        shard_groups = {}
        for idx in valid_indices:
            found = False
            for (start, end), shard in self.shard_map.items():
                if start <= idx <= end:
                    shard_groups.setdefault(shard, []).append(idx - start)
                    found = True
                    break
            if not found:
                st.warning(f"Index {idx} not covered by any shard (0-{self.total_docs - 1})")

        # Load each shard lazily and collect the requested rows.
        results = []
        for shard, local_indices in shard_groups.items():
            try:
                if shard not in self.loaded_shards:
                    self.loaded_shards[shard] = pd.read_parquet(
                        self.shard_dir / shard,
                        columns=["title", "summary", "source"]
                    )
                if local_indices:
                    results.append(self.loaded_shards[shard].iloc[local_indices])
            except Exception as e:
                st.error(f"Error loading shard {shard}: {str(e)}")
                continue

        return pd.concat(results).reset_index(drop=True) if results else pd.DataFrame()


class SemanticSearch:
    def __init__(self):
        self.shard_dir = Path("compressed_shards")
        self.model = None
        self.index_shards = []
        self.metadata_mgr = MetadataManager()
        self.shard_sizes = []
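    # Streamlit's cache_resource skips hashing parameters whose names start with
    # an underscore, so `_self` keeps the (unhashable) instance out of the cache key.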
    @st.cache_resource
    def load_model(_self):
        return SentenceTransformer('all-MiniLM-L6-v2')

    def initialize_system(self):
        self.model = self.load_model()
        self._load_faiss_shards()

    def _load_faiss_shards(self):
        """Load every FAISS index shard and record its size."""
        self.index_shards = []
        self.shard_sizes = []
        for shard_path in sorted(self.shard_dir.glob("*.index")):
            index = faiss.read_index(str(shard_path))
            self.index_shards.append(index)
            self.shard_sizes.append(index.ntotal)
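    # Global indices are positions in the virtual concatenation of all shards.
    # For example (hypothetical sizes), with shard_sizes == [1000, 1000],
    # local index 5 in shard 1 maps to global index 1000 + 5 == 1005.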
    def _global_index(self, shard_idx, local_idx):
        """Convert a shard-local index into a global index."""
        return sum(self.shard_sizes[:shard_idx]) + local_idx

    def search(self, query, top_k=5):
        """Search all FAISS shards for a query and return a results DataFrame."""
        if not query or not self.index_shards:
            return pd.DataFrame()

        try:
            query_embedding = self.model.encode([query], convert_to_numpy=True)
        except Exception as e:
            st.error(f"Query encoding failed: {str(e)}")
            return pd.DataFrame()

        all_distances = []
        all_global_indices = []

        # Query each shard, keeping distance/index pairs aligned while dropping
        # invalid slots (FAISS pads with -1 when a shard has fewer than top_k hits).
        for shard_idx, index in enumerate(self.index_shards):
            if index.ntotal == 0:
                continue
            try:
                distances, indices = index.search(query_embedding, top_k)
                for dist, idx in zip(distances[0], indices[0]):
                    if 0 <= idx < index.ntotal:
                        all_distances.append(dist)
                        all_global_indices.append(self._global_index(shard_idx, idx))
            except Exception as e:
                st.error(f"Search failed in shard {shard_idx}: {str(e)}")
                continue

        return self._process_results(
            np.array(all_distances),
            np.array(all_global_indices),
            top_k
        )

    def _process_results(self, distances, global_indices, top_k):
        """Turn raw distances and global indices into a formatted DataFrame."""
        # Use .size for emptiness checks; truth-testing a numpy array raises.
        if global_indices.size == 0 or distances.size == 0:
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])

        try:
            # Fetch metadata for the matched documents.
            results = self.metadata_mgr.get_metadata(global_indices.tolist())

            # Metadata lookup may drop invalid indices, so trim distances to match.
            if len(results) != len(distances):
                distances = distances[:len(results)]

            # Heuristic similarity score: smaller FAISS distances map closer to 1.
            results['similarity'] = 1 - (distances / 2)

            # Deduplicate, sort by similarity, and keep the global top_k.
            results = (
                results.drop_duplicates(subset=["title", "source"])
                .sort_values("similarity", ascending=False)
                .head(top_k)
            )
            return results.reset_index(drop=True)
        except Exception as e:
            st.error(f"Result processing failed: {str(e)}")
            return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
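

# Minimal usage sketch: one way to wire these classes into a Streamlit app.
# The widget labels and the __main__ guard are illustrative assumptions, not
# part of the search utilities themselves.
if __name__ == "__main__":
    searcher = SemanticSearch()
    searcher.initialize_system()

    query = st.text_input("Search query")
    if query:
        hits = searcher.search(query, top_k=5)
        if hits.empty:
            st.info("No results found.")
        else:
            st.dataframe(hits)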