Spaces:
Running
Running
Create search.py
Browse files- src/api/search.py +150 -0
src/api/search.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import time
|
| 3 |
+
import traceback
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter, File, Form, HTTPException, Request, UploadFile, Depends
|
| 6 |
+
|
| 7 |
+
from src.core.config import DEFAULT_PINECONE_KEY, IDX_FACES, IDX_OBJECTS
|
| 8 |
+
from src.core.security import get_verified_keys
|
| 9 |
+
from src.services.db_client import (
|
| 10 |
+
merge_face_results, merge_object_results,
|
| 11 |
+
pinecone_pool, search_faces, search_objects,
|
| 12 |
+
)
|
| 13 |
+
from src.core.logging import log
|
| 14 |
+
from src.common.utils import face_ui_score, get_ip, is_default_key, to_list
|
| 15 |
+
|
| 16 |
+
# Router on which this module's search endpoint is registered.
router = APIRouter()
|
| 17 |
+
|
| 18 |
+
@router.post("/api/search")
async def search_database(
    request: Request,
    file: UploadFile = File(...),
    detect_faces: bool = Form(True),
    user_id: str = Form(""),
    keys: dict = Depends(get_verified_keys)
):
    """Search the vector indexes for matches to an uploaded image.

    The upload is read fully into memory and embedded by the AI manager
    stored on ``request.app.state.ai`` while holding
    ``request.app.state.ai_semaphore``. The resulting vectors are split by
    their ``"type"`` ("face" / "object"). When ``detect_faces`` is true and
    at least one face vector came back, the face lane runs; otherwise the
    object lane runs.

    Args:
        request: Incoming request; supplies the app state and the client IP.
        file: Uploaded image.
        detect_faces: Passed to the AI manager and used to pick the lane.
        user_id: Used for logging only in this function; logged as
            "anonymous" when empty.
        keys: Verified credentials; ``keys["pinecone_key"]`` selects the
            Pinecone client from the pool.

    Returns:
        The dict produced by ``_run_face_search`` or ``_run_object_search``.

    Raises:
        HTTPException: Re-raised unchanged when raised inside the handler;
            any other exception is logged and converted to a 500 whose
            detail is ``str(e)``.
    """
    ip = get_ip(request)
    start = time.perf_counter()
    mode = "guest" if is_default_key(keys["pinecone_key"], DEFAULT_PINECONE_KEY) else "personal"
    who = user_id or "anonymous"

    log("INFO", "search.start", user_id=who, ip=ip, mode=mode,
        filename=file.filename, detect_faces=detect_faces)

    try:
        file_bytes = await file.read()
        ai_manager = request.app.state.ai
        sem = request.app.state.ai_semaphore

        async with sem:
            vectors = await ai_manager.process_image_bytes_async(file_bytes, detect_faces=detect_faces)

        # Measured from request start, so it also covers the upload read and
        # any time spent waiting on the semaphore.
        inference_ms = round((time.perf_counter() - start) * 1000)
        face_vectors = [v for v in vectors if v["type"] == "face"]
        object_vectors = [v for v in vectors if v["type"] == "object"]
        # Sorted so the logged lane order is stable; iterating a set of
        # strings has no guaranteed order across processes.
        lanes_used = sorted({v["type"] for v in vectors})

        log("INFO", "search.inference_done", user_id=who, ip=ip, mode=mode,
            face_vecs=len(face_vectors), obj_vecs=len(object_vectors), inference_ms=inference_ms)

        pc = pinecone_pool.get(keys["pinecone_key"])
        idx_obj = pc.Index(IDX_OBJECTS)

        if detect_faces and face_vectors:
            # Only the face lane needs the face index, so the handle is
            # created here rather than on every request.
            idx_face = pc.Index(IDX_FACES)
            return await _run_face_search(
                face_vectors, object_vectors, idx_face, idx_obj,
                start, user_id, ip, mode, lanes_used,
            )
        return await _run_object_search(object_vectors, idx_obj, start, user_id, ip, mode, lanes_used)

    except HTTPException:
        # Already carries the intended status code; pass it through.
        raise
    except Exception as e:
        log("ERROR", "search.error", user_id=who, ip=ip, mode=mode,
            error=str(e), traceback=traceback.format_exc()[-800:])
        raise HTTPException(500, str(e)) from e
