# neat / old_train.py
# Uploaded by eyad-silx: "Upload old_train.py with huggingface_hub"
# Commit 0e0538d (verified)
import jax
import jax.numpy as jnp
from jax import random
from evojax.task.slimevolley import SlimeVolley
from typing import List, Tuple, Dict
import numpy as np
import time
class NodeGene:
    """A single node (neuron) of a genome.

    Attributes:
        id: Integer identifier of the node.
        type: ``'input'``, ``'hidden'`` or ``'output'``.
        activation: Name of the activation function attached to the node.
        bias: Additive bias of the node (a Python float).
    """

    def __init__(self, id: int, node_type: str, activation: str = 'tanh',
                 bias: float = None):
        """Create a node gene.

        Args:
            id: Identifier of the node.
            node_type: ``'input'``, ``'hidden'`` or ``'output'``.
            activation: Activation function name; defaults to ``'tanh'``.
            bias: Explicit bias. When omitted, a small random bias is drawn
                from a normal distribution with standard deviation 0.1.
        """
        self.id = id
        self.type = node_type  # 'input', 'hidden', or 'output'
        self.activation = activation
        if bias is None:
            # Draw from NumPy's global RNG instead of a key derived from the
            # wall clock: a millisecond timestamp gives identical biases to
            # same-id nodes created within the same millisecond, and it cannot
            # be made reproducible. ``np.random.seed`` now controls the draw.
            bias = float(np.random.normal() * 0.1)
        self.bias = bias
class ConnectionGene:
    """A weighted, directed link between two nodes of a genome.

    Attributes:
        source: Id of the node the connection reads from.
        target: Id of the node the connection feeds into.
        weight: Connection weight.
        enabled: Whether the connection takes part in the forward pass.
        innovation: ``hash((source, target))``, so two genes linking the same
            pair of nodes share the same innovation value.
    """

    def __init__(self, source: int, target: int, weight: float = None, enabled: bool = True):
        """Create a connection gene.

        Args:
            source: Id of the source node.
            target: Id of the target node.
            weight: Explicit weight. When omitted, a small random weight is
                drawn from a normal distribution with standard deviation 0.1.
            enabled: Initial enabled state.
        """
        self.source = source
        self.target = target
        if weight is None:
            # Only touch the RNG when a weight is actually needed, and use
            # NumPy's global RNG rather than a wall-clock-derived key (which
            # collides for genes created in the same millisecond and cannot be
            # seeded for reproducible runs).
            weight = float(np.random.normal() * 0.1)  # Small random weight
        self.weight = weight
        self.enabled = enabled
        self.innovation = hash((source, target))
