File size: 8,178 Bytes
baaf109 980d5ca baaf109 1fb8a6d 980d5ca baaf109 980d5ca baaf109 1fb8a6d baaf109 a080e03 baaf109 980d5ca baaf109 980d5ca baaf109 1fb8a6d baaf109 980d5ca baaf109 1fb8a6d baaf109 1fb8a6d baaf109 980d5ca baaf109 1fb8a6d baaf109 1fb8a6d baaf109 1fb8a6d baaf109 1fb8a6d baaf109 1fb8a6d 980d5ca baaf109 1fb8a6d baaf109 1fb8a6d baaf109 1fb8a6d baaf109 1fb8a6d 980d5ca 1fb8a6d baaf109 1fb8a6d baaf109 1fb8a6d baaf109 1fb8a6d baaf109 1fb8a6d baaf109 1fb8a6d baaf109 1fb8a6d baaf109 a080e03 980d5ca baaf109 1fb8a6d a080e03 1fb8a6d baaf109 1fb8a6d baaf109 980d5ca baaf109 980d5ca baaf109 1fb8a6d baaf109 980d5ca baaf109 1fb8a6d baaf109 a080e03 1fb8a6d a080e03 baaf109 1fb8a6d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
import logging
import os
import tempfile
from datetime import datetime
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
from scipy import signal
from scipy.integrate import trapezoid
from scipy.stats import skew, kurtosis
# Root logging setup: INFO and above, each record prefixed with a
# "YYYY-MM-DD HH:MM:SS - LEVEL - " header.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
class EEGProcessor:
    """Load, preprocess and extract features from EEG recordings stored as EDF.

    Attributes:
        sampling_rate: Sampling frequency in Hz used for spectral estimates.
            Starts at 250 and is overwritten with the recording's ``sfreq``
            by :meth:`load_eeg`.
        freq_bands: Mapping of band name to ``(low_hz, high_hz)``; each band
            is integrated inclusively at both edges by :meth:`extract_features`.
    """

    def __init__(self):
        self.sampling_rate = 250
        self.freq_bands = {
            'delta': (0.5, 4),
            'theta': (4, 8),
            'alpha': (8, 13),
            'beta': (13, 30),
            'gamma': (30, 50),
        }
        logger.info("EEG Processor initialized")

    def load_eeg(self, file_path):
        """Load an EDF file fully into memory and record its sampling rate.

        Args:
            file_path: Path to an ``.edf`` file.

        Returns:
            The object returned by ``mne.io.read_raw_edf(..., preload=True)``.

        Raises:
            ValueError: If the file does not exist or MNE fails to read it.
        """
        try:
            if not os.path.exists(file_path):
                raise ValueError("File does not exist")
            raw = mne.io.read_raw_edf(file_path, preload=True)
            self.sampling_rate = raw.info['sfreq']
            logger.info(f"Loaded EEG with {len(raw.ch_names)} channels at {self.sampling_rate} Hz")
            return raw
        except Exception as e:
            logger.error(f"Error loading EEG file: {str(e)}")
            raise ValueError(f"Error loading EEG file: {str(e)}") from e

    def preprocess_signal(self, raw):
        """Band-pass to 0.5-50 Hz and notch-filter 50/60 Hz, modifying ``raw`` in place.

        Cut-offs that are not strictly below the Nyquist frequency
        (``sfreq / 2``) are skipped: MNE refuses filter frequencies at or above
        Nyquist, so e.g. a 100 Hz recording gets only the 0.5 Hz high-pass.

        Args:
            raw: An MNE Raw object with preloaded data.

        Returns:
            The same ``raw`` object, filtered.
        """
        try:
            logger.info("Starting signal preprocessing")
            nyquist = raw.info['sfreq'] / 2.0
            # h_freq=None turns the band-pass into a pure high-pass.
            h_freq = 50. if 50. < nyquist else None
            raw.filter(l_freq=0.5, h_freq=h_freq, fir_design='firwin')
            notch_freqs = [freq for freq in (50, 60) if freq < nyquist]
            if notch_freqs:
                raw.notch_filter(freqs=notch_freqs)
            return raw
        except Exception as e:
            logger.error(f"Error in preprocessing: {str(e)}")
            raise

    @staticmethod
    def _integrate_band(freqs, psd, low_freq, high_freq):
        """Integrate ``psd`` over ``low_freq <= f <= high_freq`` along its last axis.

        Returns 0 for bands containing no frequency bins, and also for bands
        containing a single bin (the trapezoid rule has no width to integrate).
        """
        in_band = np.logical_and(freqs >= low_freq, freqs <= high_freq)
        return trapezoid(psd[..., in_band], freqs[in_band], axis=-1)

    def _calculate_band_power(self, data, low_freq, high_freq):
        """
        Calculate power in a specific frequency band for each channel.

        Args:
            data: Signal array with time on the last axis; either 1-D (one
                channel) or 2-D (channels x samples).
            low_freq, high_freq: Inclusive band edges in Hz.

        Returns:
            Array of band powers, one per channel (a scalar for 1-D input).
        """
        freqs, psd = signal.welch(data, fs=self.sampling_rate)
        return self._integrate_band(freqs, psd, low_freq, high_freq)

    def extract_features(self, raw):
        """Compute per-channel time- and frequency-domain features.

        Args:
            raw: Object exposing ``get_data()`` (channels x samples) and
                ``ch_names``, e.g. an MNE Raw.

        Returns:
            dict with keys ``'channel_names'`` and ``'n_channels'``;
            ``'mean'``, ``'variance'``, ``'skewness'`` and ``'kurtosis'``
            (one value per channel, from numpy / scipy.stats with default
            arguments); and ``'band_powers'``, mapping each name in
            ``self.freq_bands`` to an array of per-channel power.
        """
        try:
            logger.info("Starting feature extraction")
            data = raw.get_data()
            n_channels = data.shape[0]

            features = {
                'channel_names': raw.ch_names,
                'n_channels': n_channels,
            }

            # Time-domain statistics per channel (axis 1 is time).
            features['mean'] = np.mean(data, axis=1)
            features['variance'] = np.var(data, axis=1)
            features['skewness'] = skew(data, axis=1)
            features['kurtosis'] = kurtosis(data, axis=1)

            # The PSD does not depend on the band, so estimate it once and
            # integrate every band from the same spectrum.
            freqs, psd = signal.welch(data, fs=self.sampling_rate)
            features['band_powers'] = {
                band_name: self._integrate_band(freqs, psd, low_freq, high_freq)
                for band_name, (low_freq, high_freq) in self.freq_bands.items()
            }

            logger.info("Feature extraction completed")
            return features
        except Exception as e:
            logger.error(f"Error in feature extraction: {str(e)}")
            raise
