import gradio as gr
import pandas as pd
import numpy as np
import pickle
import scipy
from scipy import signal
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split, GridSearchCV

# The uploaded recording is shared between the Gradio callbacks below.
global_data = None

def get_data_preview(file):
    global global_data
    global_data = pd.read_csv(file.name)
    # Add an empty object-dtype label column for label_data() to fill in.
    global_data['label'] = np.nan
    global_data['label'] = global_data['label'].astype(object)
    print("Data preview:\n", global_data.head())
    return global_data.head()

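# The pipeline assumes the uploaded CSV has an index column (dropped later in
# preprocess_data) followed by a single raw-EEG column, e.g. (values made up
# for illustration):
#
#   ,eeg
#   0,-12.4
#   1,-11.9
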
def label_data(ranges):
    global global_data
    print("Ranges received for labeling:", ranges)
    for i, (start, end, label) in enumerate(ranges.values):
        start = int(start)
        end = int(end)
        print(f"Processing range {i}: start={start}, end={end}, label={label}")
        if start < 0 or start >= len(global_data):
            print(f"Invalid range: start={start}, end={end}, label={label}")
            continue
        if end >= len(global_data):
            print(f"End index {end} exceeds data length {len(global_data)}. Adjusting to {len(global_data) - 1}.")
            end = len(global_data) - 1
        # .loc slicing is end-inclusive, so rows start..end all receive the label.
        global_data.loc[start:end, 'label'] = label
    print("Data after labeling:\n", global_data.tail())
    return global_data.tail()

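# label_data receives the "Ranges for Labeling" table as a DataFrame; a
# hypothetical example of its contents (labels are whatever the user types):
#
#   Start Index  End Index  Label
#   0            5000       relaxed
#   5001         10000      focused
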
def preprocess_data():
    global global_data
    try:
        # Drop the CSV's index column and name the remaining columns.
        global_data.drop(columns=global_data.columns[0], inplace=True)
        global_data.columns = ['raw_eeg', 'label']
        raw_data = global_data['raw_eeg']
        labels_old = global_data['label']

        sampling_rate = 512
        notch_freq = 50.0  # mains interference
        lowcut, highcut = 0.5, 30.0

        nyquist = 0.5 * sampling_rate

        # signal.iirnotch expects the notch frequency in Hz when fs is given;
        # passing a pre-normalized frequency together with fs would place the
        # notch near 0.2 Hz instead of 50 Hz. Q=30 is a conventional quality
        # factor for a mains notch (the original Q=0.05 would make the notch
        # wider than the whole band).
        b_notch, a_notch = signal.iirnotch(notch_freq, Q=30.0, fs=sampling_rate)

        # The band-pass is designed with cut-offs normalized to Nyquist (no fs
        # argument), so here the division is correct.
        lowcut_normalized = lowcut / nyquist
        highcut_normalized = highcut / nyquist
        b_bandpass, a_bandpass = signal.butter(4, [lowcut_normalized, highcut_normalized], btype='band')

        features = []
        labels = []

        def calculate_psd_features(segment, sampling_rate):
            # Welch PSD over the whole window, then summed energy in the
            # standard EEG bands plus the alpha/beta ratio.
            f, psd_values = signal.welch(segment, fs=sampling_rate, nperseg=len(segment))
            alpha_indices = np.where((f >= 8) & (f <= 13))
            beta_indices = np.where((f >= 14) & (f <= 30))
            theta_indices = np.where((f >= 4) & (f <= 7))
            delta_indices = np.where((f >= 0.5) & (f <= 3))
            energy_alpha = np.sum(psd_values[alpha_indices])
            energy_beta = np.sum(psd_values[beta_indices])
            energy_theta = np.sum(psd_values[theta_indices])
            energy_delta = np.sum(psd_values[delta_indices])
            alpha_beta_ratio = energy_alpha / energy_beta
            return {
                'E_alpha': energy_alpha,
                'E_beta': energy_beta,
                'E_theta': energy_theta,
                'E_delta': energy_delta,
                'alpha_beta_ratio': alpha_beta_ratio,
            }

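        # Note: with nperseg=len(segment), Welch reduces to a single
        # unaveraged periodogram of the window; a smaller nperseg would
        # average several overlapping sub-windows, trading frequency
        # resolution for lower variance.
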
        def calculate_additional_features(segment, sampling_rate):
            f, psd = signal.welch(segment, fs=sampling_rate, nperseg=len(segment))
            peak_frequency = f[np.argmax(psd)]
            spectral_centroid = np.sum(f * psd) / np.sum(psd)
            # Slope of the log-log spectrum; f[0] = 0 Hz is skipped to avoid log(0).
            log_f = np.log(f[1:])
            log_psd = np.log(psd[1:])
            spectral_slope = np.polyfit(log_f, log_psd, 1)[0]
            return {
                'peak_frequency': peak_frequency,
                'spectral_centroid': spectral_centroid,
                'spectral_slope': spectral_slope,
            }

        # Slide a 512-sample (one-second) window over the recording with 50% overlap.
        for i in range(0, len(raw_data) - 512, 256):
            print(f"Processing segment {i} to {i + 512}")
            # .iloc yields exactly 512 samples; .loc slicing is end-inclusive
            # and would return 513.
            segment = raw_data.iloc[i:i + 512]
            segment = pd.to_numeric(segment, errors='coerce')
            segment = signal.filtfilt(b_notch, a_notch, segment)
            segment = signal.filtfilt(b_bandpass, a_bandpass, segment)
            segment_features = calculate_psd_features(segment, sampling_rate)
            additional_features = calculate_additional_features(segment, sampling_rate)
            segment_features = {**segment_features, **additional_features}
            features.append(segment_features)
            # Each window inherits the label of its first sample.
            labels.append(labels_old.iloc[i])

        columns = ['E_alpha', 'E_beta', 'E_theta', 'E_delta', 'alpha_beta_ratio',
                   'peak_frequency', 'spectral_centroid', 'spectral_slope']
        df_features = pd.DataFrame(features, columns=columns)
        df_features['label'] = labels

        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(df_features.drop('label', axis=1))
        df_scaled = pd.DataFrame(X_scaled, columns=columns)
        df_scaled['label'] = df_features['label']

        processed_data_filename = 'processed_data.csv'
        df_scaled.to_csv(processed_data_filename, index=False)

        scaler_filename = 'scaler.pkl'
        with open(scaler_filename, 'wb') as file:
            pickle.dump(scaler, file)

        return ("Data preprocessing complete! Download the processed data and scaler below.",
                processed_data_filename, scaler_filename)

    except Exception as e:
        print(f"An error occurred during preprocessing: {e}")
        return f"An error occurred during preprocessing: {e}", None, None

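# Minimal local smoke test, outside the Gradio flow. 'sample.csv' and the
# label range are hypothetical stand-ins, and _File mimics the object Gradio
# hands to get_data_preview:
#
#   class _File:
#       name = 'sample.csv'
#   get_data_preview(_File())
#   label_data(pd.DataFrame([[0, 5000, 'relaxed']],
#                           columns=['Start Index', 'End Index', 'Label']))
#   status, data_path, scaler_path = preprocess_data()
#   print(status)
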
def train_model():
    global global_data
    try:
        preprocess_status, processed_data_filename, scaler_filename = preprocess_data()
        if processed_data_filename is None:
            return preprocess_status, None, None

        df_scaled = pd.read_csv(processed_data_filename)
        X = df_scaled.drop('label', axis=1)
        y = df_scaled['label']

        # X_test/y_test are held out here but not evaluated in this function.
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        # Grid-search an RBF-kernel SVM over C and gamma with 5-fold CV.
        param_grid = {
            'C': [0.1, 1, 10, 100],
            'gamma': ['scale', 'auto', 0.1, 0.01, 0.001, 0.0001],
            'kernel': ['rbf'],
        }
        svc = SVC(probability=True)
        grid_search = GridSearchCV(estimator=svc, param_grid=param_grid, cv=5, verbose=2, n_jobs=-1)
        grid_search.fit(X_train, y_train)

        model = grid_search.best_estimator_
        model_filename = 'model.pkl'
        with open(model_filename, 'wb') as file:
            pickle.dump(model, file)

        return "Training complete! Download the model and scaler below.", model_filename, scaler_filename

    except Exception as e:
        print(f"An error occurred during training: {e}")
        return f"An error occurred during training: {e}", None, None

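# Sketch of downstream inference with the saved artifacts; X_new is a
# hypothetical DataFrame with the same eight feature columns, built the same
# way as in preprocess_data:
#
#   with open('model.pkl', 'rb') as f:
#       model = pickle.load(f)
#   with open('scaler.pkl', 'rb') as f:
#       scaler = pickle.load(f)
#   predictions = model.predict(scaler.transform(X_new))
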
with gr.Blocks() as demo:
    file_input = gr.File(label="Upload CSV File")
    data_preview = gr.Dataframe(label="Data Preview", interactive=False)
    ranges_input = gr.Dataframe(headers=["Start Index", "End Index", "Label"], label="Ranges for Labeling")
    labeled_data_preview = gr.Dataframe(label="Labeled Data Preview", interactive=False)

    training_status = gr.Textbox(label="Training Status")
    model_file = gr.File(label="Download Trained Model")
    scaler_file = gr.File(label="Download Scaler")

    file_input.upload(get_data_preview, inputs=file_input, outputs=data_preview)
    label_button = gr.Button("Label Data")
    label_button.click(label_data, inputs=[ranges_input], outputs=labeled_data_preview, queue=True)
    train_button = gr.Button("Train Model")
    train_button.click(train_model, outputs=[training_status, model_file, scaler_file])

demo.launch()