AdarshDRC committed on
Commit
8affd2a
·
verified ·
1 Parent(s): ef7075f

Create search.py

Browse files
Files changed (1) hide show
  1. src/api/search.py +150 -0
src/api/search.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import time
3
+ import traceback
4
+
5
+ from fastapi import APIRouter, File, Form, HTTPException, Request, UploadFile, Depends
6
+
7
+ from src.core.config import DEFAULT_PINECONE_KEY, IDX_FACES, IDX_OBJECTS
8
+ from src.core.security import get_verified_keys
9
+ from src.services.db_client import (
10
+ merge_face_results, merge_object_results,
11
+ pinecone_pool, search_faces, search_objects,
12
+ )
13
+ from src.core.logging import log
14
+ from src.common.utils import face_ui_score, get_ip, is_default_key, to_list
15
+
16
+ router = APIRouter()
17
+
18
@router.post("/api/search")
async def search_database(
    request: Request,
    file: UploadFile = File(...),
    detect_faces: bool = Form(True),
    user_id: str = Form(""),
    keys: dict = Depends(get_verified_keys),
):
    """Search the Pinecone indexes for matches to an uploaded image.

    The upload is read fully into memory and handed to the AI manager stored
    on ``request.app.state.ai`` while holding ``request.app.state.ai_semaphore``.
    The returned vectors are split by their ``"type"`` field ("face" /
    "object"). When ``detect_faces`` is set and at least one face vector came
    back, the face search path runs; otherwise only the object path runs.

    Args:
        request: Incoming request; supplies the client IP and the app state.
        file: Uploaded image.
        detect_faces: Forwarded to the AI manager and used to pick the path.
        user_id: Caller-supplied identifier, used for logging only
            ("anonymous" is logged when it is empty).
        keys: Verified keys; ``keys["pinecone_key"]`` selects the Pinecone
            client from ``pinecone_pool``.

    Returns:
        The dict produced by ``_run_face_search`` or ``_run_object_search``.

    Raises:
        HTTPException: 400 when the upload is empty. An HTTPException raised
            further down is passed through unchanged. Any other failure is
            logged and reported as a 500 whose detail is ``str(e)``.
    """
    ip = get_ip(request)
    start = time.perf_counter()
    mode = "guest" if is_default_key(keys["pinecone_key"], DEFAULT_PINECONE_KEY) else "personal"
    who = user_id or "anonymous"

    log("INFO", "search.start", user_id=who, ip=ip, mode=mode,
        filename=file.filename, detect_faces=detect_faces)

    try:
        file_bytes = await file.read()
        if not file_bytes:
            # Nothing to run inference on: report a client error instead of
            # letting the failure surface later as an opaque 500.
            raise HTTPException(400, "Uploaded file is empty.")

        ai_manager = request.app.state.ai
        sem = request.app.state.ai_semaphore

        # The semaphore bounds how many inference calls run at once.
        async with sem:
            vectors = await ai_manager.process_image_bytes_async(file_bytes, detect_faces=detect_faces)

        # Measured from request start, so it includes the upload read and any
        # time spent waiting on the semaphore.
        inference_ms = round((time.perf_counter() - start) * 1000)
        face_vectors = [v for v in vectors if v["type"] == "face"]
        object_vectors = [v for v in vectors if v["type"] == "object"]
        lanes_used = list({v["type"] for v in vectors})

        log("INFO", "search.inference_done", user_id=who, ip=ip, mode=mode,
            face_vecs=len(face_vectors), obj_vecs=len(object_vectors), inference_ms=inference_ms)

        pc = pinecone_pool.get(keys["pinecone_key"])
        idx_obj = pc.Index(IDX_OBJECTS)

        if detect_faces and face_vectors:
            # The face index handle is only needed on this path.
            idx_face = pc.Index(IDX_FACES)
            return await _run_face_search(
                face_vectors, object_vectors, idx_face, idx_obj,
                start, user_id, ip, mode, lanes_used,
            )
        return await _run_object_search(object_vectors, idx_obj, start, user_id, ip, mode, lanes_used)

    except HTTPException:
        # Already carries the intended status code; do not rewrap as a 500.
        raise
    except Exception as e:
        # Only the tail of the traceback is logged to keep the entry bounded.
        log("ERROR", "search.error", user_id=who, ip=ip, mode=mode,
            error=str(e), traceback=traceback.format_exc()[-800:])
        raise HTTPException(500, str(e)) from e
