import gradio as gr
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from scipy import signal
import pickle
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split, GridSearchCV

# Global variable to store the uploaded data
global_data = None

def get_data_preview(file):
    global global_data
    global_data = pd.read_csv(file.name)
    global_data['label'] = np.nan  # Initialize a label column
    global_data['label'] = global_data['label'].astype(object)  # Allow string labels alongside NaN
    return global_data.head()
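
# Note on the expected upload (an assumption inferred from preprocess_data
# below): a two-column CSV whose first column is an index/timestamp that gets
# dropped, and whose second column holds the raw EEG samples.
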
def label_data(ranges):
    global global_data
    print("Ranges received for labeling:", ranges)
    for start, end, label in ranges.values:
        try:
            start, end = int(start), int(end)
        except (TypeError, ValueError):
            continue  # skip empty or malformed rows from the Dataframe input
        if start < 0 or start >= len(global_data):
            continue
        if end >= len(global_data):
            end = len(global_data) - 1
        global_data.loc[start:end, 'label'] = label  # .loc slicing is end-inclusive
    return global_data.tail()
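
# Example of the ranges table label_data expects (the labels shown are
# illustrative, not fixed class names):
#
#   Start Index | End Index | Label
#   0           | 5119      | relaxed
#   5120        | 10239     | focused
#
# Each row assigns Label to samples start..end inclusive.
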
def preprocess_data():
    global global_data
    try:
        # Drop the original index/timestamp column; the remaining columns are
        # the raw signal and the label column added at upload time.
        global_data.drop(columns=global_data.columns[0], inplace=True)
        global_data.columns = ['raw_eeg', 'label']
        raw_data = global_data['raw_eeg']
        labels_old = global_data['label']

        sampling_rate = 512
        notch_freq = 50.0  # mains interference
        lowcut, highcut = 0.5, 30.0
        nyquist = 0.5 * sampling_rate

        # 50 Hz notch: iirnotch takes the target frequency in Hz when fs is
        # supplied, and Q=30 gives a conventionally narrow notch.
        b_notch, a_notch = signal.iirnotch(notch_freq, Q=30.0, fs=sampling_rate)
        # Band-pass: butter() takes cutoffs normalized to the Nyquist frequency.
        lowcut_normalized = lowcut / nyquist
        highcut_normalized = highcut / nyquist
        b_bandpass, a_bandpass = signal.butter(4, [lowcut_normalized, highcut_normalized], btype='band')
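
        # Optional sanity check (illustrative, not required by the pipeline):
        # inspect the band-pass response with signal.freqz, e.g.
        #   w, h = signal.freqz(b_bandpass, a_bandpass, fs=sampling_rate)
        # |h| should be near 1 between 0.5 and 30 Hz and fall off outside.
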
        features = []
        labels = []
        def calculate_psd_features(segment, sampling_rate):
            # Welch power spectral density, then total energy per EEG band.
            f, psd_values = signal.welch(segment, fs=sampling_rate, nperseg=len(segment))
            alpha_indices = np.where((f >= 8) & (f <= 13))
            beta_indices = np.where((f >= 14) & (f <= 30))
            theta_indices = np.where((f >= 4) & (f <= 7))
            delta_indices = np.where((f >= 0.5) & (f <= 3))
            energy_alpha = np.sum(psd_values[alpha_indices])
            energy_beta = np.sum(psd_values[beta_indices])
            energy_theta = np.sum(psd_values[theta_indices])
            energy_delta = np.sum(psd_values[delta_indices])
            # Guard against division by zero on beta-silent segments
            alpha_beta_ratio = energy_alpha / energy_beta if energy_beta > 0 else np.nan
            return {
                'E_alpha': energy_alpha,
                'E_beta': energy_beta,
                'E_theta': energy_theta,
                'E_delta': energy_delta,
                'alpha_beta_ratio': alpha_beta_ratio,
            }
        def calculate_additional_features(segment, sampling_rate):
            f, psd = signal.welch(segment, fs=sampling_rate, nperseg=len(segment))
            peak_frequency = f[np.argmax(psd)]
            spectral_centroid = np.sum(f * psd) / np.sum(psd)
            # Skip the DC bin (f[0] == 0) before taking logs
            log_f = np.log(f[1:])
            log_psd = np.log(psd[1:])
            spectral_slope = np.polyfit(log_f, log_psd, 1)[0]
            return {
                'peak_frequency': peak_frequency,
                'spectral_centroid': spectral_centroid,
                'spectral_slope': spectral_slope,
            }
        # Slide a 512-sample (1 s) window with 50% overlap over the recording.
        for i in range(0, len(raw_data) - 512, 256):
            print(f"Processing segment {i} to {i + 512}")
            segment = raw_data.iloc[i:i + 512]  # iloc is end-exclusive: exactly 512 samples
            segment = pd.to_numeric(segment, errors='coerce')
            segment = signal.filtfilt(b_notch, a_notch, segment)
            segment = signal.filtfilt(b_bandpass, a_bandpass, segment)
            segment_features = calculate_psd_features(segment, sampling_rate)
            additional_features = calculate_additional_features(segment, sampling_rate)
            segment_features = {**segment_features, **additional_features}
            features.append(segment_features)
            labels.append(labels_old.iloc[i])  # label at the window start
        columns = ['E_alpha', 'E_beta', 'E_theta', 'E_delta', 'alpha_beta_ratio',
                   'peak_frequency', 'spectral_centroid', 'spectral_slope']
        df_features = pd.DataFrame(features, columns=columns)
        df_features['label'] = labels

        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(df_features.drop('label', axis=1))
        df_scaled = pd.DataFrame(X_scaled, columns=columns)
        df_scaled['label'] = df_features['label']

        processed_data_filename = 'processed_data.csv'
        df_scaled.to_csv(processed_data_filename, index=False)
        scaler_filename = 'scaler.pkl'
        with open(scaler_filename, 'wb') as file:
            pickle.dump(scaler, file)
        return "Data preprocessing complete! Download the processed data and scaler below.", processed_data_filename, scaler_filename
    except Exception as e:
        return f"An error occurred during preprocessing: {e}", None, None

def train_model():
    global global_data
    try:
        preprocess_status, processed_data_filename, scaler_filename = preprocess_data()
        if processed_data_filename is None:
            return preprocess_status, None, None
        df_scaled = pd.read_csv(processed_data_filename)
        # Segments outside every labeled range carry NaN labels and would make
        # SVC.fit fail, so drop them before training.
        df_scaled = df_scaled.dropna(subset=['label'])
        X = df_scaled.drop('label', axis=1)
        y = df_scaled['label']
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        param_grid = {'C': [0.1, 1, 10, 100],
                      'gamma': ['scale', 'auto', 0.1, 0.01, 0.001, 0.0001],
                      'kernel': ['rbf']}
        svc = SVC(probability=True)
        grid_search = GridSearchCV(estimator=svc, param_grid=param_grid, cv=5, verbose=2, n_jobs=-1)
        grid_search.fit(X_train, y_train)
        model = grid_search.best_estimator_
        model_filename = 'model.pkl'
        with open(model_filename, 'wb') as file:
            pickle.dump(model, file)
        return "Training complete! Download the model and scaler below.", model_filename, scaler_filename
    except Exception as e:
        print(f"An error occurred during training: {e}")
        return f"An error occurred during training: {e}", None, None
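
# Illustrative sketch (not wired into the UI): how the saved artifacts could be
# reused for inference. predict_segment is a hypothetical helper; it assumes a
# feature row with the same eight features, in the same order, that
# preprocess_data produces.
def predict_segment(feature_row):
    with open('scaler.pkl', 'rb') as f:
        scaler = pickle.load(f)
    with open('model.pkl', 'rb') as f:
        model = pickle.load(f)
    X = scaler.transform([feature_row])  # apply the training-time scaling
    return model.predict(X)[0]
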
with gr.Blocks() as demo:
    file_input = gr.File(label="Upload CSV File")
    data_preview = gr.Dataframe(label="Data Preview", interactive=False)
    ranges_input = gr.Dataframe(headers=["Start Index", "End Index", "Label"], label="Ranges for Labeling")
    labeled_data_preview = gr.Dataframe(label="Labeled Data Preview", interactive=False)
    training_status = gr.Textbox(label="Training Status")
    model_file = gr.File(label="Download Trained Model")
    scaler_file = gr.File(label="Download Scaler")

    file_input.upload(get_data_preview, inputs=file_input, outputs=data_preview)
    label_button = gr.Button("Label Data")
    label_button.click(label_data, inputs=[ranges_input], outputs=labeled_data_preview, queue=True)
    train_button = gr.Button("Train Model")
    train_button.click(train_model, outputs=[training_status, model_file, scaler_file])

demo.launch()
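
# To run locally: `python app.py` (assuming the conventional Space entry-point
# name), then open the local URL Gradio prints (http://127.0.0.1:7860 by default).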