eyad-silx committed on
Commit f0a4dca · verified · 1 Parent(s): c536443

Upload neat\backprop_neat.py with huggingface_hub

Files changed (1):
  1. neat/backprop_neat.py +300 -0

neat/backprop_neat.py ADDED
@@ -0,0 +1,300 @@
"""BackpropNEAT implementation."""

import jax
import jax.numpy as jnp
import numpy as np
from typing import List

from .network import Network
from .genome import Genome


class BackpropNEAT:
    """Backpropagation-based NEAT implementation."""

    def __init__(self, population_size=5, n_inputs=2, n_outputs=1, n_hidden=64,
                 learning_rate=0.01, beta=0.9):
        """Initialize BackpropNEAT."""
        self.population_size = population_size
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.n_hidden = n_hidden
        self.learning_rate = learning_rate
        self.beta = beta

        # Track the best network found so far (updated in evolve()).
        self.best_fitness = float('-inf')
        self.best_network = None

        # Initialize population
        self.population = []
        self.momentum_buffers = []

        for _ in range(population_size):
            # Create genome with skip connections
            genome = Genome(n_inputs, n_outputs, n_hidden)
            genome.add_layer_connections()    # Add standard layer connections
            genome.add_skip_connections(0.3)  # Add skip connections with 30% probability

            # Create network from genome
            network = Network(genome)
            self.population.append(network)

            # Initialize momentum buffer for this network
            momentum = {
                'weights': {k: jnp.zeros_like(w) for k, w in network.params['weights'].items()},
                'biases': jnp.zeros_like(network.params['biases']),
                'gamma': jnp.zeros_like(network.params['gamma']),
                'beta': jnp.zeros_like(network.params['beta'])
            }
            self.momentum_buffers.append(momentum)

        # Create train step function
        self._train_step = self._make_train_step()

        # Bind train step to each network
        for i, network in enumerate(self.population):
            network.population_idx = i
            # Create a bound method for each network
            network._train_step = lambda p, x, y, idx=i: self._train_step(self, p, x, y, idx)
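            # Note: the `idx=i` default argument pins each lambda to its own
            # network index; without it, Python's late-binding closures would
            # make every network capture the final loop value of `i` and train
            # against the same momentum buffer.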

    def forward(self, params, x):
        """Forward pass through network."""
        # Reuses network 0's forward function; this assumes the parameter
        # pytree `params` is structurally compatible across the population.
        return self.population[0].forward(params, x)

    def _make_train_step(self):
        """Create training step function."""
        # Constants for numerical stability
        eps = 1e-7
        min_lr = 1e-6
        max_lr = 1e-2

        def loss_fn(params, x, y):
            """Compute loss for parameters."""
            logits = self.forward(params, x)

            # Binary cross-entropy loss with label smoothing
            alpha = 0.1  # Label smoothing factor

            # Smooth labels
            y_smooth = (1 - alpha) * y + alpha * 0.5

            # Convert logits to probabilities
            probs = 0.5 * (logits + 1)  # Map from [-1, 1] to [0, 1]
            probs = jnp.clip(probs, eps, 1 - eps)

            # Compute loss with label smoothing
            bce_loss = -jnp.mean(
                0.5 * (1 + y_smooth) * jnp.log(probs) +
                0.5 * (1 - y_smooth) * jnp.log(1 - probs)
            )

            # L2 regularization with very small weight
            l2_reg = sum(jnp.sum(w ** 2) for w in params['weights'].values())
            return bce_loss + 1e-6 * l2_reg
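
        # Reading of the loss above: outputs and labels both live in [-1, 1]
        # (tanh-style), and p = (v + 1) / 2 maps them onto [0, 1]. The terms
        # 0.5 * (1 +/- y_smooth) are then the smoothed target probabilities of
        # the positive and negative class, so the expression is the standard
        # binary cross-entropy  L = -[q * log(p) + (1 - q) * log(1 - p)].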

        @jax.jit
        def compute_updates(params, x, y):
            """Compute gradients and loss."""
            loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y)
            return grads, loss_value

        def train_step(self, params, x, y, network_idx):
            """Perform single training step with momentum."""
            # Compute gradients
            grads, loss_value = compute_updates(params, x, y)

            # Get momentum buffer for this network
            momentum = self.momentum_buffers[network_idx]

            # Gradient norm for adaptive learning rate
            grad_norm = jnp.sqrt(
                sum(jnp.sum(g ** 2) for g in grads['weights'].values()) +
                jnp.sum(grads['biases'] ** 2) +
                jnp.sum(grads['gamma'] ** 2) +
                jnp.sum(grads['beta'] ** 2) +
                eps  # Add eps for numerical stability
            )

            # Compute adaptive learning rate
            if grad_norm > 1.0:
                effective_lr = self.learning_rate / grad_norm
            else:
                effective_lr = self.learning_rate * (1.0 + jnp.log(grad_norm + eps))

            # Clip learning rate to reasonable range
            effective_lr = jnp.clip(effective_lr, min_lr, max_lr)
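
            # The updates below are SGD with momentum plus weight decay:
            #   v <- beta * v + (1 - beta) * clip(g)
            #   w <- w - lr_eff * (v + 0.0001 * w)
            # Element-wise gradient clipping before the momentum buffer bounds
            # the per-step update magnitude; layer-norm parameters use a 10x
            # smaller learning rate and tighter clipping.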

            # Update weights momentum with adaptive learning rate
            new_weights = {}
            for k in params['weights'].keys():
                grad = grads['weights'][k]

                # Update momentum with gradient clipping
                momentum['weights'][k] = (
                    self.beta * momentum['weights'][k] +
                    (1 - self.beta) * jnp.clip(grad, -1.0, 1.0)
                )

                # Apply update with weight decay
                weight_decay = 0.0001 * params['weights'][k]
                new_weights[k] = params['weights'][k] - effective_lr * (
                    momentum['weights'][k] + weight_decay
                )

            # Update biases momentum
            momentum['biases'] = (
                self.beta * momentum['biases'] +
                (1 - self.beta) * jnp.clip(grads['biases'], -1.0, 1.0)
            )
            new_biases = params['biases'] - effective_lr * momentum['biases']

            # Update layer norm parameters with smaller learning rate
            ln_lr = 0.1 * effective_lr  # Slower updates for stability

            # Gamma (scale)
            momentum['gamma'] = (
                self.beta * momentum['gamma'] +
                (1 - self.beta) * jnp.clip(grads['gamma'], -0.1, 0.1)
            )
            new_gamma = params['gamma'] - ln_lr * momentum['gamma']
            new_gamma = jnp.clip(new_gamma, 0.1, 10.0)  # Prevent collapse

            # Beta (shift)
            momentum['beta'] = (
                self.beta * momentum['beta'] +
                (1 - self.beta) * jnp.clip(grads['beta'], -0.1, 0.1)
            )
            new_beta = params['beta'] - ln_lr * momentum['beta']

            return {
                'weights': new_weights,
                'biases': new_biases,
                'gamma': new_gamma,
                'beta': new_beta
            }, loss_value

        return train_step

    def _mutate_genome(self, genome: Genome) -> Genome:
        """Mutate genome weights and biases with Gaussian noise."""
        new_genome = genome.copy()

        # Perturb weights and biases, each with probability 0.1
        for key in list(new_genome.params['weights'].keys()):
            if np.random.random() < 0.1:
                new_genome.params['weights'][key] += np.random.normal(0, 0.2)

        for key in list(new_genome.params['biases'].keys()):
            if np.random.random() < 0.1:
                new_genome.params['biases'][key] += np.random.normal(0, 0.2)

        return new_genome
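
    # Note: unlike classic NEAT, this mutation operator only perturbs existing
    # weights and biases; it never adds nodes or connections, so the set of
    # architectures is fixed at initialization (layer connections plus the
    # randomly sampled skip connections).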

    def _select_parent(self, fitnesses: List[float]) -> int:
        """Select parent using tournament selection."""
        tournament_size = 3
        best_idx = np.random.randint(len(fitnesses))
        best_fitness = fitnesses[best_idx]

        for _ in range(tournament_size - 1):
            idx = np.random.randint(len(fitnesses))
            if fitnesses[idx] > best_fitness:
                best_idx = idx
                best_fitness = fitnesses[idx]

        return best_idx

    def _compute_fitness(self, network: Network, x: jnp.ndarray, y: jnp.ndarray,
                         n_epochs: int = 100, batch_size: int = 32) -> float:
        """Train the network briefly and return its best accuracy as fitness."""
        n_samples = x.shape[0]
        best_loss = float('inf')
        best_accuracy = 0.0

        # Initial prediction
        initial_pred = network.predict(x)
        initial_acc = float(jnp.mean(initial_pred == y))

        # Train network
        no_improve = 0
        for epoch in range(n_epochs):
            # Shuffle data
            perm = np.random.permutation(n_samples)
            x_shuffled = x[perm]
            y_shuffled = y[perm]

            # Train in batches (slicing clamps at the array end automatically)
            epoch_losses = []
            for i in range(0, n_samples, batch_size):
                batch_x = x_shuffled[i:i + batch_size]
                batch_y = y_shuffled[i:i + batch_size]

                # Train step
                network.params, loss = network._train_step(network.params, batch_x, batch_y)
                epoch_losses.append(float(loss))

            # Update best loss
            avg_loss = float(np.mean(epoch_losses))
            if avg_loss < best_loss:
                best_loss = avg_loss
                no_improve = 0
            else:
                no_improve += 1

            # Compute accuracy
            predictions = network.predict(x)
            accuracy = float(jnp.mean(predictions == y))
            best_accuracy = max(best_accuracy, accuracy)

            # Print progress every 10 epochs
            if epoch % 10 == 0:
                print(f"Epoch {epoch}: Loss = {avg_loss:.4f}, Accuracy = {accuracy:.4f}")

            # Early stopping if good accuracy or no improvement
            if accuracy > 0.95 or no_improve >= 10:
                print(f"Early stopping at epoch {epoch}")
                print(f"Final accuracy: {accuracy:.4f}")
                break

        # Print improvement
        print(f"Network improved from {initial_acc:.4f} to {best_accuracy:.4f}")

        # Fitness is the best accuracy reached during training
        return float(best_accuracy)

    def evolve(self, x: jnp.ndarray, y: jnp.ndarray, n_generations: int = 50) -> Network:
        """Evolve network architectures."""
        for generation in range(n_generations):
            print(f"\nGeneration {generation}")

            # Evaluate current population
            fitnesses = []
            for network in self.population:
                fitness = self._compute_fitness(network, x, y)
                fitnesses.append(fitness)

                # Update best network
                if fitness > self.best_fitness:
                    self.best_fitness = fitness
                    self.best_network = Network(network.genome.copy())
                    print(f"New best fitness: {fitness:.4f}")

            # Create new population through selection and mutation
            new_population = []

            # Keep best network (elitism)
            best_idx = int(np.argmax(fitnesses))
            new_population.append(Network(self.population[best_idx].genome.copy()))

            # Create rest of population
            while len(new_population) < self.population_size:
                # Select parent
                parent_idx = self._select_parent(fitnesses)
                parent = self.population[parent_idx].genome

                # Create child through mutation
                child_genome = self._mutate_genome(parent)
                new_population.append(Network(child_genome))

            self.population = new_population

            # Re-bind train steps and reset momentum for the new networks,
            # since fresh Network objects lack the binding done in __init__
            for i, network in enumerate(self.population):
                network.population_idx = i
                network._train_step = lambda p, bx, by, idx=i: self._train_step(self, p, bx, by, idx)
                self.momentum_buffers[i] = {
                    'weights': {k: jnp.zeros_like(w) for k, w in network.params['weights'].items()},
                    'biases': jnp.zeros_like(network.params['biases']),
                    'gamma': jnp.zeros_like(network.params['gamma']),
                    'beta': jnp.zeros_like(network.params['beta'])
                }

        return self.best_network
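

if __name__ == "__main__":
    # Hypothetical smoke test on XOR: assumes the sibling Genome/Network
    # modules expose the interfaces used above (predict, params, etc.) and
    # uses labels in {-1, +1} to match the convention of loss_fn. Run as
    # `python -m neat.backprop_neat` so the relative imports resolve.
    xor_x = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
    xor_y = jnp.array([[-1.0], [1.0], [1.0], [-1.0]])

    neat = BackpropNEAT(population_size=5, n_inputs=2, n_outputs=1,
                        n_hidden=64, learning_rate=0.01, beta=0.9)
    best = neat.evolve(xor_x, xor_y, n_generations=5)
    print("Best network predictions:", best.predict(xor_x))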