Harry00 commited on
Commit
22007d1
·
verified ·
1 Parent(s): 949842e

Upload mle/inference.py

Browse files
Files changed (1) hide show
  1. mle/inference.py +446 -0
mle/inference.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Moteur d'Inférence avec Apprentissage en Ligne
3
+
4
+ L'inférence minimise l'énergie par descente stochastique locale.
5
+ À chaque itération :
6
+ 1. Calcule les voisins via le routeur
7
+ 2. Évalue l'énergie du paysage
8
+ 3. Sélectionne les flips de bits qui réduisent l'énergie
9
+ 4. Met à jour les associations (apprentissage en ligne)
10
+ 5. Détecte motifs pour abstraction
11
+
12
+ La minimisation est un processus de Monte Carlo / Hopfield-like
13
+ mais avec mémoire adaptative et apprentissage continu.
14
+ """
15
+
16
+ import numpy as np
17
+ from numba import njit, prange
18
+ from typing import List, Tuple, Dict, Optional, Callable
19
+ import logging
20
+ import time
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ VECTOR_SIZE = 4096
25
+
26
+
27
+ @njit(cache=True)
28
+ def random_flip_batch(state: np.ndarray, n_flips: int, rng_seed: int) -> np.ndarray:
29
+ """Flip aléatoire de n_flips bits."""
30
+ np.random.seed(rng_seed)
31
+ new_state = state.copy()
32
+ flip_indices = np.random.choice(VECTOR_SIZE, size=n_flips, replace=False)
33
+ for idx in flip_indices:
34
+ new_state[idx] = 1 - new_state[idx]
35
+ return new_state
36
+
37
+
38
+ @njit(cache=True)
39
+ def hamming_distance(a: np.ndarray, b: np.ndarray) -> int:
40
+ """Distance de Hamming entre deux vecteurs binaires."""
41
+ dist = 0
42
+ for i in range(len(a)):
43
+ dist += a[i] ^ b[i]
44
+ return dist
45
+
46
+
47
+ class InferenceResult:
48
+ """Résultat d'une inférence complète."""
49
+
50
+ def __init__(self):
51
+ self.initial_state: Optional[np.ndarray] = None
52
+ self.final_state: Optional[np.ndarray] = None
53
+ self.energy_trajectory: List[float] = []
54
+ self.neighbor_trajectory: List[List[Tuple[int, float]]] = []
55
+ self.n_iterations = 0
56
+ self.converged = False
57
+ self.creation_events: List[Dict] = []
58
+ self.learning_events: List[Dict] = []
59
+ self.execution_time_ms = 0.0
60
+
61
+
62
+ class InferenceEngine:
63
+ """
64
+ Moteur d'inférence par minimisation d'énergie avec apprentissage en ligne.
65
+
66
+ Paramètres clés:
67
+ - temperature: contrôle le bruit dans la descente (plus haut = plus exploratoire)
68
+ - max_iterations: nombre max d'itérations de minimisation
69
+ - energy_tolerance: seuil de convergence
70
+ - learning_rate: vitesse d'apprentissage pendant l'inférence
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ temperature: float = 0.5,
76
+ cooling_rate: float = 0.995,
77
+ max_iterations: int = 100,
78
+ energy_tolerance: float = 1.0,
79
+ learning_rate: float = 0.01,
80
+ online_learning: bool = True,
81
+ pattern_detection_interval: int = 10,
82
+ convergence_window: int = 5,
83
+ early_stop_threshold: float = 0.001,
84
+ ):
85
+ self.temperature = temperature
86
+ self.cooling_rate = cooling_rate
87
+ self.max_iterations = max_iterations
88
+ self.energy_tolerance = energy_tolerance
89
+ self.learning_rate = learning_rate
90
+ self.online_learning = online_learning
91
+ self.pattern_detection_interval = pattern_detection_interval
92
+ self.convergence_window = convergence_window
93
+ self.early_stop_threshold = early_stop_threshold
94
+
95
+ # Stats
96
+ self.total_inferences = 0
97
+ self.total_iterations = 0
98
+ self.total_converged = 0
99
+ self.avg_inference_time_ms = 0.0
100
+
101
+ def infer(
102
+ self,
103
+ initial_state: np.ndarray,
104
+ memory_table,
105
+ router,
106
+ energy_landscape,
107
+ binder,
108
+ k_neighbors: int = 10,
109
+ external_callback: Optional[Callable] = None,
110
+ ) -> InferenceResult:
111
+ """
112
+ Inférence complète avec minimisation d'énergie et apprentissage en ligne.
113
+
114
+ Args:
115
+ initial_state: état initial du système (4096 bits)
116
+ memory_table: SparseAddressTable
117
+ router: HammingRouter
118
+ energy_landscape: EnergyLandscape
119
+ binder: CircularBinder
120
+ k_neighbors: nombre de voisins à considérer
121
+ external_callback: fonction optionnelle appelée à chaque itération
122
+
123
+ Returns:
124
+ InferenceResult avec trajectoire et événements d'apprentissage
125
+ """
126
+ import time
127
+ t0 = time.time()
128
+
129
+ result = InferenceResult()
130
+ result.initial_state = initial_state.copy()
131
+
132
+ current_state = initial_state.copy()
133
+ temperature = self.temperature
134
+
135
+ prev_energy = float('inf')
136
+ energy_window = []
137
+
138
+ # Trajectoire des états pour détection de motifs
139
+ state_trajectory = [current_state.copy()]
140
+
141
+ for iteration in range(self.max_iterations):
142
+ # 1. Route vers les voisins les plus proches
143
+ neighbors_info = router.route(
144
+ current_state,
145
+ k=k_neighbors,
146
+ use_cache=True
147
+ )
148
+
149
+ if len(neighbors_info) == 0:
150
+ # Pas de voisins : état nouveau, potentiellement créer
151
+ break
152
+
153
+ neighbor_indices = [idx for idx, _ in neighbors_info]
154
+ neighbor_distances = [dist for _, dist in neighbors_info]
155
+
156
+ # Récupère les vecteurs voisins depuis la mémoire
157
+ neighbor_vectors = np.array([
158
+ memory_table.vectors[idx]
159
+ for idx in neighbor_indices
160
+ if memory_table.active_mask[idx]
161
+ ], dtype=np.uint8)
162
+
163
+ if len(neighbor_vectors) == 0:
164
+ break
165
+
166
+ neighbor_ids = [
167
+ memory_table.metadata[idx].id
168
+ for idx in neighbor_indices
169
+ if memory_table.active_mask[idx]
170
+ ]
171
+
172
+ # 2. Calcule l'énergie actuelle
173
+ energy = energy_landscape.compute_energy(
174
+ current_state,
175
+ neighbor_vectors,
176
+ neighbor_ids,
177
+ )
178
+
179
+ result.energy_trajectory.append(energy)
180
+ result.neighbor_trajectory.append(neighbors_info)
181
+
182
+ # 3. Détermine les flips optimaux
183
+ deltas = energy_landscape.get_bit_flip_deltas(
184
+ current_state,
185
+ neighbor_vectors,
186
+ neighbor_ids,
187
+ )
188
+
189
+ # Sélectionne les flips qui réduisent l'énergie
190
+ # avec bruit thermique pour exploration
191
+ flip_probs = np.exp(-deltas / max(temperature, 0.01))
192
+ flip_probs = flip_probs / np.sum(flip_probs)
193
+
194
+ # Choix déterministe + stochastique
195
+ n_candidates = max(1, int(VECTOR_SIZE * 0.005)) # ~20 bits
196
+ top_candidates = np.argsort(-flip_probs)[:n_candidates * 2]
197
+
198
+ # Favorise les flips qui réduisent l'énergie
199
+ beneficial = deltas[top_candidates] < 0
200
+ if np.any(beneficial):
201
+ # Fait tous les flips bénéfiques avec probabilité selon température
202
+ selected = top_candidates[
203
+ np.random.random(len(top_candidates)) < flip_probs[top_candidates]
204
+ ]
205
+ else:
206
+ # Échappement local : flip aléatoire contrôlé
207
+ selected = np.random.choice(
208
+ VECTOR_SIZE,
209
+ size=max(1, int(n_candidates * temperature)),
210
+ replace=False,
211
+ p=flip_probs
212
+ )
213
+
214
+ if len(selected) > 0:
215
+ new_state = current_state.copy()
216
+ new_state[selected] = 1 - new_state[selected]
217
+
218
+ # Calcule la nouvelle énergie
219
+ new_energy = energy_landscape.compute_energy(
220
+ new_state,
221
+ neighbor_vectors,
222
+ neighbor_ids,
223
+ )
224
+
225
+ # Acceptation Metropolis-Hastings
226
+ delta_e = new_energy - energy
227
+ if delta_e < 0 or np.random.random() < np.exp(-delta_e / max(temperature, 0.01)):
228
+ current_state = new_state
229
+ energy = new_energy
230
+
231
+ state_trajectory.append(current_state.copy())
232
+
233
+ # 4. Apprentissage en ligne
234
+ if self.online_learning:
235
+ learning_events = self._online_learning_step(
236
+ current_state,
237
+ neighbor_vectors,
238
+ neighbor_ids,
239
+ neighbor_indices,
240
+ energy,
241
+ iteration,
242
+ memory_table,
243
+ energy_landscape,
244
+ )
245
+ result.learning_events.extend(learning_events)
246
+
247
+ # 5. Détection périodique de motifs pour abstraction
248
+ if iteration > 0 and iteration % self.pattern_detection_interval == 0:
249
+ patterns = memory_table.detect_frequent_patterns(
250
+ [st for st in state_trajectory[-self.pattern_detection_interval:]],
251
+ min_frequency=3
252
+ )
253
+
254
+ for pattern in patterns:
255
+ # Crée une abstraction si le pattern est fréquent
256
+ new_id = memory_table.create_vector(
257
+ context=pattern,
258
+ abstraction_level=1,
259
+ )
260
+ result.creation_events.append({
261
+ 'type': 'abstraction',
262
+ 'id': new_id,
263
+ 'iteration': iteration,
264
+ })
265
+
266
+ # 6. Callback externe
267
+ if external_callback:
268
+ external_callback({
269
+ 'iteration': iteration,
270
+ 'energy': energy,
271
+ 'state': current_state,
272
+ 'neighbors': neighbors_info,
273
+ })
274
+
275
+ # 7. Vérification convergence
276
+ energy_window.append(energy)
277
+ if len(energy_window) > self.convergence_window:
278
+ energy_window.pop(0)
279
+
280
+ if len(energy_window) >= self.convergence_window:
281
+ energy_std = np.std(energy_window)
282
+ energy_mean = np.mean(energy_window)
283
+ if energy_std / max(abs(energy_mean), 1.0) < self.early_stop_threshold:
284
+ result.converged = True
285
+ break
286
+
287
+ # Refroidissement
288
+ temperature *= self.cooling_rate
289
+ prev_energy = energy
290
+
291
+ # Inférence terminée
292
+ result.final_state = current_state.copy()
293
+ result.n_iterations = iteration + 1
294
+
295
+ # Apprentissage post-inférence : renforce les associations
296
+ # si l'inférence a convergé vers un état stable
297
+ if self.online_learning and result.converged:
298
+ self._post_inference_learning(
299
+ result,
300
+ memory_table,
301
+ energy_landscape,
302
+ router,
303
+ )
304
+
305
+ t1 = time.time()
306
+ result.execution_time_ms = (t1 - t0) * 1000
307
+
308
+ # Stats
309
+ self.total_inferences += 1
310
+ self.total_iterations += result.n_iterations
311
+ if result.converged:
312
+ self.total_converged += 1
313
+ self.avg_inference_time_ms = (
314
+ self.avg_inference_time_ms * (self.total_inferences - 1) + result.execution_time_ms
315
+ ) / self.total_inferences
316
+
317
+ return result
318
+
319
+ def _online_learning_step(
320
+ self,
321
+ state: np.ndarray,
322
+ neighbor_vectors: np.ndarray,
323
+ neighbor_ids: List[int],
324
+ neighbor_indices: List[int],
325
+ energy: float,
326
+ iteration: int,
327
+ memory_table,
328
+ energy_landscape,
329
+ ) -> List[Dict]:
330
+ """
331
+ Effectue un pas d'apprentissage pendant l'inférence.
332
+ Mises à jour locales uniquement.
333
+
334
+ Returns:
335
+ Liste d'événements d'apprentissage
336
+ """
337
+ events = []
338
+
339
+ # Met à jour les métadonnées des voisins
340
+ for idx in neighbor_indices:
341
+ if memory_table.active_mask[idx]:
342
+ meta = memory_table.metadata[idx]
343
+ meta.record_access(memory_table.time_step, energy)
344
+
345
+ # Met à jour le paysage d'énergie
346
+ is_stable = iteration > 5 and len(energy_landscape.energy_history) > 10
347
+ energy_landscape.update_from_state(
348
+ state,
349
+ neighbor_ids,
350
+ energy,
351
+ is_stable=is_stable,
352
+ )
353
+
354
+ # Met à jour les coactivations entre voisins
355
+ for i, idx1 in enumerate(neighbor_indices):
356
+ for idx2 in neighbor_indices[i+1:]:
357
+ if memory_table.active_mask[idx1] and memory_table.active_mask[idx2]:
358
+ id1 = memory_table.metadata[idx1].id
359
+ id2 = memory_table.metadata[idx2].id
360
+ # Renforce l'association si coactivation fréquente
361
+ strength = 1.0 / (1.0 + energy / 1000.0)
362
+ memory_table.metadata[idx1].update_coactivation(id2, strength)
363
+ memory_table.metadata[idx2].update_coactivation(id1, strength)
364
+
365
+ # Crée un nouveau vecteur si l'état est suffisamment différent
366
+ # de tous les voisins (configuration récurrente ou nouvelle)
367
+ if iteration > 3:
368
+ min_neighbor_dist = min([
369
+ float(np.sum(state != memory_table.vectors[idx]))
370
+ for idx in neighbor_indices
371
+ if memory_table.active_mask[idx]
372
+ ]) if neighbor_indices else float('inf')
373
+
374
+ if min_neighbor_dist > memory_table.creation_threshold:
375
+ # Nouvelle configuration intéressante
376
+ new_id = memory_table.create_vector(context=state)
377
+ events.append({
378
+ 'type': 'creation',
379
+ 'id': new_id,
380
+ 'reason': 'novel_pattern',
381
+ 'distance': float(min_neighbor_dist),
382
+ 'iteration': iteration,
383
+ })
384
+
385
+ return events
386
+
387
+ def _post_inference_learning(
388
+ self,
389
+ result: InferenceResult,
390
+ memory_table,
391
+ energy_landscape,
392
+ router,
393
+ ):
394
+ """
395
+ Apprentissage après convergence.
396
+ Renforce les associations dans la trajectoire de basse énergie.
397
+ """
398
+ if len(result.neighbor_trajectory) < 3:
399
+ return
400
+
401
+ # Identifie la phase de basse énergie
402
+ energies = np.array(result.energy_trajectory)
403
+ min_energy_idx = int(np.argmin(energies))
404
+
405
+ # Les voisins à ce point sont "la réponse"
406
+ if min_energy_idx < len(result.neighbor_trajectory):
407
+ stable_neighbors = result.neighbor_trajectory[min_energy_idx]
408
+ stable_ids = [nid for nid, _ in stable_neighbors]
409
+
410
+ # Renforce les associations entre voisins stables
411
+ for i, id1 in enumerate(stable_ids):
412
+ for id2 in stable_ids[i+1:]:
413
+ pair = tuple(sorted((id1, id2)))
414
+ current = energy_landscape.associations.get(pair, 0.0)
415
+ energy_landscape.associations[pair] = min(
416
+ 1.0,
417
+ current + self.learning_rate * 2.0
418
+ )
419
+
420
+ # Met à jour le routeur avec les nouvelles routes apprises
421
+ final_state = result.final_state
422
+ final_packed = router.pack_bits_to_uint64(final_state) if hasattr(router, 'pack_bits_to_uint64') else None
423
+ if final_packed is not None:
424
+ ph = router._pattern_hash(final_packed) if hasattr(router, '_pattern_hash') else None
425
+ if ph is not None:
426
+ router.route_cache[ph] = [
427
+ (nid, 1.0 / (1.0 + dist))
428
+ for nid, dist in stable_neighbors
429
+ ]
430
+
431
+ def get_stats(self) -> Dict:
432
+ convergence_rate = (
433
+ self.total_converged / self.total_inferences
434
+ if self.total_inferences > 0 else 0.0
435
+ )
436
+ avg_iterations = (
437
+ self.total_iterations / self.total_inferences
438
+ if self.total_inferences > 0 else 0.0
439
+ )
440
+
441
+ return {
442
+ 'total_inferences': self.total_inferences,
443
+ 'convergence_rate': convergence_rate,
444
+ 'avg_iterations': avg_iterations,
445
+ 'avg_inference_time_ms': self.avg_inference_time_ms,
446
+ }