|
\ |
|
import json, os |
|
import numpy as np, pandas as pd |
|
import faiss |
|
from sentence_transformers import SentenceTransformer, CrossEncoder |
|
|
|
class SloganSearcher:
    """Semantic search over a prebuilt slogan corpus, with optional reranking.

    Loads three assets from ``assets_dir``:
    ``meta.json`` (encoder model name plus column configuration),
    ``slogans_clean.parquet`` (the corpus rows) and ``faiss.index``.
    Vector id ``i`` returned by the index is used as positional row ``i`` of
    the parquet frame — presumably the build step wrote them in the same
    order; TODO confirm against the asset builder.
    """

    def __init__(self, assets_dir="assets", use_rerank=False, rerank_model="cross-encoder/stsb-roberta-base"):
        """Load metadata, corpus, vector index and models.

        Args:
            assets_dir: Directory holding ``meta.json``,
                ``slogans_clean.parquet`` and ``faiss.index``.
            use_rerank: If true, candidates from the index are re-scored
                with a ``CrossEncoder`` and re-sorted by that score.
            rerank_model: Cross-encoder model name; only loaded when
                ``use_rerank`` is true.

        Raises:
            FileNotFoundError: If ``meta.json`` does not exist.
        """
        meta_path = os.path.join(assets_dir, "meta.json")
        if not os.path.exists(meta_path):
            raise FileNotFoundError(f"Missing {meta_path}. Build assets first.")
        # JSON is UTF-8; don't depend on the platform's locale encoding.
        with open(meta_path, "r", encoding="utf-8") as f:
            self.meta = json.load(f)

        self.df = pd.read_parquet(os.path.join(assets_dir, "slogans_clean.parquet"))
        self.index = faiss.read_index(os.path.join(assets_dir, "faiss.index"))
        # Query encoder named in meta.json; presumably the model used to build
        # the index (query and corpus vectors must live in the same space).
        self.encoder = SentenceTransformer(self.meta["model_name"])

        self.use_rerank = use_rerank
        self.reranker = CrossEncoder(rerank_model) if use_rerank else None

        # Column fed to the reranker; rows where it is missing use fallback_col.
        self.text_col = self.meta.get("text_col", "description")
        # Column shown to the user as "display" and used as reranker fallback.
        self.fallback_col = self.meta.get("fallback_col", "tagline")
        # Whether query embeddings are L2-normalized before searching.
        self.norm = bool(self.meta.get("normalized", True))

    def _result_columns(self):
        """Return the output column names for the current rerank setting."""
        return ["display", "score"] + (["rerank_score"] if self.use_rerank else [])

    def search(self, query: str, top_k=5, rerank_top_n=20):
        """Return the best-matching corpus rows for ``query``.

        Args:
            query: Free-text query. Non-strings and blank strings yield an
                empty result.
            top_k: Maximum number of rows to return. Values <= 0 yield an
                empty result.
            rerank_top_n: Number of index candidates to fetch before
                reranking (at least ``top_k``); ignored without reranking.

        Returns:
            A DataFrame indexed by the corpus row labels with columns
            ``display`` (the ``fallback_col`` value), ``score`` (the raw
            value returned by ``index.search`` — similarity or distance
            depending on the index type), and, when reranking is enabled,
            ``rerank_score`` (rows sorted by it, descending). It may have
            fewer than ``top_k`` rows if the index holds fewer vectors.
        """
        if not isinstance(query, str) or len(query.strip()) == 0:
            return pd.DataFrame(columns=self._result_columns())
        top_k = int(top_k)
        if top_k <= 0:
            return pd.DataFrame(columns=self._result_columns())

        n_candidates = max(top_k, int(rerank_top_n)) if self.use_rerank else top_k
        q = self.encoder.encode([query], convert_to_numpy=True, normalize_embeddings=self.norm)
        # FAISS expects a C-contiguous float32 matrix of shape (n_queries, dim).
        q = np.ascontiguousarray(q, dtype=np.float32)
        sims, idxs = self.index.search(q, n_candidates)

        # FAISS pads with label -1 when fewer than k vectors are available;
        # iloc would treat -1 as "last row" and return bogus hits.
        valid = idxs[0] >= 0
        row_ids = idxs[0][valid].tolist()
        scores = sims[0][valid].tolist()
        if not row_ids:
            return pd.DataFrame(columns=self._result_columns())

        results = self.df.iloc[row_ids].copy()
        results["score"] = scores

        if self.use_rerank:
            # Rerank on the main text column, falling back per row.
            texts = results[self.text_col].fillna(results[self.fallback_col]).astype(str).tolist()
            results["rerank_score"] = self.reranker.predict([[query, t] for t in texts])
            results = results.sort_values("rerank_score", ascending=False)

        results = results.head(top_k).copy()
        results["display"] = results[self.fallback_col]
        return results[self._result_columns()]
|
|