chouchouvs committed on
Commit
e08033d
·
verified ·
1 Parent(s): 9059356

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +134 -62
main.py CHANGED
@@ -80,89 +80,161 @@ if "deepinfra" in EMB_BACKEND_ORDER and not DI_TOKEN:
80
  LOG.warning("DEEPINFRA_API_KEY manquant — tentatives DeepInfra échoueront.")
81
 
82
  # ---------- Vector store abstraction ----------
83
- class VectorStoreBase:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  def ensure_collection(self, name: str, dim: int): ...
85
- def upsert(self, name: str, vectors: np.ndarray, payloads: List[dict]) -> int: ...
86
- def search(self, name: str, query_vec: np.ndarray, limit: int):
87
- """return list of objects with .score and .payload"""
88
- ...
89
  def wipe(self, name: str): ...
90
 
91
- class MemoryHit:
92
- def __init__(self, score: float, payload: dict):
93
- self.score = score
94
- self.payload = payload
95
 
96
- class MemoryStore(VectorStoreBase):
97
- """Simple store en mémoire (cosine sur vecteurs normalisés). Persistance: vie du process."""
98
  def __init__(self):
99
- self.data: Dict[str, Dict[str, Any]] = {} # {col: {"dim": d, "vecs": np.ndarray [N,d], "payloads": List[dict]}}
100
- LOG.warning("Vector store: MEMORY (fallback). Les données sont volatiles (perdues au restart).")
101
 
102
  def ensure_collection(self, name: str, dim: int):
103
- col = self.data.get(name)
104
- if not col:
105
- self.data[name] = {"dim": dim, "vecs": np.zeros((0, dim), dtype=np.float32), "payloads": []}
106
-
107
- def upsert(self, name: str, vectors: np.ndarray, payloads: List[dict]) -> int:
108
- self.ensure_collection(name, vectors.shape[1])
109
- col = self.data[name]
110
- if vectors.ndim != 2 or vectors.shape[1] != col["dim"]:
111
- raise RuntimeError(f"MemoryStore: bad shape {vectors.shape}, expected (*,{col['dim']})")
112
- col["vecs"] = np.vstack([col["vecs"], vectors.astype(np.float32)])
113
- col["payloads"].extend(payloads)
114
- return vectors.shape[0]
115
-
116
- def search(self, name: str, query_vec: np.ndarray, limit: int):
117
- col = self.data.get(name)
118
- if not col or col["vecs"].shape[0] == 0:
119
  return []
120
- V = col["vecs"] # [N,d], déjà normalisés
121
- q = query_vec.reshape(1, -1) # [1,d]
122
- scores = (V @ q.T).ravel() # cos sim
123
- idx = np.argsort(-scores)[:limit]
124
- return [MemoryHit(float(scores[i]), col["payloads"][i]) for i in idx]
 
 
 
 
 
 
 
125
 
126
  def wipe(self, name: str):
127
- if name in self.data:
128
- del self.data[name]
129
 
130
- class QdrantStore(VectorStoreBase):
131
- def __init__(self, url: str, api_key: Optional[str]):
132
- if QdrantClient is None:
133
- raise RuntimeError("qdrant-client non installé.")
 
 
134
  self.client = QdrantClient(url=url, api_key=api_key if api_key else None)
135
- # ping rapide
 
 
 
 
136
  try:
137
- _ = self.client.get_collections()
138
- LOG.info("Connecté à Qdrant.")
139
- except Exception as e:
140
- raise RuntimeError(f"Connexion Qdrant impossible: {e}")
 
141
 
142
  def ensure_collection(self, name: str, dim: int):
 
143
  try:
144
- self.client.get_collection(name); return
145
  except Exception:
146
- pass
147
- self.client.create_collection(
148
- collection_name=name,
149
- vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
150
- )
151
-
152
- def upsert(self, name: str, vectors: np.ndarray, payloads: List[dict]) -> int:
153
- points = [
154
- PointStruct(id=None, vector=v.tolist(), payload=payloads[i])
 
 
 
 
 
 
 
 
 
 
 
 
155
  for i, v in enumerate(vectors)
156
  ]
