Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
from sklearn.preprocessing import StandardScaler | |
import scipy | |
from scipy import signal | |
import pickle | |
from sklearn.svm import SVC | |
from sklearn.model_selection import train_test_split, GridSearchCV | |
# Shared state for the Gradio callbacks: the uploaded CSV plus its "label"
# column. get_data_preview fills it in; it stays None until a file is uploaded.
global_data: "pd.DataFrame | None" = None
def get_data_preview(file):
    """Load an uploaded CSV into the module-level ``global_data`` and preview it.

    Args:
        file: Value produced by the ``gr.File`` upload component. Depending on
            the Gradio version this is either a path (``str`` / path-like) or a
            tempfile-like object whose ``.name`` is the path; both are accepted.

    Returns:
        The first five rows (``DataFrame.head()``) of the loaded frame, which
        now ends with an empty ``label`` column.
    """
    import os  # local import keeps this block self-contained

    global global_data
    # Newer gr.File hands over a filepath string; older ones a wrapper object.
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    global_data = pd.read_csv(path)
    # Object dtype so arbitrary label values (usually strings typed into the
    # labelling table) can later be written into the all-NaN column.
    global_data["label"] = pd.Series(np.nan, index=global_data.index, dtype=object)
    return global_data.head()
def label_data(ranges):
    """Write labels into ``global_data['label']`` for each requested index range.

    Args:
        ranges: Table from the "Ranges for Labeling" ``gr.Dataframe``; each row
            is ``(start index, end index, label)``. ``start``/``end`` are row
            labels of ``global_data`` and the range is inclusive at both ends
            (``DataFrame.loc`` slicing).

    Returns:
        The last five rows (``DataFrame.tail()``) of ``global_data`` afterwards.

    Raises:
        ValueError: If no CSV has been uploaded yet.

    Rows whose start/end cannot be parsed as integers (for example the blank
    row an editable Gradio table shows) and rows whose start lies outside the
    data are skipped; an ``end`` past the last row is clamped to it.
    """
    global global_data
    if global_data is None:
        raise ValueError("Upload a CSV file before labeling data.")
    print("Ranges received for labeling:", ranges)
    if ranges is None:
        return global_data.tail()
    n_rows = len(global_data)
    for start, end, label in ranges.values:
        try:
            start, end = int(start), int(end)
        except (TypeError, ValueError):
            # Blank / NaN / non-numeric cells: nothing sensible to label.
            continue
        if not 0 <= start < n_rows:
            continue
        end = min(end, n_rows - 1)
        global_data.loc[start:end, "label"] = label
    return global_data.tail()
def _design_filters(sampling_rate, notch_freq, notch_q, lowcut, highcut):
    """Return ``((b, a) of the notch, (b, a) of the 4th-order band-pass)``."""
    # With ``fs`` supplied, both design functions take frequencies in Hz, so the
    # values must NOT be pre-normalised by the Nyquist frequency.
    notch = signal.iirnotch(notch_freq, Q=notch_q, fs=sampling_rate)
    band_pass = signal.butter(4, [lowcut, highcut], btype="band", fs=sampling_rate)
    return notch, band_pass


def _segment_features(segment, sampling_rate):
    """Return the eight spectral features of one filtered segment as a dict.

    Band energies sum the Welch PSD over inclusive ranges (Hz): delta 0.5-3,
    theta 4-7, alpha 8-13, beta 14-30. Welch uses one window spanning the whole
    segment, so e.g. 512 samples at 512 Hz give bins exactly 1 Hz apart.
    """
    freqs, psd = signal.welch(segment, fs=sampling_rate, nperseg=len(segment))

    def band_energy(low, high):
        return np.sum(psd[(freqs >= low) & (freqs <= high)])

    e_alpha = band_energy(8, 13)
    e_beta = band_energy(14, 30)
    # Skip the DC bin: log(0 Hz) is undefined.
    slope = np.polyfit(np.log(freqs[1:]), np.log(psd[1:]), 1)[0]
    return {
        "E_alpha": e_alpha,
        "E_beta": e_beta,
        "E_theta": band_energy(4, 7),
        "E_delta": band_energy(0.5, 3),
        "alpha_beta_ratio": e_alpha / e_beta,
        "peak_frequency": freqs[np.argmax(psd)],
        "spectral_centroid": np.sum(freqs * psd) / np.sum(psd),
        "spectral_slope": slope,
    }


