# -*- coding: utf-8 -*-
from __future__ import annotations

import os, time, uuid, logging
from typing import List, Optional, Dict, Any, Tuple

import numpy as np
import requests
from fastapi import FastAPI, BackgroundTasks, Header, HTTPException
from pydantic import BaseModel, Field
from qdrant_client import QdrantClient
from qdrant_client.http.models import VectorParams, Distance, PointStruct
# ---------- logging ----------
logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s")
LOG = logging.getLogger("remote_indexer")
# ---------- ENV ----------
EMB_BACKEND = os.getenv("EMB_BACKEND", "hf").strip().lower()  # "hf" (default) or "deepinfra"

# HF
HF_TOKEN = os.getenv("HF_API_TOKEN", "").strip()
HF_MODEL = os.getenv("HF_EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
# If you have a private Inference Endpoint, or want the "models/..." API, set HF_API_URL,
# e.g. https://api-inference.huggingface.co/models/sentence-transformers/all-MiniLM-L6-v2
HF_URL = (os.getenv("HF_API_URL", "").strip()
          or f"https://api-inference.huggingface.co/pipeline/feature-extraction/{HF_MODEL}")

# DeepInfra
DI_TOKEN = os.getenv("DEEPINFRA_API_KEY", "").strip()
DI_MODEL = os.getenv("DEEPINFRA_EMBED_MODEL", "thenlper/gte-small").strip()
DI_URL = os.getenv("DEEPINFRA_EMBED_URL", "https://api.deepinfra.com/v1/embeddings").strip()

# Qdrant
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
QDRANT_API = os.getenv("QDRANT_API_KEY", "").strip()

# Service API auth (simple shared-secret header)
AUTH_TOKEN = os.getenv("REMOTE_INDEX_TOKEN", "").strip()

LOG.info(f"Embeddings backend = {EMB_BACKEND}")
if EMB_BACKEND == "hf" and not HF_TOKEN:
    LOG.warning("HF_API_TOKEN missing; HF /index and /query will fail.")
if EMB_BACKEND == "deepinfra" and not DI_TOKEN:
    LOG.warning("DEEPINFRA_API_KEY missing; DeepInfra embeddings will fail.")
# ---------- Clients ----------
qdr: Optional[QdrantClient] = None  # keep the name bound even if init fails below
try:
    qdr = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API if QDRANT_API else None)
except Exception as e:
    LOG.warning(f"Qdrant client init failed: {e}")
# ---------- Pydantic ----------
class FileIn(BaseModel):
    path: str
    text: str

class IndexRequest(BaseModel):
    project_id: str = Field(..., min_length=1)
    files: List[FileIn]
    chunk_size: int = 1200
    overlap: int = 200
    batch_size: int = 8
    store_text: bool = True

class QueryRequest(BaseModel):
    project_id: str
    query: str
    top_k: int = 6
# ---------- Job store (in-memory) ----------
JOBS: Dict[str, Dict[str, Any]] = {}  # {job_id: {"status": "...", "logs": [...], "created": ts}}

# ---------- Utils ----------
def _auth(x_auth: Optional[str]):
    """Reject the request unless the shared-secret header matches (no-op if AUTH_TOKEN is unset)."""
    if AUTH_TOKEN and (x_auth or "") != AUTH_TOKEN:
        raise HTTPException(status_code=401, detail="Unauthorized")
def _hf_post_embeddings(batch: List[str]) -> Tuple[np.ndarray, int]:
    if not HF_TOKEN:
        raise RuntimeError("HF_API_TOKEN missing (backend=hf).")
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    try:
        r = requests.post(HF_URL, headers=headers, json=batch, timeout=120)
        size = int(r.headers.get("Content-Length", "0"))
        if r.status_code >= 400:
            # Detailed log to diagnose 403/4xx responses
            try:
                LOG.error(f"HF error {r.status_code}: {r.text}")
            except Exception:
                LOG.error(f"HF error {r.status_code} (no body)")
        r.raise_for_status()
        data = r.json()
    except Exception as e:
        raise RuntimeError(f"HF POST failed: {e}")
    arr = np.array(data, dtype=np.float32)
    # [batch, dim] (sentence-transformers) or [batch, tokens, dim] -> mean-pooling
    if arr.ndim == 3:
        arr = arr.mean(axis=1)
    if arr.ndim != 2:
        raise RuntimeError(f"HF: unexpected embeddings shape: {arr.shape}")
    # L2-normalize so cosine similarity reduces to a dot product
    norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
    arr = arr / norms
    return arr.astype(np.float32), size
def _di_post_embeddings(batch: List[str]) -> Tuple[np.ndarray, int]:
    if not DI_TOKEN:
        raise RuntimeError("DEEPINFRA_API_KEY missing (backend=deepinfra).")
    headers = {"Authorization": f"Bearer {DI_TOKEN}", "Content-Type": "application/json"}
    payload = {"model": DI_MODEL, "input": batch}
    try:
        r = requests.post(DI_URL, headers=headers, json=payload, timeout=120)
        size = int(r.headers.get("Content-Length", "0"))
        if r.status_code >= 400:
            try:
                LOG.error(f"DeepInfra error {r.status_code}: {r.text}")
            except Exception:
                LOG.error(f"DeepInfra error {r.status_code} (no body)")
        r.raise_for_status()
        js = r.json()
    except Exception as e:
        raise RuntimeError(f"DeepInfra POST failed: {e}")
    # OpenAI-style response: {"data": [{"embedding": [...], "index": 0}, ...]}
    data = js.get("data")
    if not isinstance(data, list) or not data:
        raise RuntimeError(f"DeepInfra embeddings: invalid response {js}")
    embs = [d.get("embedding") for d in data]
    arr = np.asarray(embs, dtype=np.float32)
    if arr.ndim != 2:
        raise RuntimeError(f"DeepInfra: unexpected embeddings shape: {arr.shape}")
    # L2-normalize so cosine similarity reduces to a dot product
    norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
    arr = arr / norms
    return arr.astype(np.float32), size
def _post_embeddings(batch: List[str]) -> Tuple[np.ndarray, int]:
    if EMB_BACKEND == "hf":
        return _hf_post_embeddings(batch)
    elif EMB_BACKEND == "deepinfra":
        return _di_post_embeddings(batch)
    else:
        raise RuntimeError(f"Unknown EMB_BACKEND: {EMB_BACKEND}")
def _ensure_collection(name: str, dim: int):
    # get_collection raises if the collection does not exist yet; create it in that case
    try:
        qdr.get_collection(name)
        return
    except Exception:
        pass
    qdr.create_collection(
        collection_name=name,
        vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
    )
def _chunk_with_spans(text: str, size: int, overlap: int):
    """Yield (start, end, chunk) sliding-window spans over text."""
    n = len(text)
    if size <= 0:
        yield (0, n, text)
        return
    step = max(1, size - overlap)  # guard: overlap >= size would otherwise stall the window
    i = 0
    while i < n:
        j = min(n, i + size)
        yield (i, j, text[i:j])
        if j >= n:
            break  # without this, the tail chunk would repeat forever when overlap > 0
        i += step
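# Worked example: len(text)=10, size=4, overlap=1 yields the spans
# (0, 4), (3, 7), (6, 10); consecutive chunks share `overlap` characters.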
def _append_log(job_id: str, line: str):
    job = JOBS.get(job_id)
    if not job:
        return
    job["logs"].append(line)

def _set_status(job_id: str, status: str):
    job = JOBS.get(job_id)
    if not job:
        return
    job["status"] = status
# ---------- Background task ----------
def run_index_job(job_id: str, req: IndexRequest):
    try:
        _set_status(job_id, "running")
        total_chunks = 0
        LOG.info(f"[{job_id}] Index start project={req.project_id} files={len(req.files)}")
        _append_log(job_id, f"Start project={req.project_id} files={len(req.files)} | backend={EMB_BACKEND}")
        if not req.files:
            raise RuntimeError("No files to index.")
        # Warm-up call on the first chunk to discover the embedding dimension
        warmup = []
        for f in req.files[:1]:
            warmup.append(next(_chunk_with_spans(f.text, req.chunk_size, req.overlap))[2])
        embs, sz = _post_embeddings(warmup)
        dim = embs.shape[1]
        col = f"proj_{req.project_id}"
        _ensure_collection(col, dim)
        _append_log(job_id, f"Collection ready: {col} (dim={dim})")
        point_id = 0
        # Per-file loop
        for fi, f in enumerate(req.files, 1):
            chunks, metas = [], []
            for ci, (start, end, chunk_txt) in enumerate(_chunk_with_spans(f.text, req.chunk_size, req.overlap)):
                chunks.append(chunk_txt)
                payload = {"path": f.path, "chunk": ci, "start": start, "end": end}
                if req.store_text:
                    payload["text"] = chunk_txt
                metas.append(payload)
                if len(chunks) >= req.batch_size:
                    vecs, sz = _post_embeddings(chunks)
                    batch_points = []
                    for k, vec in enumerate(vecs):
                        batch_points.append(PointStruct(id=point_id, vector=vec.tolist(), payload=metas[k]))
                        point_id += 1
                    qdr.upsert(collection_name=col, points=batch_points)
                    total_chunks += len(chunks)
                    _append_log(job_id, f"file {fi}/{len(req.files)}: +{len(chunks)} chunks (total={total_chunks}) ~{sz/1024:.1f}KiB")
                    chunks, metas = [], []
            # Flush whatever remains at the end of the file
            if chunks:
                vecs, sz = _post_embeddings(chunks)
                batch_points = []
                for k, vec in enumerate(vecs):
                    batch_points.append(PointStruct(id=point_id, vector=vec.tolist(), payload=metas[k]))
                    point_id += 1
                qdr.upsert(collection_name=col, points=batch_points)
                total_chunks += len(chunks)
                _append_log(job_id, f"file {fi}/{len(req.files)}: +{len(chunks)} chunks (total={total_chunks}) ~{sz/1024:.1f}KiB")
        _append_log(job_id, f"Done. chunks={total_chunks}")
        _set_status(job_id, "done")
        LOG.info(f"[{job_id}] Index finished. chunks={total_chunks}")
    except Exception as e:
        LOG.exception("Index job failed")
        _append_log(job_id, f"ERROR: {e}")
        _set_status(job_id, "error")
# ---------- API ----------
app = FastAPI()

@app.get("/")
def root():
    return {
        "ok": True,
        "service": "remote-indexer",
        "backend": EMB_BACKEND,
        "hf_url": HF_URL if EMB_BACKEND == "hf" else None,
        "di_model": DI_MODEL if EMB_BACKEND == "deepinfra" else None,
        "docs": "/health, /index, /status/{job_id}, /query, /wipe"
    }

@app.get("/health")
def health():
    return {"ok": True}

def _check_backend_ready():
    if EMB_BACKEND == "hf" and not HF_TOKEN:
        raise HTTPException(400, "HF_API_TOKEN missing on the server (backend=hf).")
    if EMB_BACKEND == "deepinfra" and not DI_TOKEN:
        raise HTTPException(400, "DEEPINFRA_API_KEY missing on the server (backend=deepinfra).")
@app.post("/index")
def start_index(req: IndexRequest, background_tasks: BackgroundTasks, x_auth_token: Optional[str] = Header(default=None)):
    _auth(x_auth_token)
    _check_backend_ready()
    job_id = uuid.uuid4().hex[:12]
    JOBS[job_id] = {"status": "queued", "logs": [], "created": time.time()}
    background_tasks.add_task(run_index_job, job_id, req)
    return {"job_id": job_id}
@app.get("/status/{job_id}")
def status(job_id: str, x_auth_token: Optional[str] = Header(default=None)):
    _auth(x_auth_token)
    j = JOBS.get(job_id)
    if not j:
        raise HTTPException(404, "unknown job")
    return {"status": j["status"], "logs": j["logs"][-800:]}
@app.post("/query")
def query(req: QueryRequest, x_auth_token: Optional[str] = Header(default=None)):
    _auth(x_auth_token)
    _check_backend_ready()
    vec, _ = _post_embeddings([req.query])
    vec = vec[0].tolist()
    col = f"proj_{req.project_id}"
    try:
        res = qdr.search(collection_name=col, query_vector=vec, limit=int(req.top_k))
    except Exception as e:
        raise HTTPException(400, f"Search failed: {e}")
    out = []
    for p in res:
        pl = p.payload or {}
        txt = pl.get("text")
        if txt and len(txt) > 800:
            txt = txt[:800] + "..."
        out.append({"path": pl.get("path"), "chunk": pl.get("chunk"), "start": pl.get("start"), "end": pl.get("end"), "text": txt})
    return {"results": out}
@app.delete("/wipe/{project_id}")
def wipe_collection(project_id: str, x_auth_token: Optional[str] = Header(default=None)):
    _auth(x_auth_token)
    col = f"proj_{project_id}"
    try:
        qdr.delete_collection(col)
        return {"ok": True}
    except Exception as e:
        raise HTTPException(400, f"wipe failed: {e}")
# ---------- Entrypoint ----------
if __name__ == "__main__":
    import uvicorn
    port = int(os.getenv("PORT", "7860"))
    LOG.info(f"===== Application Startup on PORT {port} =====")
    uvicorn.run(app, host="0.0.0.0", port=port)
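
# ---------- Usage sketch ----------
# A minimal client flow (assumptions: the service listens on http://localhost:7860
# and REMOTE_INDEX_TOKEN is unset; otherwise add an "X-Auth-Token" header):
#
#   import requests, time
#   base = "http://localhost:7860"
#   job = requests.post(f"{base}/index", json={
#       "project_id": "demo",
#       "files": [{"path": "README.md", "text": "hello world"}],
#   }).json()
#   while requests.get(f"{base}/status/{job['job_id']}").json()["status"] not in ("done", "error"):
#       time.sleep(1)
#   hits = requests.post(f"{base}/query", json={"project_id": "demo", "query": "hello"}).json()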