chouchouvs committed
Commit 750efe9 · verified · 1 Parent(s): ca62f03

Update main.py
Files changed (1)
  1. main.py +161 -217
main.py CHANGED
@@ -1,26 +1,14 @@
  # -*- coding: utf-8 -*-
- """
- FastAPI + Gradio: asynchronous indexing service with FAISS.
- This file was fixed to:
-
- * import `JobState` correctly (relative import)
- * guarantee that the `app` directory is on the PYTHONPATH when the script is launched
- * keep all previous features (indexing, search, UI)
- """
-
  from __future__ import annotations

  import os
  import io
  import json
  import time
- import hashlib
- import logging
  import tarfile
- import sys
- from pathlib import Path
- from typing import List, Dict, Any, Tuple, Optional
-
  from concurrent.futures import ThreadPoolExecutor

  import numpy as np
@@ -32,19 +20,9 @@ from pydantic import BaseModel

  import gradio as gr

- # --------------------------------------------------------------------------- #
- # PYTHONPATH SETUP (so that relative imports work)
- # --------------------------------------------------------------------------- #
- # If the script is launched from the `app/` directory, the `app` package is not
- # discovered automatically. We add the parent directory to sys.path.
- CURRENT_DIR = Path(__file__).resolve().parent
- PROJECT_ROOT = CURRENT_DIR.parent
- if str(PROJECT_ROOT) not in sys.path:
-     sys.path.insert(0, str(PROJECT_ROOT))
-
- # --------------------------------------------------------------------------- #
- # LOGGING
- # --------------------------------------------------------------------------- #
  LOG = logging.getLogger("remote-indexer-async")
  if not LOG.handlers:
      h = logging.StreamHandler()
@@ -59,25 +37,31 @@ if not DBG.handlers:
      DBG.addHandler(hd)
  DBG.setLevel(logging.DEBUG)

- # --------------------------------------------------------------------------- #
- # CONFIGURATION (environment variables)
- # --------------------------------------------------------------------------- #
  PORT = int(os.getenv("PORT", "7860"))
- DATA_ROOT = os.getenv("DATA_ROOT", "/tmp/data")
  os.makedirs(DATA_ROOT, exist_ok=True)

  EMB_PROVIDER = os.getenv("EMB_PROVIDER", "dummy").strip().lower()
  EMB_MODEL = os.getenv("EMB_MODEL", "sentence-transformers/all-mpnet-base-v2").strip()
  EMB_BATCH = int(os.getenv("EMB_BATCH", "32"))
- EMB_DIM = int(os.getenv("EMB_DIM", "64"))  # reduced dimension (optimization)

  MAX_WORKERS = int(os.getenv("MAX_WORKERS", "1"))

- # --------------------------------------------------------------------------- #
- # CACHE DIRECTORIES (avoids PermissionError)
- # --------------------------------------------------------------------------- #
  def _setup_cache_dirs() -> Dict[str, str]:
      os.environ.setdefault("HOME", "/home/user")
      CACHE_ROOT = os.getenv("CACHE_ROOT", "/tmp/.cache").rstrip("/")
      paths = {
          "root": CACHE_ROOT,
@@ -103,22 +87,32 @@ def _setup_cache_dirs() -> Dict[str, str]:
      os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
      os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

-     LOG.info("Caches configured: %s", json.dumps(paths, indent=2))
      return paths

-
  CACHE_PATHS = _setup_cache_dirs()

- # --------------------------------------------------------------------------- #
- # IMPORT OF THE STATE CLASS (fixed: relative import)
- # --------------------------------------------------------------------------- #
- # The `index_state.py` file lives in `app/core/`.
- # From inside the `app` directory it can be imported via the `core` package.
- from core.index_state import JobState  # <-- FIXED IMPORT

- # --------------------------------------------------------------------------- #
- # GLOBALS
- # --------------------------------------------------------------------------- #
  JOBS: Dict[str, JobState] = {}

  def _now() -> str:
@@ -132,18 +126,18 @@ def _proj_dirs(project_id: str) -> Tuple[str, str, str]:
      os.makedirs(fx_dir, exist_ok=True)
      return base, ds_dir, fx_dir

- def _add_msg(st: JobState, msg: str) -> None:
      st.messages.append(f"[{_now()}] {msg}")
      LOG.info("[%s] %s", st.job_id, msg)
      DBG.debug("[%s] %s", st.job_id, msg)

- def _set_stage(st: JobState, stage: str) -> None:
      st.stage = stage
      _add_msg(st, f"stage={stage}")

- # --------------------------------------------------------------------------- #
- # UTILITIES (chunking, normalization, etc.)
- # --------------------------------------------------------------------------- #
  def _chunk_text(text: str, size: int = 200, overlap: int = 20) -> List[str]:
      text = (text or "").replace("\r\n", "\n")
      tokens = list(text)
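Note: the chunker's body is elided by the diff viewer, but the signature and `tokens = list(text)` indicate a character-level sliding window. A minimal sketch of those assumed semantics (the helper name and the exact boundary handling are illustrative, not the committed code):

def chunk_chars(text: str, size: int = 200, overlap: int = 20) -> list[str]:
    # Slide a fixed-size character window; consecutive chunks share `overlap` characters.
    step = max(1, size - overlap)
    return [text[i:i + size] for i in range(0, max(1, len(text)), step)]

# "abcdefghij" with size=4, overlap=2 -> a new window starts every 2 characters
assert chunk_chars("abcdefghij", size=4, overlap=2) == ["abcd", "cdef", "efgh", "ghij", "ij"]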
@@ -167,13 +161,7 @@ def _l2_normalize(x: np.ndarray) -> np.ndarray:
      n = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
      return x / n

- # --------------------------------------------------------------------------- #
- # EMBEDDING PROVIDERS
- # --------------------------------------------------------------------------- #
- _ST_MODEL = None
- _HF_TOKENIZER = None
- _HF_MODEL = None
-
  def _emb_dummy(texts: List[str], dim: int = EMB_DIM) -> np.ndarray:
      vecs = np.zeros((len(texts), dim), dtype="float32")
      for i, t in enumerate(texts):
@@ -183,12 +171,13 @@ def _emb_dummy(texts: List[str], dim: int = EMB_DIM) -> np.ndarray:
          vecs[i] = v / (np.linalg.norm(v) + 1e-9)
      return vecs

  def _get_st_model():
      global _ST_MODEL
      if _ST_MODEL is None:
          from sentence_transformers import SentenceTransformer
          _ST_MODEL = SentenceTransformer(EMB_MODEL, cache_folder=CACHE_PATHS["st"])
-         LOG.info("[st] model loaded: %s (cache=%s)", EMB_MODEL, CACHE_PATHS["st"])
      return _ST_MODEL

  def _emb_st(texts: List[str]) -> np.ndarray:
@@ -202,6 +191,7 @@ def _emb_st(texts: List[str]) -> np.ndarray:
      ).astype("float32")
      return vecs

  def _get_hf_model():
      global _HF_TOKENIZER, _HF_MODEL
      if _HF_MODEL is None or _HF_TOKENIZER is None:
@@ -209,10 +199,10 @@ def _get_hf_model():
          _HF_TOKENIZER = AutoTokenizer.from_pretrained(EMB_MODEL, cache_dir=CACHE_PATHS["hf_tf"])
          _HF_MODEL = AutoModel.from_pretrained(EMB_MODEL, cache_dir=CACHE_PATHS["hf_tf"])
          _HF_MODEL.eval()
-         LOG.info("[hf] model loaded: %s (cache=%s)", EMB_MODEL, CACHE_PATHS["hf_tf"])
      return _HF_TOKENIZER, _HF_MODEL

- def _mean_pool(last_hidden_state: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
      mask = attention_mask[..., None].astype(last_hidden_state.dtype)
      summed = (last_hidden_state * mask).sum(axis=1)
      counts = mask.sum(axis=1).clip(min=1e-9)
@@ -221,28 +211,25 @@ def _mean_pool(last_hidden_state: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
  def _emb_hf(texts: List[str]) -> np.ndarray:
      import torch
      tok, mod = _get_hf_model()
-     all_vecs: List[np.ndarray] = []
      bs = max(1, EMB_BATCH)
      with torch.no_grad():
          for i in range(0, len(texts), bs):
-             batch = texts[i:i + bs]
              enc = tok(batch, padding=True, truncation=True, return_tensors="pt")
              out = mod(**enc)
              last = out.last_hidden_state  # (b, t, h)
              pooled = _mean_pool(last.numpy(), enc["attention_mask"].numpy())
              all_vecs.append(pooled.astype("float32"))
-     return np.concatenate(all_vecs, axis=0)

- # --------------------------------------------------------------------------- #
- # DATASET / FAISS I/O
- # --------------------------------------------------------------------------- #
- def _save_dataset(ds_dir: str, rows: List[Dict[str, Any]], store_text: bool = True) -> None:
      os.makedirs(ds_dir, exist_ok=True)
      data_path = os.path.join(ds_dir, "data.jsonl")
      with open(data_path, "w", encoding="utf-8") as f:
          for r in rows:
-             if not store_text:
-                 r = {k: v for k, v in r.items() if k != "text"}
              f.write(json.dumps(r, ensure_ascii=False) + "\n")
      meta = {"format": "jsonl", "columns": ["path", "text", "chunk_id"], "count": len(rows)}
      with open(os.path.join(ds_dir, "meta.json"), "w", encoding="utf-8") as f:
@@ -252,7 +239,7 @@ def _load_dataset(ds_dir: str) -> List[Dict[str, Any]]:
      data_path = os.path.join(ds_dir, "data.jsonl")
      if not os.path.isfile(data_path):
          return []
-     out: List[Dict[str, Any]] = []
      with open(data_path, "r", encoding="utf-8") as f:
          for line in f:
              try:
@@ -261,163 +248,118 @@ def _load_dataset(ds_dir: str) -> List[Dict[str, Any]]:
              continue
      return out

- def _save_faiss(fx_dir: str, xb: np.ndarray, meta: Dict[str, Any]) -> None:
      os.makedirs(fx_dir, exist_ok=True)
      idx_path = os.path.join(fx_dir, "emb.faiss")
-
-     # ------------------- QUANTIZED INDEX (IVF-PQ) ------------------- #
-     quantizer = faiss.IndexFlatIP(xb.shape[1])  # inner product (cosine if normalized)
-     index = faiss.IndexIVFPQ(quantizer, xb.shape[1], 100, 8, 8)  # nlist=100, m=8, nbits=8
-
-     # training on a subsample (at most 10k vectors)
-     rng = np.random.default_rng(0)
-     train = xb[rng.choice(xb.shape[0], min(10_000, xb.shape[0]), replace=False)]
-     index.train(train)
-
      index.add(xb)
      faiss.write_index(index, idx_path)
-
-     meta.update({"index_type": "IVF_PQ", "nlist": 100, "m": 8, "nbits": 8})
      with open(os.path.join(fx_dir, "meta.json"), "w", encoding="utf-8") as f:
          json.dump(meta, f, ensure_ascii=False, indent=2)
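Note: the commit drops this IVF-PQ index in favour of a flat inner-product index (see the added `_save_faiss` further down). One plausible reason is the training requirement: an IVF index must be trained, and FAISS refuses to cluster 100 coarse centroids from fewer than 100 vectors, which a small test corpus easily triggers. A minimal sketch of that failure mode (counts and dimensions are illustrative):

import numpy as np
import faiss

d, nlist, m, nbits = 64, 100, 8, 8            # same parameters as the removed code
quantizer = faiss.IndexFlatIP(d)
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits)
xb = np.random.rand(50, d).astype("float32")  # fewer vectors than nlist centroids
try:
    index.train(xb)
except RuntimeError as e:                      # FAISS raises when nx < k during clustering
    print("training failed:", e)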

  def _load_faiss(fx_dir: str) -> faiss.Index:
      idx_path = os.path.join(fx_dir, "emb.faiss")
      if not os.path.isfile(idx_path):
-         raise FileNotFoundError(f"FAISS index not found: {idx_path}")
-     # mmap: the index stays on disk; RAM is only used for queries
-     return faiss.read_index(idx_path, faiss.IO_FLAG_MMAP)

  def _tar_dir_to_bytes(dir_path: str) -> bytes:
      bio = io.BytesIO()
-     with tarfile.open(fileobj=bio, mode="w:gz", compresslevel=9) as tar:
          tar.add(dir_path, arcname=os.path.basename(dir_path))
      bio.seek(0)
      return bio.read()
- # --------------------------------------------------------------------------- #
- # THREAD POOL (asynchronous)
- # --------------------------------------------------------------------------- #
  EXECUTOR = ThreadPoolExecutor(max_workers=max(1, MAX_WORKERS))
- LOG.info("ThreadPoolExecutor initialized: max_workers=%s", MAX_WORKERS)
-
- def _do_index_job(
-     st: JobState,
-     files: List[Dict[str, str]],
-     chunk_size: int,
-     overlap: int,
-     batch_size: int,
-     store_text: bool,
- ) -> None:
      """
-     Full pipeline:
-     1. Chunking
-     2. Embedding (dummy / st / hf)
-     3. Dimensionality reduction (PCA) if needed
-     4. Dataset save (text optional)
-     5. Quantized FAISS index + mmap
      """
      try:
          base, ds_dir, fx_dir = _proj_dirs(st.project_id)

-         # ------------------- 1. Chunking -------------------
          _set_stage(st, "chunking")
          rows: List[Dict[str, Any]] = []
          st.total_files = len(files)
-
-         for f in files:
-             path = (f.get("path") or "unknown").strip()
-             txt = f.get("text") or ""
-             chunks = _chunk_text(txt, size=chunk_size, overlap=overlap)
-             for i, ck in enumerate(chunks):
-                 rows.append({"path": path, "text": ck, "chunk_id": i})
-
          st.total_chunks = len(rows)
          _add_msg(st, f"Total chunks = {st.total_chunks}")

-         # ------------------- 2. Embedding -------------------
          _set_stage(st, "embedding")
          texts = [r["text"] for r in rows]
-
          if EMB_PROVIDER == "dummy":
              xb = _emb_dummy(texts, dim=EMB_DIM)
          elif EMB_PROVIDER == "st":
              xb = _emb_st(texts)
          else:
              xb = _emb_hf(texts)
-
-         # ------------------- 3. PCA reduction (if needed) -------------------
-         if xb.shape[1] != EMB_DIM:
-             from sklearn.decomposition import PCA
-             pca = PCA(n_components=EMB_DIM, random_state=0)
-             xb = pca.fit_transform(xb).astype("float32")
-             LOG.info("PCA reduction applied: %d -> %d dimensions", xb.shape[1], EMB_DIM)
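Note: this removed PCA step is presumably the bug the commit fixes. Document vectors were reduced to EMB_DIM, but `/search` below embeds queries at the model's native width and never applies the fitted PCA, so `index.d != q.shape[1]` and the endpoint returns its 500. If reduction is wanted, the same fitted transform has to be applied on both sides; a sketch, assuming scikit-learn as in the removed code:

import numpy as np
from sklearn.decomposition import PCA

docs = np.random.rand(1000, 768).astype("float32")   # native model width (illustrative)
pca = PCA(n_components=64, random_state=0).fit(docs)  # fit once on the corpus, then persist it
xb = pca.transform(docs).astype("float32")            # index side
q = pca.transform(np.random.rand(1, 768).astype("float32"))  # query side, same fitted PCA
assert xb.shape[1] == q.shape[1] == 64                # dimensions now agree at search time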
 
          st.embedded = xb.shape[0]
-         _add_msg(st, f"Embeddings generated: {st.embedded}")

-         # ------------------- 4. Dataset save -------------------
-         _save_dataset(ds_dir, rows, store_text=store_text)
-         _add_msg(st, f"Dataset saved to {ds_dir}")

-         # ------------------- 5. FAISS index -------------------
          _set_stage(st, "indexing")
-         meta = {
-             "dim": int(xb.shape[1]),
              "count": int(xb.shape[0]),
              "provider": EMB_PROVIDER,
-             "model": EMB_MODEL if EMB_PROVIDER != "dummy" else None,
          }
-         _save_faiss(fx_dir, xb, meta)
          st.indexed = int(xb.shape[0])
          _add_msg(st, f"FAISS written to {os.path.join(fx_dir, 'emb.faiss')}")

          _set_stage(st, "done")
          st.finished_at = time.time()
      except Exception as e:
-         LOG.exception("Job %s failed", st.job_id)
          st.errors.append(str(e))
-         _add_msg(st, f"❌ Exception: {e}")
          st.stage = "failed"
          st.finished_at = time.time()

-
- def _submit_job(
-     project_id: str,
-     files: List[Dict[str, str]],
-     chunk_size: int,
-     overlap: int,
-     batch_size: int,
-     store_text: bool,
- ) -> str:
      job_id = hashlib.sha1(f"{project_id}{time.time()}".encode()).hexdigest()[:12]
      st = JobState(job_id=job_id, project_id=project_id, stage="pending", messages=[])
      JOBS[job_id] = st

-     LOG.info("Job %s created (%d files)", job_id, len(files))
-
-     EXECUTOR.submit(
-         _do_index_job,
-         st,
-         files,
-         chunk_size,
-         overlap,
-         batch_size,
-         store_text,
-     )
-     st.stage = "queued"
      return job_id

- # --------------------------------------------------------------------------- #
- # FASTAPI
- # --------------------------------------------------------------------------- #
  fastapi_app = FastAPI(title="remote-indexer-async", version="3.0.0")
  fastapi_app.add_middleware(
      CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
  )

  class FileItem(BaseModel):
@@ -430,11 +372,11 @@ class IndexRequest(BaseModel):
      chunk_size: int = 200
      overlap: int = 20
      batch_size: int = 32
-     store_text: bool = True  # can be disabled via the payload or env

  @fastapi_app.get("/health")
  def health():
-     return {
          "ok": True,
          "service": "remote-indexer-async",
          "provider": EMB_PROVIDER,
@@ -442,13 +384,18 @@ def health():
          "cache_root": os.getenv("CACHE_ROOT", "/tmp/.cache"),
          "workers": MAX_WORKERS,
          "data_root": DATA_ROOT,
-         "emb_dim": EMB_DIM,
      }

  @fastapi_app.post("/index")
  def index(req: IndexRequest):
      """
-     Asynchronous launch: returns a `job_id` immediately.
      """
      try:
          files = [fi.model_dump() for fi in req.files]
@@ -462,7 +409,7 @@ def index(req: IndexRequest):
          )
          return {"job_id": job_id}
      except Exception as e:
-         LOG.exception("Index submission error")
          raise HTTPException(status_code=500, detail=str(e))

  @fastapi_app.get("/status/{job_id}")
@@ -481,16 +428,17 @@ class SearchRequest(BaseModel):
  def search(req: SearchRequest):
      base, ds_dir, fx_dir = _proj_dirs(req.project_id)

-     # Check that the index exists
-     if not (os.path.isfile(os.path.join(fx_dir, "emb.faiss")) and
-             os.path.isfile(os.path.join(ds_dir, "data.jsonl"))):
          raise HTTPException(status_code=409, detail="Index not ready (come back later)")

      rows = _load_dataset(ds_dir)
      if not rows:
          raise HTTPException(status_code=404, detail="dataset not found")

-     # Embed the query (same provider as the index)
      if EMB_PROVIDER == "dummy":
          q = _emb_dummy([req.query], dim=EMB_DIM)[0:1, :]
      elif EMB_PROVIDER == "st":
@@ -498,13 +446,10 @@ def search(req: SearchRequest):
      else:
          q = _emb_hf([req.query])[0:1, :]

-     # FAISS search (mmap)
      index = _load_faiss(fx_dir)
      if index.d != q.shape[1]:
-         raise HTTPException(
-             status_code=500,
-             detail=f"incompatible dims: index.d={index.d} vs query={q.shape[1]}",
-         )
      scores, ids = index.search(q, int(max(1, req.k)))
      ids = ids[0].tolist()
      scores = scores[0].tolist()
@@ -517,73 +462,72 @@ def search(req: SearchRequest):
          out.append({"path": r.get("path"), "text": r.get("text"), "score": float(sc)})
      return {"results": out}

- # --------------------------------------------------------------------------- #
- # ARTIFACTS EXPORT (gzip)
- # --------------------------------------------------------------------------- #
  @fastapi_app.get("/artifacts/{project_id}/dataset")
  def download_dataset(project_id: str):
-     _, ds_dir, _ = _proj_dirs(project_id)
      if not os.path.isdir(ds_dir):
          raise HTTPException(status_code=404, detail="Dataset not found")
      buf = _tar_dir_to_bytes(ds_dir)
-     hdr = {"Content-Disposition": f'attachment; filename="{project_id}_dataset.tgz"'}
-     return StreamingResponse(io.BytesIO(buf), media_type="application/gzip", headers=hdr)

  @fastapi_app.get("/artifacts/{project_id}/faiss")
  def download_faiss(project_id: str):
-     _, _, fx_dir = _proj_dirs(project_id)
      if not os.path.isdir(fx_dir):
          raise HTTPException(status_code=404, detail="FAISS not found")
      buf = _tar_dir_to_bytes(fx_dir)
-     hdr = {"Content-Disposition": f'attachment; filename="{project_id}_faiss.tgz"'}
-     return StreamingResponse(io.BytesIO(buf), media_type="application/gzip", headers=hdr)

- # --------------------------------------------------------------------------- #
- # GRADIO UI (optional quick test)
- # --------------------------------------------------------------------------- #
  def _ui_index(project_id: str, sample_text: str):
      files = [{"path": "sample.txt", "text": sample_text}]
      try:
          req = IndexRequest(project_id=project_id, files=[FileItem(**f) for f in files])
-     except Exception as e:
-         return f"❌ Validation: {e}"
      try:
          res = index(req)
-         return f"Job launched: {res['job_id']}"
      except Exception as e:
-         return f"Error: {e}"

  def _ui_search(project_id: str, query: str, k: int):
      try:
          res = search(SearchRequest(project_id=project_id, query=query, k=int(k)))
          return json.dumps(res, ensure_ascii=False, indent=2)
      except Exception as e:
-         return f"Error: {e}"
-
- with gr.Blocks(title="Remote Indexer (Async, optimized)", analytics_enabled=False) as ui:
-     gr.Markdown("## Remote Indexer — Async (quantized FAISS, mmap, optional text)")
-     with gr.Row():
-         pid = gr.Textbox(label="Project ID", value="DEMO")
-         txt = gr.Textbox(label="Sample text", lines=4, value="Alpha bravo charlie delta echo foxtrot.")
-         btn_idx = gr.Button("Launch index (sample)")
-         out_idx = gr.Textbox(label="Result")
-         btn_idx.click(_ui_index, inputs=[pid, txt], outputs=[out_idx])
-
-     with gr.Row():
          q = gr.Textbox(label="Query", value="alpha")
-         k = gr.Slider(1, 20, value=5, step=1, label="Top-K")
-         btn_q = gr.Button("Search")
-         out_q = gr.Code(label="Results")
-         btn_q.click(_ui_search, inputs=[pid, q, k], outputs=[out_q])

-     # Mount the Gradio UI on the same FastAPI server
  fastapi_app = gr.mount_gradio_app(fastapi_app, ui, path="/ui")

- # --------------------------------------------------------------------------- #
- # MAIN
- # --------------------------------------------------------------------------- #
  if __name__ == "__main__":
      import uvicorn
-
-     LOG.info("Starting Uvicorn, port %s, UI available at /ui", PORT)
      uvicorn.run(fastapi_app, host="0.0.0.0", port=PORT)
 
  # -*- coding: utf-8 -*-
  from __future__ import annotations

  import os
  import io
  import json
  import time
  import tarfile
+ import logging
+ import hashlib
+ from typing import Dict, Any, List, Tuple, Optional

  from concurrent.futures import ThreadPoolExecutor

  import numpy as np

  import gradio as gr

+ # =============================================================================
+ # LOGGING
+ # =============================================================================
  LOG = logging.getLogger("remote-indexer-async")
  if not LOG.handlers:
      h = logging.StreamHandler()
      DBG.addHandler(hd)
  DBG.setLevel(logging.DEBUG)

+ # =============================================================================
+ # CONFIG (via ENV)
+ # =============================================================================
  PORT = int(os.getenv("PORT", "7860"))
+ DATA_ROOT = os.getenv("DATA_ROOT", "/tmp/data")  # internal Space storage (volatile on the free tier)
  os.makedirs(DATA_ROOT, exist_ok=True)

+ # Embedding provider:
+ #   - "dummy": deterministic random vectors (very fast)
+ #   - "st"   : Sentence-Transformers (CPU-friendly)
+ #   - "hf"   : plain Transformers (AutoModel/AutoTokenizer)
  EMB_PROVIDER = os.getenv("EMB_PROVIDER", "dummy").strip().lower()
  EMB_MODEL = os.getenv("EMB_MODEL", "sentence-transformers/all-mpnet-base-v2").strip()
  EMB_BATCH = int(os.getenv("EMB_BATCH", "32"))
+ EMB_DIM = int(os.getenv("EMB_DIM", "128"))  # used for the dummy provider

+ # Worker pool size (asynchronous)
  MAX_WORKERS = int(os.getenv("MAX_WORKERS", "1"))
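Note: all of these settings are read once at import time, so they have to be in the environment before `main.py` is imported. A small sketch (the values are examples, not Space defaults):

import os

# Set the configuration before importing main; the module freezes these at import.
os.environ["EMB_PROVIDER"] = "st"
os.environ["EMB_MODEL"] = "sentence-transformers/all-MiniLM-L6-v2"  # illustrative model
os.environ["EMB_BATCH"] = "16"
os.environ["MAX_WORKERS"] = "2"

import main  # noqa: E402  (reads EMB_PROVIDER and friends at module import)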
 
+ # =============================================================================
+ # CACHE DIRECTORIES (avoids PermissionError: '/.cache')
+ # =============================================================================
  def _setup_cache_dirs() -> Dict[str, str]:
      os.environ.setdefault("HOME", "/home/user")
+
      CACHE_ROOT = os.getenv("CACHE_ROOT", "/tmp/.cache").rstrip("/")
      paths = {
          "root": CACHE_ROOT,

      os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
      os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

+     LOG.info("Caches configured: %s", json.dumps(paths, indent=2))
      return paths

  CACHE_PATHS = _setup_cache_dirs()

+ # Lazy global cache (for the models)
+ _ST_MODEL = None
+ _HF_TOKENIZER = None
+ _HF_MODEL = None
+
+ # =============================================================================
+ # JOB STATE
+ # =============================================================================
+ class JobState(BaseModel):
+     job_id: str
+     project_id: str
+     stage: str = "pending"  # pending -> chunking -> embedding -> indexing -> done/failed
+     total_files: int = 0
+     total_chunks: int = 0
+     embedded: int = 0
+     indexed: int = 0
+     errors: List[str] = []
+     messages: List[str] = []
+     started_at: float = time.time()
+     finished_at: Optional[float] = None
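Note: `started_at: float = time.time()` is evaluated once, when the class body runs at import, so every job created from this model shares that single timestamp as its default. A `default_factory` is evaluated per instance; a sketch of the fix (the class name is illustrative):

import time
from pydantic import BaseModel, Field

class JobStateFixed(BaseModel):
    job_id: str
    started_at: float = Field(default_factory=time.time)  # evaluated per instance

a = JobStateFixed(job_id="a")
time.sleep(0.01)
b = JobStateFixed(job_id="b")
assert b.started_at > a.started_at  # distinct creation times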
  JOBS: Dict[str, JobState] = {}

  def _now() -> str:

      os.makedirs(fx_dir, exist_ok=True)
      return base, ds_dir, fx_dir

+ def _add_msg(st: JobState, msg: str):
      st.messages.append(f"[{_now()}] {msg}")
      LOG.info("[%s] %s", st.job_id, msg)
      DBG.debug("[%s] %s", st.job_id, msg)

+ def _set_stage(st: JobState, stage: str):
      st.stage = stage
      _add_msg(st, f"stage={stage}")

+ # =============================================================================
+ # UTILS
+ # =============================================================================
  def _chunk_text(text: str, size: int = 200, overlap: int = 20) -> List[str]:
      text = (text or "").replace("\r\n", "\n")
      tokens = list(text)

      n = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
      return x / n

+ # ----------------------- PROVIDER: DUMMY --------------------------------------
  def _emb_dummy(texts: List[str], dim: int = EMB_DIM) -> np.ndarray:
      vecs = np.zeros((len(texts), dim), dtype="float32")
      for i, t in enumerate(texts):

          vecs[i] = v / (np.linalg.norm(v) + 1e-9)
      return vecs

+ # ----------------- PROVIDER: Sentence-Transformers ----------------------------
  def _get_st_model():
      global _ST_MODEL
      if _ST_MODEL is None:
          from sentence_transformers import SentenceTransformer
          _ST_MODEL = SentenceTransformer(EMB_MODEL, cache_folder=CACHE_PATHS["st"])
+         LOG.info("[st] model loaded: %s (cache=%s)", EMB_MODEL, CACHE_PATHS["st"])
      return _ST_MODEL

  def _emb_st(texts: List[str]) -> np.ndarray:

      ).astype("float32")
      return vecs

+ # ----------------------- PROVIDER: Transformers (HF) --------------------------
  def _get_hf_model():
      global _HF_TOKENIZER, _HF_MODEL
      if _HF_MODEL is None or _HF_TOKENIZER is None:

          _HF_TOKENIZER = AutoTokenizer.from_pretrained(EMB_MODEL, cache_dir=CACHE_PATHS["hf_tf"])
          _HF_MODEL = AutoModel.from_pretrained(EMB_MODEL, cache_dir=CACHE_PATHS["hf_tf"])
          _HF_MODEL.eval()
+         LOG.info("[hf] model loaded: %s (cache=%s)", EMB_MODEL, CACHE_PATHS["hf_tf"])
      return _HF_TOKENIZER, _HF_MODEL

+ def _mean_pool(last_hidden_state: "np.ndarray", attention_mask: "np.ndarray") -> "np.ndarray":
      mask = attention_mask[..., None].astype(last_hidden_state.dtype)
      summed = (last_hidden_state * mask).sum(axis=1)
      counts = mask.sum(axis=1).clip(min=1e-9)

  def _emb_hf(texts: List[str]) -> np.ndarray:
      import torch
      tok, mod = _get_hf_model()
+     all_vecs = []
      bs = max(1, EMB_BATCH)
      with torch.no_grad():
          for i in range(0, len(texts), bs):
+             batch = texts[i:i+bs]
              enc = tok(batch, padding=True, truncation=True, return_tensors="pt")
              out = mod(**enc)
              last = out.last_hidden_state  # (b, t, h)
              pooled = _mean_pool(last.numpy(), enc["attention_mask"].numpy())
              all_vecs.append(pooled.astype("float32"))
+     vecs = np.concatenate(all_vecs, axis=0)
+     return _l2_normalize(vecs)
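Note: `_mean_pool` averages token vectors while ignoring padding positions. A tiny worked example of the same computation (the values are made up):

import numpy as np

hidden = np.array([[[1.0, 0.0], [3.0, 2.0], [9.0, 9.0]]])  # (batch=1, tokens=3, hidden=2)
mask = np.array([[1, 1, 0]])                               # third token is padding
summed = (hidden * mask[..., None]).sum(axis=1)            # [[4., 2.]]
counts = mask.sum(axis=1, keepdims=True)                   # [[2]]
print(summed / counts)                                     # [[2. 1.]]: padding ignored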
 
+ # ---------------------------- DATASET / FAISS ---------------------------------
+ def _save_dataset(ds_dir: str, rows: List[Dict[str, Any]]):
      os.makedirs(ds_dir, exist_ok=True)
      data_path = os.path.join(ds_dir, "data.jsonl")
      with open(data_path, "w", encoding="utf-8") as f:
          for r in rows:
              f.write(json.dumps(r, ensure_ascii=False) + "\n")
      meta = {"format": "jsonl", "columns": ["path", "text", "chunk_id"], "count": len(rows)}
      with open(os.path.join(ds_dir, "meta.json"), "w", encoding="utf-8") as f:

      data_path = os.path.join(ds_dir, "data.jsonl")
      if not os.path.isfile(data_path):
          return []
+     out = []
      with open(data_path, "r", encoding="utf-8") as f:
          for line in f:
              try:

              continue
      return out

+ def _save_faiss(fx_dir: str, xb: np.ndarray, meta: Dict[str, Any]):
      os.makedirs(fx_dir, exist_ok=True)
      idx_path = os.path.join(fx_dir, "emb.faiss")
+     index = faiss.IndexFlatIP(xb.shape[1])  # cosine ~ inner product if embeddings are normalized
      index.add(xb)
      faiss.write_index(index, idx_path)
      with open(os.path.join(fx_dir, "meta.json"), "w", encoding="utf-8") as f:
          json.dump(meta, f, ensure_ascii=False, indent=2)
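Note: `IndexFlatIP` computes raw inner products; it only behaves like cosine similarity because the vectors are L2-normalized (the hf and dummy providers normalize explicitly; the st provider is assumed to do so in its elided body). A quick self-contained check of that equivalence:

import numpy as np
import faiss

xb = np.random.rand(100, 64).astype("float32")
xb /= np.linalg.norm(xb, axis=1, keepdims=True)  # unit vectors: inner product == cosine
index = faiss.IndexFlatIP(xb.shape[1])
index.add(xb)
scores, ids = index.search(xb[:1], 1)            # query with a vector from the index
assert ids[0][0] == 0 and abs(scores[0][0] - 1.0) < 1e-5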
 
  def _load_faiss(fx_dir: str) -> faiss.Index:
      idx_path = os.path.join(fx_dir, "emb.faiss")
      if not os.path.isfile(idx_path):
+         raise FileNotFoundError(f"FAISS index not found: {idx_path}")
+     return faiss.read_index(idx_path)

  def _tar_dir_to_bytes(dir_path: str) -> bytes:
      bio = io.BytesIO()
+     with tarfile.open(fileobj=bio, mode="w:gz") as tar:
          tar.add(dir_path, arcname=os.path.basename(dir_path))
      bio.seek(0)
      return bio.read()

+ # =============================================================================
+ # WORKER POOL (asynchronous)
+ # =============================================================================
  EXECUTOR = ThreadPoolExecutor(max_workers=max(1, MAX_WORKERS))
+ LOG.info("ThreadPoolExecutor initialized: max_workers=%s", MAX_WORKERS)
+
+ def _do_index_job(st: JobState, files: List[Dict[str, str]], chunk_size: int, overlap: int, batch_size: int, store_text: bool) -> None:
      """
+     Heavy task run in a worker thread.
+     Updates the state 'st' throughout the pipeline.
      """
      try:
          base, ds_dir, fx_dir = _proj_dirs(st.project_id)

+         # 1) Chunking
          _set_stage(st, "chunking")
          rows: List[Dict[str, Any]] = []
          st.total_files = len(files)
+         for it in files:
+             path = (it.get("path") or "unknown").strip()
+             txt = it.get("text") or ""
+             chks = _chunk_text(txt, size=int(chunk_size), overlap=int(overlap))
+             _add_msg(st, f"{path}: len(text)={len(txt)} chunks={len(chks)}")
+             for ci, ck in enumerate(chks):
+                 rows.append({"path": path, "text": ck, "chunk_id": ci})
          st.total_chunks = len(rows)
          _add_msg(st, f"Total chunks = {st.total_chunks}")

+         # 2) Embedding
          _set_stage(st, "embedding")
          texts = [r["text"] for r in rows]
          if EMB_PROVIDER == "dummy":
              xb = _emb_dummy(texts, dim=EMB_DIM)
+             dim = xb.shape[1]
          elif EMB_PROVIDER == "st":
              xb = _emb_st(texts)
+             dim = xb.shape[1]
          else:
              xb = _emb_hf(texts)
+             dim = xb.shape[1]

          st.embedded = xb.shape[0]
+         _add_msg(st, f"Embeddings {st.embedded}/{st.total_chunks}")
+         _add_msg(st, f"Embeddings dim={dim}")

+         # 3) Save the dataset (text)
+         _save_dataset(ds_dir, rows)
+         _add_msg(st, f"Dataset (without index) saved to {ds_dir}")

+         # 4) FAISS
          _set_stage(st, "indexing")
+         faiss_meta = {
+             "dim": int(dim),
              "count": int(xb.shape[0]),
              "provider": EMB_PROVIDER,
+             "model": EMB_MODEL if EMB_PROVIDER != "dummy" else None
          }
+         _save_faiss(fx_dir, xb, meta=faiss_meta)
          st.indexed = int(xb.shape[0])
          _add_msg(st, f"FAISS written to {os.path.join(fx_dir, 'emb.faiss')}")
+         _add_msg(st, f"OK, dataset+index ready (project={st.project_id})")

          _set_stage(st, "done")
          st.finished_at = time.time()
      except Exception as e:
+         LOG.exception("Job %s failed", st.job_id)
          st.errors.append(str(e))
+         _add_msg(st, f"❌ Exception: {e}")
          st.stage = "failed"
          st.finished_at = time.time()

+ def _submit_job(project_id: str, files: List[Dict[str, str]], chunk_size: int, overlap: int, batch_size: int, store_text: bool) -> str:
      job_id = hashlib.sha1(f"{project_id}{time.time()}".encode()).hexdigest()[:12]
      st = JobState(job_id=job_id, project_id=project_id, stage="pending", messages=[])
      JOBS[job_id] = st
+     _add_msg(st, f"Job {job_id} created for project {project_id}")
+     _add_msg(st, f"Index start project={project_id} files={len(files)} chunk_size={chunk_size} overlap={overlap} batch_size={batch_size} store_text={store_text} provider={EMB_PROVIDER} model={EMB_MODEL if EMB_PROVIDER!='dummy' else '-'}")

+     # Submit to the pool (returns immediately)
+     EXECUTOR.submit(_do_index_job, st, files, chunk_size, overlap, batch_size, store_text)
+     _set_stage(st, "queued")
      return job_id
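Note: `_submit_job` discards the Future returned by `EXECUTOR.submit`, which is fine here because `_do_index_job` catches everything itself; any exception raised before that handler runs would otherwise go unobserved. A hedged sketch of an alternative (names are illustrative):

from concurrent.futures import Future

def _log_uncaught(fut: Future) -> None:
    # Retrieve the exception (if any) so it is logged instead of silently dropped.
    exc = fut.exception()
    if exc is not None:
        LOG.error("indexing job crashed outside its own handler: %r", exc)

future = EXECUTOR.submit(_do_index_job, st, files, chunk_size, overlap, batch_size, store_text)
future.add_done_callback(_log_uncaught)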
 
+ # =============================================================================
+ # FASTAPI
+ # =============================================================================
  fastapi_app = FastAPI(title="remote-indexer-async", version="3.0.0")
  fastapi_app.add_middleware(
      CORSMiddleware,
+     allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
  )

  class FileItem(BaseModel):

      chunk_size: int = 200
      overlap: int = 20
      batch_size: int = 32
+     store_text: bool = True

  @fastapi_app.get("/health")
  def health():
+     info = {
          "ok": True,
          "service": "remote-indexer-async",
          "provider": EMB_PROVIDER,

          "cache_root": os.getenv("CACHE_ROOT", "/tmp/.cache"),
          "workers": MAX_WORKERS,
          "data_root": DATA_ROOT,
      }
+     return info
+
+ @fastapi_app.get("/")
+ def root_redirect():
+     return {"ok": True, "service": "remote-indexer-async", "ui": "/ui"}

  @fastapi_app.post("/index")
  def index(req: IndexRequest):
      """
+     ASYNCHRONOUS: returns a job_id immediately.
+     Processing happens in the background via the thread pool.
      """
      try:
          files = [fi.model_dump() for fi in req.files]

          )
          return {"job_id": job_id}
      except Exception as e:
+         LOG.exception("index failed (submit)")
          raise HTTPException(status_code=500, detail=str(e))

  @fastapi_app.get("/status/{job_id}")
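Note: the intended client flow against these endpoints is submit, poll, then query. A hedged end-to-end sketch (the base URL and the exact `/status` payload fields are assumptions, based on JobState above):

import time
import requests

BASE = "http://localhost:7860"  # wherever the Space is served

job_id = requests.post(f"{BASE}/index", json={
    "project_id": "DEMO",
    "files": [{"path": "sample.txt", "text": "Alpha bravo charlie."}],
}).json()["job_id"]

while True:  # poll the async job until it reaches a terminal stage
    stage = requests.get(f"{BASE}/status/{job_id}").json().get("stage")
    if stage in ("done", "failed"):
        break
    time.sleep(1.0)

res = requests.post(f"{BASE}/search", json={"project_id": "DEMO", "query": "alpha", "k": 5})
print(res.json())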
 
  def search(req: SearchRequest):
      base, ds_dir, fx_dir = _proj_dirs(req.project_id)

+     # If the index does not exist yet, answer 409 (conflict / not ready)
+     idx_path = os.path.join(fx_dir, "emb.faiss")
+     ds_path = os.path.join(ds_dir, "data.jsonl")
+     if not (os.path.isfile(idx_path) and os.path.isfile(ds_path)):
          raise HTTPException(status_code=409, detail="Index not ready (come back later)")

      rows = _load_dataset(ds_dir)
      if not rows:
          raise HTTPException(status_code=404, detail="dataset not found")

+     # Embed the query with the SAME provider
      if EMB_PROVIDER == "dummy":
          q = _emb_dummy([req.query], dim=EMB_DIM)[0:1, :]
      elif EMB_PROVIDER == "st":

      else:
          q = _emb_hf([req.query])[0:1, :]

+     # FAISS
      index = _load_faiss(fx_dir)
      if index.d != q.shape[1]:
+         raise HTTPException(status_code=500, detail=f"incompatible dims: index.d={index.d} vs query={q.shape[1]}")
      scores, ids = index.search(q, int(max(1, req.k)))
      ids = ids[0].tolist()
      scores = scores[0].tolist()

          out.append({"path": r.get("path"), "text": r.get("text"), "score": float(sc)})
      return {"results": out}
+ # ----------- ARTIFACTS EXPORT -----------
  @fastapi_app.get("/artifacts/{project_id}/dataset")
  def download_dataset(project_id: str):
+     base, ds_dir, _ = _proj_dirs(project_id)
      if not os.path.isdir(ds_dir):
          raise HTTPException(status_code=404, detail="Dataset not found")
      buf = _tar_dir_to_bytes(ds_dir)
+     headers = {"Content-Disposition": f'attachment; filename="{project_id}_dataset.tgz"'}
+     return StreamingResponse(io.BytesIO(buf), media_type="application/gzip", headers=headers)

  @fastapi_app.get("/artifacts/{project_id}/faiss")
  def download_faiss(project_id: str):
+     base, _, fx_dir = _proj_dirs(project_id)
      if not os.path.isdir(fx_dir):
          raise HTTPException(status_code=404, detail="FAISS not found")
      buf = _tar_dir_to_bytes(fx_dir)
+     headers = {"Content-Disposition": f'attachment; filename="{project_id}_faiss.tgz"'}
+     return StreamingResponse(io.BytesIO(buf), media_type="application/gzip", headers=headers)
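Note: these endpoints stream a gzipped tarball of the project directory. A small client sketch for fetching and unpacking one (the URL and target path are illustrative):

import io
import tarfile
import requests

resp = requests.get("http://localhost:7860/artifacts/DEMO/faiss", timeout=60)
resp.raise_for_status()
with tarfile.open(fileobj=io.BytesIO(resp.content), mode="r:gz") as tar:
    tar.extractall("/tmp/DEMO_faiss")  # contains emb.faiss and meta.json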
 
+ # =============================================================================
+ # GRADIO UI (optional, for testing)
+ # =============================================================================
  def _ui_index(project_id: str, sample_text: str):
      files = [{"path": "sample.txt", "text": sample_text}]
+     from pydantic import ValidationError
      try:
          req = IndexRequest(project_id=project_id, files=[FileItem(**f) for f in files])
+     except ValidationError as e:
+         return f"Error: {e}"
      try:
          res = index(req)
+         return f"Job launched: {res['job_id']}"
      except Exception as e:
+         return f"Index error: {e}"

  def _ui_search(project_id: str, query: str, k: int):
      try:
          res = search(SearchRequest(project_id=project_id, query=query, k=int(k)))
          return json.dumps(res, ensure_ascii=False, indent=2)
      except Exception as e:
+         return f"Search error: {e}"
+
+ with gr.Blocks(title="Remote Indexer (Async FAISS)", analytics_enabled=False) as ui:
+     gr.Markdown("## Remote Indexer — **Async** (API: `/index`, `/status/{job}`, `/search`, `/artifacts/...`).")
+     gr.Markdown(f"**Provider**: `{EMB_PROVIDER}` — **Model**: `{EMB_MODEL if EMB_PROVIDER!='dummy' else '-'}` — **Cache**: `{os.getenv('CACHE_ROOT', '/tmp/.cache')}` — **Workers**: `{MAX_WORKERS}`")
+     with gr.Tab("Index"):
+         pid = gr.Textbox(label="Project ID", value="DEEPWEB")
+         sample = gr.Textbox(label="Sample text", value="Alpha bravo charlie delta echo foxtrot.", lines=4)
+         btn = gr.Button("Launch index (sample)")
+         out = gr.Textbox(label="Result")
+         btn.click(_ui_index, inputs=[pid, sample], outputs=[out])
+
+     with gr.Tab("Search"):
+         pid2 = gr.Textbox(label="Project ID", value="DEEPWEB")
          q = gr.Textbox(label="Query", value="alpha")
+         k = gr.Slider(1, 20, value=5, step=1, label="k")
+         btn2 = gr.Button("Search")
+         out2 = gr.Code(label="Results")
+         btn2.click(_ui_search, inputs=[pid2, q, k], outputs=[out2])

  fastapi_app = gr.mount_gradio_app(fastapi_app, ui, path="/ui")

+ # =============================================================================
+ # MAIN
+ # =============================================================================
  if __name__ == "__main__":
      import uvicorn
+     LOG.info("Starting Uvicorn on 0.0.0.0:%s (UI_PATH=/ui), async index", PORT)
      uvicorn.run(fastapi_app, host="0.0.0.0", port=PORT)