class Genome:
    """Genotype of a network: node genes plus connection genes.

    Attributes:
        node_genes: Mapping from node id to ``NodeGene``. Ids
            ``0 .. n_inputs-1`` are inputs, the next three ids are outputs and
            every later id is a hidden node.
        connection_genes: List of ``ConnectionGene`` objects.

    All randomness comes from NumPy's global RNG, so ``np.random.seed`` makes
    construction and mutation reproducible.
    """

    # Evaluation rank of each node type; used to keep new links consistent
    # with the order in which Network.forward processes nodes
    # (inputs, then hidden nodes by id, then output nodes by id).
    _TYPE_RANK = {'input': 0, 'hidden': 1, 'output': 2}

    def __init__(self, n_inputs: int, n_outputs: int):
        """Build a randomly wired starting genome.

        Args:
            n_inputs: Number of input nodes.
            n_outputs: Accepted for interface compatibility but ignored: the
                genome always has exactly 3 outputs (left, right, jump).
        """
        n_outputs = 3  # Force exactly 3 outputs
        input_ids = list(range(n_inputs))
        output_ids = [n_inputs + j for j in range(n_outputs)]

        self.node_genes = {i: NodeGene(i, 'input') for i in input_ids}
        for node_id in output_ids:
            self.node_genes[node_id] = NodeGene(node_id, 'output')

        self.connection_genes: List[ConnectionGene] = []

        # Direct input -> output links, each present with 70% probability.
        self._add_random_links(input_ids, output_ids, probability=0.7)

        # One to three hidden nodes, each wired to a random half of the
        # inputs and a random half of the outputs.
        n_hidden = int(np.random.randint(1, 4))
        hidden_start = n_inputs + n_outputs
        for node_id in range(hidden_start, hidden_start + n_hidden):
            self.node_genes[node_id] = NodeGene(node_id, 'hidden')
            self._add_random_links(input_ids, [node_id], probability=0.5)
            self._add_random_links([node_id], output_ids, probability=0.5)

    def _add_random_links(self, sources, targets, probability, scale=0.5):
        """Link each (source, target) pair with the given probability."""
        for source in sources:
            for target in targets:
                if np.random.random() < probability:
                    weight = float(np.random.normal() * scale)
                    self.connection_genes.append(
                        ConnectionGene(source, target, weight=weight)
                    )

    def mutate(self, config: Dict):
        """Mutate this genome in place.

        Args:
            config: Mapping with the keys ``'weight_mutation_rate'``,
                ``'weight_mutation_power'``, ``'add_node_rate'`` and
                ``'add_connection_rate'``.

        Previously every call started from ``PRNGKey(0)``, so every genome
        received the same sequence of mutation decisions and perturbations.
        Each call now consumes fresh values from NumPy's global RNG.
        """
        self._mutate_weights(config)
        self._mutate_biases()
        if np.random.random() < config['add_node_rate']:
            self._add_node()
        if np.random.random() < config['add_connection_rate']:
            self._add_connection()

    def _mutate_weights(self, config):
        """Perturb (or occasionally reset) each connection weight."""
        rate = config['weight_mutation_rate']
        power = config['weight_mutation_power']
        for conn in self.connection_genes:
            if np.random.random() >= rate:
                continue
            if np.random.random() < 0.1:
                # Sometimes reset the weight completely.
                conn.weight = float(np.random.normal() * 0.5)
            else:
                # Otherwise nudge the existing weight.
                conn.weight += float(np.random.normal() * power)

    def _mutate_biases(self):
        """Give every node a 10% chance of a small bias perturbation."""
        for node in self.node_genes.values():
            if np.random.random() < 0.1:
                node.bias += float(np.random.normal() * 0.1)

    def _add_node(self):
        """Split a random connection by inserting a new hidden node."""
        if not self.connection_genes:
            return
        conn = self.connection_genes[np.random.randint(len(self.connection_genes))]
        new_id = max(self.node_genes) + 1
        self.node_genes[new_id] = NodeGene(new_id, 'hidden')
        # Both replacement links get fresh random weights.
        weight_in = float(np.random.normal() * 0.5)
        weight_out = float(np.random.normal() * 0.5)
        self.connection_genes.append(
            ConnectionGene(conn.source, new_id, weight=weight_in)
        )
        self.connection_genes.append(
            ConnectionGene(new_id, conn.target, weight=weight_out)
        )
        # The split connection is replaced by the two new ones.
        conn.enabled = False

    def _add_connection(self):
        """Try up to 10 times to add one new, valid connection."""
        nodes = list(self.node_genes)
        if not nodes:
            return
        existing = {(c.source, c.target) for c in self.connection_genes}
        for _ in range(10):
            source = nodes[np.random.randint(len(nodes))]
            target = nodes[np.random.randint(len(nodes))]
            if self._is_valid_link(source, target, existing):
                weight = float(np.random.normal() * 0.5)
                self.connection_genes.append(
                    ConnectionGene(source, target, weight=weight)
                )
                break

    def _is_valid_link(self, source, target, existing):
        """Return True if source -> target is a new, usable feed-forward link.

        A plain ``source < target`` id comparison is not enough here: it lets
        links end in input nodes (which are never evaluated), lets outputs feed
        hidden nodes, and forbids hidden -> output links, because hidden ids
        are larger than output ids.
        """
        if source == target or (source, target) in existing:
            return False
        src = self.node_genes[source]
        dst = self.node_genes[target]
        if dst.type == 'input' or src.type == 'output':
            return False
        # The source must be evaluated before the target.
        src_pos = (self._TYPE_RANK[src.type], source)
        dst_pos = (self._TYPE_RANK[dst.type], target)
        if src_pos >= dst_pos:
            return False
        # Splitting a connection can create links from a higher hidden id to
        # a lower one, so ordering alone does not rule out cycles.
        return not self._has_path(target, source)

    def _has_path(self, start, goal):
        """Return True if ``goal`` is reachable from ``start`` via connections."""
        adjacency = {}
        for conn in self.connection_genes:
            adjacency.setdefault(conn.source, []).append(conn.target)
        stack = [start]
        seen = {start}
        while stack:
            current = stack.pop()
            if current == goal:
                return True
            for nxt in adjacency.get(current, ()):
                if nxt not in seen:
                    seen.add(nxt)
                    stack.append(nxt)
        return False
