Upload mle/routing.py
Browse files- mle/routing.py +324 -0
mle/routing.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Routing optimisé via Hamming distance + bit-slicing SIMD
|
| 3 |
+
|
| 4 |
+
Le routing est le cœur de la récupération rapide dans la mémoire distribuée.
|
| 5 |
+
On utilise :
|
| 6 |
+
- Bit-slicing : découpe les vecteurs 4096 bits en tranches de 64 bits
|
| 7 |
+
- Hamming distance via popcount par tranche (SIMD-friendly)
|
| 8 |
+
- Index inversé pour sauts rapides
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import logging
import time
from typing import Dict, List, Optional, Tuple

import numpy as np
from numba import njit, prange, uint64
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
VECTOR_SIZE = 4096
|
| 19 |
+
SLICE_BITS = 64
|
| 20 |
+
NUM_SLICES = VECTOR_SIZE // SLICE_BITS
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@njit(cache=True)
def pack_bits_to_uint64(bits: np.ndarray) -> np.ndarray:
    """
    Pack a 4096-bit vector (one uint8 per bit) into 64 uint64 words.

    Word i holds bits [i*64, (i+1)*64); bit j of a slice is stored at bit
    position j of its word (little-endian within each word), matching
    unpack_uint64_to_bits.
    """
    packed = np.zeros(NUM_SLICES, dtype=np.uint64)
    one = np.uint64(1)
    for slice_idx in range(NUM_SLICES):
        base = slice_idx * SLICE_BITS
        word = np.uint64(0)
        for bit_pos in range(SLICE_BITS):
            if bits[base + bit_pos]:
                word |= one << np.uint64(bit_pos)
        packed[slice_idx] = word
    return packed
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@njit(cache=True)
def unpack_uint64_to_bits(packed: np.ndarray) -> np.ndarray:
    """Expand 64 packed uint64 words back into a 4096-element uint8 bit vector (inverse of pack_bits_to_uint64)."""
    bits = np.zeros(VECTOR_SIZE, dtype=np.uint8)
    one = np.uint64(1)
    for slice_idx in range(NUM_SLICES):
        word = packed[slice_idx]
        base = slice_idx * SLICE_BITS
        for bit_pos in range(SLICE_BITS):
            bits[base + bit_pos] = np.uint8((word >> np.uint64(bit_pos)) & one)
    return bits
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@njit(parallel=True, cache=True)
def hamming_uint64_batch(query_packed: np.ndarray, table_packed: np.ndarray) -> np.ndarray:
    """
    Hamming distance between one packed vector and N packed vectors.

    Uses XOR + a branch-free SWAR popcount per 64-bit slice; the outer
    loop over rows is parallelised by numba's prange.

    Args:
        query_packed: (NUM_SLICES,) uint64
        table_packed: (N, NUM_SLICES) uint64

    Returns:
        distances: (N,) int32 — number of differing bits per row
    """
    N = table_packed.shape[0]
    distances = np.empty(N, dtype=np.int32)

    for i in prange(N):
        dist = np.int32(0)
        for j in range(NUM_SLICES):
            # Differing bits between query and row i in this 64-bit slice.
            xor_val = query_packed[j] ^ table_packed[i, j]
            # SWAR popcount of a uint64 (classic bit-twiddling sequence):
            x = xor_val
            # Step 1: sum of each adjacent bit pair, held in 2-bit fields.
            x = x - ((x >> np.uint64(1)) & np.uint64(0x5555555555555555))
            # Step 2: sum adjacent 2-bit fields into 4-bit fields.
            x = (x & np.uint64(0x3333333333333333)) + ((x >> np.uint64(2)) & np.uint64(0x3333333333333333))
            # Step 3: sum adjacent 4-bit fields into byte fields.
            x = (x + (x >> np.uint64(4))) & np.uint64(0x0F0F0F0F0F0F0F0F)
            # Steps 4-6: fold byte sums down into the low byte.
            x = x + (x >> np.uint64(8))
            x = x + (x >> np.uint64(16))
            x = x + (x >> np.uint64(32))
            # Low 7 bits now hold the slice's bit count (max 64 fits in 7 bits).
            dist += np.int32(x & np.uint64(0x7F))
        distances[i] = dist

    return distances
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class BitSliceIndex:
    """
    Inverted index over bit slices.

    For each 64-bit slice position, maintains a hash map from observed
    slice values to the list of vector indices containing that value.
    """

    def __init__(self):
        # slice position -> {slice value -> [vector indices]}
        self.slice_maps: List[Dict[int, List[int]]] = [
            {} for _ in range(NUM_SLICES)
        ]

    def add_vector(self, vector_idx: int, packed: np.ndarray):
        """Register a packed vector under each of its slice values."""
        for slice_idx, slice_map in enumerate(self.slice_maps):
            slice_val = int(packed[slice_idx])
            slice_map.setdefault(slice_val, []).append(vector_idx)

    def remove_vector(self, vector_idx: int, packed: np.ndarray):
        """Unregister a packed vector from the index."""
        for slice_idx, slice_map in enumerate(self.slice_maps):
            slice_val = int(packed[slice_idx])
            bucket = slice_map.get(slice_val)
            if bucket is not None and vector_idx in bucket:
                bucket.remove(vector_idx)
                if not bucket:
                    # Fix: drop buckets that become empty, otherwise the
                    # maps grow without bound as vectors are added/removed.
                    del slice_map[slice_val]

    def query_candidates(
        self,
        query_packed: np.ndarray,
        top_slices: int = 8,
        max_candidates: int = 100
    ) -> List[Tuple[int, int]]:
        """
        Return the vectors sharing the most slice values with the query.

        Probes the query's `top_slices` rarest *non-empty* buckets — rarer
        slice values are more discriminating. Empty buckets are skipped:
        they cannot yield candidates, and letting them occupy probe slots
        (as the original code did, since 0 entries sorts first) could make
        a query with few shared slices return no candidates at all.

        Returns:
            List of (vector_idx, match_count), best match first,
            truncated to max_candidates.
        """
        # Rank this query's buckets by rarity, ignoring empty ones.
        slice_counts = []
        for slice_idx, slice_map in enumerate(self.slice_maps):
            n_entries = len(slice_map.get(int(query_packed[slice_idx]), []))
            if n_entries > 0:
                slice_counts.append((slice_idx, n_entries))

        # Prefer rare (more discriminating) slices.
        slice_counts.sort(key=lambda x: x[1])

        counts: Dict[int, int] = {}
        for slice_idx, _ in slice_counts[:top_slices]:
            slice_val = int(query_packed[slice_idx])
            for vec_idx in self.slice_maps[slice_idx].get(slice_val, []):
                counts[vec_idx] = counts.get(vec_idx, 0) + 1

        # Sort by number of matching slices, descending.
        candidates = sorted(counts.items(), key=lambda x: -x[1])
        return candidates[:max_candidates]
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class HammingRouter:
    """
    Optimised router based on Hamming distance + bit-slicing.

    Responsibilities:
    - Convert bit vectors to packed uint64 form
    - Maintain the inverted bit-slice index
    - Route queries to their nearest stored neighbours
    - Learn (cache) frequently used routes
    """

    def __init__(
        self,
        use_index: bool = True,
        index_top_slices: int = 8,
        cache_size: int = 1000,
        learn_routes: bool = True,
    ):
        """
        Args:
            use_index: prune candidates through the inverted bit-slice index
            index_top_slices: number of discriminating slices probed per query
            cache_size: maximum number of learned routes kept
            learn_routes: enable the learned-route cache
        """
        self.use_index = use_index
        self.index_top_slices = index_top_slices
        self.learn_routes = learn_routes

        # Packed storage: vector_idx -> (NUM_SLICES,) uint64 array
        self.packed_vectors: Dict[int, np.ndarray] = {}
        self.index = BitSliceIndex()

        # Route cache: pattern_hash -> [(target_idx, frequency)]
        self.route_cache: Dict[int, List[Tuple[int, float]]] = {}
        self.cache_size = cache_size
        self.cache_hits = 0
        self.cache_misses = 0

        # Timing statistics
        self.query_count = 0
        self.avg_query_time = 0.0

    def add_vector(self, vector_idx: int, vector: np.ndarray):
        """Add a (4096,) uint8 bit vector to the router (packs + indexes it)."""
        packed = pack_bits_to_uint64(vector)
        self.packed_vectors[vector_idx] = packed
        if self.use_index:
            self.index.add_vector(vector_idx, packed)

    def remove_vector(self, vector_idx: int, vector: np.ndarray):
        """
        Remove a vector from the router.

        `vector` is kept for interface compatibility; the stored packed
        form is used, so removal works even if the caller's copy drifted.
        """
        if vector_idx in self.packed_vectors:
            packed = self.packed_vectors[vector_idx]
            if self.use_index:
                self.index.remove_vector(vector_idx, packed)
            del self.packed_vectors[vector_idx]

    def update_vector(self, vector_idx: int, new_vector: np.ndarray):
        """Replace an existing vector's contents (re-packs and re-indexes)."""
        old_packed = self.packed_vectors.get(vector_idx)
        new_packed = pack_bits_to_uint64(new_vector)
        self.packed_vectors[vector_idx] = new_packed

        if self.use_index and old_packed is not None:
            self.index.remove_vector(vector_idx, old_packed)
            self.index.add_vector(vector_idx, new_packed)

    def _pattern_hash(self, packed: np.ndarray) -> int:
        """FNV-1a-style 64-bit hash of a packed vector, for the route cache."""
        # XOR-fold the slices with an FNV mix; Python ints avoid overflow,
        # the mask keeps the state to 64 bits.
        h = 0xcbf29ce484222325  # FNV offset basis
        for i in range(NUM_SLICES):
            h ^= int(packed[i])
            h = (h * 0x100000001b3) & 0xFFFFFFFFFFFFFFFF  # FNV prime
        return h

    def route(
        self,
        query: np.ndarray,
        candidate_indices: Optional[List[int]] = None,
        k: int = 5,
        use_cache: bool = True
    ) -> List[Tuple[int, float]]:
        """
        Route a query to its k nearest stored vectors by Hamming distance.

        Args:
            query: (4096,) uint8 bit vector
            candidate_indices: candidate indices (None -> all, or index-pruned)
            k: number of results
            use_cache: allow the learned-route cache to answer/learn

        Returns:
            [(vector_idx, distance)] sorted by increasing distance.
            NOTE: cache hits report a placeholder distance of 0.0 because
            the true distance is not stored in the cache.
        """
        t0 = time.perf_counter()  # perf_counter: monotonic, high resolution

        query_packed = pack_bits_to_uint64(query)

        # Try the learned-route cache first.
        ph: Optional[int] = None
        if use_cache and self.learn_routes:
            ph = self._pattern_hash(query_packed)
            if ph in self.route_cache:
                self.cache_hits += 1
                # Drop entries whose vectors have since been removed.
                results = [
                    (idx, 0.0) for idx, _freq in self.route_cache[ph]
                    if idx in self.packed_vectors
                ]
                if len(results) >= k:
                    self._update_timing(time.perf_counter() - t0)
                    return results[:k]
            else:
                self.cache_misses += 1

        # Determine the candidate set.
        if candidate_indices is None:
            if self.use_index and len(self.packed_vectors) > 100:
                # Large store: prune the search space through the index.
                candidates = self.index.query_candidates(
                    query_packed,
                    top_slices=self.index_top_slices,
                    max_candidates=200
                )
                candidate_indices = [idx for idx, _ in candidates]
            else:
                candidate_indices = list(self.packed_vectors.keys())

        if not candidate_indices:
            # Fix: record timing here too, so get_stats() stays consistent
            # (the original skipped _update_timing on this path).
            self._update_timing(time.perf_counter() - t0)
            return []

        # Exact Hamming distances over the candidate set.
        candidates_packed = np.array(
            [self.packed_vectors[idx] for idx in candidate_indices],
            dtype=np.uint64,
        )
        distances = hamming_uint64_batch(query_packed, candidates_packed)

        # Keep the k smallest distances.
        order = np.argsort(distances)[:k]
        results = [
            (candidate_indices[si], float(distances[si]))
            for si in order
        ]

        # Learn / refresh the route. Fix: the original only stored a route
        # when its hash was absent, so stale entries (pointing at removed
        # vectors) were never refreshed; we now always overwrite, evicting
        # only when inserting a brand-new key over a full cache.
        if self.learn_routes and use_cache:
            if ph is None:
                ph = self._pattern_hash(query_packed)
            if ph not in self.route_cache and len(self.route_cache) >= self.cache_size:
                # Random eviction keeps the cache bounded.
                keys = list(self.route_cache.keys())
                del self.route_cache[keys[np.random.randint(0, len(keys))]]
            self.route_cache[ph] = [
                (idx, 1.0 / (1.0 + dist)) for idx, dist in results
            ]

        self._update_timing(time.perf_counter() - t0)
        return results

    def _update_timing(self, elapsed: float):
        """Fold one query's elapsed time into the running average."""
        self.query_count += 1
        # Incremental mean: numerically stabler than recomputing the sum.
        self.avg_query_time += (elapsed - self.avg_query_time) / self.query_count

    def get_stats(self) -> Dict:
        """Return routing statistics: query count, avg latency, cache hit rate, size."""
        total = self.cache_hits + self.cache_misses
        hit_rate = self.cache_hits / total if total > 0 else 0.0
        return {
            'query_count': self.query_count,
            'avg_query_time_ms': self.avg_query_time * 1000,
            'cache_hit_rate': hit_rate,
            'n_vectors': len(self.packed_vectors),
        }
|