|
| 64 |
+
|
| 65 |
+
async def _run_face_search(face_vectors, object_vectors, idx_face, idx_obj, start, user_id, ip, mode, lanes_used) -> dict:
    """Run the face lane: query every face vector and every object vector concurrently.

    Each face vector is searched with ``search_faces`` in a worker thread
    and becomes one group that describes the query face and carries up to
    50 matches, sorted by UI score in descending order. Object vectors are
    searched with ``search_objects`` in the same ``asyncio.gather`` call.

    Args:
        face_vectors: Dicts with a ``"vector"`` entry and optional
            ``"det_score"``, ``"face_idx"``, ``"face_crop"``, ``"bbox"``
            and ``"face_width_px"`` entries.
        object_vectors: Dicts with a ``"vector"`` entry.
        idx_face: Index handle passed to ``search_faces``.
        idx_obj: Index handle passed to ``search_objects``.
        start: ``time.perf_counter()`` value used to compute the duration.
        user_id: Logged as "anonymous" when empty.
        ip: Client IP, for logging.
        mode: Key mode label, for logging.
        lanes_used: Lane names reported in the completion log.

    Returns:
        Dict with ``"mode": "face"``, ``"face_groups"`` (groups that have at
        least one match), ``"results"`` (``merge_face_results`` over all
        groups) and ``"object_results"`` (``merge_object_results`` output).

    Raises:
        HTTPException: 404 when a search error's text contains "404".
        Exception: Any other search error propagates unchanged.
    """

    async def _query_face(fv: dict) -> dict:
        vec = to_list(fv["vector"])
        det_score = fv.get("det_score", 1.0)
        try:
            image_map = await asyncio.to_thread(search_faces, idx_face, vec, det_score)
        except Exception as e:
            if "404" in str(e):
                raise HTTPException(404, "Pinecone index not found. Go to Settings → Verify & Save.") from e
            raise

        matches = [
            {
                "url": url,
                "score": face_ui_score(d["raw_score"]),
                "raw_score": round(d["raw_score"], 4),
                "face_crop": d["face_crop"],
                "folder": d["folder"],
                "caption": "👤 Verified Identity",
            }
            for url, d in image_map.items()
        ]
        # Stable sort, best UI score first; only the top 50 are returned.
        matches.sort(key=lambda m: m["score"], reverse=True)

        return {
            "query_face_idx": fv.get("face_idx", 0),
            "query_face_crop": fv.get("face_crop", ""),
            "query_bbox": fv.get("bbox", []),
            "det_score": det_score,
            "face_width_px": fv.get("face_width_px", 0),
            "matches": matches[:50],
        }

    async def _query_obj_single(ov: dict) -> list:
        vec = to_list(ov["vector"])
        try:
            return await asyncio.to_thread(search_objects, idx_obj, vec)
        except Exception as e:
            if "404" in str(e):
                raise HTTPException(404, "Pinecone index not found.") from e
            raise

    face_tasks = [_query_face(fv) for fv in face_vectors]
    obj_tasks = [_query_obj_single(ov) for ov in object_vectors]
    all_results = await asyncio.gather(*face_tasks, *obj_tasks)

    # gather preserves argument order: face results first, then object results.
    n_faces = len(face_tasks)
    raw_groups = list(all_results[:n_faces])
    obj_nested = list(all_results[n_faces:])

    merged_face = merge_face_results(raw_groups)
    merged_objects = merge_object_results(obj_nested)

    # Groups without matches are merged above but left out of the response.
    face_groups = [g for g in raw_groups if g.get("matches")]

    duration_ms = round((time.perf_counter() - start) * 1000)
    # Report the lanes the caller passed in instead of a fixed
    # ["face", "object"], which was wrong when no object vectors existed.
    log("INFO", "search.complete", user_id=user_id or "anonymous", ip=ip, mode=mode,
        lanes=lanes_used, face_groups=len(face_groups), face_results=len(merged_face),
        object_results=len(merged_objects), duration_ms=duration_ms)

    return {
        "mode": "face",
        "face_groups": face_groups,
        "results": merged_face,
        "object_results": merged_objects,
    }
|
| 129 |
+
|
| 130 |
+
async def _run_object_search(object_vectors, idx_obj, start, user_id, ip, mode, lanes_used) -> dict:
    """Run the object lane: query every object vector concurrently and merge the hits.

    Each vector is searched with ``search_objects`` in a worker thread, and
    the per-vector results are combined with ``merge_object_results``. With
    no object vectors, nothing is queried and the result list is empty; the
    completion log is still written so every search records its duration.

    Args:
        object_vectors: Dicts with a ``"vector"`` entry; may be empty.
        idx_obj: Index handle passed to ``search_objects``.
        start: ``time.perf_counter()`` value used to compute the duration.
        user_id: Logged as "anonymous" when empty.
        ip: Client IP, for logging.
        mode: Key mode label, for logging.
        lanes_used: Lane names reported in the completion log.

    Returns:
        Dict with ``"mode": "object"``, ``"results"`` (merged hits, or ``[]``
        when there were no vectors) and an empty ``"face_groups"`` list.

    Raises:
        HTTPException: 404 when a search error's text contains "404".
        Exception: Any other search error propagates unchanged.
    """

    async def _query_obj(ov: dict) -> list:
        vec = to_list(ov["vector"])
        try:
            return await asyncio.to_thread(search_objects, idx_obj, vec)
        except Exception as e:
            if "404" in str(e):
                raise HTTPException(404, "Pinecone index not found.") from e
            raise

    if object_vectors:
        nested = await asyncio.gather(*[_query_obj(ov) for ov in object_vectors])
        final = merge_object_results(nested)
    else:
        # Nothing to query; keep the same empty payload as before.
        final = []

    duration_ms = round((time.perf_counter() - start) * 1000)
    log("INFO", "search.complete", user_id=user_id or "anonymous", ip=ip, mode=mode,
        lanes=lanes_used, results=len(final), duration_ms=duration_ms)

    return {"mode": "object", "results": final, "face_groups": []}
|