class Network:
    """Feed-forward phenotype that evaluates a genome.

    Attributes:
        genome: The genome being evaluated. Connections, weights and biases
            are read from it on every forward pass.
        input_nodes / hidden_nodes / output_nodes: Node genes of each type,
            sorted by id, captured when the network is constructed.
    """

    # Activation functions selectable through ``NodeGene.activation``.
    # Unknown names fall back to tanh, which is what every node used before.
    _ACTIVATIONS = {
        'tanh': jnp.tanh,
        'sigmoid': jax.nn.sigmoid,
        'relu': jax.nn.relu,
        'identity': lambda v: v,
    }

    def __init__(self, genome: "Genome"):
        """Wrap ``genome``; it must contain exactly 3 output nodes."""
        self.genome = genome
        nodes = list(genome.node_genes.values())

        def by_type(kind):
            # Sort by id so the ordering is consistent between runs.
            return sorted((n for n in nodes if n.type == kind), key=lambda n: n.id)

        self.input_nodes = by_type('input')
        self.hidden_nodes = by_type('hidden')
        self.output_nodes = by_type('output')
        # Verify we have exactly 3 output nodes
        assert len(self.output_nodes) == 3, f"Expected 3 output nodes, got {len(self.output_nodes)}"

    def _evaluation_order(self, incoming):
        """Order hidden and output nodes so dependencies are computed first.

        The base order is hidden nodes by id followed by output nodes by id.
        A node is deferred while any enabled connection feeds it from a node
        that has not been evaluated yet. This matters because a node inserted
        by splitting a connection gets the highest id but must run before the
        lower-id node it feeds. If the graph contains a cycle, the first
        remaining node of the base order is taken so the loop always ends.
        """
        pending = self.hidden_nodes + self.output_nodes
        unresolved = {node.id for node in pending}
        deps = {
            node.id: {
                conn.source
                for conn in incoming.get(node.id, ())
                if conn.source != node.id and conn.source in unresolved
            }
            for node in pending
        }
        order = []
        while pending:
            chosen = next(
                (n for n in pending if not (deps[n.id] & unresolved)),
                pending[0],
            )
            pending = [n for n in pending if n is not chosen]
            unresolved.discard(chosen.id)
            order.append(chosen)
        return order

    def forward(self, x: jnp.ndarray) -> jnp.ndarray:
        """Run the network on a batch of observations.

        Args:
            x: Array whose last axis holds one value per input node. A 1-D
                array is treated as a single sample; any extra leading axes
                are collapsed into one batch axis.

        Returns:
            Array of shape ``(batch_size, 3)`` with one column per output node
            (in ascending node-id order).
        """
        if len(x.shape) == 1:
            x = jnp.expand_dims(x, 0)
        elif len(x.shape) > 2:
            # Collapse stray leading axes; indexing them as features would
            # silently read the wrong values.
            x = jnp.reshape(x, (-1, x.shape[-1]))
        batch_size = x.shape[0]

        # Every node starts at its bias; inputs are then overwritten.
        values = {
            node.id: jnp.zeros((batch_size,)) + node.bias
            for node in self.genome.node_genes.values()
        }
        for i, node in enumerate(self.input_nodes):
            values[node.id] = x[:, i]

        # Group enabled connections by target once instead of rescanning the
        # whole connection list for every node.
        incoming = {}
        for conn in self.genome.connection_genes:
            if conn.enabled:
                incoming.setdefault(conn.target, []).append(conn)

        for node in self._evaluation_order(incoming):
            total = jnp.zeros((batch_size,)) + node.bias
            for conn in incoming.get(node.id, ()):
                total = total + values[conn.source] * conn.weight
            activation = self._ACTIVATIONS.get(
                getattr(node, 'activation', 'tanh'), jnp.tanh
            )
            values[node.id] = activation(total)

        # Stack along a new last axis to get (batch_size, 3).
        return jnp.stack([values[node.id] for node in self.output_nodes], axis=-1)
