# BCI / app.py
# Author: Umang-Bansal
# Commit 5e22eaa ("Create app.py", verified), 5.33 kB
import gradio as gr
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
import scipy
from scipy import signal
import pickle
def get_data_preview(file):
    """Return the first five rows of an uploaded CSV file for display.

    Parameters
    ----------
    file : str | os.PathLike | object with a ``.name`` attribute
        Either a filesystem path, or an upload object that carries its
        path in ``.name`` (the original code read ``file.name``).

    Returns
    -------
    pandas.DataFrame
        ``DataFrame.head()`` of the parsed CSV.
    """
    import os

    # A pathlib.Path also has ``.name`` (just the basename), so plain paths
    # must be recognised before falling back to the ``.name`` attribute.
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    return pd.read_csv(path).head()
def label_data(file, start, end, label):
    """Load the uploaded CSV and assign ``label`` to rows ``start``..``end``.

    The range is label-based and inclusive at both ends (``DataFrame.loc``
    slicing). A ``label`` column is created if the CSV has none; rows outside
    the range then hold NaN in that column.

    Parameters
    ----------
    file : str | os.PathLike | object with a ``.name`` attribute
        Path to the CSV, or an upload object carrying the path in ``.name``.
    start, end : int or float
        First and last row label to mark; converted with ``int()`` because
        numeric UI widgets may deliver floats.
    label : scalar
        Value written into the ``label`` column for the selected rows.

    Returns
    -------
    pandas.DataFrame
        The CSV contents with the ``label`` column updated.
    """
    import os

    # pathlib.Path.name is only the basename, so test for real paths first.
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    data = pd.read_csv(path)
    data.loc[int(start):int(end), 'label'] = label
    return data
# Column order of the feature table returned by preprocess_data.
_FEATURE_COLUMNS = [
    'E_alpha', 'E_beta', 'E_theta', 'E_delta', 'alpha_beta_ratio',
    'peak_frequency', 'spectral_centroid', 'spectral_slope',
]


def _window_features(segment, sampling_rate):
    """Compute band energies and spectral-shape features for one filtered window.

    A single Welch PSD whose segment spans the whole window feeds every
    feature. Returns a dict keyed by the names in ``_FEATURE_COLUMNS``.
    """
    f, psd = signal.welch(segment, fs=sampling_rate, nperseg=len(segment))

    def band_energy(low, high):
        # Sum of the PSD bins whose frequency lies in [low, high] Hz.
        return np.sum(psd[(f >= low) & (f <= high)])

    energy_alpha = band_energy(8, 13)
    energy_beta = band_energy(14, 30)
    energy_theta = band_energy(4, 7)
    energy_delta = band_energy(0.5, 3)

    # Straight-line fit in log-log space. The DC bin (f == 0) and any
    # zero-power bin would give log(0) = -inf and break the fit, so only
    # strictly positive bins are used.
    fit_bins = (f > 0) & (psd > 0)
    if np.count_nonzero(fit_bins) >= 2:
        spectral_slope = np.polyfit(np.log(f[fit_bins]), np.log(psd[fit_bins]), 1)[0]
    else:
        spectral_slope = np.nan

    # A window with zero beta / zero total power yields inf or nan here
    # instead of emitting divide-by-zero warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        alpha_beta_ratio = energy_alpha / energy_beta
        spectral_centroid = np.sum(f * psd) / np.sum(psd)

    return {
        'E_alpha': energy_alpha,
        'E_beta': energy_beta,
        'E_theta': energy_theta,
        'E_delta': energy_delta,
        'alpha_beta_ratio': alpha_beta_ratio,
        'peak_frequency': f[np.argmax(psd)],
        'spectral_centroid': spectral_centroid,
        'spectral_slope': spectral_slope,
    }


def preprocess_data(data, *, sampling_rate=512, window=512, step=256,
                    notch_freq=50.0, notch_q=30.0, lowcut=0.5, highcut=30.0):
    """Convert a labeled raw-EEG table into one spectral-feature row per window.

    The first column of ``data`` is discarded (presumably an index column
    written when the CSV was saved; TODO confirm). The remaining two columns
    are treated as the raw EEG sample and its label. ``data`` itself is not
    modified.

    Each window of ``window`` samples, advanced by ``step`` samples, is
    notch-filtered at ``notch_freq`` Hz and then band-pass filtered to
    ``[lowcut, highcut]`` Hz (4th-order Butterworth). Both filters are applied
    with ``filtfilt``, i.e. forward and backward. The result is reduced to the
    features in ``_FEATURE_COLUMNS``. A window's label is the label of its
    first sample.

    Parameters
    ----------
    data : pandas.DataFrame
        Exactly three columns: <discarded>, raw EEG sample, label.
    sampling_rate : float
        Sample rate of the recording in Hz (default 512).
    window, step : int
        Window length and hop, in samples (defaults 512 and 256). ``window``
        must exceed ``filtfilt``'s default padding length.
    notch_freq, notch_q : float
        Notch centre in Hz and its quality factor. The -3 dB bandwidth is
        ``notch_freq / notch_q``.
    lowcut, highcut : float
        Band-pass edges in Hz.

    Returns
    -------
    pandas.DataFrame
        Columns ``_FEATURE_COLUMNS`` plus ``'label'``. The table is empty when
        ``data`` has fewer than ``window`` rows.

    Raises
    ------
    ValueError
        If ``data`` does not have exactly three columns.
    """
    if data.shape[1] != 3:
        raise ValueError(
            f"expected 3 columns (<index>, raw_eeg, label), got {data.shape[1]}"
        )
    # Positional copy: the caller's DataFrame is left untouched.
    table = data.iloc[:, 1:].copy()
    table.columns = ['raw_eeg', 'label']

    # Coerce once up front. Non-numeric cells become NaN, which propagates
    # into the features of every window containing them.
    raw = pd.to_numeric(table['raw_eeg'], errors='coerce').to_numpy(dtype=float)
    labels_per_sample = table['label'].to_numpy()

    # With fs given, iirnotch expects w0 in Hz (not normalized to Nyquist).
    # Q=30 gives a ~1.7 Hz wide notch; the former Q=0.05 would be ~1000 Hz wide.
    b_notch, a_notch = signal.iirnotch(notch_freq, Q=notch_q, fs=sampling_rate)
    b_band, a_band = signal.butter(4, [lowcut, highcut], btype='band', fs=sampling_rate)

    rows, labels = [], []
    # "+ 1" keeps the final window that ends exactly on the last sample.
    for start in range(0, len(raw) - window + 1, step):
        segment = raw[start:start + window]  # exactly `window` samples
        segment = signal.filtfilt(b_notch, a_notch, segment)
        segment = signal.filtfilt(b_band, a_band, segment)
        rows.append(_window_features(segment, sampling_rate))
        labels.append(labels_per_sample[start])

    df_features = pd.DataFrame(rows, columns=_FEATURE_COLUMNS)
    df_features['label'] = labels
    return df_features
