|
import gradio as gr |
|
import numpy as np |
|
import mne |
|
import matplotlib.pyplot as plt |
|
from scipy import signal |
|
from scipy.stats import skew, kurtosis |
|
import pandas as pd |
|
from pathlib import Path |
|
import tempfile |
|
import os |
|
import logging |
|
from datetime import datetime |
|
|
|
|
|
# Root logging configuration for the app: INFO and above, timestamped lines.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
_LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'

logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATEFMT)

# Module-level logger used by every component below.
logger = logging.getLogger(__name__)
|
|
|
class EEGProcessor:
    """Load, filter and summarise EEG recordings stored as EDF files.

    Attributes:
        sampling_rate: Sampling frequency in Hz that ``_calculate_band_power``
            uses when no explicit ``fs`` is passed. Starts at 250 and is
            overwritten by ``load_eeg`` with the rate read from the file.
        freq_bands: Mapping ``band name -> (low_hz, high_hz)`` iterated by
            ``extract_features``.
    """

    def __init__(self):
        # Placeholder rate; load_eeg() replaces it with the file's own rate.
        self.sampling_rate = 250
        # Band edges in Hz. Neighbouring bands share an edge; with the
        # trapezoidal integration in _calculate_band_power the shared bin only
        # bounds both integrals, so no frequency segment is counted twice.
        self.freq_bands = {
            'delta': (0.5, 4),
            'theta': (4, 8),
            'alpha': (8, 13),
            'beta': (13, 30),
            'gamma': (30, 50),
        }
        logger.info("EEG Processor initialized")

    def load_eeg(self, file_path):
        """Load EEG data from an EDF file into memory.

        Args:
            file_path: Path to the ``.edf`` file.

        Returns:
            The object returned by ``mne.io.read_raw_edf(..., preload=True)``.
            As a side effect, ``self.sampling_rate`` is set to its
            ``info['sfreq']``.

        Raises:
            ValueError: If the file does not exist or MNE fails to read it.
                The underlying exception is chained as ``__cause__``.
        """
        try:
            if not os.path.exists(file_path):
                raise ValueError("File does not exist")

            raw = mne.io.read_raw_edf(file_path, preload=True)
            self.sampling_rate = raw.info['sfreq']
            logger.info(f"Loaded EEG with {len(raw.ch_names)} channels at {self.sampling_rate} Hz")
            return raw

        except Exception as e:
            logger.error(f"Error loading EEG file: {str(e)}")
            raise ValueError(f"Error loading EEG file: {str(e)}") from e

    def preprocess_signal(self, raw, l_freq=0.5, h_freq=50.0, notch_freqs=(50, 60)):
        """Band-pass and notch filter ``raw``.

        MNE's ``Raw.filter`` / ``Raw.notch_filter`` operate in place; the same
        object is returned for convenience.

        Args:
            raw: MNE Raw object with preloaded data.
            l_freq: High-pass edge in Hz.
            h_freq: Low-pass edge in Hz, or ``None`` for no low-pass. It is
                skipped when it is not below the Nyquist frequency, because
                MNE rejects such a cutoff (e.g. a 100 Hz recording).
            notch_freqs: Mains frequencies (Hz) to notch out, or ``None``.
                Entries at or above Nyquist are dropped for the same reason.

        Returns:
            ``raw`` after filtering.
        """
        try:
            logger.info("Starting signal preprocessing")
            nyquist = raw.info['sfreq'] / 2.0

            if h_freq is not None and h_freq >= nyquist:
                logger.warning(
                    f"Low-pass cutoff {h_freq} Hz is not below Nyquist "
                    f"({nyquist} Hz); skipping low-pass")
                h_freq = None
            raw.filter(l_freq=l_freq, h_freq=h_freq, fir_design='firwin')

            valid_notches = [f for f in (notch_freqs or ()) if f < nyquist]
            if valid_notches:
                raw.notch_filter(freqs=valid_notches)
            return raw

        except Exception as e:
            logger.error(f"Error in preprocessing: {str(e)}")
            raise

    def _calculate_band_power(self, data, low_freq, high_freq, fs=None):
        """Integrate the Welch PSD of ``data`` over ``[low_freq, high_freq]``.

        Args:
            data: Signal array whose last axis is time. A 2-D
                ``(n_channels, n_samples)`` array yields one value per
                channel; a 1-D array yields a scalar.
            low_freq: Lower band edge in Hz (inclusive).
            high_freq: Upper band edge in Hz (inclusive).
            fs: Sampling rate in Hz; defaults to ``self.sampling_rate``.

        Returns:
            Band power with shape ``data.shape[:-1]``, computed by trapezoidal
            integration over the PSD bins inside the band. A band containing
            fewer than two bins integrates to 0.
        """
        if fs is None:
            fs = self.sampling_rate
        freqs, psd = signal.welch(data, fs=fs, axis=-1)
        idx = np.logical_and(freqs >= low_freq, freqs <= high_freq)
        # np.trapz was renamed np.trapezoid in NumPy 2.0 (old name deprecated);
        # prefer the new name and fall back for NumPy 1.x.
        trapezoid = getattr(np, 'trapezoid', None) or np.trapz
        return trapezoid(psd[..., idx], freqs[idx], axis=-1)

    def extract_features(self, raw):
        """Compute per-channel time- and frequency-domain features.

        Args:
            raw: MNE Raw object (data read via ``raw.get_data()``).

        Returns:
            Dict with keys ``'channel_names'``, ``'n_channels'``, ``'mean'``,
            ``'variance'``, ``'skewness'``, ``'kurtosis'`` (scipy's default,
            i.e. excess kurtosis) and ``'band_powers'``, a dict mapping each
            name in ``self.freq_bands`` to a per-channel power array.
        """
        try:
            logger.info("Starting feature extraction")
            data = raw.get_data()
            # Integrate with the rate of the recording actually being analysed,
            # not whatever this processor last loaded.
            fs = raw.info['sfreq']

            features = {
                'channel_names': raw.ch_names,
                'n_channels': data.shape[0],
                'mean': np.mean(data, axis=1),
                'variance': np.var(data, axis=1),
                'skewness': skew(data, axis=1),
                'kurtosis': kurtosis(data, axis=1),
            }
            features['band_powers'] = {
                band_name: self._calculate_band_power(data, low_freq, high_freq, fs=fs)
                for band_name, (low_freq, high_freq) in self.freq_bands.items()
            }

            logger.info("Feature extraction completed")
            return features

        except Exception as e:
            logger.error(f"Error in feature extraction: {str(e)}")
            raise
