Harry00 commited on
Commit
bd05e61
·
verified ·
1 Parent(s): b4b14c2

Upload mle/binding.py

Browse files
Files changed (1) hide show
  1. mle/binding.py +237 -0
mle/binding.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Binding via Convolution Circulaire
3
+
4
+ Le binding est une opération fondamentale pour composer et décomposer des
5
+ représentations distribuées. On utilise la convolution circulaire qui :
6
+ - Est commutative dans le domaine fréquentiel (FFT)
7
+ - Permet binding/unbinding par multiplication/déconvolution
8
+ - Supporte la composition de structures
9
+
10
+ Optimisé avec numpy FFT et support pour binding chaîné.
11
+ """
12
+
13
+ import numpy as np
14
+ from numba import njit, prange, complex128, float64
15
+ from typing import List, Tuple, Optional, Dict
16
+ import logging
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ VECTOR_SIZE = 4096
21
+
22
+
23
+ class CircularBinder:
24
+ """
25
+ Binder circulaire pour composition/décomposition de vecteurs.
26
+
27
+ Le binding utilise la convolution circulaire qui peut être calculée
28
+ efficacement via FFT :
29
+ a ⊗ b = IFFT(FFT(a) * FFT(b))
30
+
31
+ Avantages :
32
+ - Commutatif : a ⊗ b = b ⊗ a
33
+ - Associatif pour composition chaînée
34
+ - Unbinding par déconvolution (division dans le domaine fréquentiel)
35
+ """
36
+
37
+ def __init__(self, noise_tolerance: float = 0.1):
38
+ self.noise_tolerance = noise_tolerance
39
+
40
+ # Cache FFT pour binding chaîné
41
+ self.fft_cache: Dict[int, np.ndarray] = {}
42
+ self.cache_limit = 1000
43
+
44
+ # Stats
45
+ self.bind_count = 0
46
+ self.unbind_count = 0
47
+
48
+ def _normalize(self, vec: np.ndarray) -> np.ndarray:
49
+ """
50
+ Normalise un vecteur pour le binding.
51
+ Convertit binaire en {-1, 1} pour la convolution.
52
+ """
53
+ # Binaire (0, 1) -> (-1, 1)
54
+ return 2.0 * vec.astype(np.float64) - 1.0
55
+
56
+ def _denormalize(self, vec: np.ndarray, threshold: float = 0.0) -> np.ndarray:
57
+ """Reconvertit en binaire."""
58
+ return (vec > threshold).astype(np.uint8)
59
+
60
+ def bind(self, a: np.ndarray, b: np.ndarray, use_fft: bool = True) -> np.ndarray:
61
+ """
62
+ Binding de deux vecteurs : a ⊗ b.
63
+
64
+ Args:
65
+ a, b: vecteurs binaires (4096,) uint8
66
+ use_fft: utilise FFT pour efficacité
67
+
68
+ Returns:
69
+ vecteur binaire résultant
70
+ """
71
+ if use_fft:
72
+ # FFT-based convolution
73
+ a_norm = self._normalize(a)
74
+ b_norm = self._normalize(b)
75
+
76
+ fft_a = np.fft.fft(a_norm)
77
+ fft_b = np.fft.fft(b_norm)
78
+
79
+ fft_result = fft_a * fft_b
80
+ result = np.fft.ifft(fft_result).real
81
+
82
+ # Normalise et seuille
83
+ result = result / (np.std(result) + 1e-8)
84
+ return self._denormalize(result)
85
+ else:
86
+ # Convolution directe (plus lente mais pas de distorsion FFT)
87
+ a_norm = self._normalize(a).astype(np.float64)
88
+ b_norm = self._normalize(b).astype(np.float64)
89
+ result = np.convolve(a_norm, b_norm, mode='same')
90
+ result = result / (np.std(result) + 1e-8)
91
+ return self._denormalize(result)
92
+
93
+ def bind_multiple(self, vectors: List[np.ndarray], use_cache: bool = True) -> np.ndarray:
94
+ """
95
+ Binding de plusieurs vecteurs : a ⊗ b ⊗ c ⊗ ...
96
+ Utilise le cache FFT pour éviter de recalculer les FFT.
97
+ """
98
+ if len(vectors) == 0:
99
+ return np.ones(VECTOR_SIZE, dtype=np.uint8) # Identité du binding
100
+
101
+ if len(vectors) == 1:
102
+ return vectors[0].copy()
103
+
104
+ # Cache les FFT si demandé
105
+ if use_cache:
106
+ ids = [id(v) for v in vectors]
107
+ fft_vec = np.ones(VECTOR_SIZE, dtype=np.complex128)
108
+
109
+ for vid, v in zip(ids, vectors):
110
+ if vid in self.fft_cache:
111
+ fft_v = self.fft_cache[vid]
112
+ else:
113
+ v_norm = self._normalize(v)
114
+ fft_v = np.fft.fft(v_norm)
115
+ if len(self.fft_cache) < self.cache_limit:
116
+ self.fft_cache[vid] = fft_v
117
+ fft_vec *= fft_v
118
+
119
+ result = np.fft.ifft(fft_vec).real
120
+ result = result / (np.std(result) + 1e-8)
121
+ return self._denormalize(result)
122
+ else:
123
+ # Bind séquentiel
124
+ result = vectors[0].copy()
125
+ for v in vectors[1:]:
126
+ result = self.bind(result, v)
127
+ return result
128
+
129
+ def unbind(self, bound: np.ndarray, a: np.ndarray, use_fft: bool = True) -> np.ndarray:
130
+ """
131
+ Déconvolution : résout bound = a ⊗ b pour trouver b.
132
+
133
+ Dans le domaine fréquentiel : b = IFFT(FFT(bound) / FFT(a))
134
+ """
135
+ if use_fft:
136
+ bound_norm = self._normalize(bound)
137
+ a_norm = self._normalize(a)
138
+
139
+ fft_bound = np.fft.fft(bound_norm)
140
+ fft_a = np.fft.fft(a_norm)
141
+
142
+ # Évite division par zéro avec tolérance
143
+ fft_a = np.where(np.abs(fft_a) < 1e-8, 1e-8, fft_a)
144
+ fft_result = fft_bound / fft_a
145
+
146
+ result = np.fft.ifft(fft_result).real
147
+ result = result / (np.std(result) + 1e-8)
148
+ return self._denormalize(result)
149
+ else:
150
+ # Déconvolution dans le domaine temporel (plus stable)
151
+ bound_norm = self._normalize(bound).astype(np.float64)
152
+ a_norm = self._normalize(a).astype(np.float64)
153
+
154
+ # Corrélation comme approximation de déconvolution
155
+ result = np.correlate(bound_norm, a_norm, mode='same')
156
+ result = result / (np.std(result) + 1e-8)
157
+ return self._denormalize(result)
158
+
159
+ def bind_role_filler(self, role: np.ndarray, filler: np.ndarray) -> np.ndarray:
160
+ """
161
+ Binding spécialisé role-filler (structure propositionnelle).
162
+ Utile pour représenter "sujet-agent", "objet-patient", etc.
163
+ """
164
+ # Shift circulaire du filler pour éviter collision avec le role
165
+ shifted_filler = np.roll(filler, VECTOR_SIZE // 4)
166
+ return self.bind(role, shifted_filler)
167
+
168
+ def unbind_role_filler(self, bound: np.ndarray, role: np.ndarray) -> np.ndarray:
169
+ """Extrait le filler d'un binding role-filler."""
170
+ unshifted = self.unbind(bound, role)
171
+ return np.roll(unshifted, -VECTOR_SIZE // 4)
172
+
173
+ def extract_similar(
174
+ self,
175
+ bound: np.ndarray,
176
+ candidates: List[np.ndarray],
177
+ top_k: int = 3
178
+ ) -> List[Tuple[np.ndarray, float]]:
179
+ """
180
+ Extrait les candidats les plus similaires à partir d'un vecteur bound.
181
+ Utile pour "décoder" une structure composée.
182
+ """
183
+ results = []
184
+ for cand in candidates:
185
+ # Déconvolution
186
+ decoded = self.unbind(bound, cand)
187
+ # Similarité avec chaque candidat
188
+ best_sim = 0.0
189
+ for other in candidates:
190
+ sim = np.mean(decoded == other)
191
+ if sim > best_sim:
192
+ best_sim = sim
193
+ results.append((cand, best_sim))
194
+
195
+ results.sort(key=lambda x: -x[1])
196
+ return results[:top_k]
197
+
198
+ def compose_structure(
199
+ self,
200
+ role_filler_pairs: List[Tuple[np.ndarray, np.ndarray]]
201
+ ) -> np.ndarray:
202
+ """
203
+ Compose une structure complète à partir de paires role-filler.
204
+ Ex: [(agent, john), (action, run), (patient, ball)]
205
+ """
206
+ bound_pairs = []
207
+ for role, filler in role_filler_pairs:
208
+ bound_pairs.append(self.bind_role_filler(role, filler))
209
+
210
+ # Somme des bindings (superposition)
211
+ result = np.zeros(VECTOR_SIZE, dtype=np.float64)
212
+ for bp in bound_pairs:
213
+ result += self._normalize(bp)
214
+
215
+ result = result / (np.std(result) + 1e-8)
216
+ return self._denormalize(result)
217
+
218
+ def decompose_structure(
219
+ self,
220
+ composite: np.ndarray,
221
+ roles: List[np.ndarray]
222
+ ) -> List[np.ndarray]:
223
+ """
224
+ Décompose une structure en ses fillers.
225
+ """
226
+ fillers = []
227
+ for role in roles:
228
+ filler = self.unbind_role_filler(composite, role)
229
+ fillers.append(filler)
230
+ return fillers
231
+
232
+ def get_stats(self) -> Dict:
233
+ return {
234
+ 'bind_count': self.bind_count,
235
+ 'unbind_count': self.unbind_count,
236
+ 'fft_cache_size': len(self.fft_cache),
237
+ }