def create_visualization(raw, features):
    """Build a three-panel matplotlib figure summarising a multi-channel recording.

    Panels, top to bottom:
      1. Up to the first 10 s of every channel, vertically offset by three
         global standard deviations so traces do not overlap.
      2. Welch power spectral density of every channel on a log y-axis.
      3. Mean band power across channels, with the across-channel standard
         deviation as error bars.

    Args:
        raw: Object exposing ``get_data()`` (channels x samples) and
            ``info['sfreq']``, e.g. an MNE Raw.
        features: Dict from ``EEGProcessor.extract_features``; this function
            reads ``'n_channels'``, ``'channel_names'`` and ``'band_powers'``.

    Returns:
        The matplotlib Figure. It has already been removed from pyplot's
        global figure registry (``plt.close``), so repeated calls in a
        long-running server do not accumulate open figures. The Figure object
        itself can still be saved or rendered.

    Raises:
        Re-raises any exception after logging it and closing the figure.
    """
    fig = None
    try:
        fig = plt.figure(figsize=(15, 20))
        n_channels = features['n_channels']
        channel_names = features['channel_names']
        sfreq = raw.info['sfreq']
        data = raw.get_data()

        # Plot 1: raw EEG signal, at most the first 10 seconds.
        ax1 = fig.add_subplot(3, 1, 1)
        n_samples = min(data.shape[1], int(10 * sfreq))
        times = np.arange(n_samples) / sfreq
        plot_data = data[:, :n_samples]
        offsets = np.arange(n_channels) * np.std(plot_data) * 3
        for i in range(n_channels):
            ax1.plot(times, plot_data[i] + offsets[i], label=channel_names[i])
        # Tick each trace's baseline with its channel name; otherwise the
        # "Channel" axis label would sit over raw amplitude offsets.
        ax1.set_yticks(offsets)
        ax1.set_yticklabels(list(channel_names[:n_channels]))
        ax1.set_title('Raw EEG Signal (first 10 seconds)')
        ax1.set_xlabel('Time (s)')
        ax1.set_ylabel('Channel')
        ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        # Plot 2: power spectrum of each channel.
        ax2 = fig.add_subplot(3, 1, 2)
        freqs, psd = signal.welch(data, fs=sfreq)
        for i in range(n_channels):
            ax2.semilogy(freqs, psd[i], label=channel_names[i])
        ax2.set_title('Power Spectrum')
        ax2.set_xlabel('Frequency (Hz)')
        ax2.set_ylabel('Power Spectral Density')
        ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        # Plot 3: band powers averaged across channels.
        ax3 = fig.add_subplot(3, 1, 3)
        band_powers = features['band_powers']
        bands = list(band_powers)
        means = [np.mean(powers) for powers in band_powers.values()]
        stds = [np.std(powers) for powers in band_powers.values()]
        ax3.bar(bands, means, yerr=stds, capsize=5)
        ax3.set_title('Average Band Powers Across Channels')
        ax3.set_xlabel('Frequency Band')
        ax3.set_ylabel('Power')

        fig.tight_layout()
    except Exception as e:
        logger.error(f"Error creating visualization: {str(e)}")
        if fig is not None:
            plt.close(fig)
        raise

    # pyplot keeps every figure it creates alive until closed. Detach this one
    # so the server does not leak a figure per request.
    plt.close(fig)
    return fig