def train_model(data):
    """Fit an RBF-kernel SVM on a labeled feature table and pickle the results.

    Every column except ``label`` is used as a feature. Rows with a missing
    label or a non-finite feature are discarded first. The remaining rows
    are split 80/20 with ``random_state=42``. A ``StandardScaler`` is fit on
    the training split only. ``C`` and ``gamma`` are chosen by 5-fold grid
    search on that split. The best estimator is then scored on the held-out
    20 %.

    The fitted model and scaler are pickled to ``model.pkl`` and
    ``scaler.pkl`` in the current working directory, overwriting any
    existing files.

    Parameters
    ----------
    data : pandas.DataFrame
        Numeric feature columns plus a ``label`` column.

    Returns
    -------
    tuple
        ``(status message, gr.File for model.pkl, gr.File for scaler.pkl)``.

    Raises
    ------
    ValueError
        If no labeled rows with finite features remain.
    """
    X = data.drop('label', axis=1)
    y = data['label']

    # SVC rejects NaN/inf inputs, and rows with a NaN label have no class to learn.
    usable = y.notna().to_numpy() & np.isfinite(X.to_numpy(dtype=float)).all(axis=1)
    X, y = X.loc[usable], y.loc[usable]
    if len(y) == 0:
        raise ValueError("No labeled rows with finite features to train on.")

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Fit the scaler on the training split only. Fitting it on every row
    # would leak test-set statistics into the training preprocessing.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    param_grid = {'C': [0.1, 1, 10, 100], 'gamma': ['scale', 'auto', 0.1, 0.01, 0.001, 0.0001], 'kernel': ['rbf']}
    grid_search = GridSearchCV(estimator=SVC(probability=True), param_grid=param_grid, cv=5, verbose=2, n_jobs=-1)
    grid_search.fit(X_train_scaled, y_train)
    model = grid_search.best_estimator_
    # Score on the held-out split so the status reflects generalization.
    test_accuracy = model.score(X_test_scaled, y_test)

    model_filename = 'model.pkl'
    scaler_filename = 'scaler.pkl'
    with open(model_filename, 'wb') as fh:
        pickle.dump(model, fh)
    with open(scaler_filename, 'wb') as fh:
        pickle.dump(scaler, fh)

    status = (
        "Training complete! Model and scaler saved. "
        f"Held-out test accuracy: {test_accuracy:.3f}"
    )
    return status, gr.File(model_filename), gr.File(scaler_filename)
def _train_from_labeled(labeled_data):
    """Extract windowed spectral features from the labeled table, then train.

    ``labeled_data`` is the value of the "Labeled Data Preview" dataframe,
    which ``label_data`` fills with the uploaded CSV plus a ``label`` column.
    Passing that table straight to ``train_model`` would fit the SVM on raw
    per-sample rows, including the CSV's leading column. ``preprocess_data``
    first reduces it to one feature row per window.
    """
    # preprocess_data may alter its argument, so hand it a copy.
    return train_model(preprocess_data(labeled_data.copy()))


with gr.Blocks() as demo:
    # Upload + preview of the raw CSV.
    file_input = gr.File(label="Upload CSV File")
    data_preview = gr.Dataframe(label="Data Preview", interactive=False)
    # Row range and class value to label; precision=0 yields whole row indices.
    start_input = gr.Number(label="Start Index", value=0, precision=0)
    end_input = gr.Number(label="End Index", value=100, precision=0)
    label_input = gr.Number(label="Label Value", value=1)
    labeled_data_preview = gr.Dataframe(label="Labeled Data Preview", interactive=False)
    # Training outputs.
    training_status = gr.Textbox(label="Training Status")
    model_file = gr.File(label="Download Trained Model")
    scaler_file = gr.File(label="Download Scaler")

    # Show the head of the CSV as soon as it is uploaded.
    file_input.upload(get_data_preview, inputs=file_input, outputs=data_preview)

    label_button = gr.Button("Label Data")
    label_button.click(label_data, inputs=[file_input, start_input, end_input, label_input], outputs=labeled_data_preview)

    train_button = gr.Button("Train Model")
    # Train on per-window spectral features rather than raw labeled samples.
    train_button.click(_train_from_labeled, inputs=labeled_data_preview, outputs=[training_status, model_file, scaler_file])

demo.launch()