def preprocess_data(
    sampling_rate=512,
    window_size=512,
    step=256,
    notch_freq=50.0,
    notch_q=30.0,
    lowcut=0.5,
    highcut=30.0,
):
    """Turn the labeled upload into a scaled feature table and save it to disk.

    The first column of ``global_data`` is discarded and the remaining two are
    treated as raw EEG samples and their labels. The signal is cut into windows
    of ``window_size`` samples every ``step`` samples; each window is notch- and
    band-pass filtered and reduced to eight spectral features. A window is
    labelled with the label of its first sample. Windows whose first sample is
    unlabeled, or that contain a non-numeric sample, are skipped because the
    classifier cannot be trained on them. Features are standardised with a
    ``StandardScaler``, which is pickled alongside the feature CSV.

    Args:
        sampling_rate: Sampling frequency of the recording in Hz.
        window_size: Samples per analysis window.
        step: Hop between consecutive window starts, in samples.
        notch_freq: Mains-interference frequency to suppress, in Hz.
        notch_q: Notch quality factor (-3 dB bandwidth = notch_freq / notch_q).
        lowcut: Lower band-pass edge in Hz.
        highcut: Upper band-pass edge in Hz.

    Returns:
        ``(status_message, processed_csv_path, scaler_pickle_path)``. On any
        failure both paths are ``None`` and the message describes the error.
    """
    try:
        if global_data is None:
            raise ValueError("no data uploaded; upload a CSV file first")
        # Work on a copy instead of dropping in place on the shared frame, so a
        # repeated run (e.g. clicking "Train Model" twice) sees the same columns.
        data = global_data.drop(columns=global_data.columns[0])
        data.columns = ["raw_eeg", "label"]
        # Convert once up front; unparsable samples become NaN.
        raw_eeg = pd.to_numeric(data["raw_eeg"], errors="coerce").to_numpy(dtype=float)
        window_labels = data["label"].to_numpy()

        (b_notch, a_notch), (b_band, a_band) = _design_filters(
            sampling_rate, notch_freq, notch_q, lowcut, highcut
        )

        features = []
        labels = []
        # "+ 1" so the last window that fits completely is included.
        for start in range(0, len(raw_eeg) - window_size + 1, step):
            stop = start + window_size
            print(f"Processing segment {start} to {stop}")
            label = window_labels[start]
            # Positional slice of exactly window_size samples (a .loc slice is
            # end-inclusive and would add one sample, moving the Welch bins off
            # the band edges).
            segment = raw_eeg[start:stop]
            if pd.isna(label) or np.isnan(segment).any():
                continue
            segment = signal.filtfilt(b_notch, a_notch, segment)
            segment = signal.filtfilt(b_band, a_band, segment)
            features.append(_segment_features(segment, sampling_rate))
            labels.append(label)

        if not features:
            raise ValueError(
                f"no labeled, fully numeric {window_size}-sample windows found; "
                "label the data first"
            )

        columns = ['E_alpha', 'E_beta', 'E_theta', 'E_delta', 'alpha_beta_ratio',
                   'peak_frequency', 'spectral_centroid', 'spectral_slope']
        df_features = pd.DataFrame(features, columns=columns)
        scaler = StandardScaler()
        df_scaled = pd.DataFrame(scaler.fit_transform(df_features), columns=columns)
        df_scaled['label'] = labels

        processed_data_filename = 'processed_data.csv'
        df_scaled.to_csv(processed_data_filename, index=False)
        scaler_filename = 'scaler.pkl'
        with open(scaler_filename, 'wb') as file:
            pickle.dump(scaler, file)
        return "Data preprocessing complete! Download the processed data and scaler below.", processed_data_filename, scaler_filename
    except Exception as e:
        # UI boundary: surface the failure in the status box instead of raising.
        return f"An error occurred during preprocessing: {e}", None, None
