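"""Gradio app for labeling raw EEG data, extracting spectral features, and
training an SVM classifier.

The uploaded CSV is assumed to have three columns: an index/timestamp column
(dropped during preprocessing), the raw EEG signal, and a label column (the
labeling step creates or overwrites it).
"""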
import gradio as gr
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from scipy import signal
import pickle
def get_data_preview(file):
    data = pd.read_csv(file.name)
    return data.head()
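# Label a contiguous row range. The CSV is re-read on every call, so each
# labeling pass starts from the original, unlabeled file.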
def label_data(file, start, end, label):
    data = pd.read_csv(file.name)
    # gr.Number returns floats, so cast before slicing; .loc is end-inclusive.
    data.loc[int(start):int(end), 'label'] = label
    return data
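# Preprocessing: drop the index column, notch out 50 Hz mains interference,
# band-pass to 0.5-30 Hz, then segment the signal and compute per-segment
# spectral features.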
def preprocess_data(data):
    data.drop(columns=data.columns[0], inplace=True)
    data.columns = ['raw_eeg', 'label']
    raw_data = data['raw_eeg']
    labels_old = data['label']

    sampling_rate = 512
    notch_freq = 50.0
    lowcut, highcut = 0.5, 30.0
    nyquist = 0.5 * sampling_rate

    # With fs given, iirnotch expects the notch frequency in Hz, not
    # normalized; Q=30 gives a narrow notch centered on 50 Hz.
    b_notch, a_notch = signal.iirnotch(notch_freq, Q=30.0, fs=sampling_rate)
    lowcut_normalized = lowcut / nyquist
    highcut_normalized = highcut / nyquist
    b_bandpass, a_bandpass = signal.butter(4, [lowcut_normalized, highcut_normalized], btype='band')
    features = []
    labels = []
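    # Band energies from the Welch power spectral density: alpha (8-13 Hz),
    # beta (14-30 Hz), theta (4-7 Hz), delta (0.5-3 Hz), plus the alpha/beta ratio.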
    def calculate_psd_features(segment, sampling_rate):
        f, psd_values = signal.welch(segment, fs=sampling_rate, nperseg=len(segment))
        alpha_indices = np.where((f >= 8) & (f <= 13))
        beta_indices = np.where((f >= 14) & (f <= 30))
        theta_indices = np.where((f >= 4) & (f <= 7))
        delta_indices = np.where((f >= 0.5) & (f <= 3))
        energy_alpha = np.sum(psd_values[alpha_indices])
        energy_beta = np.sum(psd_values[beta_indices])
        energy_theta = np.sum(psd_values[theta_indices])
        energy_delta = np.sum(psd_values[delta_indices])
        alpha_beta_ratio = energy_alpha / energy_beta
        return {
            'E_alpha': energy_alpha,
            'E_beta': energy_beta,
            'E_theta': energy_theta,
            'E_delta': energy_delta,
            'alpha_beta_ratio': alpha_beta_ratio
        }
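    # Broadband descriptors: dominant frequency, spectral centroid, and the
    # slope of the log-log PSD (a 1/f-like measure).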
    def calculate_additional_features(segment, sampling_rate):
        f, psd = signal.welch(segment, fs=sampling_rate, nperseg=len(segment))
        peak_frequency = f[np.argmax(psd)]
        spectral_centroid = np.sum(f * psd) / np.sum(psd)
        log_f = np.log(f[1:])  # skip the DC bin to avoid log(0)
        log_psd = np.log(psd[1:])
        spectral_slope = np.polyfit(log_f, log_psd, 1)[0]
        return {
            'peak_frequency': peak_frequency,
            'spectral_centroid': spectral_centroid,
            'spectral_slope': spectral_slope
        }
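    # Slide a 512-sample window in 256-sample steps (50% overlap); each window
    # is filtered and reduced to one feature row, labeled by its first sample.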
    for i in range(0, len(raw_data) - 512, 256):
        segment = raw_data.iloc[i:i + 512]  # iloc is end-exclusive: exactly 512 samples
        segment = pd.to_numeric(segment, errors='coerce').fillna(0.0)  # zero-fill bad values before filtering
        segment = signal.filtfilt(b_notch, a_notch, segment)
        segment = signal.filtfilt(b_bandpass, a_bandpass, segment)
        segment_features = calculate_psd_features(segment, 512)
        additional_features = calculate_additional_features(segment, 512)
        segment_features = {**segment_features, **additional_features}
        features.append(segment_features)
        labels.append(labels_old[i])
    columns = ['E_alpha', 'E_beta', 'E_theta', 'E_delta', 'alpha_beta_ratio',
               'peak_frequency', 'spectral_centroid', 'spectral_slope']
    df_features = pd.DataFrame(features, columns=columns)
    df_features['label'] = labels
    return df_features
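# Train an RBF-kernel SVM on the extracted features. Hyperparameters are tuned
# with a 5-fold grid search; the fitted scaler is pickled alongside the model
# so the same normalization can be reapplied at inference time.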
def train_model(data):
    data = preprocess_data(data)  # the labeled raw signal must be filtered and featurized before training
    scaler = StandardScaler()
    X = data.drop('label', axis=1)
    y = data['label']
    X_scaled = scaler.fit_transform(X)
    X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
    param_grid = {'C': [0.1, 1, 10, 100], 'gamma': ['scale', 'auto', 0.1, 0.01, 0.001, 0.0001], 'kernel': ['rbf']}
    svc = SVC(probability=True)
    grid_search = GridSearchCV(estimator=svc, param_grid=param_grid, cv=5, verbose=2, n_jobs=-1)
    grid_search.fit(X_train, y_train)
    model = grid_search.best_estimator_
    model_filename = 'model.pkl'
    scaler_filename = 'scaler.pkl'
    with open(model_filename, 'wb') as file:
        pickle.dump(model, file)
    with open(scaler_filename, 'wb') as file:
        pickle.dump(scaler, file)
    return "Training complete! Model and scaler saved.", gr.File(model_filename), gr.File(scaler_filename)
with gr.Blocks() as demo:
    file_input = gr.File(label="Upload CSV File")
    data_preview = gr.Dataframe(label="Data Preview", interactive=False)
    start_input = gr.Number(label="Start Index", value=0)
    end_input = gr.Number(label="End Index", value=100)
    label_input = gr.Number(label="Label Value", value=1)
    labeled_data_preview = gr.Dataframe(label="Labeled Data Preview", interactive=False)
    training_status = gr.Textbox(label="Training Status")
    model_file = gr.File(label="Download Trained Model")
    scaler_file = gr.File(label="Download Scaler")

    file_input.upload(get_data_preview, inputs=file_input, outputs=data_preview)
    label_button = gr.Button("Label Data")
    label_button.click(label_data, inputs=[file_input, start_input, end_input, label_input],
                       outputs=labeled_data_preview)
    train_button = gr.Button("Train Model")
    train_button.click(train_model, inputs=labeled_data_preview,
                       outputs=[training_status, model_file, scaler_file])

demo.launch()
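# Minimal inference sketch (not part of the app): load the saved artifacts and
# classify one sample. `feature_row` is a hypothetical 2-D array holding the
# eight feature columns produced by preprocess_data.
#
#   with open('model.pkl', 'rb') as f:
#       model = pickle.load(f)
#   with open('scaler.pkl', 'rb') as f:
#       scaler = pickle.load(f)
#   prediction = model.predict(scaler.transform(feature_row))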