sloganAI / logic /search.py
asaf1602's picture
Upload folder using huggingface_hub
bb8fa77 verified
\
import json, os
import numpy as np, pandas as pd
import faiss
from sentence_transformers import SentenceTransformer, CrossEncoder
class SloganSearcher:
    """Nearest-neighbour slogan lookup with optional cross-encoder reranking.

    The searcher loads three artefacts from ``assets_dir``:

    * ``meta.json`` -- must contain ``model_name``; may contain ``text_col``
      (default ``"description"``), ``fallback_col`` (default ``"tagline"``)
      and ``normalized`` (default ``True``).
    * ``slogans_clean.parquet`` -- the table whose rows are returned.
    * ``faiss.index`` -- the vector index; the labels it returns are used as
      positional row numbers into the parquet table.
    """

    def __init__(self, assets_dir="assets", use_rerank=False, rerank_model="cross-encoder/stsb-roberta-base"):
        """Load metadata, data table, index and models.

        Args:
            assets_dir: Directory holding ``meta.json``,
                ``slogans_clean.parquet`` and ``faiss.index``.
            use_rerank: When true, a ``CrossEncoder`` is loaded and used by
                :meth:`search` to reorder candidates.
            rerank_model: Model name passed to ``CrossEncoder``; only used
                when ``use_rerank`` is true.

        Raises:
            FileNotFoundError: If ``meta.json`` is absent from ``assets_dir``.
        """
        meta_path = os.path.join(assets_dir, "meta.json")
        if not os.path.exists(meta_path):
            raise FileNotFoundError(f"Missing {meta_path}. Build assets first.")
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(meta_path, "r", encoding="utf-8") as f:
            self.meta = json.load(f)

        self.df = pd.read_parquet(os.path.join(assets_dir, "slogans_clean.parquet"))
        self.index = faiss.read_index(os.path.join(assets_dir, "faiss.index"))
        self.encoder = SentenceTransformer(self.meta["model_name"])

        self.use_rerank = use_rerank
        # The cross-encoder is only loaded when reranking is requested.
        self.reranker = CrossEncoder(rerank_model) if use_rerank else None

        self.text_col = self.meta.get("text_col", "description")
        self.fallback_col = self.meta.get("fallback_col", "tagline")
        self.norm = bool(self.meta.get("normalized", True))

    def _result_columns(self):
        """Return the ordered column names of the frame produced by search."""
        columns = ["display", "score"]
        if self.use_rerank:
            columns.append("rerank_score")
        return columns

    def _empty_result(self):
        """Return an empty frame with the same columns as a real result."""
        return pd.DataFrame(columns=self._result_columns())

    def search(self, query: str, top_k=5, rerank_top_n=20):
        """Return up to ``top_k`` rows that best match ``query``.

        Args:
            query: Free-text query. A non-string or blank query yields an
                empty result.
            top_k: Maximum number of rows to return. Values below 1 yield an
                empty result.
            rerank_top_n: Number of candidates fetched from the index when
                reranking is enabled (at least ``top_k`` are always fetched).

        Returns:
            A ``pandas.DataFrame`` with columns ``display`` (the value of the
            fallback column) and ``score`` (the value reported by the index),
            plus ``rerank_score`` when reranking is enabled. With reranking
            the rows are ordered by descending ``rerank_score``; otherwise
            they keep the order returned by the index.
        """
        if not isinstance(query, str) or not query.strip():
            return self._empty_result()

        top_k = int(top_k)
        if top_k < 1:
            # Nothing was asked for; also avoids handing k <= 0 to the index.
            return self._empty_result()

        n_candidates = max(top_k, int(rerank_top_n)) if self.use_rerank else top_k

        q = self.encoder.encode([query], convert_to_numpy=True, normalize_embeddings=self.norm)
        sims, idxs = self.index.search(q, n_candidates)

        # FAISS pads the label array with -1 when fewer than k neighbours
        # exist. Passing -1 to ``iloc`` would silently select the LAST row,
        # so the padding has to be removed before the lookup.
        hits = [
            (int(row), float(sim))
            for row, sim in zip(idxs[0].tolist(), sims[0].tolist())
            if row >= 0
        ]
        if not hits:
            return self._empty_result()

        results = self.df.iloc[[row for row, _ in hits]].copy()
        results["score"] = [sim for _, sim in hits]

        if self.use_rerank:
            # Score each candidate's text against the query; rows with no
            # primary text fall back to the fallback column.
            texts = results[self.text_col].fillna(results[self.fallback_col]).astype(str).tolist()
            pairs = [[query, t] for t in texts]
            results["rerank_score"] = self.reranker.predict(pairs)
            results = results.sort_values("rerank_score", ascending=False)

        results = results.head(top_k)
        # ``assign`` builds a new frame, so no write lands on a slice of
        # ``results`` (avoids pandas' SettingWithCopy warning).
        results = results.assign(display=results[self.fallback_col])
        return results[self._result_columns()]