# invincible-jha's picture
# Update app.py
# 1fb8a6d verified
import gradio as gr
import numpy as np
import mne
import matplotlib.pyplot as plt
from scipy import signal
from scipy.stats import skew, kurtosis
import pandas as pd
from pathlib import Path
import tempfile
import os
import logging
from datetime import datetime
# Configure logging
# Root logging is set up once at import time: INFO and above, with records
# rendered as "<timestamp> - <LEVEL> - <message>".
logging.basicConfig(
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-scoped logger shared by every function and class in this file.
logger = logging.getLogger(__name__)
class EEGProcessor:
    """Load, filter and summarise multi-channel EEG recordings.

    Attributes:
        sampling_rate: Sampling frequency in Hz. Starts at 250 and is replaced
            by the recording's own rate when ``load_eeg`` succeeds.
        freq_bands: Mapping of band name to ``(low_hz, high_hz)`` edges that
            ``extract_features`` integrates the power spectrum over.
    """

    def __init__(self):
        self.sampling_rate = 250
        self.freq_bands = {
            'delta': (0.5, 4),
            'theta': (4, 8),
            'alpha': (8, 13),
            'beta': (13, 30),
            'gamma': (30, 50)
        }
        logger.info("EEG Processor initialized")

    def load_eeg(self, file_path):
        """Read an EDF recording into memory and adopt its sampling rate.

        Args:
            file_path: Path of the EDF file on disk.

        Returns:
            The object returned by ``mne.io.read_raw_edf(..., preload=True)``.

        Raises:
            ValueError: If the path does not exist or the file cannot be read.
                The underlying exception is chained as ``__cause__``.
        """
        try:
            if not os.path.exists(file_path):
                raise ValueError("File does not exist")
            raw = mne.io.read_raw_edf(file_path, preload=True)
            self.sampling_rate = raw.info['sfreq']
            logger.info(f"Loaded EEG with {len(raw.ch_names)} channels at {self.sampling_rate} Hz")
            return raw
        except Exception as e:
            logger.error(f"Error loading EEG file: {str(e)}")
            # Chain the cause so the original traceback is not lost.
            raise ValueError(f"Error loading EEG file: {str(e)}") from e

    def preprocess_signal(self, raw):
        """Band-limit the recording and notch out mains frequencies.

        Filtering is applied through ``raw.filter`` / ``raw.notch_filter`` and
        the same ``raw`` object is returned. The 50 Hz low-pass and the
        50/60 Hz notches are only requested when they lie below the
        recording's Nyquist frequency, so low-rate recordings (e.g. 100 Hz)
        are still processed instead of failing.

        Args:
            raw: Recording object exposing ``info['sfreq']``, ``filter`` and
                ``notch_filter`` (as returned by ``load_eeg``).

        Returns:
            The filtered ``raw`` object.

        Raises:
            Exception: Any error raised by the filtering calls is logged and
                re-raised unchanged.
        """
        try:
            logger.info("Starting signal preprocessing")
            nyquist = raw.info['sfreq'] / 2.0

            # A 50 Hz low-pass cannot be designed at or above Nyquist; in that
            # case the signal holds no content above 50 Hz anyway, so only the
            # 0.5 Hz high-pass is applied.
            h_freq = 50. if nyquist > 50. else None
            raw.filter(l_freq=0.5, h_freq=h_freq, fir_design='firwin')

            # Keep headroom below Nyquist for the notch's stop/transition band.
            # NOTE(review): 2 Hz is a conservative margin — confirm against the
            # MNE notch_filter defaults in use.
            margin_hz = 2.0
            notch_freqs = [f for f in (50, 60) if f + margin_hz < nyquist]
            if notch_freqs:
                raw.notch_filter(freqs=notch_freqs)
            else:
                logger.info("Skipping notch filter: mains frequencies are above Nyquist")
            return raw
        except Exception as e:
            logger.error(f"Error in preprocessing: {str(e)}")
            raise

    def _calculate_band_power(self, data, low_freq, high_freq):
        """Integrate the Welch power spectrum over one frequency band.

        Args:
            data: Samples with time on the last axis; either a single 1-D
                signal or a 2-D ``(channels, samples)`` array.
            low_freq: Lower band edge in Hz (inclusive).
            high_freq: Upper band edge in Hz (inclusive).

        Returns:
            Band power per channel: a 1-D array for 2-D input, a scalar for
            1-D input.
        """
        data = np.asarray(data)
        fs = float(self.sampling_rate)

        # Size the Welch segment in seconds rather than samples: a fixed
        # 256-sample window gives a frequency resolution of fs/256, which at
        # high sampling rates leaves narrow bands (e.g. delta) with zero or a
        # single bin and therefore zero integrated power. Two seconds yields
        # 0.5 Hz resolution; never go below the 256-sample default and never
        # exceed the signal length.
        nperseg = int(min(data.shape[-1], max(256, round(2 * fs))))
        freqs, psd = signal.welch(data, fs=fs, nperseg=nperseg)

        in_band = (freqs >= low_freq) & (freqs <= high_freq)

        # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
        # fall back to the old name only when the new one is unavailable.
        integrate = getattr(np, "trapezoid", None) or np.trapz

        # Ellipsis indexing / axis=-1 keeps this valid for 1-D and 2-D input.
        return integrate(psd[..., in_band], freqs[in_band], axis=-1)

    def extract_features(self, raw):
        """Compute per-channel time-domain statistics and band powers.

        Args:
            raw: Recording object exposing ``get_data()`` and ``ch_names``.
                ``get_data()`` is treated as ``(channels, samples)``.

        Returns:
            dict with keys ``'channel_names'``, ``'n_channels'``, ``'mean'``,
            ``'variance'``, ``'skewness'``, ``'kurtosis'`` (one value per
            channel) and ``'band_powers'`` (band name -> per-channel power).

        Raises:
            Exception: Any error during computation is logged and re-raised.
        """
        try:
            logger.info("Starting feature extraction")
            data = raw.get_data()
            n_channels = data.shape[0]

            # Initialize feature dictionary with channel names
            features = {
                'channel_names': raw.ch_names,
                'n_channels': n_channels
            }

            # Time-domain statistics, reduced along the sample axis.
            features['mean'] = np.mean(data, axis=1)
            features['variance'] = np.var(data, axis=1)
            features['skewness'] = skew(data, axis=1)
            features['kurtosis'] = kurtosis(data, axis=1)

            # Frequency-domain features: one power value per channel per band.
            features['band_powers'] = {
                band_name: self._calculate_band_power(data, low_freq, high_freq)
                for band_name, (low_freq, high_freq) in self.freq_bands.items()
            }

            logger.info("Feature extraction completed")
            return features
        except Exception as e:
            logger.error(f"Error in feature extraction: {str(e)}")
            raise
def create_visualization(raw, features):
    """Build a three-panel figure summarising an EEG recording.

    Panels, top to bottom: the first 10 seconds of every channel (vertically
    offset so traces do not overlap), the Welch power spectrum of every
    channel on a log axis, and the mean band power across channels with the
    across-channel standard deviation as error bars.

    Args:
        raw: Recording object exposing ``get_data()`` and ``info['sfreq']``.
        features: Dict containing ``'n_channels'``, ``'channel_names'`` and
            ``'band_powers'`` (band name -> per-channel powers).

    Returns:
        The populated matplotlib figure. It remains registered with pyplot;
        the caller is responsible for closing it when done.

    Raises:
        Exception: Any plotting error is logged and re-raised; the partially
            built figure is closed first so it does not leak.
    """
    fig = None
    try:
        fig = plt.figure(figsize=(15, 20))
        n_channels = features['n_channels']
        channel_names = features['channel_names']
        sfreq = raw.info['sfreq']
        data = raw.get_data()

        # Plot 1: Raw EEG signal (first 10 seconds)
        ax1 = fig.add_subplot(3, 1, 1)
        n_plot = min(data.shape[1], int(10 * sfreq))
        times = np.arange(n_plot) / sfreq
        plot_data = data[:, :n_plot]
        # Stack channels vertically, spaced by three global standard deviations.
        offsets = np.arange(n_channels) * np.std(plot_data) * 3
        for name, trace, offset in zip(channel_names, plot_data, offsets):
            ax1.plot(times, trace + offset, label=name)
        ax1.set_title('Raw EEG Signal (first 10 seconds)')
        ax1.set_xlabel('Time (s)')
        ax1.set_ylabel('Channel')
        ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        # Plot 2: Power spectrum
        ax2 = fig.add_subplot(3, 1, 2)
        freqs, psd = signal.welch(data, fs=sfreq)
        for name, channel_psd in zip(channel_names, psd[:n_channels]):
            ax2.semilogy(freqs, channel_psd, label=name)
        ax2.set_title('Power Spectrum')
        ax2.set_xlabel('Frequency (Hz)')
        ax2.set_ylabel('Power Spectral Density')
        ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        # Plot 3: Average band powers across channels
        ax3 = fig.add_subplot(3, 1, 3)
        band_powers = features['band_powers']
        bands = list(band_powers)
        means = [np.mean(band_powers[band]) for band in bands]
        stds = [np.std(band_powers[band]) for band in bands]
        ax3.bar(bands, means, yerr=stds, capsize=5)
        ax3.set_title('Average Band Powers Across Channels')
        ax3.set_xlabel('Frequency Band')
        ax3.set_ylabel('Power')

        fig.tight_layout()
        return fig
    except Exception as e:
        logger.error(f"Error creating visualization: {str(e)}")
        # pyplot keeps every figure it created alive until closed; release the
        # half-built one since no caller will ever receive it.
        if fig is not None:
            plt.close(fig)
        raise
def process_eeg(file_obj):
    """Main processing function for the Gradio interface.

    Writes the uploaded EDF content to a temporary file, runs it through
    ``EEGProcessor`` (load -> preprocess -> feature extraction), renders the
    summary figure and formats the per-channel features as text.

    Args:
        file_obj: Uploaded file content. Accepts raw ``bytes``/``bytearray``,
            a filesystem path (``str`` or ``os.PathLike``), or any object
            with a ``read()`` method returning bytes.

    Returns:
        Tuple ``(fig, summary)``: the matplotlib figure and the feature table
        rendered with ``DataFrame.to_string()``.

    Raises:
        gr.Error: On any failure, including a missing upload; the original
            exception is chained as ``__cause__``.
    """
    tmp_path = None
    try:
        logger.info("Starting EEG processing")

        # Handle file content
        if file_obj is None:
            # Clicking the button without an upload passes None.
            raise ValueError("No file uploaded")
        if isinstance(file_obj, (bytes, bytearray)):
            file_content = file_obj
        elif isinstance(file_obj, (str, os.PathLike)):
            with open(file_obj, 'rb') as src:
                file_content = src.read()
        else:
            file_content = file_obj.read()

        # Create temporary file; the EDF reader needs a real path on disk.
        with tempfile.NamedTemporaryFile(suffix='.edf', delete=False) as tmp_file:
            # Record the path before writing so a failed write is still cleaned up.
            tmp_path = tmp_file.name
            tmp_file.write(file_content)

        # Process EEG
        processor = EEGProcessor()
        raw = processor.load_eeg(tmp_path)
        raw = processor.preprocess_signal(raw)
        features = processor.extract_features(raw)

        # Create visualizations
        fig = create_visualization(raw, features)
        # Detach the figure from pyplot's global registry so a long-running
        # server does not accumulate one figure per request. The Figure object
        # itself stays valid and can still be rendered by the caller.
        plt.close(fig)

        # Create feature summary table
        feature_summary = pd.DataFrame({
            'Channel': features['channel_names'],
            'Mean': features['mean'],
            'Variance': features['variance'],
            'Skewness': features['skewness'],
            'Kurtosis': features['kurtosis']
        })

        # Add band powers to summary
        for band, powers in features['band_powers'].items():
            feature_summary[f'{band}_power'] = powers

        logger.info("Processing completed successfully")
        return fig, feature_summary.to_string()
    except Exception as e:
        logger.error(f"Error in EEG processing: {str(e)}")
        raise gr.Error(f"Error processing EEG: {str(e)}") from e
    finally:
        # Clean up on success and on failure alike; cleanup problems are
        # logged rather than allowed to mask the real outcome.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {tmp_path}: {cleanup_error}")
# Create Gradio interface
# The UI is assembled at import time and bound to the module-level name
# `iface`. Components are laid out in the order they are created inside the
# nested `with` blocks, so statement order here determines the page layout.
with gr.Blocks(title="EEG Signal Analysis") as iface:
    gr.Markdown("# EEG Signal Analysis Tool")
    gr.Markdown("Upload an EEG file (.edf format) for analysis")
    with gr.Row():
        # type="binary" matches the `isinstance(file_obj, bytes)` branch in
        # process_eeg.
        file_input = gr.File(
            label="Upload EEG File",
            file_types=[".edf"],
            type="binary"
        )
    with gr.Row():
        analyze_btn = gr.Button("Analyze EEG")
    with gr.Row():
        # Output widgets, in the same order as process_eeg's return tuple.
        plot_output = gr.Plot(label="EEG Analysis Plots")
        feature_output = gr.Textbox(label="Feature Analysis", lines=10)
    # Wire the button: the uploaded file goes to process_eeg, and its
    # (figure, text) result fills the plot and textbox.
    analyze_btn.click(
        fn=process_eeg,
        inputs=[file_input],
        outputs=[plot_output, feature_output]
    )
# Launch the application
# Only start the server when run as a script, not when imported.
# NOTE(review): share=True asks Gradio for a public share link — confirm this
# is intended for the deployment environment.
if __name__ == "__main__":
    iface.launch(share=True)