chouchouvs commited on
Commit
8a04dcd
·
verified ·
1 Parent(s): 6cb5d1b

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +126 -87
main.py CHANGED
@@ -9,6 +9,7 @@ import tarfile
9
  import logging
10
  import hashlib
11
  from typing import Dict, Any, List, Tuple, Optional
 
12
 
13
  import numpy as np
14
  import faiss
@@ -22,18 +23,25 @@ import gradio as gr
22
  # =============================================================================
23
  # LOGGING
24
  # =============================================================================
25
- LOG = logging.getLogger("remote-indexer-space")
26
  if not LOG.handlers:
27
  h = logging.StreamHandler()
28
  h.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
29
  LOG.addHandler(h)
30
  LOG.setLevel(logging.INFO)
31
 
 
 
 
 
 
 
 
32
  # =============================================================================
33
  # CONFIG (via ENV)
34
  # =============================================================================
35
  PORT = int(os.getenv("PORT", "7860"))
36
- DATA_ROOT = os.getenv("DATA_ROOT", "/tmp/data") # stockage interne du Space
37
  os.makedirs(DATA_ROOT, exist_ok=True)
38
 
39
  # Provider d'embeddings:
@@ -45,11 +53,13 @@ EMB_MODEL = os.getenv("EMB_MODEL", "sentence-transformers/paraphrase-multilingua
45
  EMB_BATCH = int(os.getenv("EMB_BATCH", "32"))
46
  EMB_DIM = int(os.getenv("EMB_DIM", "128")) # utilisé pour dummy
47
 
 
 
 
48
  # =============================================================================
49
- # CACHE DIRECTORIES (crucial pour éviter PermissionError: '/.cache')
50
  # =============================================================================
51
  def _setup_cache_dirs() -> Dict[str, str]:
52
- # HOME peut être vide -> expanduser('~') => '/' -> '/.cache' -> Permission denied
53
  os.environ.setdefault("HOME", "/home/user")
54
 
55
  CACHE_ROOT = os.getenv("CACHE_ROOT", "/tmp/.cache").rstrip("/")
@@ -68,15 +78,12 @@ def _setup_cache_dirs() -> Dict[str, str]:
68
  except Exception as e:
69
  LOG.warning("Impossible de créer %s : %s", p, e)
70
 
71
- # Variables standard HF/Transformers/Torch/ST
72
  os.environ["HF_HOME"] = paths["hf_home"]
73
  os.environ["HF_HUB_CACHE"] = paths["hf_hub"]
74
  os.environ["TRANSFORMERS_CACHE"] = paths["hf_tf"]
75
  os.environ["TORCH_HOME"] = paths["torch"]
76
  os.environ["SENTENCE_TRANSFORMERS_HOME"] = paths["st"]
77
- os.environ["MPLCONFIGDIR"] = paths["mpl"] # évite les warnings matplotlib
78
-
79
- # Qualité de vie
80
  os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
81
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
82
 
@@ -122,6 +129,7 @@ def _proj_dirs(project_id: str) -> Tuple[str, str, str]:
122
  def _add_msg(st: JobState, msg: str):
123
  st.messages.append(f"[{_now()}] {msg}")
124
  LOG.info("[%s] %s", st.job_id, msg)
 
125
 
126
  def _set_stage(st: JobState, stage: str):
127
  st.stage = stage
@@ -183,14 +191,6 @@ def _emb_st(texts: List[str]) -> np.ndarray:
183
  ).astype("float32")
184
  return vecs
185
 
186
- def _st_dim() -> int:
187
- model = _get_st_model()
188
- try:
189
- return int(model.get_sentence_embedding_dimension())
190
- except Exception:
191
- v = model.encode(["dimension probe"], convert_to_numpy=True)
192
- return int(v.shape[1])
193
-
194
  # ----------------------- PROVIDER: Transformers (HF) --------------------------
195
  def _get_hf_model():
196
  global _HF_TOKENIZER, _HF_MODEL
@@ -219,18 +219,11 @@ def _emb_hf(texts: List[str]) -> np.ndarray:
219
  enc = tok(batch, padding=True, truncation=True, return_tensors="pt")