def evaluate_network(network: Network, env: SlimeVolley, n_episodes: int = 10) -> float:
    """Play ``n_episodes`` episodes and return the mean shaped reward.

    Args:
        network: Controller whose ``forward`` maps an observation batch to
            three raw action values (left, right, jump).
        env: Environment; it is reset with a key carrying a leading batch
            axis of size 1 and stepped with a ``(1, 3)`` action array.
        n_episodes: Number of episodes to average over.

    Returns:
        Average per-episode reward, including the shaping bonuses below.
    """
    # Positions of the ball features inside the flattened observation.
    # NOTE(review): assumes the layout [agent x, y, vx, vy, ball x, y, vx, vy,
    # opponent ...]. The original code read index 1 as "ball height" although
    # its ball x / ball vx indices (4 and 6) imply that ball y is index 5 --
    # confirm against the evojax SlimeVolley task.
    ball_x_idx, ball_y_idx, ball_vx_idx = 4, 5, 6
    max_steps = 1000
    thresholds = jnp.array([0.3, 0.3, 0.4])  # left, right, jump

    # Seed the episode keys from NumPy's global RNG so that a run can be
    # reproduced with np.random.seed (the wall clock cannot be replayed).
    master_key = random.PRNGKey(int(np.random.randint(0, 2**31 - 1)))

    total_reward = 0.0
    for _ in range(n_episodes):
        master_key, reset_key = random.split(master_key)
        state = env.reset(reset_key[None, :])  # Add batch dimension
        done = False
        reward = 0.0
        episode_reward = 0.0
        steps = 0

        while not done and steps < max_steps:
            # The env was reset with a batched key, so the observation may
            # already carry a batch axis; flattening to (1, n_features)
            # handles both cases without stacking a second batch axis.
            obs = jnp.reshape(state.obs, (1, -1)) / 10.0  # scale inputs

            raw_action = network.forward(obs)  # shape (1, 3)
            binary_action = (raw_action > thresholds).astype(jnp.float32)

            # Never press left and right together: keep the stronger one.
            both_active = jnp.logical_and(binary_action[:, 0] > 0, binary_action[:, 1] > 0)
            prefer_left = raw_action[:, 0] > raw_action[:, 1]
            binary_action = binary_action.at[:, 0].set(
                jnp.where(both_active, prefer_left.astype(jnp.float32), binary_action[:, 0])
            )
            binary_action = binary_action.at[:, 1].set(
                jnp.where(both_active, (~prefer_left).astype(jnp.float32), binary_action[:, 1])
            )

            next_state, reward, done = env.step(state, binary_action)

            # Reduce the (possibly batched) reward / done to Python scalars.
            reward = float(jnp.reshape(jnp.asarray(reward), (-1,))[0])
            done = bool(jnp.reshape(jnp.asarray(done), (-1,))[0])

            # Small reward for movement to encourage exploration.
            movement_reward = 0.1 if bool(jnp.any(binary_action[:, :2] > 0)) else 0.0

            # One host transfer, then index features (not the batch axis).
            next_obs = np.asarray(next_state.obs).reshape(-1)
            ball_height = float(next_obs[ball_y_idx])
            ball_x = float(next_obs[ball_x_idx])
            ball_vx = float(next_obs[ball_vx_idx])

            height_reward = 0.1 if ball_height > 0.5 else 0.0  # ball kept in play
            position_reward = 0.2 if ball_x > 0 else 0.0  # ball on opponent's side
            velocity_reward = 0.1 if ball_vx > 0 else 0.0  # ball heading to opponent

            # Game outcome counts double; shaping bonuses are scaled down.
            bonus_reward = movement_reward + height_reward + position_reward + velocity_reward
            episode_reward += reward * 2.0 + bonus_reward * 0.5

            state = next_state
            steps += 1

        # Early termination bonus
        if done and reward > 0:  # Won the point
            episode_reward += 10.0
        total_reward += episode_reward

    return total_reward / n_episodes
