BCI / app.py
Umang-Bansal's picture
Update app.py
dda0def verified
import gradio as gr
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import scipy
from scipy import signal
import pickle
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split, GridSearchCV
# Module-level store for the most recently uploaded CSV (with an appended
# 'label' column). get_data_preview replaces it, label_data writes labels into
# it, and preprocess_data builds features from it; it stays None until a file
# has been uploaded.
# NOTE(review): being module-level, this state is shared by every caller of the
# callbacks (e.g. concurrent Gradio sessions) — consider gr.State if users must
# be isolated from each other.
global_data: "pd.DataFrame | None" = None
def get_data_preview(file):
    """Load an uploaded CSV into ``global_data`` and return a preview.

    A ``label`` column (object dtype, initialised to NaN) is appended so that
    ``label_data`` can later assign labels to row ranges.

    Parameters
    ----------
    file : str | os.PathLike | object with ``.name`` | file-like | None
        What ``gr.File`` delivers: depending on the Gradio version this is
        either a path string or a tempfile wrapper whose ``.name`` is the
        on-disk path. Any other object accepted by ``pd.read_csv`` (e.g. an
        in-memory buffer) is passed through unchanged.

    Returns
    -------
    pandas.DataFrame | None
        The first five rows of the loaded data, or ``None`` when no file was
        given (``global_data`` is left untouched in that case).
    """
    global global_data
    import os  # local import: the module header does not import os

    if file is None:
        print("No file provided; nothing to preview.")
        return None
    if isinstance(file, (str, os.PathLike)):
        # Already a path; note Path.name would be only the bare filename.
        source = file
    else:
        # Tempfile wrappers expose the on-disk path as .name; plain buffers
        # have no .name and are read directly.
        source = getattr(file, "name", file)
    global_data = pd.read_csv(source)
    global_data['label'] = np.nan  # every row starts unlabelled
    # Object dtype lets label_data store arbitrary (e.g. string) labels.
    global_data['label'] = global_data['label'].astype(object)
    print("Data preview:\n", global_data.head())
    return global_data.head()
def label_data(ranges):
    """Assign labels to inclusive row-index ranges of ``global_data``.

    Parameters
    ----------
    ranges : pandas.DataFrame
        One row per range with three columns, in order: start index, end
        index, label (the "Ranges for Labeling" table). Start/end may arrive
        as strings or numbers.

    Returns
    -------
    pandas.DataFrame | None
        ``global_data.tail()`` after labelling, or ``None`` when no data has
        been uploaded yet.

    Rows whose start/end cannot be parsed as integers (e.g. the blank rows an
    editable table may contain) are skipped instead of aborting the whole
    call, as are ranges starting outside the data or ending before they
    start. An end index past the last row is clamped to the last row.
    """
    global global_data
    print("Ranges received for labeling:", ranges)
    if global_data is None:
        print("No data uploaded yet; upload a CSV file before labeling.")
        return None
    n_rows = len(global_data)
    for i, (start, end, label) in enumerate(ranges.values):
        try:
            # float() first so numeric strings such as "3.0" are accepted too.
            start = int(float(start))
            end = int(float(end))
        except (TypeError, ValueError, OverflowError):
            print(f"Skipping range {i}: start/end are not integers: start={start!r}, end={end!r}")
            continue
        print(f"Processing range {i}: start={start}, end={end}, label={label}")
        if start < 0 or start >= n_rows or end < start:
            print(f"Invalid range: start={start}, end={end}, label={label}")
            continue
        if end >= n_rows:
            print(f"End index {end} exceeds data length {n_rows}. Adjusting to {n_rows - 1}.")
            end = n_rows - 1
        # .loc slicing is label-based and includes `end`; with the default
        # RangeIndex produced by read_csv, labels equal row positions.
        global_data.loc[start:end, 'label'] = label
    print("Data after labeling:\n", global_data.tail())
    return global_data.tail()
