# morpho-logic-engine / mle/inference.py
# Uploaded by Harry00 (commit 22007d1, verified)
"""
Moteur d'Inférence avec Apprentissage en Ligne
L'inférence minimise l'énergie par descente stochastique locale.
À chaque itération :
1. Calcule les voisins via le routeur
2. Évalue l'énergie du paysage
3. Sélectionne les flips de bits qui réduisent l'énergie
4. Met à jour les associations (apprentissage en ligne)
5. Détecte motifs pour abstraction
La minimisation est un processus de Monte Carlo / Hopfield-like
mais avec mémoire adaptative et apprentissage continu.
"""
import numpy as np
from numba import njit, prange
from typing import List, Tuple, Dict, Optional, Callable
import logging
import time
logger = logging.getLogger(__name__)
VECTOR_SIZE = 4096
@njit(cache=True)
def random_flip_batch(state: np.ndarray, n_flips: int, rng_seed: int) -> np.ndarray:
    """Return a copy of *state* with `n_flips` randomly chosen bits flipped.

    Seeds the RNG with `rng_seed` so the flip pattern is reproducible;
    the flip positions are drawn without replacement over VECTOR_SIZE.
    """
    np.random.seed(rng_seed)
    flipped = state.copy()
    chosen = np.random.choice(VECTOR_SIZE, size=n_flips, replace=False)
    for pos in chosen:
        flipped[pos] = 1 - flipped[pos]
    return flipped
@njit(cache=True)
def hamming_distance(a: np.ndarray, b: np.ndarray) -> int:
    """Hamming distance between two binary vectors (count of differing bits)."""
    total = 0
    # XOR of each pair of bits is 1 exactly where they differ.
    for i, bit in enumerate(a):
        total += bit ^ b[i]
    return total
class InferenceResult:
    """Outcome of a complete inference run.

    Holds the start/end states, the per-iteration energy and neighbor
    trajectories, convergence info, timing, and the creation/learning
    events recorded along the way.
    """

    def __init__(self):
        # Start and end states of the minimization.
        self.initial_state: Optional[np.ndarray] = None
        self.final_state: Optional[np.ndarray] = None
        # Per-iteration records.
        self.energy_trajectory: List[float] = []
        self.neighbor_trajectory: List[List[Tuple[int, float]]] = []
        # Events emitted during online learning / abstraction.
        self.creation_events: List[Dict] = []
        self.learning_events: List[Dict] = []
        # Run summary.
        self.n_iterations = 0
        self.converged = False
        self.execution_time_ms = 0.0
