Harry00 commited on
Commit
b4b14c2
·
verified ·
1 Parent(s): 22007d1

Upload mle/routing.py

Browse files
Files changed (1) hide show
  1. mle/routing.py +324 -0
mle/routing.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Routing optimisé via Hamming distance + bit-slicing SIMD
3
+
4
+ Le routing est le cœur de la récupération rapide dans la mémoire distribuée.
5
+ On utilise :
6
+ - Bit-slicing : découpe les vecteurs 4096 bits en tranches de 64 bits
7
+ - Hamming distance via popcount par tranche (SIMD-friendly)
8
+ - Index inversé pour sauts rapides
9
+ """
10
+
11
+ import numpy as np
12
+ from numba import njit, prange, uint64
13
+ from typing import List, Tuple, Dict, Optional
14
+ import logging
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ VECTOR_SIZE = 4096
19
+ SLICE_BITS = 64
20
+ NUM_SLICES = VECTOR_SIZE // SLICE_BITS
21
+
22
+
23
+ @njit(cache=True)
24
+ def pack_bits_to_uint64(bits: np.ndarray) -> np.ndarray:
25
+ """
26
+ Pack un vecteur 4096 bits (uint8) en 64 uint64.
27
+ Chaque uint64 contient 64 bits du vecteur original.
28
+ """
29
+ result = np.zeros(NUM_SLICES, dtype=np.uint64)
30
+ for i in range(NUM_SLICES):
31
+ val = np.uint64(0)
32
+ for j in range(SLICE_BITS):
33
+ if bits[i * SLICE_BITS + j]:
34
+ val |= np.uint64(1) << np.uint64(j)
35
+ result[i] = val
36
+ return result
37
+
38
+
39
+ @njit(cache=True)
40
+ def unpack_uint64_to_bits(packed: np.ndarray) -> np.ndarray:
41
+ """Unpack 64 uint64 en un vecteur 4096 bits."""
42
+ result = np.zeros(VECTOR_SIZE, dtype=np.uint8)
43
+ for i in range(NUM_SLICES):
44
+ val = packed[i]
45
+ for j in range(SLICE_BITS):
46
+ result[i * SLICE_BITS + j] = np.uint8((val >> np.uint64(j)) & np.uint64(1))
47
+ return result
48
+
49
+
50
+ @njit(parallel=True, cache=True)
51
+ def hamming_uint64_batch(query_packed: np.ndarray, table_packed: np.ndarray) -> np.ndarray:
52
+ """
53
+ Distance de Hamming entre un vecteur packé et N vecteurs packés.
54
+ Utilise XOR + popcount via bit-twiddling (très rapide).
55
+
56
+ Args:
57
+ query_packed: (NUM_SLICES,) uint64
58
+ table_packed: (N, NUM_SLICES) uint64
59
+
60
+ Returns:
61
+ distances: (N,) int32
62
+ """
63
+ N = table_packed.shape[0]
64
+ distances = np.empty(N, dtype=np.int32)
65
+
66
+ for i in prange(N):
67
+ dist = np.int32(0)
68
+ for j in range(NUM_SLICES):
69
+ xor_val = query_packed[j] ^ table_packed[i, j]
70
+ # Popcount pour uint64
71
+ x = xor_val
72
+ x = x - ((x >> np.uint64(1)) & np.uint64(0x5555555555555555))
73
+ x = (x & np.uint64(0x3333333333333333)) + ((x >> np.uint64(2)) & np.uint64(0x3333333333333333))
74
+ x = (x + (x >> np.uint64(4))) & np.uint64(0x0F0F0F0F0F0F0F0F)
75
+ x = x + (x >> np.uint64(8))
76
+ x = x + (x >> np.uint64(16))
77
+ x = x + (x >> np.uint64(32))
78
+ dist += np.int32(x & np.uint64(0x7F))
79
+ distances[i] = dist
80
+
81
+ return distances
82
+
83
+
84
+ class BitSliceIndex:
85
+ """
86
+ Index inversé par tranches de bits.
87
+ Pour chaque tranche (64 bits), maintient une table de hachage des
88
+ tranches observées vers les vecteurs qui les contiennent.
89
+ """
90
+
91
+ def __init__(self):
92
+ # slice_idx -> {slice_hash -> [vector_indices]}
93
+ self.slice_maps: List[Dict[int, List[int]]] = [
94
+ {} for _ in range(NUM_SLICES)
95
+ ]
96
+
97
+ def add_vector(self, vector_idx: int, packed: np.ndarray):
98
+ """Ajoute un vecteur à l'index."""
99
+ for slice_idx in range(NUM_SLICES):
100
+ slice_val = int(packed[slice_idx])
101
+ if slice_val not in self.slice_maps[slice_idx]:
102
+ self.slice_maps[slice_idx][slice_val] = []
103
+ self.slice_maps[slice_idx][slice_val].append(vector_idx)
104
+
105
+ def remove_vector(self, vector_idx: int, packed: np.ndarray):
106
+ """Retire un vecteur de l'index."""
107
+ for slice_idx in range(NUM_SLICES):
108
+ slice_val = int(packed[slice_idx])
109
+ if slice_val in self.slice_maps[slice_idx]:
110
+ lst = self.slice_maps[slice_idx][slice_val]
111
+ if vector_idx in lst:
112
+ lst.remove(vector_idx)
113
+
114
+ def query_candidates(
115
+ self,
116
+ query_packed: np.ndarray,
117
+ top_slices: int = 8,
118
+ max_candidates: int = 100
119
+ ) -> List[Tuple[int, int]]:
120
+ """
121
+ Retourne les candidats qui partagent le plus de tranches avec la requête.
122
+
123
+ Returns:
124
+ List of (vector_idx, match_count)
125
+ """
126
+ counts: Dict[int, int] = {}
127
+
128
+ # Prend les top_slices tranches les plus discriminantes
129
+ # (celles qui ont le moins d'entrées dans l'index)
130
+ slice_counts = []
131
+ for slice_idx in range(NUM_SLICES):
132
+ slice_val = int(query_packed[slice_idx])
133
+ n_entries = len(self.slice_maps[slice_idx].get(slice_val, []))
134
+ slice_counts.append((slice_idx, n_entries))
135
+
136
+ # Trie : préfère les tranches rares (plus discriminantes)
137
+ slice_counts.sort(key=lambda x: x[1])
138
+
139
+ for slice_idx, _ in slice_counts[:top_slices]:
140
+ slice_val = int(query_packed[slice_idx])
141
+ for vec_idx in self.slice_maps[slice_idx].get(slice_val, []):
142
+ counts[vec_idx] = counts.get(vec_idx, 0) + 1
143
+
144
+ # Trie par nombre de matches
145
+ candidates = sorted(counts.items(), key=lambda x: -x[1])
146
+ return candidates[:max_candidates]
147
+
148
+
149
+ class HammingRouter:
150
+ """
151
+ Routeur optimisé basé sur Hamming + bit-slicing.
152
+
153
+ Responsabilités :
154
+ - Convertir les vecteurs en format packé uint64
155
+ - Maintenir l'index inversé
156
+ - Router les requêtes vers les voisins les plus proches
157
+ - Apprendre les routes fréquentes
158
+ """
159
+
160
+ def __init__(
161
+ self,
162
+ use_index: bool = True,
163
+ index_top_slices: int = 8,
164
+ cache_size: int = 1000,
165
+ learn_routes: bool = True,
166
+ ):
167
+ self.use_index = use_index
168
+ self.index_top_slices = index_top_slices
169
+ self.learn_routes = learn_routes
170
+
171
+ # Stockage packé
172
+ self.packed_vectors: Dict[int, np.ndarray] = {} # vector_idx -> packed
173
+ self.index = BitSliceIndex()
174
+
175
+ # Cache de routes : pattern_hash -> [(target_idx, frequency)]
176
+ self.route_cache: Dict[int, List[Tuple[int, float]]] = {}
177
+ self.cache_size = cache_size
178
+ self.cache_hits = 0
179
+ self.cache_misses = 0
180
+
181
+ # Stats
182
+ self.query_count = 0
183
+ self.avg_query_time = 0.0
184
+
185
+ def add_vector(self, vector_idx: int, vector: np.ndarray):
186
+ """Ajoute un vecteur au routeur."""
187
+ packed = pack_bits_to_uint64(vector)
188
+ self.packed_vectors[vector_idx] = packed
189
+ if self.use_index:
190
+ self.index.add_vector(vector_idx, packed)
191
+
192
+ def remove_vector(self, vector_idx: int, vector: np.ndarray):
193
+ """Retire un vecteur du routeur."""
194
+ if vector_idx in self.packed_vectors:
195
+ packed = self.packed_vectors[vector_idx]
196
+ if self.use_index:
197
+ self.index.remove_vector(vector_idx, packed)
198
+ del self.packed_vectors[vector_idx]
199
+
200
+ def update_vector(self, vector_idx: int, new_vector: np.ndarray):
201
+ """Met à jour un vecteur existant."""
202
+ old_packed = self.packed_vectors.get(vector_idx)
203
+ new_packed = pack_bits_to_uint64(new_vector)
204
+ self.packed_vectors[vector_idx] = new_packed
205
+
206
+ if self.use_index and old_packed is not None:
207
+ self.index.remove_vector(vector_idx, old_packed)
208
+ self.index.add_vector(vector_idx, new_packed)
209
+
210
+ def _pattern_hash(self, packed: np.ndarray) -> int:
211
+ """Hachage rapide d'un vecteur packé pour le cache."""
212
+ # XOR fold des tranches avec mix simple (en Python int pour éviter overflow)
213
+ h = 0xcbf29ce484222325 # FNV offset basis
214
+ for i in range(NUM_SLICES):
215
+ h ^= int(packed[i])
216
+ h = (h * 0x100000001b3) & 0xFFFFFFFFFFFFFFFF
217
+ return h
218
+
219
+ def route(
220
+ self,
221
+ query: np.ndarray,
222
+ candidate_indices: Optional[List[int]] = None,
223
+ k: int = 5,
224
+ use_cache: bool = True
225
+ ) -> List[Tuple[int, float]]:
226
+ """
227
+ Route une requête vers les k voisins les plus proches.
228
+
229
+ Args:
230
+ query: vecteur (4096,) uint8
231
+ candidate_indices: indices candidats (si None, utilise tous)
232
+ k: nombre de résultats
233
+
234
+ Returns:
235
+ [(vector_idx, distance)] trié par distance
236
+ """
237
+ import time
238
+ t0 = time.time()
239
+
240
+ query_packed = pack_bits_to_uint64(query)
241
+
242
+ # Essaie le cache
243
+ if use_cache and self.learn_routes:
244
+ ph = self._pattern_hash(query_packed)
245
+ if ph in self.route_cache:
246
+ self.cache_hits += 1
247
+ # Filtre les entrées qui existent encore
248
+ results = [
249
+ (idx, 0.0) for idx, freq in self.route_cache[ph]
250
+ if idx in self.packed_vectors
251
+ ]
252
+ if len(results) >= k:
253
+ t1 = time.time()
254
+ self._update_timing(t1 - t0)
255
+ return results[:k]
256
+ else:
257
+ self.cache_misses += 1
258
+
259
+ # Détermine les candidats
260
+ if candidate_indices is None:
261
+ if self.use_index and len(self.packed_vectors) > 100:
262
+ # Utilise l'index pour réduire les candidats
263
+ candidates = self.index.query_candidates(
264
+ query_packed,
265
+ top_slices=self.index_top_slices,
266
+ max_candidates=200
267
+ )
268
+ candidate_indices = [idx for idx, _ in candidates]
269
+ else:
270
+ candidate_indices = list(self.packed_vectors.keys())
271
+
272
+ if len(candidate_indices) == 0:
273
+ return []
274
+
275
+ # Calcule les distances de Hamming
276
+ candidates_packed = np.array([
277
+ self.packed_vectors[idx] for idx in candidate_indices
278
+ ], dtype=np.uint64)
279
+
280
+ distances = hamming_uint64_batch(query_packed, candidates_packed)
281
+
282
+ # Trie
283
+ sorted_idx = np.argsort(distances)[:k]
284
+ results = [
285
+ (candidate_indices[si], float(distances[si]))
286
+ for si in sorted_idx
287
+ ]
288
+
289
+ # Met à jour le cache
290
+ if self.learn_routes and use_cache:
291
+ ph = self._pattern_hash(query_packed)
292
+ if ph not in self.route_cache:
293
+ if len(self.route_cache) >= self.cache_size:
294
+ # Éviction aléatoire
295
+ keys = list(self.route_cache.keys())
296
+ to_evict = keys[np.random.randint(0, len(keys))]
297
+ del self.route_cache[to_evict]
298
+
299
+ # Stocke la route apprise
300
+ self.route_cache[ph] = [
301
+ (idx, 1.0 / (1.0 + dist)) for idx, dist in results
302
+ ]
303
+
304
+ t1 = time.time()
305
+ self._update_timing(t1 - t0)
306
+
307
+ return results
308
+
309
+ def _update_timing(self, elapsed: float):
310
+ """Met à jour les statistiques de temps."""
311
+ self.query_count += 1
312
+ self.avg_query_time = (
313
+ self.avg_query_time * (self.query_count - 1) + elapsed
314
+ ) / self.query_count
315
+
316
+ def get_stats(self) -> Dict:
317
+ total = self.cache_hits + self.cache_misses
318
+ hit_rate = self.cache_hits / total if total > 0 else 0.0
319
+ return {
320
+ 'query_count': self.query_count,
321
+ 'avg_query_time_ms': self.avg_query_time * 1000,
322
+ 'cache_hit_rate': hit_rate,
323
+ 'n_vectors': len(self.packed_vectors),
324
+ }