|
|
|
def create_visualization(raw, features):
    """Build a three-panel summary figure for a preprocessed recording.

    Panels, top to bottom:
      1. The first 10 s of every channel (or the whole recording if shorter),
         vertically offset so traces do not overlap.
      2. Welch power spectral density of each channel on a log y-axis.
      3. Mean band power across channels, with the across-channel standard
         deviation as error bars.

    The figure is built with ``matplotlib.figure.Figure`` directly instead of
    ``pyplot``. It is therefore never registered in pyplot's global figure
    manager, which has two benefits for a long-running server:
      - each request no longer leaves an open figure behind;
      - concurrent requests cannot draw onto each other's "current" figure.

    Args:
        raw: MNE Raw object; reads ``get_data()`` and ``info['sfreq']``.
        features: Dict from ``EEGProcessor.extract_features``; reads
            ``'n_channels'``, ``'channel_names'`` and ``'band_powers'``.

    Returns:
        The populated ``matplotlib.figure.Figure``.
    """
    from matplotlib.figure import Figure

    try:
        sfreq = raw.info['sfreq']
        n_channels = features['n_channels']
        channel_names = features['channel_names']
        data = raw.get_data()

        fig = Figure(figsize=(15, 20))
        ax1, ax2, ax3 = fig.subplots(3, 1)

        # Panel 1: time series of the first 10 seconds.
        n_samples = min(data.shape[1], int(10 * sfreq))
        times = np.arange(n_samples) / sfreq
        plot_data = data[:, :n_samples]
        # Space channels three (global) standard deviations apart.
        offsets = np.arange(n_channels) * np.std(plot_data) * 3
        for name, trace, offset in zip(channel_names, plot_data, offsets):
            ax1.plot(times, trace + offset, label=name)
        ax1.set_title('Raw EEG Signal (first 10 seconds)')
        ax1.set_xlabel('Time (s)')
        ax1.set_ylabel('Channel')
        ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        # Panel 2: per-channel power spectrum over the full recording.
        freqs, psd = signal.welch(data, fs=sfreq)
        for name, spectrum in zip(channel_names, psd[:n_channels]):
            ax2.semilogy(freqs, spectrum, label=name)
        ax2.set_title('Power Spectrum')
        ax2.set_xlabel('Frequency (Hz)')
        ax2.set_ylabel('Power Spectral Density')
        ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        # Panel 3: band powers summarised across channels.
        band_powers = features['band_powers']
        bands = list(band_powers)
        means = [np.mean(powers) for powers in band_powers.values()]
        stds = [np.std(powers) for powers in band_powers.values()]
        ax3.bar(bands, means, yerr=stds, capsize=5)
        ax3.set_title('Average Band Powers Across Channels')
        ax3.set_xlabel('Frequency Band')
        ax3.set_ylabel('Power')

        fig.tight_layout()
        return fig

    except Exception as e:
        logger.error(f"Error creating visualization: {str(e)}")
        raise
