Upload mle/binding.py
Browse files- mle/binding.py +237 -0
mle/binding.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Binding via Convolution Circulaire
|
| 3 |
+
|
| 4 |
+
Le binding est une opération fondamentale pour composer et décomposer des
|
| 5 |
+
représentations distribuées. On utilise la convolution circulaire qui :
|
| 6 |
+
- Est commutative dans le domaine fréquentiel (FFT)
|
| 7 |
+
- Permet binding/unbinding par multiplication/déconvolution
|
| 8 |
+
- Supporte la composition de structures
|
| 9 |
+
|
| 10 |
+
Optimisé avec numpy FFT et support pour binding chaîné.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
from numba import njit, prange, complex128, float64
|
| 15 |
+
from typing import List, Tuple, Optional, Dict
|
| 16 |
+
import logging
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
VECTOR_SIZE = 4096
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CircularBinder:
|
| 24 |
+
"""
|
| 25 |
+
Binder circulaire pour composition/décomposition de vecteurs.
|
| 26 |
+
|
| 27 |
+
Le binding utilise la convolution circulaire qui peut être calculée
|
| 28 |
+
efficacement via FFT :
|
| 29 |
+
a ⊗ b = IFFT(FFT(a) * FFT(b))
|
| 30 |
+
|
| 31 |
+
Avantages :
|
| 32 |
+
- Commutatif : a ⊗ b = b ⊗ a
|
| 33 |
+
- Associatif pour composition chaînée
|
| 34 |
+
- Unbinding par déconvolution (division dans le domaine fréquentiel)
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(self, noise_tolerance: float = 0.1):
|
| 38 |
+
self.noise_tolerance = noise_tolerance
|
| 39 |
+
|
| 40 |
+
# Cache FFT pour binding chaîné
|
| 41 |
+
self.fft_cache: Dict[int, np.ndarray] = {}
|
| 42 |
+
self.cache_limit = 1000
|
| 43 |
+
|
| 44 |
+
# Stats
|
| 45 |
+
self.bind_count = 0
|
| 46 |
+
self.unbind_count = 0
|
| 47 |
+
|
| 48 |
+
def _normalize(self, vec: np.ndarray) -> np.ndarray:
|
| 49 |
+
"""
|
| 50 |
+
Normalise un vecteur pour le binding.
|
| 51 |
+
Convertit binaire en {-1, 1} pour la convolution.
|
| 52 |
+
"""
|
| 53 |
+
# Binaire (0, 1) -> (-1, 1)
|
| 54 |
+
return 2.0 * vec.astype(np.float64) - 1.0
|
| 55 |
+
|
| 56 |
+
def _denormalize(self, vec: np.ndarray, threshold: float = 0.0) -> np.ndarray:
|
| 57 |
+
"""Reconvertit en binaire."""
|
| 58 |
+
return (vec > threshold).astype(np.uint8)
|
| 59 |
+
|
| 60 |
+
def bind(self, a: np.ndarray, b: np.ndarray, use_fft: bool = True) -> np.ndarray:
|
| 61 |
+
"""
|
| 62 |
+
Binding de deux vecteurs : a ⊗ b.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
a, b: vecteurs binaires (4096,) uint8
|
| 66 |
+
use_fft: utilise FFT pour efficacité
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
vecteur binaire résultant
|
| 70 |
+
"""
|
| 71 |
+
if use_fft:
|
| 72 |
+
# FFT-based convolution
|
| 73 |
+
a_norm = self._normalize(a)
|
| 74 |
+
b_norm = self._normalize(b)
|
| 75 |
+
|
| 76 |
+
fft_a = np.fft.fft(a_norm)
|
| 77 |
+
fft_b = np.fft.fft(b_norm)
|
| 78 |
+
|
| 79 |
+
fft_result = fft_a * fft_b
|
| 80 |
+
result = np.fft.ifft(fft_result).real
|
| 81 |
+
|
| 82 |
+
# Normalise et seuille
|
| 83 |
+
result = result / (np.std(result) + 1e-8)
|
| 84 |
+
return self._denormalize(result)
|
| 85 |
+
else:
|
| 86 |
+
# Convolution directe (plus lente mais pas de distorsion FFT)
|
| 87 |
+
a_norm = self._normalize(a).astype(np.float64)
|
| 88 |
+
b_norm = self._normalize(b).astype(np.float64)
|
| 89 |
+
result = np.convolve(a_norm, b_norm, mode='same')
|
| 90 |
+
result = result / (np.std(result) + 1e-8)
|
| 91 |
+
return self._denormalize(result)
|
| 92 |
+
|
| 93 |
+
def bind_multiple(self, vectors: List[np.ndarray], use_cache: bool = True) -> np.ndarray:
|
| 94 |
+
"""
|
| 95 |
+
Binding de plusieurs vecteurs : a ⊗ b ⊗ c ⊗ ...
|
| 96 |
+
Utilise le cache FFT pour éviter de recalculer les FFT.
|
| 97 |
+
"""
|
| 98 |
+
if len(vectors) == 0:
|
| 99 |
+
return np.ones(VECTOR_SIZE, dtype=np.uint8) # Identité du binding
|
| 100 |
+
|
| 101 |
+
if len(vectors) == 1:
|
| 102 |
+
return vectors[0].copy()
|
| 103 |
+
|
| 104 |
+
# Cache les FFT si demandé
|
| 105 |
+
if use_cache:
|
| 106 |
+
ids = [id(v) for v in vectors]
|
| 107 |
+
fft_vec = np.ones(VECTOR_SIZE, dtype=np.complex128)
|
| 108 |
+
|
| 109 |
+
for vid, v in zip(ids, vectors):
|
| 110 |
+
if vid in self.fft_cache:
|
| 111 |
+
fft_v = self.fft_cache[vid]
|
| 112 |
+
else:
|
| 113 |
+
v_norm = self._normalize(v)
|
| 114 |
+
fft_v = np.fft.fft(v_norm)
|
| 115 |
+
if len(self.fft_cache) < self.cache_limit:
|
| 116 |
+
self.fft_cache[vid] = fft_v
|
| 117 |
+
fft_vec *= fft_v
|
| 118 |
+
|
| 119 |
+
result = np.fft.ifft(fft_vec).real
|
| 120 |
+
result = result / (np.std(result) + 1e-8)
|
| 121 |
+
return self._denormalize(result)
|
| 122 |
+
else:
|
| 123 |
+
# Bind séquentiel
|
| 124 |
+
result = vectors[0].copy()
|
| 125 |
+
for v in vectors[1:]:
|
| 126 |
+
result = self.bind(result, v)
|
| 127 |
+
return result
|
| 128 |
+
|
| 129 |
+
def unbind(self, bound: np.ndarray, a: np.ndarray, use_fft: bool = True) -> np.ndarray:
|
| 130 |
+
"""
|
| 131 |
+
Déconvolution : résout bound = a ⊗ b pour trouver b.
|
| 132 |
+
|
| 133 |
+
Dans le domaine fréquentiel : b = IFFT(FFT(bound) / FFT(a))
|
| 134 |
+
"""
|
| 135 |
+
if use_fft:
|
| 136 |
+
bound_norm = self._normalize(bound)
|
| 137 |
+
a_norm = self._normalize(a)
|
| 138 |
+
|
| 139 |
+
fft_bound = np.fft.fft(bound_norm)
|
| 140 |
+
fft_a = np.fft.fft(a_norm)
|
| 141 |
+
|
| 142 |
+
# Évite division par zéro avec tolérance
|
| 143 |
+
fft_a = np.where(np.abs(fft_a) < 1e-8, 1e-8, fft_a)
|
| 144 |
+
fft_result = fft_bound / fft_a
|
| 145 |
+
|
| 146 |
+
result = np.fft.ifft(fft_result).real
|
| 147 |
+
result = result / (np.std(result) + 1e-8)
|
| 148 |
+
return self._denormalize(result)
|
| 149 |
+
else:
|
| 150 |
+
# Déconvolution dans le domaine temporel (plus stable)
|
| 151 |
+
bound_norm = self._normalize(bound).astype(np.float64)
|
| 152 |
+
a_norm = self._normalize(a).astype(np.float64)
|
| 153 |
+
|
| 154 |
+
# Corrélation comme approximation de déconvolution
|
| 155 |
+
result = np.correlate(bound_norm, a_norm, mode='same')
|
| 156 |
+
result = result / (np.std(result) + 1e-8)
|
| 157 |
+
return self._denormalize(result)
|
| 158 |
+
|
| 159 |
+
def bind_role_filler(self, role: np.ndarray, filler: np.ndarray) -> np.ndarray:
|
| 160 |
+
"""
|
| 161 |
+
Binding spécialisé role-filler (structure propositionnelle).
|
| 162 |
+
Utile pour représenter "sujet-agent", "objet-patient", etc.
|
| 163 |
+
"""
|
| 164 |
+
# Shift circulaire du filler pour éviter collision avec le role
|
| 165 |
+
shifted_filler = np.roll(filler, VECTOR_SIZE // 4)
|
| 166 |
+
return self.bind(role, shifted_filler)
|
| 167 |
+
|
| 168 |
+
def unbind_role_filler(self, bound: np.ndarray, role: np.ndarray) -> np.ndarray:
|
| 169 |
+
"""Extrait le filler d'un binding role-filler."""
|
| 170 |
+
unshifted = self.unbind(bound, role)
|
| 171 |
+
return np.roll(unshifted, -VECTOR_SIZE // 4)
|
| 172 |
+
|
| 173 |
+
def extract_similar(
|
| 174 |
+
self,
|
| 175 |
+
bound: np.ndarray,
|
| 176 |
+
candidates: List[np.ndarray],
|
| 177 |
+
top_k: int = 3
|
| 178 |
+
) -> List[Tuple[np.ndarray, float]]:
|
| 179 |
+
"""
|
| 180 |
+
Extrait les candidats les plus similaires à partir d'un vecteur bound.
|
| 181 |
+
Utile pour "décoder" une structure composée.
|
| 182 |
+
"""
|
| 183 |
+
results = []
|
| 184 |
+
for cand in candidates:
|
| 185 |
+
# Déconvolution
|
| 186 |
+
decoded = self.unbind(bound, cand)
|
| 187 |
+
# Similarité avec chaque candidat
|
| 188 |
+
best_sim = 0.0
|
| 189 |
+
for other in candidates:
|
| 190 |
+
sim = np.mean(decoded == other)
|
| 191 |
+
if sim > best_sim:
|
| 192 |
+
best_sim = sim
|
| 193 |
+
results.append((cand, best_sim))
|
| 194 |
+
|
| 195 |
+
results.sort(key=lambda x: -x[1])
|
| 196 |
+
return results[:top_k]
|
| 197 |
+
|
| 198 |
+
def compose_structure(
|
| 199 |
+
self,
|
| 200 |
+
role_filler_pairs: List[Tuple[np.ndarray, np.ndarray]]
|
| 201 |
+
) -> np.ndarray:
|
| 202 |
+
"""
|
| 203 |
+
Compose une structure complète à partir de paires role-filler.
|
| 204 |
+
Ex: [(agent, john), (action, run), (patient, ball)]
|
| 205 |
+
"""
|
| 206 |
+
bound_pairs = []
|
| 207 |
+
for role, filler in role_filler_pairs:
|
| 208 |
+
bound_pairs.append(self.bind_role_filler(role, filler))
|
| 209 |
+
|
| 210 |
+
# Somme des bindings (superposition)
|
| 211 |
+
result = np.zeros(VECTOR_SIZE, dtype=np.float64)
|
| 212 |
+
for bp in bound_pairs:
|
| 213 |
+
result += self._normalize(bp)
|
| 214 |
+
|
| 215 |
+
result = result / (np.std(result) + 1e-8)
|
| 216 |
+
return self._denormalize(result)
|
| 217 |
+
|
| 218 |
+
def decompose_structure(
|
| 219 |
+
self,
|
| 220 |
+
composite: np.ndarray,
|
| 221 |
+
roles: List[np.ndarray]
|
| 222 |
+
) -> List[np.ndarray]:
|
| 223 |
+
"""
|
| 224 |
+
Décompose une structure en ses fillers.
|
| 225 |
+
"""
|
| 226 |
+
fillers = []
|
| 227 |
+
for role in roles:
|
| 228 |
+
filler = self.unbind_role_filler(composite, role)
|
| 229 |
+
fillers.append(filler)
|
| 230 |
+
return fillers
|
| 231 |
+
|
| 232 |
+
def get_stats(self) -> Dict:
|
| 233 |
+
return {
|
| 234 |
+
'bind_count': self.bind_count,
|
| 235 |
+
'unbind_count': self.unbind_count,
|
| 236 |
+
'fft_cache_size': len(self.fft_cache),
|
| 237 |
+
}
|