class InferenceEngine:
    """
    Inference engine: energy minimization with online learning.

    Key parameters:
    - temperature: controls the noise in the descent (higher = more exploratory)
    - cooling_rate: multiplicative temperature decay applied each iteration
    - max_iterations: maximum number of minimization iterations
    - energy_tolerance: convergence threshold
    - learning_rate: learning speed during inference
    """
    def __init__(
        self,
        temperature: float = 0.5,
        cooling_rate: float = 0.995,
        max_iterations: int = 100,
        energy_tolerance: float = 1.0,
        learning_rate: float = 0.01,
        online_learning: bool = True,
        pattern_detection_interval: int = 10,
        convergence_window: int = 5,
        early_stop_threshold: float = 0.001,
    ):
        self.temperature = temperature
        self.cooling_rate = cooling_rate
        self.max_iterations = max_iterations
        # NOTE(review): energy_tolerance is stored but never read in this file.
        self.energy_tolerance = energy_tolerance
        self.learning_rate = learning_rate
        self.online_learning = online_learning
        self.pattern_detection_interval = pattern_detection_interval
        self.convergence_window = convergence_window
        self.early_stop_threshold = early_stop_threshold
        # Running stats, updated by infer() and reported by get_stats().
        self.total_inferences = 0
        self.total_iterations = 0
        self.total_converged = 0
        self.avg_inference_time_ms = 0.0
    def infer(
        self,
        initial_state: np.ndarray,
        memory_table,
        router,
        energy_landscape,
        binder,
        k_neighbors: int = 10,
        external_callback: Optional[Callable] = None,
    ) -> InferenceResult:
        """
        Full inference: energy minimization with online learning.

        Args:
            initial_state: initial system state (4096 bits)
            memory_table: SparseAddressTable
            router: HammingRouter
            energy_landscape: EnergyLandscape
            binder: CircularBinder (unused here; presumably kept for API symmetry — verify)
            k_neighbors: number of neighbors to consider
            external_callback: optional function called at each iteration

        Returns:
            InferenceResult with the trajectory and learning events
        """
        # NOTE(review): redundant — `time` is already imported at module level.
        import time
        t0 = time.time()
        result = InferenceResult()
        result.initial_state = initial_state.copy()
        current_state = initial_state.copy()
        temperature = self.temperature
        # NOTE(review): prev_energy is written at the end of each iteration but never read.
        prev_energy = float('inf')
        energy_window = []
        # State trajectory kept for periodic pattern detection.
        state_trajectory = [current_state.copy()]
        for iteration in range(self.max_iterations):
            # 1. Route to the nearest neighbors.
            neighbors_info = router.route(
                current_state,
                k=k_neighbors,
                use_cache=True
            )
            if len(neighbors_info) == 0:
                # No neighbors: novel state, potentially create one.
                break
            neighbor_indices = [idx for idx, _ in neighbors_info]
            # NOTE(review): neighbor_distances is computed but never used below.
            neighbor_distances = [dist for _, dist in neighbors_info]
            # Fetch the neighbor vectors from memory (active slots only).
            neighbor_vectors = np.array([
                memory_table.vectors[idx]
                for idx in neighbor_indices
                if memory_table.active_mask[idx]
            ], dtype=np.uint8)
            if len(neighbor_vectors) == 0:
                break
            # ids aligned with neighbor_vectors (same active_mask filter).
            neighbor_ids = [
                memory_table.metadata[idx].id
                for idx in neighbor_indices
                if memory_table.active_mask[idx]
            ]
            # 2. Compute the current energy.
            energy = energy_landscape.compute_energy(
                current_state,
                neighbor_vectors,
                neighbor_ids,
            )
            result.energy_trajectory.append(energy)
            result.neighbor_trajectory.append(neighbors_info)
            # 3. Determine the optimal bit flips: per-bit energy deltas.
            deltas = energy_landscape.get_bit_flip_deltas(
                current_state,
                neighbor_vectors,
                neighbor_ids,
            )
            # Select flips that reduce the energy,
            # with thermal noise for exploration (Boltzmann weights).
            flip_probs = np.exp(-deltas / max(temperature, 0.01))
            flip_probs = flip_probs / np.sum(flip_probs)
            # Deterministic + stochastic choice.
            n_candidates = max(1, int(VECTOR_SIZE * 0.005))  # ~20 bits
            top_candidates = np.argsort(-flip_probs)[:n_candidates * 2]
            # Favor flips that reduce the energy.
            beneficial = deltas[top_candidates] < 0
            if np.any(beneficial):
                # Apply flips among the top candidates, each with probability
                # given by its Boltzmann weight.
                # NOTE(review): `beneficial` gates which branch runs but the
                # selection itself does not filter on it — confirm intended.
                selected = top_candidates[
                    np.random.random(len(top_candidates)) < flip_probs[top_candidates]
                ]
            else:
                # Local-minimum escape: controlled random flips over all bits.
                selected = np.random.choice(
                    VECTOR_SIZE,
                    size=max(1, int(n_candidates * temperature)),
                    replace=False,
                    p=flip_probs
                )
            if len(selected) > 0:
                new_state = current_state.copy()
                new_state[selected] = 1 - new_state[selected]
                # Compute the new energy.
                new_energy = energy_landscape.compute_energy(
                    new_state,
                    neighbor_vectors,
                    neighbor_ids,
                )
                # Metropolis-Hastings acceptance: always accept downhill moves,
                # accept uphill moves with probability exp(-dE/T).
                delta_e = new_energy - energy
                if delta_e < 0 or np.random.random() < np.exp(-delta_e / max(temperature, 0.01)):
                    current_state = new_state
                    energy = new_energy
                    state_trajectory.append(current_state.copy())
            # 4. Online learning.
            if self.online_learning:
                learning_events = self._online_learning_step(
                    current_state,
                    neighbor_vectors,
                    neighbor_ids,
                    neighbor_indices,
                    energy,
                    iteration,
                    memory_table,
                    energy_landscape,
                )
                result.learning_events.extend(learning_events)
            # 5. Periodic pattern detection for abstraction.
            if iteration > 0 and iteration % self.pattern_detection_interval == 0:
                patterns = memory_table.detect_frequent_patterns(
                    [st for st in state_trajectory[-self.pattern_detection_interval:]],
                    min_frequency=3
                )
                for pattern in patterns:
                    # Create an abstraction if the pattern is frequent.
                    new_id = memory_table.create_vector(
                        context=pattern,
                        abstraction_level=1,
                    )
                    result.creation_events.append({
                        'type': 'abstraction',
                        'id': new_id,
                        'iteration': iteration,
                    })
            # 6. External callback.
            if external_callback:
                external_callback({
                    'iteration': iteration,
                    'energy': energy,
                    'state': current_state,
                    'neighbors': neighbors_info,
                })
            # 7. Convergence check: relative std-dev of a sliding energy window.
            energy_window.append(energy)
            if len(energy_window) > self.convergence_window:
                energy_window.pop(0)
            if len(energy_window) >= self.convergence_window:
                energy_std = np.std(energy_window)
                energy_mean = np.mean(energy_window)
                if energy_std / max(abs(energy_mean), 1.0) < self.early_stop_threshold:
                    result.converged = True
                    break
            # Cooling schedule.
            temperature *= self.cooling_rate
            prev_energy = energy
        # Inference finished.
        result.final_state = current_state.copy()
        # NOTE(review): `iteration` is unbound if max_iterations == 0 — assumes >= 1.
        result.n_iterations = iteration + 1
        # Post-inference learning: reinforce the associations
        # if the inference converged to a stable state.
        if self.online_learning and result.converged:
            self._post_inference_learning(
                result,
                memory_table,
                energy_landscape,
                router,
            )
        t1 = time.time()
        result.execution_time_ms = (t1 - t0) * 1000
        # Update the running stats (incremental mean for timing).
        self.total_inferences += 1
        self.total_iterations += result.n_iterations
        if result.converged:
            self.total_converged += 1
        self.avg_inference_time_ms = (
            self.avg_inference_time_ms * (self.total_inferences - 1) + result.execution_time_ms
        ) / self.total_inferences
        return result
    def _online_learning_step(
        self,
        state: np.ndarray,
        neighbor_vectors: np.ndarray,
        neighbor_ids: List[int],
        neighbor_indices: List[int],
        energy: float,
        iteration: int,
        memory_table,
        energy_landscape,
    ) -> List[Dict]:
        """
        Perform one learning step during inference.
        Local updates only.

        Returns:
            List of learning events (creation events for novel states)
        """
        events = []
        # Update the metadata of the active neighbors (access bookkeeping).
        for idx in neighbor_indices:
            if memory_table.active_mask[idx]:
                meta = memory_table.metadata[idx]
                meta.record_access(memory_table.time_step, energy)
        # Update the energy landscape; a state counts as "stable" only after
        # a few iterations and once some energy history has accumulated.
        is_stable = iteration > 5 and len(energy_landscape.energy_history) > 10
        energy_landscape.update_from_state(
            state,
            neighbor_ids,
            energy,
            is_stable=is_stable,
        )
        # Update the coactivations between every pair of active neighbors.
        for i, idx1 in enumerate(neighbor_indices):
            for idx2 in neighbor_indices[i+1:]:
                if memory_table.active_mask[idx1] and memory_table.active_mask[idx2]:
                    id1 = memory_table.metadata[idx1].id
                    id2 = memory_table.metadata[idx2].id
                    # Reinforce the association; low energy => stronger (max 1.0).
                    strength = 1.0 / (1.0 + energy / 1000.0)
                    memory_table.metadata[idx1].update_coactivation(id2, strength)
                    memory_table.metadata[idx2].update_coactivation(id1, strength)
        # Create a new vector if the state is sufficiently different
        # from all neighbors (recurrent or novel configuration).
        if iteration > 3:
            min_neighbor_dist = min([
                float(np.sum(state != memory_table.vectors[idx]))
                for idx in neighbor_indices
                if memory_table.active_mask[idx]
            ]) if neighbor_indices else float('inf')
            if min_neighbor_dist > memory_table.creation_threshold:
                # Interesting novel configuration: store it.
                new_id = memory_table.create_vector(context=state)
                events.append({
                    'type': 'creation',
                    'id': new_id,
                    'reason': 'novel_pattern',
                    'distance': float(min_neighbor_dist),
                    'iteration': iteration,
                })
        return events
    def _post_inference_learning(
        self,
        result: InferenceResult,
        memory_table,
        energy_landscape,
        router,
    ):
        """
        Learning after convergence.
        Reinforces the associations along the low-energy trajectory.
        """
        if len(result.neighbor_trajectory) < 3:
            return
        # Identify the low-energy phase.
        energies = np.array(result.energy_trajectory)
        min_energy_idx = int(np.argmin(energies))
        # The neighbors at that point are "the answer".
        if min_energy_idx < len(result.neighbor_trajectory):
            stable_neighbors = result.neighbor_trajectory[min_energy_idx]
            # NOTE(review): infer() stores (index, distance) tuples from
            # router.route in neighbor_trajectory; here the first element is
            # treated as an id — confirm index vs id semantics with the router.
            stable_ids = [nid for nid, _ in stable_neighbors]
            # Reinforce the associations between the stable neighbors
            # (pairwise, capped at 1.0).
            for i, id1 in enumerate(stable_ids):
                for id2 in stable_ids[i+1:]:
                    pair = tuple(sorted((id1, id2)))
                    current = energy_landscape.associations.get(pair, 0.0)
                    energy_landscape.associations[pair] = min(
                        1.0,
                        current + self.learning_rate * 2.0
                    )
            # Update the router with the newly learned routes; the hasattr
            # guards make this a best-effort step for routers lacking the API.
            final_state = result.final_state
            final_packed = router.pack_bits_to_uint64(final_state) if hasattr(router, 'pack_bits_to_uint64') else None
            if final_packed is not None:
                ph = router._pattern_hash(final_packed) if hasattr(router, '_pattern_hash') else None
                if ph is not None:
                    # Cache similarity scores as 1/(1+distance).
                    router.route_cache[ph] = [
                        (nid, 1.0 / (1.0 + dist))
                        for nid, dist in stable_neighbors
                    ]
    def get_stats(self) -> Dict:
        """Return aggregate inference statistics (rates and averages)."""
        convergence_rate = (
            self.total_converged / self.total_inferences
            if self.total_inferences > 0 else 0.0
        )
        avg_iterations = (
            self.total_iterations / self.total_inferences
            if self.total_inferences > 0 else 0.0
        )
        return {
            'total_inferences': self.total_inferences,
            'convergence_rate': convergence_rate,
            'avg_iterations': avg_iterations,
            'avg_inference_time_ms': self.avg_inference_time_ms,
        }