220
  out = mod(**enc)
221
  last = out.last_hidden_state # (b, t, h)
222
- pooled = _mean_pool(last.numpy(), enc["attention_mask"].numpy()) # numpy
223
  all_vecs.append(pooled.astype("float32"))
224
  vecs = np.concatenate(all_vecs, axis=0)
225
  return _l2_normalize(vecs)
226
 
227
- def _hf_dim() -> int:
228
- try:
229
- _, mod = _get_hf_model()
230
- return int(getattr(mod.config, "hidden_size", 768))
231
- except Exception:
232
- return 768
233
-
234
  # ---------------------------- DATASET / FAISS ---------------------------------
235
  def _save_dataset(ds_dir: str, rows: List[Dict[str, Any]]):
236
  os.makedirs(ds_dir, exist_ok=True)
@@ -278,74 +271,44 @@ def _tar_dir_to_bytes(dir_path: str) -> bytes:
278
  return bio.read()
279
 
280
  # =============================================================================
281
- # FASTAPI
282
  # =============================================================================
283
- fastapi_app = FastAPI(title="remote-indexer", version="2.1.0")
284
- fastapi_app.add_middleware(
285
- CORSMiddleware,
286
- allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
287
- )
288
-
289
- class FileItem(BaseModel):
290
- path: str
291
- text: str
292
-
293
- class IndexRequest(BaseModel):
294
- project_id: str
295
- files: List[FileItem]
296
- chunk_size: int = 200
297
- overlap: int = 20
298
- batch_size: int = 32
299
- store_text: bool = True
300
-
301
- @fastapi_app.get("/health")
302
- def health():
303
- info = {
304
- "ok": True,
305
- "service": "remote-indexer",
306
- "provider": EMB_PROVIDER,
307
- "model": EMB_MODEL if EMB_PROVIDER != "dummy" else None,
308
- "cache_root": os.getenv("CACHE_ROOT", "/tmp/.cache"),
309
- }
310
- return info
311
-
312
- @fastapi_app.get("/")
313
- def root_redirect():
314
- return {"ok": True, "service": "remote-indexer", "ui": "/ui"}
315
-
316
- @fastapi_app.post("/index")
317
- def index(req: IndexRequest):
318
- job_id = hashlib.sha1(f"{req.project_id}{time.time()}".encode()).hexdigest()[:12]
319
- st = JobState(job_id=job_id, project_id=req.project_id, stage="pending", messages=[])
320
- JOBS[job_id] = st
321
- _add_msg(st, f"Job {job_id} créé pour project {req.project_id}")
322
- _add_msg(st, f"Index start project={req.project_id} files={len(req.files)} chunk_size={req.chunk_size} overlap={req.overlap} batch_size={req.batch_size} store_text={req.store_text} provider={EMB_PROVIDER} model={EMB_MODEL if EMB_PROVIDER!='dummy' else '-'}")
323
  try:
324
- base, ds_dir, fx_dir = _proj_dirs(req.project_id)
325
 
326
  # 1) Chunking
327
  _set_stage(st, "chunking")
328
  rows: List[Dict[str, Any]] = []
329
- st.total_files = len(req.files)
330
- for it in req.files:
331
- txt = it.text or ""
332
- chunks = _chunk_text(txt, size=req.chunk_size, overlap=req.overlap)
333
- _add_msg(st, f"{it.path}: len(text)={len(txt)} chunks={len(chunks)}")
334
- for ci, ck in enumerate(chunks):
335
- rows.append({"path": it.path, "text": ck, "chunk_id": ci})
 
336
  st.total_chunks = len(rows)
337
  _add_msg(st, f"Total chunks = {st.total_chunks}")
338
 
339
  # 2) Embedding
340
  _set_stage(st, "embedding")
 
341
  if EMB_PROVIDER == "dummy":
342
- xb = _emb_dummy([r["text"] for r in rows], dim=EMB_DIM)
343
  dim = xb.shape[1]
344
  elif EMB_PROVIDER == "st":