def _calculate_psd_features(segment, sampling_rate):
    """Return band energies of *segment* and its alpha/beta energy ratio.

    Each energy is the sum of the Welch PSD over the bins of one band
    (delta 0.5-3 Hz, theta 4-7 Hz, alpha 8-13 Hz, beta 14-30 Hz); the whole
    segment is used as a single Welch window.
    """
    f, psd_values = scipy.signal.welch(segment, fs=sampling_rate, nperseg=len(segment))
    energy_alpha = np.sum(psd_values[(f >= 8) & (f <= 13)])
    energy_beta = np.sum(psd_values[(f >= 14) & (f <= 30)])
    energy_theta = np.sum(psd_values[(f >= 4) & (f <= 7)])
    energy_delta = np.sum(psd_values[(f >= 0.5) & (f <= 3)])
    return {
        'E_alpha': energy_alpha,
        'E_beta': energy_beta,
        'E_theta': energy_theta,
        'E_delta': energy_delta,
        # numpy float division: inf/nan (with a RuntimeWarning) if beta energy is 0.
        'alpha_beta_ratio': energy_alpha / energy_beta,
    }


def _calculate_additional_features(segment, sampling_rate):
    """Return peak frequency, spectral centroid and log-log spectral slope of *segment*."""
    f, psd = scipy.signal.welch(segment, fs=sampling_rate, nperseg=len(segment))
    # Drop the 0 Hz bin so log(f) is finite; the slope is that of a straight
    # line fitted to log(PSD) against log(f).
    log_f = np.log(f[1:])
    log_psd = np.log(psd[1:])
    return {
        'peak_frequency': f[np.argmax(psd)],
        'spectral_centroid': np.sum(f * psd) / np.sum(psd),
        'spectral_slope': np.polyfit(log_f, log_psd, 1)[0],
    }


def preprocess_data(*, sampling_rate=512, window_size=512, step=256):
    """Filter the uploaded EEG, extract spectral features per window and scale them.

    The raw signal is the second column of ``global_data`` (the first column is
    ignored) and each window's label is the ``label`` value at its first
    sample. Every full window of ``window_size`` samples, advanced by ``step``,
    is notch-filtered at 50 Hz, band-passed to 0.5-30 Hz (4th-order
    Butterworth, zero-phase via ``filtfilt``) and reduced to 8 spectral
    features. Windows whose first sample is unlabelled, or that contain
    non-numeric samples, are skipped. The features are standardised with a
    ``StandardScaler`` fitted on all kept windows.

    ``global_data`` is only read, never modified, so this can be run repeatedly.

    Side effects: writes ``processed_data.csv`` and ``scaler.pkl`` to the
    current working directory.

    Parameters
    ----------
    sampling_rate : float
        Sampling rate of the raw signal in Hz.
    window_size, step : int
        Window length and hop, both in samples.

    Returns
    -------
    tuple[str, str | None, str | None]
        (status message, processed-data CSV path, scaler pickle path). On any
        error the message describes it and both paths are ``None``.
    """
    try:
        if global_data is None:
            raise ValueError("no data uploaded; upload a CSV file first")
        signal_columns = [c for c in global_data.columns if c != 'label']
        if len(signal_columns) < 2:
            raise ValueError("expected a leading index column followed by a raw EEG column")
        # Convert once for the whole recording; unparseable cells become NaN.
        raw_data = pd.to_numeric(global_data[signal_columns[1]], errors='coerce').to_numpy(dtype=float)
        labels_old = global_data['label']

        notch_freq = 50.0
        lowcut, highcut = 0.5, 30.0
        nyquist = 0.5 * sampling_rate
        # With fs given, iirnotch expects w0 in Hz; Q = w0 / bandwidth, so
        # Q=30 removes a ~1.7 Hz band around the mains frequency.
        b_notch, a_notch = signal.iirnotch(notch_freq, Q=30.0, fs=sampling_rate)
        b_bandpass, a_bandpass = signal.butter(4, [lowcut / nyquist, highcut / nyquist], btype='band')

        features = []
        labels = []
        # "+ 1" keeps the last window that ends exactly at the final sample.
        for start in range(0, len(raw_data) - window_size + 1, step):
            print(f"Processing segment {start} to {start + window_size}")
            label = labels_old.iloc[start]
            if pd.isna(label):
                # Unlabelled targets would make the classifier fit fail later.
                print(f"Skipping segment {start}: no label at its first sample")
                continue
            # Positional, end-exclusive slice: exactly window_size samples.
            segment = raw_data[start:start + window_size]
            if not np.all(np.isfinite(segment)):
                print(f"Skipping segment {start}: contains non-numeric samples")
                continue
            segment = signal.filtfilt(b_notch, a_notch, segment)
            segment = signal.filtfilt(b_bandpass, a_bandpass, segment)
            features.append({
                **_calculate_psd_features(segment, sampling_rate),
                **_calculate_additional_features(segment, sampling_rate),
            })
            labels.append(label)
        if not features:
            raise ValueError("no labeled segments to process; label at least one full window first")

        columns = ['E_alpha', 'E_beta', 'E_theta', 'E_delta', 'alpha_beta_ratio', 'peak_frequency', 'spectral_centroid', 'spectral_slope']
        df_features = pd.DataFrame(features, columns=columns)
        df_features['label'] = labels
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(df_features[columns])
        df_scaled = pd.DataFrame(X_scaled, columns=columns)
        df_scaled['label'] = df_features['label']
        processed_data_filename = 'processed_data.csv'
        df_scaled.to_csv(processed_data_filename, index=False)
        scaler_filename = 'scaler.pkl'
        with open(scaler_filename, 'wb') as file:
            pickle.dump(scaler, file)
        return "Data preprocessing complete! Download the processed data and scaler below.", processed_data_filename, scaler_filename
    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing.
        print(f"An error occurred during preprocessing: {e}")
        return f"An error occurred during preprocessing: {e}", None, None
