eyad-silx committed on
Commit
ecccd48
·
verified ·
1 Parent(s): 4d76cfb

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +434 -0
train.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Train NEAT networks to play volleyball using hardware acceleration when available."""
2
+
3
import copy
import io
import os
import time
import zlib
from typing import Dict, List, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from evojax.task.slimevolley import SlimeVolley
from jax import random
from PIL import Image
13
+
14
# Best-effort accelerator setup: export the CUDA/XLA settings, then report
# which devices JAX can see. Any failure is reported and execution continues.
try:
    os.environ.update({
        'CUDA_VISIBLE_DEVICES': '0',
        'XLA_PYTHON_CLIENT_PREALLOCATE': 'false',
        'XLA_PYTHON_CLIENT_ALLOCATOR': 'platform',
    })

    # Report what JAX detected and which platform the first device is on.
    print(f"JAX devices available: {jax.devices()}")
    print(f"Using device: {jax.devices()[0].platform.upper()}")
except Exception as e:
    print(f"Note: Using CPU - {str(e)}")
26
+
27
class NodeGene:
    """A gene representing a node in the neural network.

    Attributes:
        id: Integer identifier of the node.
        type: One of ``'input'``, ``'hidden'`` or ``'output'``.
        activation: Name of the activation function (default ``'tanh'``).
        bias: Initial bias, drawn from a normal distribution scaled by 0.1
            using a seed derived from the node id.
    """

    def __init__(self, id: int, node_type: str, activation: str = 'tanh'):
        # NOTE: the parameter name ``id`` shadows the builtin; it is kept
        # because it is part of the constructor's keyword interface.
        self.id = id
        self.type = node_type  # 'input', 'hidden', or 'output'
        self.activation = activation

        # Derive the seed from a CRC of the node name. The builtin ``hash()``
        # of a string is salted per interpreter process, so it would give a
        # different "deterministic" bias on every run.
        seed = zlib.crc32(f"node_{id}".encode("utf-8"))
        key = random.PRNGKey(seed)
        self.bias = float(random.normal(key, shape=()) * 0.1)
38
+
39
class ConnectionGene:
    """A gene representing a weighted, directed connection between two nodes.

    Attributes:
        source: Id of the node the connection reads from.
        target: Id of the node the connection feeds into.
        enabled: Whether the connection takes part in the forward pass.
        innovation: Identifier computed as ``hash((source, target))``.
        weight: Connection weight. When not supplied it is drawn from a
            normal distribution scaled by 0.1, seeded from the endpoint ids.
    """

    def __init__(self, source: int, target: int,
                 weight: Optional[float] = None, enabled: bool = True):
        self.source = source
        self.target = target
        self.enabled = enabled
        self.innovation = hash((source, target))

        if weight is None:
            # Seed from a CRC of the connection name. ``hash()`` of a string
            # is salted per process, so it is not reproducible across runs.
            seed = zlib.crc32(f"conn_{source}_{target}".encode("utf-8"))
            key = random.PRNGKey(seed)
            weight = float(random.normal(key, shape=()) * 0.1)
        self.weight = weight
53
+
54
class Genome:
    """Node and connection genes describing one network.

    Node ids are laid out as: inputs ``0 .. n_inputs-1``, then three output
    nodes, then hidden nodes. A new genome starts with random input->output
    links plus one to three hidden nodes that are randomly wired between the
    inputs and the outputs.
    """

    def __init__(self, n_inputs: int, n_outputs: int):
        """Build a randomly initialised genome.

        Args:
            n_inputs: Number of input nodes.
            n_outputs: Accepted for interface compatibility but overridden:
                the genome always has exactly 3 outputs (left, right, jump).
        """
        # Create input nodes (0 to n_inputs-1)
        self.node_genes = {i: NodeGene(i, 'input') for i in range(n_inputs)}

        # Create exactly 3 output nodes for left, right, jump
        n_outputs = 3  # Force exactly 3 outputs
        for i in range(n_outputs):
            self.node_genes[n_inputs + i] = NodeGene(n_inputs + i, 'output')

        self.connection_genes: List[ConnectionGene] = []
        key = self._fresh_key()

        # Direct input -> output links, each present with 70% probability.
        for i in range(n_inputs):
            for j in range(n_outputs):
                key = self._maybe_connect(key, i, n_inputs + j, 0.7)

        # Between 1 and 3 hidden nodes (the upper bound of randint is exclusive).
        key, count_key = random.split(key)
        n_hidden = int(random.randint(count_key, (), 1, 4))
        hidden_start = n_inputs + n_outputs

        for i in range(n_hidden):
            node_id = hidden_start + i
            self.node_genes[node_id] = NodeGene(node_id, 'hidden')

            # Each input feeds this hidden node with 50% probability ...
            for j in range(n_inputs):
                key = self._maybe_connect(key, j, node_id, 0.5)

            # ... and this hidden node feeds each output with 50% probability.
            for j in range(n_outputs):
                key = self._maybe_connect(key, node_id, n_inputs + j, 0.5)

    @staticmethod
    def _fresh_key():
        """Return a new JAX PRNG key seeded from NumPy's global RNG.

        A millisecond wall-clock seed can repeat when two genomes are built
        or mutated in quick succession, which would make them identical.
        Seeding from ``np.random`` avoids that and makes runs reproducible
        through ``np.random.seed``.
        """
        seed = int(np.random.randint(0, 2**32 - 1, dtype=np.uint32))
        return random.PRNGKey(seed)

    def _maybe_connect(self, key, source: int, target: int, probability: float):
        """Add ``source -> target`` with the given probability; return the next key."""
        key, gate_key, weight_key = random.split(key, 3)
        if float(random.uniform(gate_key, shape=())) < probability:
            weight = float(random.normal(weight_key, shape=()) * 0.5)
            self.connection_genes.append(
                ConnectionGene(source, target, weight=weight)
            )
        return key

    def _creates_cycle(self, source: int, target: int) -> bool:
        """Return True if adding ``source -> target`` would close a loop.

        Disabled connections are included in the search so that the answer
        does not depend on their current enabled state.
        """
        if source == target:
            return True
        successors: Dict[int, List[int]] = {}
        for conn in self.connection_genes:
            successors.setdefault(int(conn.source), []).append(int(conn.target))

        # Depth-first search from ``target``: reaching ``source`` means the
        # new link would complete a cycle.
        stack = [target]
        seen = {target}
        while stack:
            current = stack.pop()
            for nxt in successors.get(current, ()):
                if nxt == source:
                    return True
                if nxt not in seen:
                    seen.add(nxt)
                    stack.append(nxt)
        return False

    def _is_valid_connection(self, source: int, target: int) -> bool:
        """Check that ``source -> target`` is usable by a feed-forward pass."""
        source_type = self.node_genes[source].type
        target_type = self.node_genes[target].type
        if target_type == 'input' or source_type == 'output':
            # Input values are fixed by the observation and outputs are the
            # final layer, so such links would be meaningless.
            return False
        if source_type == 'hidden' and target_type == 'hidden' and source >= target:
            # Keep hidden-to-hidden links pointing from lower to higher id.
            return False
        return not self._creates_cycle(source, target)

    def _add_node(self, key):
        """Split a random enabled connection with a new hidden node; return the next key."""
        enabled = [conn for conn in self.connection_genes if conn.enabled]
        if not enabled:
            return key

        # Only enabled connections are split: splitting a disabled one would
        # silently re-activate a path that evolution had switched off.
        conn = enabled[int(np.random.randint(len(enabled)))]
        new_id = max(self.node_genes) + 1
        self.node_genes[new_id] = NodeGene(new_id, 'hidden')

        key, first_key, second_key = random.split(key, 3)
        weight1 = float(random.normal(first_key, shape=()) * 0.5)
        weight2 = float(random.normal(second_key, shape=()) * 0.5)
        self.connection_genes.append(
            ConnectionGene(conn.source, new_id, weight=weight1)
        )
        self.connection_genes.append(
            ConnectionGene(new_id, conn.target, weight=weight2)
        )

        # The two new links replace the one that was split.
        conn.enabled = False
        return key

    def _add_connection(self, key):
        """Try up to 10 times to add one new valid connection; return the next key."""
        node_ids = list(self.node_genes)
        existing = {(int(c.source), int(c.target)) for c in self.connection_genes}
        for _ in range(10):
            source = int(np.random.choice(node_ids))
            target = int(np.random.choice(node_ids))
            if (source, target) in existing:
                continue
            if not self._is_valid_connection(source, target):
                continue
            key, weight_key = random.split(key)
            weight = float(random.normal(weight_key, shape=()) * 0.5)
            self.connection_genes.append(
                ConnectionGene(source, target, weight=weight)
            )
            break
        return key

    def mutate(self, config: Dict):
        """Mutate this genome in place.

        Args:
            config: Mapping that must provide ``'weight_mutation_rate'``,
                ``'weight_mutation_power'``, ``'add_node_rate'`` and
                ``'add_connection_rate'``.
        """
        key = self._fresh_key()

        # Mutate connection weights
        for conn in self.connection_genes:
            key, gate_key, mode_key, value_key = random.split(key, 4)
            if float(random.uniform(gate_key, shape=())) >= config['weight_mutation_rate']:
                continue
            if float(random.uniform(mode_key, shape=())) < 0.1:
                # Sometimes reset the weight completely
                conn.weight = float(random.normal(value_key, shape=()) * 0.5)
            else:
                # Otherwise perturb the existing weight
                conn.weight += float(
                    random.normal(value_key, shape=()) * config['weight_mutation_power']
                )

        # Mutate node biases (10% chance per node)
        for node in self.node_genes.values():
            key, gate_key, value_key = random.split(key, 3)
            if float(random.uniform(gate_key, shape=())) < 0.1:
                node.bias += float(random.normal(value_key, shape=()) * 0.1)

        # Structural mutation: add a new node
        key, gate_key = random.split(key)
        if float(random.uniform(gate_key, shape=())) < config['add_node_rate']:
            key = self._add_node(key)

        # Structural mutation: add a new connection
        key, gate_key = random.split(key)
        if float(random.uniform(gate_key, shape=())) < config['add_connection_rate']:
            self._add_connection(key)
182
+
183
class Network:
    """Evaluates a genome as a feed-forward network with tanh units.

    The input, hidden and output node lists are captured (sorted by id) when
    the network is constructed; connections and biases are read from the
    genome on every forward pass.
    """

    def __init__(self, genome: "Genome"):
        """Wrap ``genome``; requires that it has exactly 3 output nodes."""
        self.genome = genome
        # Sort nodes by ID once to ensure consistent ordering per node type.
        ordered = sorted(genome.node_genes.values(), key=lambda node: node.id)
        self.input_nodes = [node for node in ordered if node.type == 'input']
        self.hidden_nodes = [node for node in ordered if node.type == 'hidden']
        self.output_nodes = [node for node in ordered if node.type == 'output']

        # Verify we have exactly 3 output nodes
        assert len(self.output_nodes) == 3, f"Expected 3 output nodes, got {len(self.output_nodes)}"

    def _evaluation_order(self) -> list:
        """Return hidden and output nodes in dependency (topological) order.

        Evaluating strictly by node id is wrong when a hidden node with a
        higher id feeds one with a lower id, which is what splitting a
        connection with a new node produces: the downstream node would read
        the upstream node's placeholder value. Nodes are therefore released
        only once every enabled upstream hidden/output node has been
        computed. If the enabled connections contain a cycle, the nodes that
        cannot be ordered are appended in id order (hidden first, then
        outputs).
        """
        pending = self.hidden_nodes + self.output_nodes
        computed_ids = {node.id for node in pending}
        upstream = {node.id: set() for node in pending}
        for conn in self.genome.connection_genes:
            if not conn.enabled or conn.source == conn.target:
                continue
            if conn.target in upstream and conn.source in computed_ids:
                upstream[conn.target].add(conn.source)

        order = []
        finished = set()
        while pending:
            ready = [node for node in pending if upstream[node.id] <= finished]
            if not ready:
                # Cycle: fall back to id order for whatever is left.
                order.extend(pending)
                break
            order.extend(ready)
            finished.update(node.id for node in ready)
            pending = [node for node in pending if node.id not in finished]
        return order

    def forward(self, x: jnp.ndarray) -> jnp.ndarray:
        """Run the network on a batch of observations.

        Args:
            x: Array of shape ``(input_dim,)`` or ``(batch_size, input_dim)``;
                column ``i`` is fed to the i-th input node (by id order).

        Returns:
            Array of shape ``(batch_size, 3)`` holding the tanh activations
            of the three output nodes, ordered by node id.
        """
        # Ensure input is 2D with shape (batch_size, input_dim)
        if len(x.shape) == 1:
            x = jnp.expand_dims(x, 0)
        batch_size = x.shape[0]

        # Every node starts at its bias so that a source that has not been
        # computed still has a defined value.
        values = {
            node.id: jnp.zeros((batch_size,)) + node.bias
            for node in self.genome.node_genes.values()
        }
        for column, node in enumerate(self.input_nodes):
            values[node.id] = x[:, column]

        # Group enabled connections by target once instead of rescanning the
        # whole connection list for every node.
        incoming = {}
        for conn in self.genome.connection_genes:
            if conn.enabled:
                incoming.setdefault(conn.target, []).append(conn)

        for node in self._evaluation_order():
            total = jnp.zeros((batch_size,)) + node.bias
            for conn in incoming.get(node.id, ()):
                total = total + values[conn.source] * conn.weight
            # NOTE: node.activation is not consulted; every unit uses tanh.
            values[node.id] = jnp.tanh(total)

        # Stack along a new last axis to get (batch_size, 3)
        return jnp.stack([values[node.id] for node in self.output_nodes], axis=-1)
231
+
232
def evaluate_parallel(networks: List[Network], env: SlimeVolley, batch_size: int = 8) -> List[float]:
    """Evaluate networks in batches, one environment instance per network.

    Args:
        networks: Networks to score.
        env: Task exposing ``reset(keys)`` and ``step(states, actions)``.
            Assumed to be a vectorised evojax task that takes one PRNG key
            per environment and actions of shape ``(batch, 3)`` — confirm
            against the installed evojax version.
        batch_size: Number of networks evaluated together.

    Returns:
        One accumulated episode reward per network, in input order.
    """
    fitness_scores: List[float] = []

    for start in range(0, len(networks), batch_size):
        batch = networks[start:start + batch_size]
        n_envs = len(batch)

        # One reset key per environment; a single unbatched key does not
        # match a task that maps reset over the leading axis.
        seed = int(time.time() * 1000) % (2**32 - 1)
        reset_keys = random.split(random.PRNGKey(seed), n_envs)
        states = env.reset(reset_keys)

        total_rewards = np.zeros(n_envs)
        active = np.ones(n_envs, dtype=bool)

        for _ in range(1000):  # Max steps per episode
            # Get observations and normalize
            observations = states.obs / 10.0

            # forward() returns (1, 3) for a single observation; take row 0
            # so the stacked actions are (batch, 3) rather than (batch, 1, 3).
            actions = np.stack([
                np.asarray(net.forward(obs))[0]
                for net, obs in zip(batch, observations)
            ])

            # Convert to binary actions
            binary_actions = (actions > 0.5).astype(np.float32)

            states, rewards, dones = env.step(states, binary_actions)

            # Only count reward for episodes that are still running, so an
            # environment that finished early cannot keep changing its score.
            step_rewards = np.asarray(rewards, dtype=np.float64).reshape(n_envs)
            total_rewards += np.where(active, step_rewards, 0.0)
            active &= ~np.asarray(dones).astype(bool).reshape(n_envs)

            if not active.any():
                break

        fitness_scores.extend(float(score) for score in total_rewards)

    return fitness_scores
274
+
275
def create_next_generation(population: List[Network], fitness_scores: List[float], config: Dict):
    """Create the next generation from the current population and its fitness.

    The fittest 20% (at least 2) are carried over unchanged. The remaining
    slots are filled with mutated copies of tournament winners drawn from
    the fittest 50% (at least 5).

    Args:
        population: Current networks, in any order.
        fitness_scores: Fitness of each network, aligned with ``population``.
        config: Mutation settings passed to ``Genome.mutate``.

    Returns:
        A new list with the same length as ``population``.
    """
    if not population:
        return []

    # Rank by fitness. The population is not assumed to be sorted, so taking
    # slices of it directly would pick arbitrary networks as "elite".
    ranked = sorted(
        range(len(population)),
        key=lambda idx: fitness_scores[idx],
        reverse=True,
    )

    # Keep top 20% unchanged (less elitism = faster adaptation)
    n_elite = min(len(population), max(2, int(0.2 * len(population))))
    next_population = [population[idx] for idx in ranked[:n_elite]]

    # Fill rest with mutated versions of top 50%
    n_top = max(5, int(0.5 * len(population)))
    breeders = ranked[:n_top]
    # Tournament of size 3, shrunk when fewer breeders are available so that
    # sampling without replacement cannot fail.
    tournament_size = min(3, len(breeders))

    while len(next_population) < len(population):
        contenders = np.random.choice(breeders, tournament_size, replace=False)
        winner = int(max(contenders, key=lambda idx: fitness_scores[idx]))

        # Mutate a private copy. Mutating the parent's genome in place would
        # also change the elite and every sibling sharing that genome. The
        # Network is built after mutation so it sees any newly added nodes.
        child_genome = copy.deepcopy(population[winner].genome)
        child_genome.mutate(config)
        next_population.append(Network(child_genome))

    return next_population
296
+
297
def record_gameplay(network: Network, env: SlimeVolley, filename: str = 'gameplay.gif', max_steps: int = 1000):
    """Play one game with ``network`` and save the rendered frames as a GIF.

    Args:
        network: Policy whose ``forward`` maps observations to 3 activations.
        env: Task exposing ``reset``, ``step`` and ``render``. Assumed to be
            a vectorised evojax task run here with a batch of one — confirm
            against the installed evojax version.
        filename: Output path of the GIF.
        max_steps: Upper bound on the number of recorded steps.
    """
    frames = []

    # Initialize a batch containing a single environment. The key is split
    # into a (1, ...) batch because the state is treated as batched below.
    seed = int(time.time() * 1000) % (2**32 - 1)
    state = env.reset(random.split(random.PRNGKey(seed), 1))
    done = False
    steps = 0

    while not done and steps < max_steps:
        # Render current frame (render is expected to return a PIL Image)
        frames.append(env.render(state))

        # Normalise the observation; reshape so that both a batched (1, d)
        # and an unbatched (d,) observation become (1, d).
        obs = jnp.reshape(state.obs, (1, -1)) / 10.0

        # Get action from network and threshold it
        raw_action = network.forward(obs)
        pressed = raw_action > 0.5

        # Prevent simultaneous left/right: when both fire, keep only the
        # direction with the larger raw activation (right wins ties).
        both_active = jnp.logical_and(pressed[:, 0], pressed[:, 1])
        prefer_left = raw_action[:, 0] > raw_action[:, 1]
        left = jnp.where(both_active, prefer_left, pressed[:, 0])
        right = jnp.where(both_active, ~prefer_left, pressed[:, 1])
        binary_action = jnp.stack(
            [left, right, pressed[:, 2]], axis=-1
        ).astype(jnp.float32)

        # Step environment; the reward is not needed for recording.
        state, _, step_done = env.step(state, binary_action)
        done = bool(jnp.all(step_done))
        steps += 1

    # Save as GIF
    if frames:
        frames[0].save(
            filename,
            save_all=True,
            append_images=frames[1:],
            duration=50,  # 20 fps
            loop=0
        )
        print(f"Gameplay recorded and saved to {filename}")
    else:
        print("No frames were recorded")
351
+
352
def main():
    """Main training loop with hardware acceleration when available.

    Evolves a population of networks for up to 1000 generations, stops early
    once the best fitness exceeds 8.0, records a GIF whenever the best
    fitness improves by more than 2.0, and records a final GIF at the end.
    """
    # Initialize environment
    env = SlimeVolley(max_steps=1000)

    # Configuration for evolution
    config = {
        'population_size': 64,
        'batch_size': 8,  # Smaller batch size for better compatibility
        'weight_mutation_rate': 0.95,
        'weight_mutation_power': 4.0,
        'add_node_rate': 0.0,
        'add_connection_rate': 0.0,
    }

    print("\nTraining Configuration:")
    print(f"Population Size: {config['population_size']}")
    print(f"Batch Size: {config['batch_size']}")
    print(f"Mutation Rate: {config['weight_mutation_rate']}")
    print("-" * 40)

    # Create initial population
    population = [
        Network(Genome(n_inputs=12, n_outputs=3))
        for _ in range(config['population_size'])
    ]

    best_fitness = float('-inf')
    best_network = None

    # Evolution loop
    for generation in range(1000):
        start_time = time.time()
        print(f"\nGeneration {generation}")

        # Evaluate population in batches
        fitness_scores = evaluate_parallel(
            population,
            env,
            batch_size=config['batch_size']
        )

        # Track best network
        max_fitness = max(fitness_scores)
        if max_fitness > best_fitness:
            # Measure the gain before overwriting the previous best;
            # comparing afterwards could never succeed.
            improvement = max_fitness - best_fitness
            best_idx = fitness_scores.index(max_fitness)
            best_fitness = max_fitness
            # Snapshot the winner so later in-place mutation of the
            # population cannot change the stored best network.
            best_network = copy.deepcopy(population[best_idx])
            print(f"New best fitness: {best_fitness:.2f}")

            # Record gameplay for significant improvements (the very first
            # generation always qualifies because the previous best is -inf).
            if improvement > 2.0:
                record_gameplay(best_network, env, f"best_gen_{generation}.gif")

        # Early stopping
        if best_fitness > 8.0:
            print(f"Target fitness reached: {best_fitness:.2f}")
            break

        # Create next generation
        population = create_next_generation(
            population,
            fitness_scores,
            config
        )

        # Print stats every 5 generations
        if generation % 5 == 0:
            gen_time = time.time() - start_time
            print(f"\nGeneration {generation} Stats:")
            print(f"Best Fitness: {max_fitness:.2f}")
            print(f"Average Fitness: {np.mean(fitness_scores):.2f}")
            print(f"Generation Time: {gen_time:.2f}s")

    print("\nTraining complete!")
    print(f"Best fitness achieved: {best_fitness:.2f}")

    # Save final network
    if best_network is not None:
        record_gameplay(best_network, env, "final_gameplay.gif")


if __name__ == '__main__':
    main()