157
- self.client.upsert(collection_name=name, points=points)
158
- return len(points)
159
-
160
- def search(self, name: str, query_vec: np.ndarray, limit: int):
161
- res = self.client.search(collection_name=name, query_vector=query_vec.tolist(), limit=limit)
162
- return res
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
  def wipe(self, name: str):
165
- self.client.delete_collection(name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
  # Sélection / auto-fallback du store
168
  STORE: VectorStoreBase
 
80
  LOG.warning("DEEPINFRA_API_KEY manquant — tentatives DeepInfra échoueront.")
81
 
82
  # ---------- Vector store abstraction ----------
83
+ # ---------- Vector Stores ----------
84
+ from typing import Dict, Any, List, Optional, Tuple
85
+ import numpy as np
86
+ import logging
87
+
88
+ LOG = logging.getLogger("remote_indexer")
89
+
90
+ try:
91
+ from qdrant_client import QdrantClient
92
+ from qdrant_client.http.models import VectorParams, Distance, PointStruct
93
+ except Exception:
94
+ QdrantClient = None
95
+ PointStruct = None
96
+
97
+
98
+ class BaseStore:
99
  def ensure_collection(self, name: str, dim: int): ...
100
+ def upsert(self, name: str, vectors: np.ndarray, payloads: List[Dict[str, Any]]) -> int: ...
101
+ def search(self, name: str, query_vec: np.ndarray, top_k: int) -> List[Dict[str, Any]]: ...
 
 
102
  def wipe(self, name: str): ...
103
 
 
 
 
 
104
 
105
+ class MemoryStore(BaseStore):
106
+ """Store en mémoire (volatile)."""
107
  def __init__(self):
108
+ # { collection: {"vecs": [np.ndarray], "payloads": [dict]} }
109
+ self.db: Dict[str, Dict[str, List[Any]]] = {}
110
 
111
  def ensure_collection(self, name: str, dim: int):
112
+ self.db.setdefault(name, {"vecs": [], "payloads": [], "dim": dim})
113
+
114
+ def upsert(self, name: str, vectors: np.ndarray, payloads: List[Dict[str, Any]]) -> int:
115
+ if name not in self.db:
116
+ raise RuntimeError(f"MemoryStore: collection {name} inconnue")
117
+ if len(vectors) != len(payloads):
118
+ raise ValueError("MemoryStore.upsert: tailles vectors/payloads incohérentes")
119
+ self.db[name]["vecs"].extend([v.astype(np.float32) for v in vectors])
120
+ self.db[name]["payloads"].extend(payloads)
121
+ return len(vectors)
122
+
123
+ def search(self, name: str, query_vec: np.ndarray, top_k: int) -> List[Dict[str, Any]]:
124
+ if name not in self.db or not self.db[name]["vecs"]:
 
 
 
125
  return []
126
+ mat = np.vstack(self.db[name]["vecs"]) # [N, dim]
127
+ q = query_vec.reshape(1, -1).astype(np.float32) # [1, dim]
128
+ # cosine similarity sur vecteurs normalisés
129
+ # (on suppose que les embeddings sont déjà normalisés en amont)
130
+ sims = (mat @ q.T).ravel() # [N]
131
+ top_idx = np.argsort(-sims)[:top_k]
132
+ out = []
133
+ for i in top_idx:
134
+ pl = dict(self.db[name]["payloads"][i])
135
+ pl["_score"] = float(sims[i])
136
+ out.append(pl)
137
+ return out
138
 
139
  def wipe(self, name: str):
140
+ self.db.pop(name, None)
 
141
 
142
+
143
+ class QdrantStore(BaseStore):
144
+ """Store Qdrant avec gestion d'IDs séquentiels par collection."""
145
+ def __init__(self, url: str, api_key: Optional[str] = None):
146
+ if QdrantClient is None or PointStruct is None:
147
+ raise RuntimeError("qdrant_client non disponible")
148
  self.client = QdrantClient(url=url, api_key=api_key if api_key else None)
149
+ # compteur d'IDs par collection
150
+ self._next_ids: Dict[str, int] = {}
151
+
152
+ def _init_next_id(self, name: str):
153
+ # on cherche le count exact des points existants pour démarrer l'ID à count
154
  try:
155
+ cnt = self.client.count(collection_name=name, exact=True).count
156
+ except Exception:
157
+ # si count échoue (collection vide juste créée), on démarre à 0
158
+ cnt = 0
159
+ self._next_ids[name] = int(cnt)
160
 
161
  def ensure_collection(self, name: str, dim: int):
162
+ # si existe déjà, rien à faire ; sinon, création
163
  try:
164
+ self.client.get_collection(name)
165
  except Exception:
166
+ self.client.create_collection(
167
+ collection_name=name,
168
+ vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
169
+ )
170
+ # initialiser le prochain id si absent
171
+ if name not in self._next_ids:
172
+ self._init_next_id(name)
173
+
174
+ def upsert(self, name: str, vectors: np.ndarray, payloads: List[Dict[str, Any]]) -> int:
175
+ if vectors is None or len(vectors) == 0:
176
+ return 0
177
+ if len(vectors) != len(payloads):
178
+ raise ValueError("QdrantStore.upsert: tailles vectors/payloads incohérentes")
179
+
180
+ if name not in self._next_ids:
181
+ self._init_next_id(name)
182
+
183
+ start = self._next_ids[name]
184
+ # construction des points avec IDs séquentiels (int)
185
+ pts = [
186
+ PointStruct(id=start + i, vector=v.astype(np.float32).tolist(), payload=payloads[i])
187
  for i, v in enumerate(vectors)
188
  ]
189
+ self.client.upsert(collection_name=name, points=pts)
190
+ added = len(pts)
191
+ self._next_ids[name] += added
192
+ LOG.debug(f"QdrantStore.upsert: +{added} points (next_id={self._next_ids[name]})")
193
+ return added
194
+
195
+ def search(self, name: str, query_vec: np.ndarray, top_k: int) -> List[Dict[str, Any]]:
196
+ if query_vec.ndim == 2:
197
+ qv = query_vec[0].astype(np.float32).tolist()
198
+ else:
199
+ qv = query_vec.astype(np.float32).tolist()
200
+ res = self.client.search(collection_name=name, query_vector=qv, limit=int(top_k))
201
+ out = []
202
+ for p in res:
203
+ pl = p.payload or {}
204
+ pl["_score"] = float(p.score) if hasattr(p, "score") else None
205
+ out.append(pl)
206
+ return out
207
 
208
  def wipe(self, name: str):
209
+ try:
210
+ self.client.delete_collection(name)
211
+ except Exception:
212
+ pass
213
+ self._next_ids.pop(name, None)
214
+
215
+
216
+ # ---------- Initialisation du store actif ----------
217
+ import os
218
+
219
+ VECTOR_STORE = os.getenv("VECTOR_STORE", "qdrant").strip().lower()
220
+ QDRANT_URL = os.getenv("QDRANT_URL", "").strip()
221
+ QDRANT_API = os.getenv("QDRANT_API_KEY", "").strip()
222
+
223
+ try:
224
+ if VECTOR_STORE == "qdrant" and QDRANT_URL:
225
+ STORE: BaseStore = QdrantStore(QDRANT_URL, api_key=QDRANT_API)
226
+ # test léger: liste des collections
227
+ _ = STORE.client.get_collections()
228
+ LOG.info("Connecté à Qdrant.")
229
+ VECTOR_STORE_ACTIVE = "QdrantStore"
230
+ else:
231
+ raise RuntimeError("Qdrant non configuré, fallback mémoire.")
232
+ except Exception as e:
233
+ LOG.error(f"Qdrant indisponible ({e}) — fallback en mémoire.")
234
+ STORE = MemoryStore()
235
+ VECTOR_STORE_ACTIVE = "MemoryStore"
236
+ LOG.warning("Vector store: MEMORY (fallback). Les données sont volatiles (perdues au restart).")
237
+
238
 
239
  # Sélection / auto-fallback du store
240
  STORE: VectorStoreBase