import numpy as np
import plotly.graph_objects as go
import mne
from typing import Dict, Optional, Tuple
import plotly.express as px
import networkx as nx


class BrainMapper:
    """Render EEG feature dictionaries as Plotly brain maps.

    Electrode positions come from MNE's built-in ``standard_1020`` montage.
    The feature dictionary passed to :meth:`create_visualization` is built
    elsewhere; the keys read here are ``features['band_powers']`` (mapping of
    band name -> values), ``features['statistics']['mean']`` and
    ``features['connectivity']['correlation']``.
    """

    # Points per axis of the interpolation grid used by the topographic and
    # surface maps.
    GRID_SIZE = 100

    def __init__(self):
        self.montage = mne.channels.make_standard_montage('standard_1020')
        self._initialize_coordinates()

    def _initialize_coordinates(self):
        """Initialize electrode coordinates from standard montage.

        Sets ``coords`` (channel name -> position), ``ch_names`` (montage
        channel order, which is also the order expected for per-channel
        feature vectors) and the ``x_coords`` / ``y_coords`` / ``z_coords``
        arrays.
        """
        pos = self.montage.get_positions()
        self.coords = pos['ch_pos']
        self.ch_names = list(self.coords.keys())
        xyz = np.asarray(
            [self.coords[ch] for ch in self.ch_names], dtype=float
        ).reshape(-1, 3)
        self.x_coords = xyz[:, 0]
        self.y_coords = xyz[:, 1]
        self.z_coords = xyz[:, 2]

    def create_visualization(self, features: Dict,
                             map_type: str = "2D Topographic") -> go.Figure:
        """Create brain visualization based on the specified type.

        Args:
            features: Feature dictionary (see class docstring for keys used).
            map_type: One of "2D Topographic", "3D Surface", "Connectivity".

        Raises:
            ValueError: If ``map_type`` is not supported.
        """
        builders = {
            "2D Topographic": self._create_topographic_map,
            "3D Surface": self._create_3d_surface,
            "Connectivity": self._create_connectivity_map,
        }
        builder = builders.get(map_type)
        if builder is None:
            raise ValueError(f"Unsupported map type: {map_type}")
        return builder(features)

    # ------------------------------------------------------------------
    # Grid helpers
    # ------------------------------------------------------------------
    def _make_grid(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return a GRID_SIZE x GRID_SIZE meshgrid spanning the electrode x/y extent."""
        xi = np.linspace(self.x_coords.min(), self.x_coords.max(), self.GRID_SIZE)
        yi = np.linspace(self.y_coords.min(), self.y_coords.max(), self.GRID_SIZE)
        return np.meshgrid(xi, yi)

    def _values_on_grid(self, values, grid_x: np.ndarray,
                        grid_y: np.ndarray) -> np.ndarray:
        """Return ``values`` as a 2-D array shaped like ``grid_x``.

        Two input layouts are accepted:

        * exactly ``grid_x.size`` values, i.e. data that is already gridded;
          these are reshaped, which matches the original behaviour;
        * one value per montage channel, in ``self.ch_names`` order; these are
          interpolated onto the grid by inverse-distance weighting in the x/y
          projection of the montage.

        NOTE(review): the per-channel ordering is an assumption about the
        feature producer; confirm it matches ``self.ch_names``.

        Raises:
            ValueError: If the number of values matches neither layout.
        """
        values = np.asarray(values, dtype=float).ravel()
        if values.size == grid_x.size:
            return values.reshape(grid_x.shape)
        n_channels = len(self.ch_names)
        if values.size != n_channels:
            raise ValueError(
                f"Expected {n_channels} per-channel values or {grid_x.size} "
                f"gridded values, got {values.size}"
            )
        # Inverse-distance weighting (power 2). Flooring the squared distance
        # keeps grid nodes that sit exactly on an electrode finite; such nodes
        # then effectively take that electrode's value.
        dx = grid_x.reshape(-1, 1) - self.x_coords.reshape(1, -1)
        dy = grid_y.reshape(-1, 1) - self.y_coords.reshape(1, -1)
        weights = 1.0 / np.maximum(dx * dx + dy * dy, 1e-12)
        interpolated = (weights @ values) / weights.sum(axis=1)
        return interpolated.reshape(grid_x.shape)

    # ------------------------------------------------------------------
    # Map builders
    # ------------------------------------------------------------------
    def _create_topographic_map(self, features: Dict) -> go.Figure:
        """Create 2D topographic map of brain activity.

        Adds two traces per frequency band (a contour map and the electrode
        markers) and a dropdown that shows one band at a time. The 'alpha'
        band is shown initially; if there is none, the first band is shown.
        """
        band_powers = features['band_powers']
        band_names = list(band_powers.keys())
        if 'alpha' in band_powers:
            default_band = 'alpha'
        else:
            default_band = band_names[0] if band_names else None

        # The grid depends only on electrode positions; build it once.
        grid_x, grid_y = self._make_grid()

        fig = go.Figure()
        for band_name in band_names:
            visible = band_name == default_band
            fig.add_trace(go.Contour(
                x=grid_x[0],
                y=grid_y[:, 0],
                z=self._values_on_grid(band_powers[band_name], grid_x, grid_y),
                name=band_name,
                colorscale='Viridis',
                showscale=True,
                visible=visible,
            ))
            fig.add_trace(go.Scatter(
                x=self.x_coords,
                y=self.y_coords,
                mode='markers+text',
                text=self.ch_names,
                textposition="top center",
                name='Electrodes',
                marker=dict(size=10, color='black'),
                visible=visible,
            ))

        # Each band owns two consecutive traces (contour, electrodes). The mask
        # for band j must have one entry per trace (2 * n_bands in total) and
        # be True only for that band's pair.
        n_bands = len(band_names)
        buttons = [
            {'label': band, 'method': 'update',
             'args': [{'visible': [i == j for i in range(n_bands) for _ in range(2)]}]}
            for j, band in enumerate(band_names)
        ]
        fig.update_layout(
            title="Brain Activity Topographic Map",
            xaxis_title="X Position",
            yaxis_title="Y Position",
            showlegend=True,
            updatemenus=[{
                'buttons': buttons,
                'direction': 'down',
                'showactive': True,
            }],
        )
        return fig

    def _create_3d_surface(self, features: Dict) -> go.Figure:
        """Create 3D surface plot of brain activity.

        ``features['statistics']['mean']`` is placed on the x/y electrode grid
        (see ``_values_on_grid``). The height of the surface is the activity
        value. Electrode markers are drawn on the surface, so every axis of the
        scene means the same thing for both traces.
        """
        grid_x, grid_y = self._make_grid()
        grid_z = self._values_on_grid(
            features['statistics']['mean'], grid_x, grid_y
        )

        fig = go.Figure()
        # go.Surface needs a 2-D z grid; an n x 1 column renders nothing useful.
        fig.add_trace(go.Surface(
            x=grid_x,
            y=grid_y,
            z=grid_z,
            colorscale='Viridis',
            name='Brain Activity',
        ))

        # Sample the surface at the grid node nearest each electrode. The
        # markers then sit on the activity surface rather than at the head's
        # physical z (metres), which would be meaningless on an
        # "Activity Level" axis.
        ix = np.abs(grid_x[0][:, None] - self.x_coords[None, :]).argmin(axis=0)
        iy = np.abs(grid_y[:, 0][:, None] - self.y_coords[None, :]).argmin(axis=0)
        fig.add_trace(go.Scatter3d(
            x=self.x_coords,
            y=self.y_coords,
            z=grid_z[iy, ix],
            mode='markers+text',
            text=self.ch_names,
            marker=dict(size=5, color='red'),
            name='Electrodes',
        ))

        fig.update_layout(
            title="3D Brain Activity Surface",
            scene=dict(
                xaxis_title="X Position",
                yaxis_title="Y Position",
                zaxis_title="Activity Level",
                camera=dict(
                    up=dict(x=0, y=0, z=1),
                    center=dict(x=0, y=0, z=0),
                    eye=dict(x=1.5, y=1.5, z=1.5),
                ),
            ),
        )
        return fig

    def _create_connectivity_map(self, features: Dict,
                                 threshold: Optional[float] = None) -> go.Figure:
        """Create brain connectivity visualization.

        Args:
            features: Must contain ``features['connectivity']['correlation']``,
                a square matrix. Node i is labelled ``self.ch_names[i]``, or
                its index once the montage names run out.
            threshold: If given, edges with ``|correlation| < threshold`` are
                dropped. With ``None`` (the default), every nonzero
                off-diagonal entry becomes an edge.

        Nodes are coloured by degree (number of connections).
        """
        adjacency = np.array(features['connectivity']['correlation'], dtype=float)
        # A channel's self-correlation is trivially 1. Left in place it becomes
        # a self-loop, drawn as a zero-length edge, and inflates every degree.
        np.fill_diagonal(adjacency, 0.0)
        # NaN correlations (e.g. from flat channels) would become edges with NaN
        # weights and corrupt the spring layout.
        adjacency[~np.isfinite(adjacency)] = 0.0
        if threshold is not None:
            adjacency[np.abs(adjacency) < threshold] = 0.0

        G = nx.from_numpy_array(adjacency)
        pos = nx.spring_layout(G, k=1, iterations=50)

        # Edges as a single line trace; None separates segments.
        edge_x, edge_y = [], []
        for u, v in G.edges():
            (x0, y0), (x1, y1) = pos[u], pos[v]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])
        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=0.5, color='#888'),
            hoverinfo='none',
            mode='lines',
        )

        nodes = list(G.nodes())
        n_names = len(self.ch_names)
        labels = [self.ch_names[n] if n < n_names else str(n) for n in nodes]
        node_degrees = [G.degree(n) for n in nodes]
        node_trace = go.Scatter(
            x=[pos[n][0] for n in nodes],
            y=[pos[n][1] for n in nodes],
            mode='markers+text',
            hoverinfo='text',
            text=labels,
            marker=dict(
                showscale=True,
                colorscale='YlOrRd',
                size=10,
                color=node_degrees,
                colorbar=dict(
                    thickness=15,
                    # `titleside` is deprecated and removed in plotly.py 6;
                    # the nested title object works on all supported versions.
                    title=dict(text='Node Connections', side='right'),
                    xanchor='left',
                ),
            ),
        )

        fig = go.Figure(
            data=[edge_trace, node_trace],
            layout=go.Layout(
                title='Brain Connectivity Network',
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            ),
        )
        return fig