michaellupo74 committed on
Commit
8a31d42
·
verified ·
1 Parent(s): 450f3df

Update app/search.py

Browse files
Files changed (1) hide show
  1. app/search.py +56 -16
app/search.py CHANGED
@@ -1,37 +1,77 @@
1
  import json
2
  from pathlib import Path
3
- from typing import List, Dict
4
- from sentence_transformers import SentenceTransformer
5
  import numpy as np
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  def load_index(env: Dict):
8
- import faiss, json
9
  index_path = Path(env["INDEX_DIR"]) / "faiss.index"
10
  meta_path = Path(env["INDEX_DIR"]) / "meta.json"
11
  if not index_path.exists():
12
  raise RuntimeError("Index not found. Run ingest first.")
13
  index = faiss.read_index(str(index_path))
14
- metas = json.load(open(meta_path))
 
15
  return index, metas
16
 
17
def embed(texts: List[str]):
    """Encode *texts* into normalized sentence embeddings.

    The "all-MiniLM-L6-v2" model is constructed on the first call and kept on
    the function object, so later calls reuse it instead of rebuilding the
    model for every query.

    Args:
        texts: Strings to embed.

    Returns:
        Whatever ``SentenceTransformer.encode`` returns when called with
        ``convert_to_numpy=True`` and ``normalize_embeddings=True``.
    """
    model = getattr(embed, "_model", None)
    if model is None:
        model = SentenceTransformer("all-MiniLM-L6-v2")
        embed._model = model
    return model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
 
 
 
 
 
 
 
 
 
 
20
 
21
  def search(q: str, env: Dict, top_k: int = 15, filters: Dict = None) -> List[Dict]:
 
22
  index, metas = load_index(env)
23
- qv = embed([q])
24
- scores, idxs = index.search(qv, top_k)
 
 
 
 
 
 
 
25
  results = []
 
 
 
26
  for score, idx in zip(scores[0], idxs[0]):
27
- if idx == -1: continue
28
- m = metas[idx]
29
- if filters:
30
- if "geo" in filters and filters["geo"] and m.get("geo") not in filters["geo"]:
 
 
 
31
  continue
32
- if "categories" in filters and filters["categories"]:
33
- if not set(filters["categories"]).intersection(set(m.get("categories",[]))):
34
- continue
35
  m["score"] = float(score)
36
  results.append(m)
 
37
  return results
 
1
  import json
2
  from pathlib import Path
3
+ from typing import List, Dict, Optional
 
4
  import numpy as np
5
 
6
+ from sentence_transformers import SentenceTransformer
7
+
8
# ---------- Process-wide embedding model (created lazily, pinned to CPU) ----------
_EMBEDDER: Optional[SentenceTransformer] = None


def _get_embedder() -> SentenceTransformer:
    """Return the shared SentenceTransformer, constructing it on first use.

    The instance is stored in the module-level ``_EMBEDDER`` so the model is
    built at most once per process; later calls return the stored instance.
    """
    global _EMBEDDER
    if _EMBEDDER is not None:
        return _EMBEDDER

    # device="cpu" is passed explicitly and the fully qualified model id is
    # used, so neither device placement nor name resolution is left implicit.
    model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        device="cpu",
    )
    # Cap the sequence length at 256 tokens to trade a little accuracy for speed.
    model.max_seq_length = 256

    _EMBEDDER = model
    return _EMBEDDER
23
+
24
  def load_index(env: Dict):
25
+ import faiss
26
  index_path = Path(env["INDEX_DIR"]) / "faiss.index"
27
  meta_path = Path(env["INDEX_DIR"]) / "meta.json"
28
  if not index_path.exists():
29
  raise RuntimeError("Index not found. Run ingest first.")
30
  index = faiss.read_index(str(index_path))
31
+ with open(meta_path, "r") as f:
32
+ metas = json.load(f)
33
  return index, metas
34
 
35
def embed(texts: List[str]) -> np.ndarray:
    """Encode *texts* with the shared embedder and return float32 vectors.

    The encoder is asked for normalized NumPy output, with the progress bar
    disabled and a batch size of 32.
    """
    encode_options = dict(
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=False,
        batch_size=32,
    )
    vectors = _get_embedder().encode(texts, **encode_options)
    # FAISS expects float32. astype(copy=False) returns the same array when
    # the dtype already matches, so no separate dtype check is needed.
    return vectors.astype(np.float32, copy=False)
48
 
49
  def search(q: str, env: Dict, top_k: int = 15, filters: Dict = None) -> List[Dict]:
50
+ import faiss
51
  index, metas = load_index(env)
52
+
53
+ qv = embed([q]) # shape (1, d) float32
54
+ # Defensive: ensure index dim matches query dim
55
+ if hasattr(index, "d") and index.d != qv.shape[1]:
56
+ raise RuntimeError(f"FAISS index dim {getattr(index, 'd', '?')} "
57
+ f"!= embedding dim {qv.shape[1]}")
58
+
59
+ scores, idxs = index.search(qv, top_k) # scores shape (1, k), idxs shape (1, k)
60
+
61
  results = []
62
+ f_geo = (filters or {}).get("geo")
63
+ f_cats = (filters or {}).get("categories")
64
+
65
  for score, idx in zip(scores[0], idxs[0]):
66
+ if idx == -1:
67
+ continue
68
+ m = dict(metas[idx]) # copy so we don’t mutate the cached list
69
+ if f_geo and m.get("geo") not in f_geo:
70
+ continue
71
+ if f_cats:
72
+ if not set(f_cats).intersection(set(m.get("categories", []))):
73
  continue
 
 
 
74
  m["score"] = float(score)
75
  results.append(m)
76
+
77
  return results