File size: 2,235 Bytes
b8397a5
 
 
 
 
 
 
bb8fa77
b8397a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
\
import json, os
import numpy as np, pandas as pd
import faiss
from sentence_transformers import SentenceTransformer, CrossEncoder

class SloganSearcher:
    """Semantic search over a table of slogans backed by a FAISS index.

    The constructor loads three artifacts from ``assets_dir``:
    ``meta.json`` (settings), ``slogans_clean.parquet`` (the rows that are
    returned to the caller) and ``faiss.index`` (the vector index that is
    queried).  Row ``i`` of the parquet table is looked up positionally for
    id ``i`` returned by the index.
    """

    def __init__(
        self,
        assets_dir: str = "assets",
        use_rerank: bool = False,
        rerank_model: str = "cross-encoder/stsb-roberta-base",
    ):
        """Load metadata, data table, index and models.

        Args:
            assets_dir: Directory containing ``meta.json``,
                ``slogans_clean.parquet`` and ``faiss.index``.
            use_rerank: When true, a ``CrossEncoder`` is loaded and used by
                :meth:`search` to re-order the retrieved candidates.
            rerank_model: Model name passed to ``CrossEncoder``; only used
                when ``use_rerank`` is true.

        Raises:
            FileNotFoundError: If ``meta.json`` is missing from
                ``assets_dir``.
        """
        meta_path = os.path.join(assets_dir, "meta.json")
        if not os.path.exists(meta_path):
            raise FileNotFoundError(f"Missing {meta_path}. Build assets first.")
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(meta_path, "r", encoding="utf-8") as f:
            self.meta = json.load(f)

        self.df = pd.read_parquet(os.path.join(assets_dir, "slogans_clean.parquet"))
        self.index = faiss.read_index(os.path.join(assets_dir, "faiss.index"))
        self.encoder = SentenceTransformer(self.meta["model_name"])

        self.use_rerank = use_rerank
        self.reranker = CrossEncoder(rerank_model) if use_rerank else None

        # Column fed to the reranker, and the column used both as its
        # fallback (when the text column is null) and as the display value.
        self.text_col = self.meta.get("text_col", "description")
        self.fallback_col = self.meta.get("fallback_col", "tagline")
        # Forwarded to the encoder as ``normalize_embeddings``.
        self.norm = bool(self.meta.get("normalized", True))

    def _result_columns(self) -> list:
        """Return the column names of the frame produced by :meth:`search`."""
        cols = ["display", "score"]
        if self.use_rerank:
            cols.append("rerank_score")
        return cols

    def _empty_result(self) -> pd.DataFrame:
        """Return an empty frame with the same columns as a normal result."""
        return pd.DataFrame(columns=self._result_columns())

    def search(self, query: str, top_k: int = 5, rerank_top_n: int = 20) -> pd.DataFrame:
        """Return the best-matching rows for ``query``.

        Args:
            query: Free-text query.  Non-string or blank input yields an
                empty result.
            top_k: Maximum number of rows to return.  Values ``<= 0`` yield
                an empty result.
            rerank_top_n: Number of candidates fetched from the index before
                reranking (only used when reranking is enabled; at least
                ``top_k`` candidates are always fetched).

        Returns:
            A DataFrame with columns ``display`` (value of the fallback
            column) and ``score`` (raw score returned by the index), plus
            ``rerank_score`` when reranking is enabled.  With reranking the
            rows are ordered by ``rerank_score`` descending; otherwise they
            keep the order returned by the index.
        """
        if not isinstance(query, str) or not query.strip():
            return self._empty_result()

        top_k = int(top_k)
        if top_k <= 0:
            return self._empty_result()
        fetch_k = max(top_k, int(rerank_top_n)) if self.use_rerank else top_k

        q = self.encoder.encode([query], convert_to_numpy=True, normalize_embeddings=self.norm)
        # FAISS expects a C-contiguous float32 matrix.
        q = np.ascontiguousarray(q, dtype=np.float32)
        sims, idxs = self.index.search(q, fetch_k)

        # FAISS pads missing neighbours with id -1 (e.g. when the index holds
        # fewer than ``fetch_k`` vectors).  ``iloc[-1]`` would silently return
        # the last row of the table, so drop those slots.
        hits = [(i, s) for i, s in zip(idxs[0].tolist(), sims[0].tolist()) if i >= 0]
        if not hits:
            return self._empty_result()

        results = self.df.iloc[[i for i, _ in hits]].copy()
        results["score"] = [s for _, s in hits]

        if self.use_rerank:
            texts = results[self.text_col].fillna(results[self.fallback_col]).astype(str).tolist()
            pairs = [[query, t] for t in texts]
            results["rerank_score"] = self.reranker.predict(pairs)
            results = results.sort_values("rerank_score", ascending=False).head(top_k)
        else:
            results = results.head(top_k)

        results["display"] = results[self.fallback_col]
        return results[self._result_columns()]