345
- xb = _emb_st([r["text"] for r in rows])
346
  dim = xb.shape[1]
347
- else: # "hf"
348
- xb = _emb_hf([r["text"] for r in rows])
349
  dim = xb.shape[1]
350
 
351
  st.embedded = xb.shape[0]
@@ -367,17 +330,86 @@ def index(req: IndexRequest):
367
  _save_faiss(fx_dir, xb, meta=faiss_meta)
368
  st.indexed = int(xb.shape[0])
369
  _add_msg(st, f"FAISS écrit sur {os.path.join(fx_dir, 'emb.faiss')}")
370
- _add_msg(st, f"OK — dataset+index prêts (projet={req.project_id})")
371
 
372
  _set_stage(st, "done")
373
  st.finished_at = time.time()
374
- return {"job_id": job_id}
375
  except Exception as e:
376
- LOG.exception("index failed")
377
  st.errors.append(str(e))
378
  _add_msg(st, f"❌ Exception: {e}")
379
  st.stage = "failed"
380
  st.finished_at = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
  raise HTTPException(status_code=500, detail=str(e))
382
 
383
  @fastapi_app.get("/status/{job_id}")
@@ -395,9 +427,16 @@ class SearchRequest(BaseModel):
395
  @fastapi_app.post("/search")
396
  def search(req: SearchRequest):
397
  base, ds_dir, fx_dir = _proj_dirs(req.project_id)
 
 
 
 
 
 
 
398
  rows = _load_dataset(ds_dir)
399
  if not rows:
400
- raise HTTPException(status_code=404, detail="dataset introuvable (index pas encore construit ?)")
401
 
402
  # Embedding de la requête avec le MÊME provider
403
  if EMB_PROVIDER == "dummy":
@@ -443,7 +482,7 @@ def download_faiss(project_id: str):
443
  return StreamingResponse(io.BytesIO(buf), media_type="application/gzip", headers=headers)
444
 
445
  # =============================================================================
446
- # GRADIO UI (facultatif)
447
  # =============================================================================
448
  def _ui_index(project_id: str, sample_text: str):
449
  files = [{"path": "sample.txt", "text": sample_text}]
@@ -465,9 +504,9 @@ def _ui_search(project_id: str, query: str, k: int):
465
  except Exception as e:
466
  return f"Erreur search: {e}"
467
 
468
- with gr.Blocks(title="Remote Indexer (FAISS)", analytics_enabled=False) as ui:
469
- gr.Markdown("## Remote Indexer — demo UI (API: `/index`, `/status/{job}`, `/search`, `/artifacts/...`).")
470
- gr.Markdown(f"**Provider**: `{EMB_PROVIDER}` — **Model**: `{EMB_MODEL if EMB_PROVIDER!='dummy' else '-'}` — **Cache**: `{os.getenv('CACHE_ROOT', '/tmp/.cache')}`")
471
  with gr.Tab("Index"):
472
  pid = gr.Textbox(label="Project ID", value="DEEPWEB")
473
  sample = gr.Textbox(label="Texte d’exemple", value="Alpha bravo charlie delta echo foxtrot.", lines=4)
@@ -490,5 +529,5 @@ fastapi_app = gr.mount_gradio_app(fastapi_app, ui, path="/ui")
490
  # =============================================================================
491
  if __name__ == "__main__":
492
  import uvicorn
493
- LOG.info("Démarrage Uvicorn sur 0.0.0.0:%s (UI_PATH=/ui)", PORT)
494
- uvicorn.run(fastapi_app, host="0.0.0.0", port=PORT)
 
9
  import logging
10
  import hashlib
11
  from typing import Dict, Any, List, Tuple, Optional
12
+ from concurrent.futures import ThreadPoolExecutor
13
 
14
  import numpy as np
15
  import faiss
 
23
  # =============================================================================
24
  # LOGGING
25
  # =============================================================================
26
+ LOG = logging.getLogger("remote-indexer-async")
27
  if not LOG.handlers:
28
  h = logging.StreamHandler()
29
  h.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
30
  LOG.addHandler(h)