64
+
65
async def _run_face_search(face_vectors, object_vectors, idx_face, idx_obj, start, user_id, ip, mode, lanes_used) -> dict:
    """Run all face and object queries concurrently and build the face-mode response.

    One ``search_faces`` call is made per face vector and one
    ``search_objects`` call per object vector, each in a worker thread via
    ``asyncio.to_thread``; all of them are awaited together.

    Args:
        face_vectors: Dicts with a ``"vector"`` entry and optional
            ``"det_score"``, ``"face_idx"``, ``"face_crop"``, ``"bbox"`` and
            ``"face_width_px"`` entries.
        object_vectors: Dicts with a ``"vector"`` entry.
        idx_face: Index handle passed to ``search_faces``.
        idx_obj: Index handle passed to ``search_objects``.
        start: ``time.perf_counter()`` reading taken at request start.
        user_id: Logged identifier ("anonymous" when empty).
        ip: Client IP, for logging.
        mode: Key mode string, for logging.
        lanes_used: Vector types that inference produced, for logging.

    Returns:
        ``{"mode": "face", "face_groups": ..., "results": ...,
        "object_results": ...}``. ``face_groups`` keeps only the per-face
        groups that have at least one match; ``results`` is
        ``merge_face_results`` over every group, including empty ones.

    Raises:
        HTTPException: 404 when a query fails with an error whose message
            contains "404". Any other query error propagates unchanged.
    """

    async def _query_face(fv: dict) -> dict:
        # Query the face index for one detected face and shape its match list.
        vec = to_list(fv["vector"])
        det_score = fv.get("det_score", 1.0)
        try:
            image_map = await asyncio.to_thread(search_faces, idx_face, vec, det_score)
        except Exception as e:
            # A missing index is recognised by substring match on the message.
            if "404" in str(e):
                raise HTTPException(404, "Pinecone index not found. Go to Settings → Verify & Save.") from e
            raise

        # image_map is read as url -> dict with "raw_score", "face_crop" and
        # "folder" entries.
        matches = [
            {
                "url": url,
                "score": face_ui_score(d["raw_score"]),
                "raw_score": round(d["raw_score"], 4),
                "face_crop": d["face_crop"],
                "folder": d["folder"],
                "caption": "👤 Verified Identity",
            }
            for url, d in image_map.items()
        ]
        matches.sort(key=lambda x: x["score"], reverse=True)

        return {
            "query_face_idx": fv.get("face_idx", 0),
            "query_face_crop": fv.get("face_crop", ""),
            "query_bbox": fv.get("bbox", []),
            "det_score": det_score,
            "face_width_px": fv.get("face_width_px", 0),
            "matches": matches[:50],  # best 50 by UI score
        }

    async def _query_obj_single(ov: dict) -> list:
        # Query the object index for one object vector.
        vec = to_list(ov["vector"])
        try:
            return await asyncio.to_thread(search_objects, idx_obj, vec)
        except Exception as e:
            if "404" in str(e):
                raise HTTPException(404, "Pinecone index not found.") from e
            raise

    face_tasks = [_query_face(fv) for fv in face_vectors]
    obj_tasks = [_query_obj_single(ov) for ov in object_vectors]
    all_results = await asyncio.gather(*face_tasks, *obj_tasks)

    # gather() preserves argument order: face results first, then objects.
    n_faces = len(face_tasks)
    raw_groups = list(all_results[:n_faces])
    obj_nested = list(all_results[n_faces:])

    merged_face = merge_face_results(raw_groups)
    merged_objects = merge_object_results(obj_nested)

    face_groups = [g for g in raw_groups if g.get("matches")]

    duration_ms = round((time.perf_counter() - start) * 1000)
    # Log the lanes inference actually produced rather than a fixed pair, so
    # a face-only search is not reported as having used the object lane.
    log("INFO", "search.complete", user_id=user_id or "anonymous", ip=ip, mode=mode,
        lanes=lanes_used, face_groups=len(face_groups), face_results=len(merged_face),
        object_results=len(merged_objects), duration_ms=duration_ms)

    return {
        "mode": "face",
        "face_groups": face_groups,
        "results": merged_face,
        "object_results": merged_objects,
    }
129
+
130
async def _run_object_search(object_vectors, idx_obj, start, user_id, ip, mode, lanes_used) -> dict:
    """Query the object index once per object vector and merge the hits.

    Each ``search_objects`` call runs in a worker thread via
    ``asyncio.to_thread`` and all calls are awaited together. With no object
    vectors the empty response is returned immediately, without querying or
    logging.

    Args:
        object_vectors: Dicts with a ``"vector"`` entry.
        idx_obj: Index handle passed to ``search_objects``.
        start: ``time.perf_counter()`` reading taken at request start.
        user_id: Logged identifier ("anonymous" when empty).
        ip: Client IP, for logging.
        mode: Key mode string, for logging.
        lanes_used: Vector types that inference produced, for logging.

    Returns:
        ``{"mode": "object", "results": ..., "face_groups": []}`` where
        ``results`` is the output of ``merge_object_results``.

    Raises:
        HTTPException: 404 when a query fails with an error whose message
            contains "404". Any other query error propagates unchanged.
    """
    response = {"mode": "object", "results": [], "face_groups": []}
    if not object_vectors:
        return response

    async def _lookup(entry: dict) -> list:
        # One index query for one object vector.
        query = to_list(entry["vector"])
        try:
            hits = await asyncio.to_thread(search_objects, idx_obj, query)
        except Exception as exc:
            # Anything that does not look like a missing index is re-raised.
            if "404" not in str(exc):
                raise
            raise HTTPException(404, "Pinecone index not found.")
        return hits

    pending = [_lookup(entry) for entry in object_vectors]
    per_vector_hits = await asyncio.gather(*pending)
    merged = merge_object_results(per_vector_hits)

    elapsed_ms = round((time.perf_counter() - start) * 1000)
    log("INFO", "search.complete", user_id=user_id or "anonymous", ip=ip, mode=mode,
        lanes=lanes_used, results=len(merged), duration_ms=elapsed_ms)

    response["results"] = merged
    return response