"""Analysis utilities for neural networks.

This module provides functions for analyzing neural network architectures,
including complexity measures and structural properties.
"""
					
						
						|  |  | 
					
						
						|  | import numpy as np | 
					
						
						|  | import networkx as nx | 
					
						
						|  | import matplotlib.pyplot as plt | 
					
						
						|  | from typing import Dict, Tuple, Union, Optional, List, Any | 
					
						
						|  | from .network import Network | 
					
						
						|  | from .genome import Genome | 
					
						
						|  | from collections import defaultdict | 
					
						
						|  | import os | 
					
						
						|  |  | 
					
						
						|  | def analyze_network_complexity(network: Network) -> Dict[str, Any]: | 
					
						
						|  | """Analyze the complexity of a neural network. | 
					
						
						|  |  | 
					
						
						|  | Computes various complexity metrics including: | 
					
						
						|  | 1. Number of nodes by type (input, hidden, output) | 
					
						
						|  | 2. Number of connections | 
					
						
						|  | 3. Network density | 
					
						
						|  | 4. Activation functions used | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | network: Network instance to analyze | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | Dictionary containing complexity metrics | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | genome = network.genome | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | n_input = genome.input_size | 
					
						
						|  | n_hidden = len(genome.hidden_nodes) | 
					
						
						|  | n_output = genome.output_size | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | n_connections = len(genome.connections) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | n_possible = (n_input + n_hidden + n_output) * (n_hidden + n_output) | 
					
						
						|  | density = n_connections / n_possible if n_possible > 0 else 0 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | activation_functions = {'relu': n_hidden + n_output} | 
					
						
						|  |  | 
					
						
						|  | return { | 
					
						
						|  | 'n_input': n_input, | 
					
						
						|  | 'n_hidden': n_hidden, | 
					
						
						|  | 'n_output': n_output, | 
					
						
						|  | 'n_connections': n_connections, | 
					
						
						|  | 'density': density, | 
					
						
						|  | 'activation_functions': activation_functions | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | def get_network_stats(network: Network) -> Dict[str, float]: | 
					
						
						|  | """Get statistical measures of network properties. | 
					
						
						|  |  | 
					
						
						|  | Computes various statistics about the network structure and parameters: | 
					
						
						|  | - Number of nodes and connections | 
					
						
						|  | - Average and std of weights and biases | 
					
						
						|  | - Network density and depth | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | network: Network instance to analyze | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | Dictionary containing network statistics | 
					
						
						|  | """ | 
					
						
						|  | stats = {} | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | stats['n_nodes'] = network.n_nodes | 
					
						
						|  | stats['n_hidden'] = network.n_nodes - network.input_size - network.output_size | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | weights = np.array(list(network.weights.values())) | 
					
						
						|  | stats['n_connections'] = len(weights) | 
					
						
						|  | stats['weight_mean'] = float(np.mean(weights)) | 
					
						
						|  | stats['weight_std'] = float(np.std(weights)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | biases = np.array(list(network.bias.values())) | 
					
						
						|  | stats['n_biases'] = len(biases) | 
					
						
						|  | stats['bias_mean'] = float(np.mean(biases)) | 
					
						
						|  | stats['bias_std'] = float(np.std(biases)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | n_possible = network.n_nodes * (network.n_nodes - 1) | 
					
						
						|  | stats['density'] = len(weights) / n_possible if n_possible > 0 else 0 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | weight_matrix = network.weight_matrix | 
					
						
						|  | depth = 0 | 
					
						
						|  | visited = set(range(network.input_size)) | 
					
						
						|  | frontier = visited.copy() | 
					
						
						|  |  | 
					
						
						|  | while frontier and depth < network.n_nodes: | 
					
						
						|  | next_frontier = set() | 
					
						
						|  | for node in frontier: | 
					
						
						|  | for next_node in range(network.n_nodes): | 
					
						
						|  | if weight_matrix[node, next_node] != 0 and next_node not in visited: | 
					
						
						|  | next_frontier.add(next_node) | 
					
						
						|  | visited.add(next_node) | 
					
						
						|  | frontier = next_frontier | 
					
						
						|  | if frontier: | 
					
						
						|  | depth += 1 | 
					
						
						|  |  | 
					
						
						|  | stats['depth'] = depth | 
					
						
						|  |  | 
					
						
						|  | return stats | 
					
						
						|  |  | 
					
						
						|  | def visualize_network_architecture(network: Network, save_path: Optional[str] = None): | 
					
						
						|  | """Visualize the network architecture using networkx. | 
					
						
						|  |  | 
					
						
						|  | Creates a layered visualization of the neural network with: | 
					
						
						|  | - Input nodes in red (leftmost layer) | 
					
						
						|  | - Hidden nodes in blue (middle layer) | 
					
						
						|  | - Output nodes in green (rightmost layer) | 
					
						
						|  | - Connections shown as arrows with thickness proportional to weight | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | network: Network instance to visualize | 
					
						
						|  | save_path: Optional path to save the visualization | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | matplotlib figure object or None if visualization fails | 
					
						
						|  | """ | 
					
						
						|  | try: | 
					
						
						|  | import networkx as nx | 
					
						
						|  | import matplotlib.pyplot as plt | 
					
						
						|  |  | 
					
						
						|  | genome = network.genome | 
					
						
						|  | G = nx.DiGraph() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | n_inputs = len([node for node in genome.node_genes.values() if node.node_type == 'input']) | 
					
						
						|  | n_outputs = len([node for node in genome.node_genes.values() if node.node_type == 'output']) | 
					
						
						|  | hidden_nodes = [node.node_id for node in genome.node_genes.values() if node.node_type == 'hidden'] | 
					
						
						|  | n_hidden = len(hidden_nodes) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | node_spacing = 1.0 | 
					
						
						|  | layer_spacing = 2.0 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | pos = {} | 
					
						
						|  | node_colors = {} | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | input_start_y = -(n_inputs - 1) * node_spacing / 2 | 
					
						
						|  | input_nodes = [node.node_id for node in genome.node_genes.values() if node.node_type == 'input'] | 
					
						
						|  | for i, node_idx in enumerate(input_nodes): | 
					
						
						|  | pos[node_idx] = (0, input_start_y + i * node_spacing) | 
					
						
						|  | node_colors[node_idx] = 'lightcoral' | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if hidden_nodes: | 
					
						
						|  | hidden_start_y = -(n_hidden - 1) * node_spacing / 2 | 
					
						
						|  | for i, node_idx in enumerate(hidden_nodes): | 
					
						
						|  | pos[node_idx] = (layer_spacing, hidden_start_y + i * node_spacing) | 
					
						
						|  | node_colors[node_idx] = 'lightblue' | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | output_start_y = -(n_outputs - 1) * node_spacing / 2 | 
					
						
						|  | output_nodes = [node.node_id for node in genome.node_genes.values() if node.node_type == 'output'] | 
					
						
						|  | for i, node_idx in enumerate(output_nodes): | 
					
						
						|  | pos[node_idx] = (2 * layer_spacing, output_start_y + i * node_spacing) | 
					
						
						|  | node_colors[node_idx] = 'lightgreen' | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | bias_node = [node.node_id for node in genome.node_genes.values() if node.node_type == 'bias'] | 
					
						
						|  | if bias_node: | 
					
						
						|  | pos[bias_node[0]] = (0, input_start_y - node_spacing) | 
					
						
						|  | node_colors[bias_node[0]] = 'yellow' | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for node_id in genome.node_genes: | 
					
						
						|  | G.add_node(node_id) | 
					
						
						|  | if node_id not in node_colors: | 
					
						
						|  | node_type = genome.node_genes[node_id].node_type | 
					
						
						|  | if node_type == 'input': | 
					
						
						|  | node_colors[node_id] = 'lightcoral' | 
					
						
						|  | elif node_type == 'hidden': | 
					
						
						|  | node_colors[node_id] = 'lightblue' | 
					
						
						|  | elif node_type == 'output': | 
					
						
						|  | node_colors[node_id] = 'lightgreen' | 
					
						
						|  | elif node_type == 'bias': | 
					
						
						|  | node_colors[node_id] = 'yellow' | 
					
						
						|  | else: | 
					
						
						|  | node_colors[node_id] = 'gray' | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if node_id not in pos: | 
					
						
						|  |  | 
					
						
						|  | pos[node_id] = (layer_spacing, 0) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for conn in genome.connection_genes: | 
					
						
						|  | if conn.enabled: | 
					
						
						|  |  | 
					
						
						|  | width = abs(conn.weight) * 2.0 | 
					
						
						|  |  | 
					
						
						|  | color = 'red' if conn.weight < 0 else 'green' | 
					
						
						|  | alpha = min(abs(conn.weight), 1.0) | 
					
						
						|  | G.add_edge(conn.source, conn.target, weight=width, color=color, alpha=alpha) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | fig = plt.figure(figsize=(12, 8)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | nx.draw_networkx_nodes(G, pos, node_color=[node_colors[node] for node in G.nodes()], | 
					
						
						|  | node_size=800, alpha=0.8) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | edge_weights = [G.get_edge_data(edge[0], edge[1])['weight'] for edge in G.edges()] | 
					
						
						|  | if edge_weights: | 
					
						
						|  | max_weight = max(edge_weights) | 
					
						
						|  | normalized_weights = [3 * w / max_weight for w in edge_weights] | 
					
						
						|  | nx.draw_networkx_edges(G, pos, edge_color=[G.get_edge_data(edge[0], edge[1])['color'] for edge in G.edges()], | 
					
						
						|  | width=normalized_weights, | 
					
						
						|  | alpha=[G.get_edge_data(edge[0], edge[1])['alpha'] for edge in G.edges()], | 
					
						
						|  | arrows=True, arrowsize=20) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | nx.draw_networkx_labels(G, pos, font_size=10) | 
					
						
						|  |  | 
					
						
						|  | plt.title("Neural Network Architecture") | 
					
						
						|  | plt.axis('off') | 
					
						
						|  |  | 
					
						
						|  | if save_path: | 
					
						
						|  |  | 
					
						
						|  | os.makedirs(os.path.dirname(save_path), exist_ok=True) | 
					
						
						|  | plt.savefig(save_path, bbox_inches='tight', dpi=300) | 
					
						
						|  | plt.close(fig) | 
					
						
						|  |  | 
					
						
						|  | return fig | 
					
						
						|  |  | 
					
						
						|  | except Exception as e: | 
					
						
						|  | print(f"Error visualizing network: {str(e)}") | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  | def plot_activation_distribution(population: List[Genome], save_path: Optional[str] = None): | 
					
						
						|  | """Plot the distribution of node types in the population. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | population: List of genomes in the population | 
					
						
						|  | save_path: Optional path to save the plot | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | matplotlib figure object or None if plotting fails | 
					
						
						|  | """ | 
					
						
						|  | try: | 
					
						
						|  |  | 
					
						
						|  | node_type_counts = defaultdict(int) | 
					
						
						|  | for genome in population: | 
					
						
						|  | node_type_counts['input'] += genome.input_size | 
					
						
						|  | node_type_counts['hidden'] += len(genome.hidden_nodes) | 
					
						
						|  | node_type_counts['output'] += genome.output_size | 
					
						
						|  |  | 
					
						
						|  | if not node_type_counts: | 
					
						
						|  | print("No nodes found in population") | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | fig = plt.figure(figsize=(10, 6)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | node_types = list(node_type_counts.keys()) | 
					
						
						|  | counts = list(node_type_counts.values()) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | colors = {'input': 'lightcoral', 'hidden': 'lightblue', 'output': 'lightgreen'} | 
					
						
						|  | plt.bar(node_types, counts, color=[colors[t] for t in node_types], alpha=0.7) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | plt.title('Distribution of Node Types in Population') | 
					
						
						|  | plt.xlabel('Node Type') | 
					
						
						|  | plt.ylabel('Total Count') | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for i, count in enumerate(counts): | 
					
						
						|  | plt.text(i, count, str(count), ha='center', va='bottom') | 
					
						
						|  |  | 
					
						
						|  | plt.tight_layout() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if save_path: | 
					
						
						|  |  | 
					
						
						|  | os.makedirs(os.path.dirname(save_path), exist_ok=True) | 
					
						
						|  | plt.savefig(save_path, bbox_inches='tight', dpi=300) | 
					
						
						|  | plt.close(fig) | 
					
						
						|  |  | 
					
						
						|  | return fig | 
					
						
						|  |  | 
					
						
						|  | except Exception as e: | 
					
						
						|  | print(f"Error plotting activation distribution: {str(e)}") | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  | def analyze_evolution_trends(stats: Dict, save_dir: str) -> None: | 
					
						
						|  | """Analyze and plot evolution trends from training history. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | stats: Dictionary containing training statistics | 
					
						
						|  | save_dir: Directory to save plots | 
					
						
						|  | """ | 
					
						
						|  | try: | 
					
						
						|  |  | 
					
						
						|  | os.makedirs(save_dir, exist_ok=True) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if not stats or 'mean_fitness' not in stats or not stats['mean_fitness']: | 
					
						
						|  | print("No evolution stats available yet") | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | generations = list(range(len(stats['mean_fitness']))) | 
					
						
						|  | if not generations: | 
					
						
						|  | print("No generations completed yet") | 
					
						
						|  | return | 
					
						
						|  |  | 
					
						
						|  | metrics = { | 
					
						
						|  | 'Fitness': { | 
					
						
						|  | 'mean': stats.get('mean_fitness', []), | 
					
						
						|  | 'best': stats.get('best_fitness', []) | 
					
						
						|  | }, | 
					
						
						|  | 'Complexity': { | 
					
						
						|  | 'mean': stats.get('mean_complexity', []), | 
					
						
						|  | 'best': stats.get('best_complexity', []) | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for metric_name, metric_data in metrics.items(): | 
					
						
						|  |  | 
					
						
						|  | if not metric_data['mean'] or not metric_data['best']: | 
					
						
						|  | print(f"No data available for {metric_name}") | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if len(generations) != len(metric_data['mean']) or len(generations) != len(metric_data['best']): | 
					
						
						|  | print(f"Data length mismatch for {metric_name}") | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | fig = plt.figure(figsize=(10, 6)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | plt.plot(generations, metric_data['mean'], label=f'Mean {metric_name}', alpha=0.7) | 
					
						
						|  | plt.plot(generations, metric_data['best'], label=f'Best {metric_name}', alpha=0.7) | 
					
						
						|  |  | 
					
						
						|  | plt.title(f'{metric_name} Over Generations') | 
					
						
						|  | plt.xlabel('Generation') | 
					
						
						|  | plt.ylabel(metric_name) | 
					
						
						|  | plt.legend() | 
					
						
						|  | plt.grid(True, alpha=0.3) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | save_path = os.path.join(save_dir, f'{metric_name.lower()}_trends.png') | 
					
						
						|  | plt.savefig(save_path, bbox_inches='tight', dpi=300) | 
					
						
						|  | plt.close(fig) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if 'n_species' in stats and stats['n_species']: | 
					
						
						|  | n_species = stats['n_species'] | 
					
						
						|  | if len(generations) == len(n_species): | 
					
						
						|  | fig = plt.figure(figsize=(10, 6)) | 
					
						
						|  | plt.plot(generations, n_species, label='Number of Species', alpha=0.7) | 
					
						
						|  | plt.title('Number of Species Over Generations') | 
					
						
						|  | plt.xlabel('Generation') | 
					
						
						|  | plt.ylabel('Number of Species') | 
					
						
						|  | plt.legend() | 
					
						
						|  | plt.grid(True, alpha=0.3) | 
					
						
						|  |  | 
					
						
						|  | save_path = os.path.join(save_dir, 'species_trends.png') | 
					
						
						|  | plt.savefig(save_path, bbox_inches='tight', dpi=300) | 
					
						
						|  | plt.close(fig) | 
					
						
						|  | else: | 
					
						
						|  | print("Species count data length mismatch") | 
					
						
						|  |  | 
					
						
						|  | except Exception as e: | 
					
						
						|  | print(f"Error analyzing evolution trends: {str(e)}") | 
					
						
						|  |  |