31
  LOG.setLevel(logging.INFO)
32
 
33
+ DBG = logging.getLogger("remote-indexer-async.debug")
34
+ if not DBG.handlers:
35
+ hd = logging.StreamHandler()
36
+ hd.setFormatter(logging.Formatter("[DEBUG] %(asctime)s - %(message)s"))
37
+ DBG.addHandler(hd)
38
+ DBG.setLevel(logging.DEBUG)
39
+
40
  # =============================================================================
41
  # CONFIG (via ENV)
42
  # =============================================================================
43
  PORT = int(os.getenv("PORT", "7860"))
44
+ DATA_ROOT = os.getenv("DATA_ROOT", "/tmp/data") # stockage interne du Space (volatile en Free)
45
  os.makedirs(DATA_ROOT, exist_ok=True)
46
 
47
  # Provider d'embeddings:
 
53
  EMB_BATCH = int(os.getenv("EMB_BATCH", "32"))
54
  EMB_DIM = int(os.getenv("EMB_DIM", "128")) # utilisé pour dummy
55
 
56
+ # Taille du pool de workers (asynchrone)
57
+ MAX_WORKERS = int(os.getenv("MAX_WORKERS", "1"))
58
+
59
  # =============================================================================
60
+ # CACHE DIRECTORIES (évite PermissionError: '/.cache')
61
  # =============================================================================
62
  def _setup_cache_dirs() -> Dict[str, str]:
 
63
  os.environ.setdefault("HOME", "/home/user")
64
 
65
  CACHE_ROOT = os.getenv("CACHE_ROOT", "/tmp/.cache").rstrip("/")
 
78
  except Exception as e:
79
  LOG.warning("Impossible de créer %s : %s", p, e)
80
 
 
81
  os.environ["HF_HOME"] = paths["hf_home"]
82
  os.environ["HF_HUB_CACHE"] = paths["hf_hub"]
83
  os.environ["TRANSFORMERS_CACHE"] = paths["hf_tf"]
84
  os.environ["TORCH_HOME"] = paths["torch"]
85
  os.environ["SENTENCE_TRANSFORMERS_HOME"] = paths["st"]
86
+ os.environ["MPLCONFIGDIR"] = paths["mpl"]
 
 
87
  os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
88
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
89
 
 
129
  def _add_msg(st: JobState, msg: str):
130
  st.messages.append(f"[{_now()}] {msg}")
131
  LOG.info("[%s] %s", st.job_id, msg)
132
+ DBG.debug("[%s] %s", st.job_id, msg)
133
 
134
  def _set_stage(st: JobState, stage: str):
135
  st.stage = stage
 
191
  ).astype("float32")
192
  return vecs
193
 
 
 
 
 
 
 
 
 
194
  # ----------------------- PROVIDER: Transformers (HF) --------------------------
195
  def _get_hf_model():
196
  global _HF_TOKENIZER, _HF_MODEL
 
219
  enc = tok(batch, padding=True, truncation=True, return_tensors="pt")
220
  out = mod(**enc)
221
  last = out.last_hidden_state # (b, t, h)
222
+ pooled = _mean_pool(last.numpy(), enc["attention_mask"].numpy())
223
  all_vecs.append(pooled.astype("float32"))
224
  vecs = np.concatenate(all_vecs, axis=0)
225
  return _l2_normalize(vecs)
226
 
 
 
 
 
 
 
 
227
  # ---------------------------- DATASET / FAISS ---------------------------------
228
  def _save_dataset(ds_dir: str, rows: List[Dict[str, Any]]):
229
  os.makedirs(ds_dir, exist_ok=True)
 
271
  return bio.read()
272
 
273
  # =============================================================================
274
+ # WORKER POOL (asynchrone)
275
  # =============================================================================