def main():
    """Evolve SlimeVolley controllers for 500 generations, logging progress."""
    import copy  # local import: only the reproduction step needs it

    # Initialize environment
    env = SlimeVolley()

    # NEAT configuration
    config = {
        'population_size': 50,  # Smaller population for faster iteration
        'weight_mutation_rate': 0.8,
        'weight_mutation_power': 0.3,  # Increased for more exploration
        'add_node_rate': 0.3,
        'add_connection_rate': 0.5,
    }

    # Create initial population
    population = [
        Network(Genome(n_inputs=12, n_outputs=3))
        for _ in range(config['population_size'])
    ]

    best_fitness = float('-inf')
    generations_without_improvement = 0

    # Evolution loop
    for generation in range(500):  # More generations
        print(f"\nGeneration {generation}")
        print("-" * 20)

        # Evaluate population
        fitnesses = []
        for i, net in enumerate(population):
            fitness = evaluate_network(net, env)
            fitnesses.append(fitness)
            print(f"Network {i}: Fitness = {fitness:.2f}")
            if fitness > best_fitness:
                best_fitness = fitness
                generations_without_improvement = 0
                print(f"New best fitness: {best_fitness:.2f}")

        # Check for improvement
        generations_without_improvement += 1
        if generations_without_improvement > 20:
            print("No improvement for 20 generations, increasing mutation rates")
            config['weight_mutation_rate'] = min(1.0, config['weight_mutation_rate'] * 1.2)
            config['weight_mutation_power'] = min(0.5, config['weight_mutation_power'] * 1.2)
            generations_without_improvement = 0

        # Print progress
        avg_fitness = sum(fitnesses) / len(fitnesses)
        print(f"\nBest fitness: {best_fitness:.2f}")
        print(f"Average fitness: {avg_fitness:.2f}")

        # Selection and reproduction
        sorted_indices = np.argsort(fitnesses)[::-1]  # Best to worst

        # Keep best networks (never more than the population holds)
        n_elite = min(5, len(population))  # Fewer elites
        new_population = [population[i] for i in sorted_indices[:n_elite]]
        print(f"Keeping top {n_elite} networks")

        # Create offspring from best networks
        parent_pool = sorted_indices[:20]
        tournament_size = min(5, len(parent_pool))
        while len(new_population) < config['population_size']:
            # Tournament selection
            tournament = np.random.choice(parent_pool, tournament_size, replace=False)
            parent_idx = tournament[np.argmax([fitnesses[i] for i in tournament])]
            parent = population[parent_idx]

            # Deep-copy the parent's genome. Copying only the containers
            # would leave the child sharing gene objects with its parent, so
            # mutating the child would also rewrite the parent and the elites.
            child_genome = copy.deepcopy(parent.genome)

            # Mutate child
            child_genome.mutate(config)

            # Add to new population
            new_population.append(Network(child_genome))

        population = new_population
        print(f"Created {len(population)} networks for next generation")


if __name__ == '__main__':
    main()