eyad-silx commited on
Commit
c12a8c5
·
verified ·
1 Parent(s): 0232428

Upload neat\visualize.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. neat//visualize.py +206 -0
neat//visualize.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Visualization utilities for NEAT networks and training progress."""
2
+ import os
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import networkx as nx
6
+ from typing import List, Dict, Any
7
+ import imageio
8
+ from IPython.display import HTML
9
+ from neat.network import Network
10
+ from neat.genome import Genome
11
+
12
+ def draw_network(network: Network, save_path: str = None) -> None:
13
+ """Draw a neural network visualization using networkx and matplotlib.
14
+
15
+ Args:
16
+ network: The network to visualize
17
+ save_path: Optional path to save the visualization
18
+ """
19
+ # Create directed graph
20
+ G = nx.DiGraph()
21
+
22
+ # Track node types and positions
23
+ node_types = {}
24
+ node_positions = {}
25
+
26
+ # Collect all unique nodes from connections
27
+ all_nodes = set()
28
+ for conn in network.connection_genes:
29
+ if conn.enabled:
30
+ all_nodes.add(conn.source)
31
+ all_nodes.add(conn.target)
32
+
33
+ # Calculate layout parameters
34
+ layer_spacing = 2.0
35
+
36
+ # Add input nodes (leftmost layer)
37
+ input_nodes = set(range(network.input_size))
38
+ input_y = np.linspace(-1, 1, len(input_nodes))
39
+ for i, node in enumerate(sorted(input_nodes)):
40
+ node_id = str(node)
41
+ node_types[node_id] = 'input'
42
+ node_positions[node_id] = np.array([0, input_y[i]])
43
+ G.add_node(node_id)
44
+ all_nodes.discard(node) # Remove from remaining nodes
45
+
46
+ # Add output nodes (rightmost layer)
47
+ output_start = len(network.node_genes) - network.output_size
48
+ output_nodes = set(range(output_start, len(network.node_genes)))
49
+ output_y = np.linspace(-1, 1, len(output_nodes))
50
+ for i, node in enumerate(sorted(output_nodes)):
51
+ node_id = str(node)
52
+ node_types[node_id] = 'output'
53
+ node_positions[node_id] = np.array([layer_spacing, output_y[i]])
54
+ G.add_node(node_id)
55
+ all_nodes.discard(node)
56
+
57
+ # Add hidden nodes (middle layer)
58
+ hidden_nodes = all_nodes # Remaining nodes are hidden
59
+ if hidden_nodes:
60
+ hidden_y = np.linspace(-1, 1, len(hidden_nodes))
61
+ for i, node in enumerate(sorted(hidden_nodes)):
62
+ node_id = str(node)
63
+ node_types[node_id] = 'hidden'
64
+ node_positions[node_id] = np.array([layer_spacing/2, hidden_y[i]])
65
+ G.add_node(node_id)
66
+
67
+ # Add connections
68
+ for conn in network.connection_genes:
69
+ if conn.enabled:
70
+ G.add_edge(str(conn.source), str(conn.target), weight=conn.weight)
71
+
72
+ # Draw the network
73
+ plt.figure(figsize=(8, 6))
74
+
75
+ # Draw nodes
76
+ for node, (x, y) in node_positions.items():
77
+ node_type = node_types[node]
78
+ if node_type == 'input':
79
+ color = 'lightblue'
80
+ elif node_type == 'hidden':
81
+ color = 'gray'
82
+ else: # output
83
+ color = 'lightgreen'
84
+ plt.scatter(x, y, c=color, s=500, zorder=2)
85
+ plt.text(x, y, node, horizontalalignment='center', verticalalignment='center')
86
+
87
+ # Draw edges
88
+ edge_weights = [G[u][v]['weight'] for u, v in G.edges()]
89
+ pos = node_positions
90
+ nx.draw_networkx_edges(G, pos, edge_color='gray',
91
+ width=1, alpha=0.5,
92
+ arrows=True, arrowsize=10,
93
+ edge_cmap=plt.cm.RdYlBu, edge_vmin=-1, edge_vmax=1,
94
+ connectionstyle="arc3,rad=0.2")
95
+
96
+ plt.title("Neural Network Architecture")
97
+ plt.axis('equal')
98
+ plt.axis('off')
99
+
100
+ if save_path:
101
+ plt.savefig(save_path, bbox_inches='tight', dpi=300)
102
+ plt.close()
103
+ else:
104
+ plt.show()
105
+
106
+ def plot_training_history(history: Dict[str, List[float]], save_path: str = None) -> None:
107
+ """Plot training metrics over generations.
108
+
109
+ Args:
110
+ history: Dictionary containing lists of metrics per generation
111
+ save_path: Optional path to save the plot
112
+ """
113
+ plt.figure(figsize=(12, 8))
114
+
115
+ # Plot fitness metrics
116
+ if 'best_fitness' in history:
117
+ plt.plot(history['best_fitness'], label='Best Fitness', color='green')
118
+ if 'avg_fitness' in history:
119
+ plt.plot(history['avg_fitness'], label='Average Fitness', color='blue')
120
+
121
+ # Plot species count if available
122
+ if 'species_count' in history:
123
+ ax2 = plt.twinx()
124
+ ax2.plot(history['species_count'], label='Species Count', color='red', linestyle='--')
125
+ ax2.set_ylabel('Number of Species')
126
+
127
+ plt.xlabel('Generation')
128
+ plt.ylabel('Fitness')
129
+ plt.title('Training Progress')
130
+ plt.legend()
131
+
132
+ if save_path:
133
+ plt.savefig(save_path, bbox_inches='tight')
134
+ plt.close()
135
+ else:
136
+ plt.show()
137
+
138
+ def create_gameplay_gif(frames: List[np.ndarray], output_path: str, fps: int = 30) -> None:
139
+ """Create a GIF from gameplay frames.
140
+
141
+ Args:
142
+ frames: List of frames as numpy arrays
143
+ output_path: Path to save the GIF
144
+ fps: Frames per second for the GIF
145
+ """
146
+ # Ensure output directory exists
147
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
148
+
149
+ # Save frames as GIF
150
+ imageio.mimsave(output_path, frames, fps=fps)
151
+
152
+ def plot_species_complexity(species_stats: List[Dict[str, Any]], save_path: str = None) -> None:
153
+ """Plot the complexity of species over generations.
154
+
155
+ Args:
156
+ species_stats: List of dictionaries containing species statistics per generation
157
+ save_path: Optional path to save the plot
158
+ """
159
+ plt.figure(figsize=(12, 8))
160
+
161
+ generations = range(len(species_stats))
162
+ avg_nodes = [stats['avg_nodes'] for stats in species_stats]
163
+ avg_connections = [stats['avg_connections'] for stats in species_stats]
164
+
165
+ plt.plot(generations, avg_nodes, label='Average Nodes', color='blue')
166
+ plt.plot(generations, avg_connections, label='Average Connections', color='green')
167
+
168
+ plt.xlabel('Generation')
169
+ plt.ylabel('Count')
170
+ plt.title('Network Complexity Over Time')
171
+ plt.legend()
172
+
173
+ if save_path:
174
+ plt.savefig(save_path, bbox_inches='tight')
175
+ plt.close()
176
+ else:
177
+ plt.show()
178
+
179
+ def plot_activation_distribution(genomes: List[Genome], save_path: str = None) -> None:
180
+ """Plot the distribution of activation functions across the population.
181
+
182
+ Args:
183
+ genomes: List of genomes to analyze
184
+ save_path: Optional path to save the plot
185
+ """
186
+ activation_counts = {}
187
+
188
+ # Count activation functions
189
+ for genome in genomes:
190
+ for node in genome.nodes.values():
191
+ activation_name = node.activation.__name__ if hasattr(node.activation, '__name__') else str(node.activation)
192
+ activation_counts[activation_name] = activation_counts.get(activation_name, 0) + 1
193
+
194
+ # Create bar plot
195
+ plt.figure(figsize=(10, 6))
196
+ plt.bar(activation_counts.keys(), activation_counts.values())
197
+ plt.xticks(rotation=45)
198
+ plt.xlabel('Activation Function')
199
+ plt.ylabel('Count')
200
+ plt.title('Distribution of Activation Functions')
201
+
202
+ if save_path:
203
+ plt.savefig(save_path, bbox_inches='tight')
204
+ plt.close()
205
+ else:
206
+ plt.show()