def train_model():
    """Preprocess the labeled upload, tune an RBF-kernel SVM, and pickle it.

    Runs ``preprocess_data`` (which writes the scaled feature CSV and the
    scaler), splits the features 80/20 with a fixed seed, grid-searches
    ``C``/``gamma`` with 5-fold CV on the training split, refits the best
    configuration with probability estimates enabled, reports its accuracy on
    the held-out 20 %, and saves it to ``model.pkl``.

    Returns:
        ``(status_message, model_filename, scaler_filename)``. Both filenames
        are ``None`` when preprocessing or training fails; the message then
        describes the error.
    """
    try:
        preprocess_status, processed_data_filename, scaler_filename = preprocess_data()
        if processed_data_filename is None:
            return preprocess_status, None, None

        df_scaled = pd.read_csv(processed_data_filename)
        X = df_scaled.drop('label', axis=1)
        y = df_scaled['label']
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        param_grid = {'C': [0.1, 1, 10, 100], 'gamma': ['scale', 'auto', 0.1, 0.01, 0.001, 0.0001], 'kernel': ['rbf']}
        # SVC.predict() does not use the Platt-scaled probabilities, so the CV
        # accuracy (and hence the chosen hyper-parameters) is the same with
        # probability=False, while each candidate fit skips the extra internal
        # cross-validation that probability=True performs.
        grid_search = GridSearchCV(estimator=SVC(), param_grid=param_grid, cv=5, verbose=2, n_jobs=-1)
        grid_search.fit(X_train, y_train)
        model = grid_search.best_estimator_
        # Refit the winner with probability estimates so the saved model
        # supports predict_proba().
        model.set_params(probability=True)
        model.fit(X_train, y_train)
        # Use the held-out split to give the user a generalisation estimate.
        test_accuracy = model.score(X_test, y_test)

        model_filename = 'model.pkl'
        with open(model_filename, 'wb') as file:
            pickle.dump(model, file)
        return (
            f"Training complete! Held-out test accuracy: {test_accuracy:.3f}. "
            "Download the model and scaler below.",
            model_filename,
            scaler_filename,
        )
    except Exception as e:
        # UI boundary: log and show the failure instead of raising into Gradio.
        print(f"An error occurred during training: {e}")
        return f"An error occurred during training: {e}", None, None
# Gradio UI: upload a CSV, label index ranges, then preprocess + train.
with gr.Blocks() as demo:
    file_input = gr.File(label="Upload CSV File")
    data_preview = gr.Dataframe(label="Data Preview", interactive=False)
    ranges_input = gr.Dataframe(headers=["Start Index", "End Index", "Label"], label="Ranges for Labeling")
    labeled_data_preview = gr.Dataframe(label="Labeled Data Preview", interactive=False)
    training_status = gr.Textbox(label="Training Status")
    model_file = gr.File(label="Download Trained Model")
    scaler_file = gr.File(label="Download Scaler")

    file_input.upload(get_data_preview, inputs=file_input, outputs=data_preview)

    label_button = gr.Button("Label Data")
    label_button.click(label_data, inputs=[ranges_input], outputs=labeled_data_preview, queue=True)

    train_button = gr.Button("Train Model")
    train_button.click(train_model, outputs=[training_status, model_file, scaler_file])

# Start the server only when run as a script, so importing this module (tests,
# tooling that looks up ``demo`` itself) has no server-launching side effect.
if __name__ == "__main__":
    demo.launch()