Spaces:
Running
Running
Update app/search.py
Browse files- app/search.py +56 -16
app/search.py
CHANGED
|
@@ -1,37 +1,77 @@
|
|
| 1 |
import json
|
| 2 |
from pathlib import Path
|
| 3 |
-
from typing import List, Dict
|
| 4 |
-
from sentence_transformers import SentenceTransformer
|
| 5 |
import numpy as np
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
def load_index(env: Dict):
|
| 8 |
-
import faiss
|
| 9 |
index_path = Path(env["INDEX_DIR"]) / "faiss.index"
|
| 10 |
meta_path = Path(env["INDEX_DIR"]) / "meta.json"
|
| 11 |
if not index_path.exists():
|
| 12 |
raise RuntimeError("Index not found. Run ingest first.")
|
| 13 |
index = faiss.read_index(str(index_path))
|
| 14 |
-
|
|
|
|
| 15 |
return index, metas
|
| 16 |
|
| 17 |
-
def embed(texts: List[str]):
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
def search(q: str, env: Dict, top_k: int = 15, filters: Dict = None) -> List[Dict]:
|
|
|
|
| 22 |
index, metas = load_index(env)
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
results = []
|
|
|
|
|
|
|
|
|
|
| 26 |
for score, idx in zip(scores[0], idxs[0]):
|
| 27 |
-
if idx == -1:
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
| 31 |
continue
|
| 32 |
-
if "categories" in filters and filters["categories"]:
|
| 33 |
-
if not set(filters["categories"]).intersection(set(m.get("categories",[]))):
|
| 34 |
-
continue
|
| 35 |
m["score"] = float(score)
|
| 36 |
results.append(m)
|
|
|
|
| 37 |
return results
|
|
|
|
| 1 |
import json
|
| 2 |
from pathlib import Path
|
| 3 |
+
from typing import List, Dict, Optional
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
|
| 6 |
+
from sentence_transformers import SentenceTransformer
|
| 7 |
+
|
| 8 |
+
# ---------- Global embedder (loaded once, CPU-safe) ----------
_EMBEDDER: Optional[SentenceTransformer] = None


def _get_embedder() -> SentenceTransformer:
    """Return the process-wide SentenceTransformer, creating it on first call."""
    global _EMBEDDER
    if _EMBEDDER is not None:
        return _EMBEDDER
    # Explicit device="cpu" avoids any device_map/meta init paths.
    # Use the canonical model id to avoid redirect surprises.
    model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        device="cpu",
    )
    # Optional: shorten for speed on Spaces; keep accuracy reasonable
    model.max_seq_length = 256
    _EMBEDDER = model
    return _EMBEDDER
|
| 23 |
+
|
| 24 |
def load_index(env: Dict):
    """Load the FAISS index and its metadata records.

    Args:
        env: Must contain "INDEX_DIR", the directory holding
             "faiss.index" and "meta.json" written by the ingest step.

    Returns:
        (index, metas): the FAISS index object and the parsed list of
        metadata dicts, in the same order as the indexed vectors.

    Raises:
        RuntimeError: if either the index file or its metadata file is
        missing (i.e. ingest has not been run, or was interrupted).
    """
    index_dir = Path(env["INDEX_DIR"])
    index_path = index_dir / "faiss.index"
    meta_path = index_dir / "meta.json"
    # Fail fast before the heavy faiss import. meta.json is checked too,
    # so a half-written index dir raises the same actionable message
    # instead of an opaque FileNotFoundError below.
    if not index_path.exists() or not meta_path.exists():
        raise RuntimeError("Index not found. Run ingest first.")
    import faiss  # lazy: only needed once the files are known to exist
    index = faiss.read_index(str(index_path))
    metas = json.loads(meta_path.read_text(encoding="utf-8"))
    return index, metas
|
| 34 |
|
| 35 |
+
def embed(texts: List[str]) -> np.ndarray:
    """Encode *texts* into L2-normalized float32 vectors for FAISS search."""
    model = _get_embedder()
    matrix = model.encode(
        texts,
        batch_size=32,
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=False,
    )
    # FAISS expects float32; asarray copies only when a cast is needed.
    return np.asarray(matrix, dtype=np.float32)
|
| 48 |
|
| 49 |
def search(q: str, env: Dict, top_k: int = 15, filters: Optional[Dict] = None) -> List[Dict]:
    """Embed *q*, run a FAISS top-k search, and return filtered metadata hits.

    Args:
        q: Free-text query string.
        env: Must contain "INDEX_DIR" pointing at the built index.
        top_k: Number of nearest neighbours to retrieve before filtering.
        filters: Optional dict with "geo" and/or "categories" lists; hits
            whose metadata does not match are dropped.

    Returns:
        Copies of the matching metadata dicts, each with a float "score"
        added, in FAISS ranking order. Because filters are applied after
        retrieval, fewer than top_k results may be returned.

    Raises:
        RuntimeError: if the index is missing, or its dimensionality does
        not match the embedding model's output.
    """
    # NOTE: faiss itself is not needed here; load_index handles it.
    index, metas = load_index(env)

    qv = embed([q])  # shape (1, d) float32
    # Defensive: ensure index dim matches query dim
    if hasattr(index, "d") and index.d != qv.shape[1]:
        raise RuntimeError(f"FAISS index dim {getattr(index, 'd', '?')} "
                           f"!= embedding dim {qv.shape[1]}")

    scores, idxs = index.search(qv, top_k)  # both shaped (1, top_k)

    results = []
    f_geo = (filters or {}).get("geo")
    f_cats = (filters or {}).get("categories")

    for score, idx in zip(scores[0], idxs[0]):
        if idx == -1:  # FAISS pads with -1 when fewer than top_k vectors exist
            continue
        m = dict(metas[idx])  # copy so we don’t mutate the cached list
        if f_geo and m.get("geo") not in f_geo:
            continue
        if f_cats and not set(f_cats).intersection(m.get("categories", [])):
            continue
        m["score"] = float(score)
        results.append(m)

    return results
|