|
|
|
def process_eeg(file_obj):
    """Run the full analysis pipeline on an uploaded EDF recording.

    This is the Gradio callback. It:
      1. writes the upload to a temporary ``.edf`` file, because the loader
         takes a path;
      2. loads, preprocesses and extracts features with ``EEGProcessor``;
      3. builds the summary figure.

    The temporary file is always removed, whether processing succeeds or fails.

    Args:
        file_obj: The uploaded file in one of these forms:
            - ``bytes`` / ``bytearray`` (what ``gr.File(type="binary")`` delivers);
            - a filesystem path (``str`` or ``os.PathLike``);
            - an object whose ``read()`` returns bytes.

    Returns:
        Tuple ``(figure, summary_text)``. ``summary_text`` is a per-channel
        table of statistics and band powers rendered by
        ``DataFrame.to_string()``.

    Raises:
        gr.Error: If no file was provided or any processing step fails. For
            failures, the underlying exception is chained as ``__cause__``.
    """
    if file_obj is None:
        raise gr.Error("Please upload an EEG file (.edf) before analyzing.")

    tmp_path = None
    try:
        logger.info("Starting EEG processing")

        if isinstance(file_obj, (bytes, bytearray)):
            file_content = bytes(file_obj)
        elif isinstance(file_obj, (str, os.PathLike)):
            file_content = Path(file_obj).read_bytes()
        else:
            file_content = file_obj.read()

        # delete=False keeps the file on disk after the handle closes so it can
        # be reopened by path; removal happens in the finally clause below.
        with tempfile.NamedTemporaryFile(suffix='.edf', delete=False) as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(file_content)

        processor = EEGProcessor()
        raw = processor.load_eeg(tmp_path)
        raw = processor.preprocess_signal(raw)
        features = processor.extract_features(raw)

        fig = create_visualization(raw, features)

        feature_summary = pd.DataFrame({
            'Channel': features['channel_names'],
            'Mean': features['mean'],
            'Variance': features['variance'],
            'Skewness': features['skewness'],
            'Kurtosis': features['kurtosis'],
        })
        for band, powers in features['band_powers'].items():
            feature_summary[f'{band}_power'] = powers

        logger.info("Processing completed successfully")
        return fig, feature_summary.to_string()

    except Exception as e:
        # logger.exception also records the traceback for debugging.
        logger.exception(f"Error in EEG processing: {str(e)}")
        raise gr.Error(f"Error processing EEG: {str(e)}") from e

    finally:
        # Previously the temp copy leaked whenever any step above failed.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as cleanup_err:
                logger.warning(f"Could not delete temporary file {tmp_path}: {cleanup_err}")
|
|
|
|
|
# Gradio front end: upload an EDF file, press the button, and receive the
# analysis figure plus a text table of per-channel features.
with gr.Blocks(title="EEG Signal Analysis") as iface:
    for _heading in ("# EEG Signal Analysis Tool",
                     "Upload an EEG file (.edf format) for analysis"):
        gr.Markdown(_heading)

    # type="binary" hands process_eeg the uploaded file's raw bytes.
    with gr.Row():
        file_input = gr.File(label="Upload EEG File", file_types=[".edf"], type="binary")

    with gr.Row():
        analyze_btn = gr.Button("Analyze EEG")

    with gr.Row():
        plot_output = gr.Plot(label="EEG Analysis Plots")
        feature_output = gr.Textbox(label="Feature Analysis", lines=10)

    # process_eeg returns (figure, summary_text); outputs are matched by position.
    _click_inputs = [file_input]
    _click_outputs = [plot_output, feature_output]
    analyze_btn.click(fn=process_eeg, inputs=_click_inputs, outputs=_click_outputs)
|
|
|
|
|
def _launch_app():
    """Start the Gradio server for the EEG analysis UI.

    ``share=True`` also asks Gradio for a public share link, so anyone holding
    that URL can reach the app and upload data to it.
    """
    iface.launch(share=True)


if __name__ == "__main__":
    _launch_app()