|
"""Visualization utilities for NEAT networks and training progress.""" |
|
import os |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import networkx as nx |
|
from typing import List, Dict, Any |
|
import imageio |
|
from IPython.display import HTML |
|
from neat.network import Network |
|
from neat.genome import Genome |
|
|
|
def _column_y_positions(count: int) -> np.ndarray:
    """Return ``count`` evenly spaced y-coordinates spanning [-1, 1].

    ``np.linspace(-1, 1, 1)`` yields ``[-1.0]``, which would pin a lone node
    to the bottom edge of its column, so a single node is centred at 0.
    ``count == 0`` yields an empty array.
    """
    if count == 1:
        return np.zeros(1)
    return np.linspace(-1, 1, count)


def draw_network(network: Network, save_path: str = None) -> None:
    """Draw a neural network visualization using networkx and matplotlib.

    Nodes are laid out in three vertical columns:

    * input nodes at x=0,
    * hidden nodes at the midpoint, and
    * output nodes on the right.

    Input ids are taken to be ``range(network.input_size)``. Output ids are the
    last ``network.output_size`` indices of ``network.node_genes``. Any other
    endpoint of an enabled connection is drawn as a hidden node.

    Only enabled connections are drawn. They appear as curved arrows coloured
    by weight on the ``RdYlBu`` colormap over [-1, 1] (red for negative, blue
    for positive).

    Args:
        network: The network to visualize.
        save_path: Optional path to save the visualization. When given, the
            figure is written at 300 dpi and closed; otherwise it is shown.
    """
    layer_spacing = 2.0  # horizontal distance between input and output columns

    G = nx.DiGraph()
    node_types = {}
    node_positions = {}

    enabled_conns = [conn for conn in network.connection_genes if conn.enabled]

    # Every endpoint of an enabled connection. Input and output ids are removed
    # as they are placed; whatever is left over is drawn as a hidden node.
    unplaced = set()
    for conn in enabled_conns:
        unplaced.add(conn.source)
        unplaced.add(conn.target)

    def place_column(nodes, node_type, x):
        """Register ``nodes`` (in sorted order) as one column at x-coordinate ``x``."""
        ordered = sorted(nodes)
        for node, y in zip(ordered, _column_y_positions(len(ordered))):
            node_id = str(node)
            node_types[node_id] = node_type
            node_positions[node_id] = np.array([x, y])
            G.add_node(node_id)
            unplaced.discard(node)

    # NOTE(review): the output-id range assumes outputs are the last entries of
    # network.node_genes — confirm against how Network orders its node genes.
    place_column(range(network.input_size), 'input', 0.0)
    num_nodes = len(network.node_genes)
    place_column(range(num_nodes - network.output_size, num_nodes), 'output', layer_spacing)
    place_column(list(unplaced), 'hidden', layer_spacing / 2)

    for conn in enabled_conns:
        G.add_edge(str(conn.source), str(conn.target), weight=conn.weight)

    plt.figure(figsize=(8, 6))

    node_colors = {'input': 'lightblue', 'hidden': 'gray', 'output': 'lightgreen'}
    for node_id, (x, y) in node_positions.items():
        plt.scatter(x, y, c=node_colors[node_types[node_id]], s=500, zorder=2)
        plt.text(x, y, node_id, horizontalalignment='center', verticalalignment='center')

    edges = list(G.edges())
    if edges:
        # edge_cmap/edge_vmin/edge_vmax only take effect when edge_color is a
        # sequence of numbers, so pass the weights (in edgelist order) rather
        # than a fixed colour.
        edge_weights = [G[u][v]['weight'] for u, v in edges]
        nx.draw_networkx_edges(
            G,
            node_positions,
            edgelist=edges,
            edge_color=edge_weights,
            edge_cmap=plt.cm.RdYlBu,
            edge_vmin=-1,
            edge_vmax=1,
            width=1,
            alpha=0.5,
            arrows=True,
            arrowsize=10,
            connectionstyle="arc3,rad=0.2",
        )

    plt.title("Neural Network Architecture")
    plt.axis('equal')
    plt.axis('off')

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
        plt.close()
    else:
        plt.show()