276
+ EXECUTOR = ThreadPoolExecutor(max_workers=max(1, MAX_WORKERS))
277
+ LOG.info("ThreadPoolExecutor initialisé : max_workers=%s", MAX_WORKERS)
278
+
279
+ def _do_index_job(st: JobState, files: List[Dict[str, str]], chunk_size: int, overlap: int, batch_size: int, store_text: bool) -> None:
280
+ """
281
+ Tâche lourde lancée dans un worker thread.
282
+ Met à jour l'état 'st' tout au long du pipeline.
283
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  try:
285
+ base, ds_dir, fx_dir = _proj_dirs(st.project_id)
286
 
287
  # 1) Chunking
288
  _set_stage(st, "chunking")
289
  rows: List[Dict[str, Any]] = []
290
+ st.total_files = len(files)
291
+ for it in files:
292
+ path = (it.get("path") or "unknown").strip()
293
+ txt = it.get("text") or ""
294
+ chks = _chunk_text(txt, size=int(chunk_size), overlap=int(overlap))
295
+ _add_msg(st, f"{path}: len(text)={len(txt)} chunks={len(chks)}")
296
+ for ci, ck in enumerate(chks):
297
+ rows.append({"path": path, "text": ck, "chunk_id": ci})
298
  st.total_chunks = len(rows)
299
  _add_msg(st, f"Total chunks = {st.total_chunks}")
300
 
301
  # 2) Embedding
302
  _set_stage(st, "embedding")
303
+ texts = [r["text"] for r in rows]
304
  if EMB_PROVIDER == "dummy":
305
+ xb = _emb_dummy(texts, dim=EMB_DIM)
306
  dim = xb.shape[1]
307
  elif EMB_PROVIDER == "st":
308
+ xb = _emb_st(texts)
309
  dim = xb.shape[1]
310
+ else:
311
+ xb = _emb_hf(texts)
312
  dim = xb.shape[1]
313
 
314
  st.embedded = xb.shape[0]
 
330
  _save_faiss(fx_dir, xb, meta=faiss_meta)
331
  st.indexed = int(xb.shape[0])
332
  _add_msg(st, f"FAISS écrit sur {os.path.join(fx_dir, 'emb.faiss')}")
333
+ _add_msg(st, f"OK — dataset+index prêts (projet={st.project_id})")
334
 
335
  _set_stage(st, "done")
336
  st.finished_at = time.time()
 
337
  except Exception as e:
338
+ LOG.exception("Job %s failed", st.job_id)
339
  st.errors.append(str(e))
340
  _add_msg(st, f"❌ Exception: {e}")
341
  st.stage = "failed"
342
  st.finished_at = time.time()
343
+
344
+ def _submit_job(project_id: str, files: List[Dict[str, str]], chunk_size: int, overlap: int, batch_size: int, store_text: bool) -> str:
345
+ job_id = hashlib.sha1(f"{project_id}{time.time()}".encode()).hexdigest()[:12]
346
+ st = JobState(job_id=job_id, project_id=project_id, stage="pending", messages=[])
347
+ JOBS[job_id] = st
348
+ _add_msg(st, f"Job {job_id} créé pour project {project_id}")
349
+ _add_msg(st, f"Index start project={project_id} files={len(files)} chunk_size={chunk_size} overlap={overlap} batch_size={batch_size} store_text={store_text} provider={EMB_PROVIDER} model={EMB_MODEL if EMB_PROVIDER!='dummy' else '-'}")
350
+
351
+ # Soumission au pool (retour immédiat)
352
+ EXECUTOR.submit(_do_index_job, st, files, chunk_size, overlap, batch_size, store_text)
353
+ _set_stage(st, "queued")
354
+ return job_id
355
+
356
+ # =============================================================================
357
+ # FASTAPI
358
+ # =============================================================================
359
+ fastapi_app = FastAPI(title="remote-indexer-async", version="3.0.0")
360
+ fastapi_app.add_middleware(
361
+ CORSMiddleware,
362
+ allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
363
+ )
364
+
365
+ class FileItem(BaseModel):
366
+ path: str
367
+ text: str
368
+
369
+ class IndexRequest(BaseModel):
370
+ project_id: str
371
+ files: List[FileItem]
372
+ chunk_size: int = 200
373
+ overlap: int = 20
374
+ batch_size: int = 32
375
+ store_text: bool = True
376
+
377
+ @fastapi_app.get("/health")
378
+ def health():
379
+ info = {
380
+ "ok": True,
381
+ "service": "remote-indexer-async",
382
+ "provider": EMB_PROVIDER,
383
+ "model": EMB_MODEL if EMB_PROVIDER != "dummy" else None,
384
+ "cache_root": os.getenv("CACHE_ROOT", "/tmp/.cache"),
385
+ "workers": MAX_WORKERS,
386
+ "data_root": DATA_ROOT,
387
+ }
388
+ return info
389
+
390
+ @fastapi_app.get("/")
391
+ def root_redirect():
392
+ return {"ok": True, "service": "remote-indexer-async", "ui": "/ui"}
393
+
394
+ @fastapi_app.post("/index")
395
+ def index(req: IndexRequest):
396
+ """
397
+ ASYNCHRONE : retourne immédiatement un job_id.
398
+ Le traitement est effectué en arrière-plan par le pool de threads.
399
+ """
400
+ try:
401
+ files = [fi.model_dump() for fi in req.files]
402
+ job_id = _submit_job(
403
+ project_id=req.project_id,
404
+ files=files,
405
+ chunk_size=int(req.chunk_size),
406
+ overlap=int(req.overlap),
407
+ batch_size=int(req.batch_size),
408
+ store_text=bool(req.store_text),
409
+ )
410
+ return {"job_id": job_id}
411
+ except Exception as e:
412
+ LOG.exception("index failed (submit)")
413
  raise HTTPException(status_code=500, detail=str(e))
414
 
415
  @fastapi_app.get("/status/{job_id}")
 
427
  @fastapi_app.post("/search")
428
  def search(req: SearchRequest):
429
  base, ds_dir, fx_dir = _proj_dirs(req.project_id)
430
+
431
+ # Si l'index n'existe pas encore, on répond 409 (conflit / pas prêt)
432
+ idx_path = os.path.join(fx_dir, "emb.faiss")
433
+ ds_path = os.path.join(ds_dir, "data.jsonl")
434
+ if not (os.path.isfile(idx_path) and os.path.isfile(ds_path)):
435
+ raise HTTPException(status_code=409, detail="Index non prêt (reviens plus tard)")
436
+
437
  rows = _load_dataset(ds_dir)
438
  if not rows:
439
+ raise HTTPException(status_code=404, detail="dataset introuvable")
440
 
441
  # Embedding de la requête avec le MÊME provider
442
  if EMB_PROVIDER == "dummy":
 
482
  return StreamingResponse(io.BytesIO(buf), media_type="application/gzip", headers=headers)
483
 
484
  # =============================================================================
485
+ # GRADIO UI (facultatif de test)
486
  # =============================================================================
487
  def _ui_index(project_id: str, sample_text: str):
488
  files = [{"path": "sample.txt", "text": sample_text}]
 
504
  except Exception as e:
505
  return f"Erreur search: {e}"
506
 
507
+ with gr.Blocks(title="Remote Indexer (Async FAISS)", analytics_enabled=False) as ui:
508
+ gr.Markdown("## Remote Indexer — **Async** (API: `/index`, `/status/{job}`, `/search`, `/artifacts/...`).")
509
+ gr.Markdown(f"**Provider**: `{EMB_PROVIDER}` — **Model**: `{EMB_MODEL if EMB_PROVIDER!='dummy' else '-'}` — **Cache**: `{os.getenv('CACHE_ROOT', '/tmp/.cache')}` — **Workers**: `{MAX_WORKERS}`")
510
  with gr.Tab("Index"):
511
  pid = gr.Textbox(label="Project ID", value="DEEPWEB")
512
  sample = gr.Textbox(label="Texte d’exemple", value="Alpha bravo charlie delta echo foxtrot.", lines=4)
 
529
  # =============================================================================
530
  if __name__ == "__main__":
531
  import uvicorn
532
+ LOG.info("Démarrage Uvicorn sur 0.0.0.0:%s (UI_PATH=/ui) — async index", PORT)
533
+ uvicorn.run(fastapi_app, host="0.0.0.0", port=PORT)