def train_model():
    """Preprocess the labelled data, tune an RBF SVC and pickle the best model.

    Runs ``preprocess_data()``, splits the scaled features 80/20
    (``random_state=42``), grid-searches ``C`` and ``gamma`` with 5-fold
    cross-validation on the training split, evaluates the best estimator on
    the held-out 20 %, and writes it to ``model.pkl``.

    NOTE(review): the scaler is fitted on all windows inside preprocess_data,
    before this split, so the held-out score is slightly optimistic.

    Returns
    -------
    tuple[str, str | None, str | None]
        (status message, model pickle path, scaler pickle path); both paths
        are ``None`` when preprocessing or training fails.
    """
    try:
        preprocess_status, processed_data_filename, scaler_filename = preprocess_data()
        if processed_data_filename is None:
            return preprocess_status, None, None
        df_scaled = pd.read_csv(processed_data_filename)
        X = df_scaled.drop('label', axis=1)
        y = df_scaled['label']
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        param_grid = {'C': [0.1, 1, 10, 100], 'gamma': ['scale', 'auto', 0.1, 0.01, 0.001, 0.0001], 'kernel': ['rbf']}
        # probability=True so the saved model also supports predict_proba.
        svc = SVC(probability=True)
        grid_search = GridSearchCV(estimator=svc, param_grid=param_grid, cv=5, verbose=2, n_jobs=-1)
        grid_search.fit(X_train, y_train)
        model = grid_search.best_estimator_
        # Evaluate on the hold-out split that the grid search never saw.
        test_accuracy = model.score(X_test, y_test)
        model_filename = 'model.pkl'
        with open(model_filename, 'wb') as file:
            pickle.dump(model, file)
        status = (
            f"Training complete! Held-out test accuracy: {test_accuracy:.3f} "
            f"(best params: {grid_search.best_params_}). Download the model and scaler below."
        )
        return status, model_filename, scaler_filename
    except Exception as e:
        # UI boundary: surface the failure in the status box.
        print(f"An error occurred during training: {e}")
        return f"An error occurred during training: {e}", None, None
# Gradio interface: upload a CSV (preview), label inclusive row-index ranges,
# then preprocess + train an SVC and offer the model/scaler pickles for download.
# Components are rendered in creation order.
with gr.Blocks() as demo:
    file_input = gr.File(label="Upload CSV File")
    data_preview = gr.Dataframe(label="Data Preview", interactive=False)
    # Each row: start row index, end row index (both inclusive), label value.
    ranges_input = gr.Dataframe(headers=["Start Index", "End Index", "Label"], label="Ranges for Labeling")
    labeled_data_preview = gr.Dataframe(label="Labeled Data Preview", interactive=False)
    training_status = gr.Textbox(label="Training Status")
    model_file = gr.File(label="Download Trained Model")
    scaler_file = gr.File(label="Download Scaler")

    file_input.upload(get_data_preview, inputs=file_input, outputs=data_preview)

    label_button = gr.Button("Label Data")
    label_button.click(label_data, inputs=[ranges_input], outputs=labeled_data_preview, queue=True)

    train_button = gr.Button("Train Model")
    train_button.click(train_model, outputs=[training_status, model_file, scaler_file])

# Start the server only when run as a script, so importing this module
# (tests, reload tooling) has no side effects beyond building `demo`.
if __name__ == "__main__":
    demo.launch()