|
|
|
def plot_training_history(history: Dict[str, List[float]], save_path: str = None) -> None:
    """Plot training metrics over generations.

    Best and average fitness share the left y-axis. If a species count is
    present, it is drawn as a dashed line on a secondary right-hand y-axis.
    A single legend lists the series from both axes.

    Args:
        history: Dictionary containing lists of metrics per generation. The
            optional keys ``'best_fitness'``, ``'avg_fitness'`` and
            ``'species_count'`` are plotted; other keys are ignored.
        save_path: Optional path to save the plot. When given, the figure is
            saved and closed; otherwise it is shown.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    if 'best_fitness' in history:
        ax.plot(history['best_fitness'], label='Best Fitness', color='green')
    if 'avg_fitness' in history:
        ax.plot(history['avg_fitness'], label='Average Fitness', color='blue')

    # Label the fitness Axes explicitly. pyplot's twinx() makes the twin the
    # current Axes, so plt.xlabel/plt.ylabel/plt.legend called afterwards would
    # hit the twin instead: its x-axis is hidden and its y-label would be
    # overwritten with 'Fitness'.
    ax.set_xlabel('Generation')
    ax.set_ylabel('Fitness')
    ax.set_title('Training Progress')

    handles, labels = ax.get_legend_handles_labels()
    legend_ax = ax

    if 'species_count' in history:
        ax2 = ax.twinx()
        ax2.plot(history['species_count'], label='Species Count', color='red', linestyle='--')
        ax2.set_ylabel('Number of Species')
        twin_handles, twin_labels = ax2.get_legend_handles_labels()
        handles += twin_handles
        labels += twin_labels
        # The twin is drawn on top of the primary Axes; attach the legend
        # there so its lines cannot paint over it.
        legend_ax = ax2

    # With no recognised keys there is nothing to list; calling legend() would
    # only emit a "No artists with labels found" warning.
    if handles:
        legend_ax.legend(handles, labels)

    if save_path:
        fig.savefig(save_path, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()
|
|
|
def create_gameplay_gif(frames: List[np.ndarray], output_path: str, fps: int = 30) -> None:
    """Create a GIF from gameplay frames.

    Any missing parent directories of ``output_path`` are created first.

    Args:
        frames: List of frames as numpy arrays.
        output_path: Path to save the GIF. A bare filename (no directory part)
            is written relative to the current working directory.
        fps: Frames per second for the GIF; forwarded as-is to
            ``imageio.mimsave``.
    """
    # os.path.dirname("game.gif") is "", and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when the path has one.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    imageio.mimsave(output_path, frames, fps=fps)
|
|
|
def plot_species_complexity(species_stats: List[Dict[str, Any]], save_path: str = None) -> None:
    """Plot average network size for each generation.

    Two lines are drawn against the generation index (0, 1, 2, ...), one for
    the ``'avg_nodes'`` entry and one for the ``'avg_connections'`` entry of
    each per-generation statistics dictionary.

    Args:
        species_stats: One statistics dictionary per generation; each must
            provide the keys ``'avg_nodes'`` and ``'avg_connections'``.
        save_path: Optional path to save the plot. If provided, the figure is
            saved and closed; otherwise it is displayed.
    """
    plt.figure(figsize=(12, 8))

    x_values = range(len(species_stats))
    # Pull out both series before drawing anything, nodes first.
    columns = {
        key: [entry[key] for entry in species_stats]
        for key in ('avg_nodes', 'avg_connections')
    }

    line_specs = (
        ('avg_nodes', 'Average Nodes', 'blue'),
        ('avg_connections', 'Average Connections', 'green'),
    )
    for key, line_label, line_color in line_specs:
        plt.plot(x_values, columns[key], label=line_label, color=line_color)

    plt.xlabel('Generation')
    plt.ylabel('Count')
    plt.title('Network Complexity Over Time')
    plt.legend()

    if not save_path:
        plt.show()
        return

    plt.savefig(save_path, bbox_inches='tight')
    plt.close()
|
|
|
def plot_activation_distribution(genomes: List[Genome], save_path: str = None) -> None:
    """Plot the distribution of activation functions across the population.

    Every node of every genome adds one to the count of its activation. Bars
    appear in the order each activation is first encountered. An activation
    is labelled by its ``__name__`` when it has one (e.g. a plain function),
    and by ``str(activation)`` otherwise.

    Args:
        genomes: List of genomes to analyze. Each genome's ``nodes`` mapping
            is iterated and each node's ``activation`` attribute is read.
        save_path: Optional path to save the plot. When given, the figure is
            saved and closed; otherwise it is shown.
    """
    from collections import Counter

    def label_of(activation: Any) -> str:
        """Return the display name used for ``activation`` on the x-axis."""
        if hasattr(activation, '__name__'):
            return activation.__name__
        return str(activation)

    # Counter keeps insertion order, so bar order matches first occurrence.
    activation_counts = Counter(
        label_of(node.activation)
        for genome in genomes
        for node in genome.nodes.values()
    )

    # Pass plain lists rather than dict views so plotting does not depend on
    # how the installed matplotlib version treats dict_keys/dict_values.
    names = list(activation_counts)
    counts = [activation_counts[name] for name in names]

    plt.figure(figsize=(10, 6))
    plt.bar(names, counts)
    # Right-align rotated labels so each label ends under its own bar.
    plt.xticks(rotation=45, ha='right')
    plt.xlabel('Activation Function')
    plt.ylabel('Count')
    plt.title('Distribution of Activation Functions')
    # Make room for the rotated labels when the figure is shown interactively
    # (savefig's bbox_inches='tight' only helps the saved file).
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
        plt.close()
    else:
        plt.show()
|
|