def process_eeg(file_obj):
    """Run the full EEG pipeline on an uploaded EDF file for the Gradio UI.

    The upload is written to a temporary ``.edf`` file so MNE can read it from
    disk. The temporary file is removed whether or not processing succeeds.

    Args:
        file_obj: Uploaded file contents as ``bytes``/``bytearray``, or any
            object with a ``read()`` method that returns bytes.

    Returns:
        ``(figure, summary_text)``: the figure from ``create_visualization``
        and a per-channel feature table rendered with ``DataFrame.to_string()``.

    Raises:
        gr.Error: If no file was supplied or any pipeline step fails, so
            Gradio shows the message to the user.
    """
    if file_obj is None:
        raise gr.Error("Please upload an EEG file (.edf) before analyzing.")

    tmp_path = None
    try:
        logger.info("Starting EEG processing")

        # Handle file content
        if isinstance(file_obj, (bytes, bytearray)):
            file_content = file_obj
        else:
            file_content = file_obj.read()

        # delete=False so the file can be reopened by path (MNE reads by path).
        with tempfile.NamedTemporaryFile(suffix='.edf', delete=False) as tmp_file:
            tmp_file.write(file_content)
            tmp_path = tmp_file.name

        # Process EEG
        processor = EEGProcessor()
        raw = processor.load_eeg(tmp_path)
        raw = processor.preprocess_signal(raw)
        features = processor.extract_features(raw)

        # Create visualizations
        fig = create_visualization(raw, features)

        # Per-channel feature summary table
        feature_summary = pd.DataFrame({
            'Channel': features['channel_names'],
            'Mean': features['mean'],
            'Variance': features['variance'],
            'Skewness': features['skewness'],
            'Kurtosis': features['kurtosis'],
        })
        for band, powers in features['band_powers'].items():
            feature_summary[f'{band}_power'] = powers

        logger.info("Processing completed successfully")
        return fig, feature_summary.to_string()
    except Exception as e:
        logger.error(f"Error in EEG processing: {str(e)}")
        raise gr.Error(f"Error processing EEG: {str(e)}") from e
    finally:
        # Always clean up the temp file, including after a failure.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as cleanup_err:
                logger.warning(f"Could not remove temporary file {tmp_path}: {cleanup_err}")
# Gradio UI: an upload row, a button row, and an output row holding the
# analysis figure next to the text feature table.
with gr.Blocks(title="EEG Signal Analysis") as iface:
    gr.Markdown("# EEG Signal Analysis Tool")
    gr.Markdown("Upload an EEG file (.edf format) for analysis")

    with gr.Row():
        # type="binary" hands process_eeg the raw uploaded bytes.
        file_input = gr.File(type="binary", file_types=[".edf"], label="Upload EEG File")

    with gr.Row():
        analyze_btn = gr.Button("Analyze EEG")

    with gr.Row():
        plot_output = gr.Plot(label="EEG Analysis Plots")
        feature_output = gr.Textbox(lines=10, label="Feature Analysis")

    # Wire the button: one input (the file), two outputs (figure, table text).
    analyze_btn.click(
        process_eeg,
        outputs=[plot_output, feature_output],
        inputs=[file_input],
    )
# Launch the application when run as a script (not on import).
if __name__ == "__main__":
    # share=True asks Gradio for a publicly reachable share link in addition
    # to the local server.
    iface.launch(share=True)