chouchouvs committed
Commit c724d0c · verified · 1 Parent(s): 152b314

Create main.py

Files changed (1)
  1. main.py +234 -0
main.py ADDED
@@ -0,0 +1,234 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import os, time, uuid, logging
from typing import List, Optional, Dict, Any, Tuple
import requests
import numpy as np
from fastapi import FastAPI, BackgroundTasks, Header, HTTPException
from pydantic import BaseModel, Field
from qdrant_client import QdrantClient
from qdrant_client.http.models import VectorParams, Distance, PointStruct

logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger("remote_indexer")

# ---------- ENV ----------
AUTH_TOKEN = os.getenv("REMOTE_INDEX_TOKEN", "").strip()  # simple header auth
HF_TOKEN = os.getenv("HF_API_TOKEN", "").strip()
HF_MODEL = os.getenv("HF_EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
HF_URL = os.getenv("HF_API_URL", "").strip() or f"https://api-inference.huggingface.co/pipeline/feature-extraction/{HF_MODEL}"

QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
QDRANT_API = os.getenv("QDRANT_API_KEY", "").strip()

if not HF_TOKEN:
    LOG.warning("HF_API_TOKEN missing: the service will refuse /index and /query.")

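# Illustrative local configuration (placeholder values, not part of the commit):
#   REMOTE_INDEX_TOKEN=changeme
#   HF_API_TOKEN=hf_xxx
#   QDRANT_URL=http://localhost:6333
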
# ---------- Clients ----------
qdr = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API if QDRANT_API else None)

# ---------- Pydantic ----------
class FileIn(BaseModel):
    path: str
    text: str

class IndexRequest(BaseModel):
    project_id: str = Field(..., min_length=1)
    files: List[FileIn]
    chunk_size: int = 1200
    overlap: int = 200
    batch_size: int = 8
    store_text: bool = True

class QueryRequest(BaseModel):
    project_id: str
    query: str
    top_k: int = 6

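# Example payloads these models accept (illustrative values):
#   POST /index  {"project_id": "demo", "files": [{"path": "README.md", "text": "..."}]}
#   POST /query  {"project_id": "demo", "query": "how is auth configured?", "top_k": 6}
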
# ---------- Jobs store (in memory) ----------
JOBS: Dict[str, Dict[str, Any]] = {}  # {job_id: {"status": "...", "logs": [...], "created": ts}}

# ---------- Utils ----------
def _auth(x_auth: Optional[str]):
    if AUTH_TOKEN and (x_auth or "") != AUTH_TOKEN:
        raise HTTPException(status_code=401, detail="Unauthorized")

def _post_embeddings(batch: List[str]) -> Tuple[np.ndarray, int]:
    if not HF_TOKEN:
        raise RuntimeError("HF_API_TOKEN missing (server).")
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    r = requests.post(HF_URL, headers=headers, json=batch, timeout=120)
    size = int(r.headers.get("Content-Length", "0"))
    r.raise_for_status()
    data = r.json()
    arr = np.array(data, dtype=np.float32)
    # arr is [batch, dim] (sentence-transformers)
    # or [batch, tokens, dim] -> mean pooling
    if arr.ndim == 3:
        arr = arr.mean(axis=1)
    if arr.ndim != 2:
        raise RuntimeError(f"Unexpected embeddings shape: {arr.shape}")
    # L2-normalize so cosine distance behaves as expected
    norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
    arr = arr / norms
    return arr.astype(np.float32), size

def _ensure_collection(name: str, dim: int):
    try:
        qdr.get_collection(name)
        return
    except Exception:
        pass
    qdr.create_collection(
        collection_name=name,
        vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
    )

def _chunk_with_spans(text: str, size: int, overlap: int):
    n = len(text)
    if size <= 0:
        yield (0, n, text)
        return
    i = 0
    while i < n:
        j = min(n, i + size)
        yield (i, j, text[i:j])
        if j >= n:
            break
        # advance by at least one character so overlap >= size cannot loop forever
        i = max(i + 1, j - overlap)

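# Quick sanity check (illustrative): _chunk_with_spans("abcdefgh", 5, 2)
# yields (0, 5, "abcde") then (3, 8, "defgh"); spans overlap by 2 characters.
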
def _append_log(job_id: str, line: str):
    job = JOBS.get(job_id)
    if not job:
        return
    job["logs"].append(line)

def _set_status(job_id: str, status: str):
    job = JOBS.get(job_id)
    if not job:
        return
    job["status"] = status

# ---------- Background task ----------
def run_index_job(job_id: str, req: IndexRequest):
    try:
        _set_status(job_id, "running")
        total_chunks = 0
        LOG.info(f"[{job_id}] Index start project={req.project_id} files={len(req.files)}")
        _append_log(job_id, f"Start project={req.project_id} files={len(req.files)}")

        # warm-up batch: embed one chunk of the first file to discover the vector dimension
        if not req.files:
            raise RuntimeError("no files to index")
        first = req.files[0]
        warmup = [next(_chunk_with_spans(first.text or " ", req.chunk_size, req.overlap))[2]]
        embs, _ = _post_embeddings(warmup)
        dim = embs.shape[1]
        col = f"proj_{req.project_id}"
        _ensure_collection(col, dim)
        _append_log(job_id, f"Collection ready: {col} (dim={dim})")

        point_id = 0

        def upsert_batch(texts: List[str], payloads: List[Dict[str, Any]]) -> int:
            # embed one batch, write it to Qdrant, return the HF response size for logging
            nonlocal point_id
            vecs, sz = _post_embeddings(texts)
            points = []
            for k, vec in enumerate(vecs):
                points.append(PointStruct(id=point_id, vector=vec.tolist(), payload=payloads[k]))
                point_id += 1
            qdr.upsert(collection_name=col, points=points)
            return sz

        # file loop
        for fi, f in enumerate(req.files, 1):
            chunks, metas = [], []
            for ci, (start, end, chunk_txt) in enumerate(_chunk_with_spans(f.text, req.chunk_size, req.overlap)):
                chunks.append(chunk_txt)
                payload = {"path": f.path, "chunk": ci, "start": start, "end": end}
                if req.store_text:
                    payload["text"] = chunk_txt
                metas.append(payload)

                if len(chunks) >= req.batch_size:
                    sz = upsert_batch(chunks, metas)
                    total_chunks += len(chunks)
                    _append_log(job_id, f"file {fi}/{len(req.files)}: +{len(chunks)} chunks (total={total_chunks}) ~{sz/1024:.1f}KiB")
                    chunks, metas = [], []

            # flush the remainder of the file
            if chunks:
                sz = upsert_batch(chunks, metas)
                total_chunks += len(chunks)
                _append_log(job_id, f"file {fi}/{len(req.files)}: +{len(chunks)} chunks (total={total_chunks}) ~{sz/1024:.1f}KiB")

        _append_log(job_id, f"Done. chunks={total_chunks}")
        _set_status(job_id, "done")
        LOG.info(f"[{job_id}] Index finished. chunks={total_chunks}")
    except Exception as e:
        LOG.exception("Index job failed")
        _append_log(job_id, f"ERROR: {e}")
        _set_status(job_id, "error")

# ---------- API ----------
app = FastAPI()

@app.get("/health")
def health():
    return {"ok": True}

@app.post("/index")
def start_index(req: IndexRequest, background_tasks: BackgroundTasks, x_auth_token: Optional[str] = Header(default=None)):
    _auth(x_auth_token)
    if not HF_TOKEN:
        raise HTTPException(400, "HF_API_TOKEN missing on the server.")
    job_id = uuid.uuid4().hex[:12]
    JOBS[job_id] = {"status": "queued", "logs": [], "created": time.time()}
    background_tasks.add_task(run_index_job, job_id, req)
    return {"job_id": job_id}

@app.get("/status/{job_id}")
def status(job_id: str, x_auth_token: Optional[str] = Header(default=None)):
    _auth(x_auth_token)
    j = JOBS.get(job_id)
    if not j:
        raise HTTPException(404, "unknown job")
    return {"status": j["status"], "logs": j["logs"][-800:]}

@app.post("/query")
def query(req: QueryRequest, x_auth_token: Optional[str] = Header(default=None)):
    _auth(x_auth_token)
    if not HF_TOKEN:
        raise HTTPException(400, "HF_API_TOKEN missing on the server.")
    vec, _ = _post_embeddings([req.query])
    vec = vec[0].tolist()
    col = f"proj_{req.project_id}"
    try:
        res = qdr.search(collection_name=col, query_vector=vec, limit=int(req.top_k))
    except Exception as e:
        raise HTTPException(400, f"Search failed: {e}")
    out = []
    for p in res:
        pl = p.payload or {}
        txt = pl.get("text")
        # hard cap snippet size
        if txt and len(txt) > 800:
            txt = txt[:800] + "..."
        out.append({"path": pl.get("path"), "chunk": pl.get("chunk"), "start": pl.get("start"), "end": pl.get("end"), "text": txt})
    return {"results": out}

@app.post("/wipe")
def wipe_collection(project_id: str, x_auth_token: Optional[str] = Header(default=None)):
    _auth(x_auth_token)
    col = f"proj_{project_id}"
    try:
        qdr.delete_collection(col)
        return {"ok": True}
    except Exception as e:
        raise HTTPException(400, f"wipe failed: {e}")
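
For reference, a minimal client sketch for this API (illustrative, not part of the commit). It assumes the service runs at http://localhost:8000 and that REMOTE_INDEX_TOKEN is set to "changeme"; FastAPI maps the x_auth_token parameter to the x-auth-token header by default.

# client_sketch.py (hypothetical helper, not in the repository)
import time
import requests

BASE = "http://localhost:8000"   # assumption: local uvicorn instance
TOKEN = "changeme"               # assumption: matches REMOTE_INDEX_TOKEN
HEADERS = {"x-auth-token": TOKEN}

# start an indexing job
r = requests.post(f"{BASE}/index", headers=HEADERS, json={
    "project_id": "demo",
    "files": [{"path": "README.md", "text": "hello world " * 200}],
})
job_id = r.json()["job_id"]

# poll /status until the background task finishes
while True:
    s = requests.get(f"{BASE}/status/{job_id}", headers=HEADERS).json()
    if s["status"] in ("done", "error"):
        print(s["status"], *s["logs"][-3:], sep="\n")
        break
    time.sleep(1)

# query the indexed project
hits = requests.post(f"{BASE}/query", headers=HEADERS, json={
    "project_id": "demo", "query": "hello", "top_k": 3,
}).json()["results"]
for h in hits:
    print(h["path"], h["chunk"], (h["text"] or "")[:60])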