File size: 7,106 Bytes
5e22eaa
 
 
 
 
 
 
6acdc14
 
5e22eaa
2c8fa83
9ac0ba6
 
5e22eaa
9ac0ba6
 
17217b4
915c4bc
9ac0ba6
5e22eaa
9ac0ba6
 
17217b4
 
cc09ded
 
bd3d9f9
cc09ded
bd3d9f9
 
9ac0ba6
9956f18
5e22eaa
2c8fa83
 
5c8bfc0
bd3d9f9
5c8bfc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e22eaa
5c8bfc0
 
6acdc14
bd3d9f9
 
6acdc14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd3d9f9
6acdc14
 
bd3d9f9
6acdc14
 
 
 
 
5e22eaa
 
 
 
9ac0ba6
5e22eaa
bd3d9f9
6acdc14
 
5e22eaa
bd3d9f9
5e22eaa
 
6acdc14
 
 
5e22eaa
2c8fa83
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import gradio as gr
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import scipy
from scipy import signal
import pickle
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split, GridSearchCV

# Module-level holder for the uploaded data. It is None until
# get_data_preview() stores the DataFrame read from the uploaded CSV; the
# labeling and preprocessing callbacks read (and label_data writes) this object.
global_data = None

def get_data_preview(file):
    """Load the uploaded CSV into ``global_data`` and return its first rows.

    Args:
        file: Either a filesystem path given as ``str`` or an object that
            exposes the path as its ``.name`` attribute (e.g. a temp-file
            wrapper). ``None`` (no file selected) is tolerated.

    Returns:
        ``DataFrame.head()`` of the freshly loaded data, including the new
        all-NaN ``label`` column, or ``None`` when ``file`` is ``None`` (in
        which case ``global_data`` is left untouched).
    """
    global global_data
    if file is None:
        return None

    # Accept both a plain path string and a file-like wrapper carrying `.name`.
    path = file if isinstance(file, str) else file.name
    frame = pd.read_csv(path)

    # Object dtype lets string labels be written later alongside the NaN
    # placeholders without a dtype conflict.
    frame['label'] = pd.Series(np.nan, index=frame.index, dtype=object)

    global_data = frame
    return global_data.head()

def label_data(ranges):
    """Write labels into ``global_data['label']`` for the given row ranges.

    Args:
        ranges: A DataFrame whose first three columns are start index, end
            index (inclusive) and label. Values may arrive as strings.

    Returns:
        ``global_data.tail()`` after the labels have been applied.

    Raises:
        ValueError: If no data has been loaded yet.

    Rows that cannot be applied are skipped instead of aborting the whole
    call: rows with fewer than three cells, non-numeric or blank indices, a
    missing/blank label, a start outside the data, or an end before the start.
    An end index past the last row is clamped to the last row.
    """
    global global_data
    if global_data is None:
        raise ValueError("No data loaded: upload a CSV file before labeling.")

    print("Ranges received for labeling:", ranges)
    last_row = len(global_data) - 1

    for row in ranges.values:
        if len(row) < 3:
            continue
        start, end, label = row[0], row[1], row[2]

        # Editable tables hand back empty strings for untouched cells; going
        # through float() also accepts values such as "3.0".
        try:
            start = int(float(start))
            end = int(float(end))
        except (TypeError, ValueError, OverflowError):
            continue

        if pd.isna(label) or (isinstance(label, str) and not label.strip()):
            continue
        if start < 0 or start > last_row:
            continue

        end = min(end, last_row)
        if end < start:
            continue

        # .loc slicing is label-based and includes `end`.
        global_data.loc[start:end, 'label'] = label

    return global_data.tail()

def _spectral_features(segment, sampling_rate):
    """Return band energies and spectral-shape descriptors for one segment."""
    # One Welch estimate feeds every feature below.
    freqs, psd = signal.welch(segment, fs=sampling_rate, nperseg=len(segment))

    def band_energy(low, high):
        # Sum of PSD bins whose frequency lies in [low, high] Hz.
        return np.sum(psd[(freqs >= low) & (freqs <= high)])

    energy_alpha = band_energy(8, 13)
    energy_beta = band_energy(14, 30)
    energy_theta = band_energy(4, 7)
    energy_delta = band_energy(0.5, 3)

    # The DC bin is skipped because log(0 Hz) is undefined.
    spectral_slope = np.polyfit(np.log(freqs[1:]), np.log(psd[1:]), 1)[0]

    return {
        'E_alpha': energy_alpha,
        'E_beta': energy_beta,
        'E_theta': energy_theta,
        'E_delta': energy_delta,
        'alpha_beta_ratio': energy_alpha / energy_beta,
        'peak_frequency': freqs[np.argmax(psd)],
        'spectral_centroid': np.sum(freqs * psd) / np.sum(psd),
        'spectral_slope': spectral_slope,
    }


def preprocess_data():
    """Turn the labelled signal in ``global_data`` into scaled spectral features.

    ``global_data`` is expected to hold three columns: a leading column that
    is discarded, the raw signal, and the label. The signal is cut into
    one-second windows (512 samples) with 50 % overlap; each window is
    notch-filtered at 50 Hz, band-passed to 0.5-30 Hz and reduced to eight
    spectral features. A window takes the label of its first sample. Windows
    whose first sample is unlabelled, or that contain non-numeric samples,
    are skipped.

    Side effects:
        Writes ``processed_data.csv`` (standardized features plus label) and
        ``scaler.pkl`` (the fitted ``StandardScaler``) to the working
        directory. ``global_data`` itself is not modified, so the function
        can be called repeatedly.

    Returns:
        ``(status_message, processed_csv_path, scaler_path)``. On any failure
        the message describes the error and both paths are ``None``.
    """
    try:
        if global_data is None:
            raise ValueError("no data loaded; upload a CSV file first")

        # Work on a copy: dropping in place would leave the shared frame with
        # one column fewer after every call and break the next run.
        data = global_data.drop(columns=global_data.columns[0])
        data.columns = ['raw_eeg', 'label']
        raw_data = pd.to_numeric(data['raw_eeg'], errors='coerce').to_numpy(dtype=float)
        labels_old = data['label'].to_numpy()

        sampling_rate = 512
        notch_freq = 50.0
        lowcut, highcut = 0.5, 30.0
        window = sampling_rate      # one second of signal per segment
        step = window // 2          # 50 % overlap between segments

        # With `fs` supplied, iirnotch expects the notch frequency in Hz (not
        # normalized). Q = notch_freq / bandwidth, so Q=30 gives a narrow
        # notch of roughly 1.7 Hz around 50 Hz.
        b_notch, a_notch = signal.iirnotch(notch_freq, Q=30.0, fs=sampling_rate)

        # butter() without `fs` takes cut-offs normalized by the Nyquist rate.
        nyquist = 0.5 * sampling_rate
        b_bandpass, a_bandpass = signal.butter(
            4, [lowcut / nyquist, highcut / nyquist], btype='band'
        )

        features = []
        labels = []

        # `+ 1` keeps the final full window; positional slicing yields exactly
        # `window` samples.
        for i in range(0, len(raw_data) - window + 1, step):
            label = labels_old[i]
            segment = raw_data[i:i + window]
            if pd.isna(label) or not np.isfinite(segment).all():
                continue

            print(f"Processing segment {i} to {i + window}")
            segment = signal.filtfilt(b_notch, a_notch, segment)
            segment = signal.filtfilt(b_bandpass, a_bandpass, segment)
            features.append(_spectral_features(segment, sampling_rate))
            labels.append(label)

        if not features:
            raise ValueError("no labelled, fully numeric segments of 512 samples found")

        columns = ['E_alpha', 'E_beta', 'E_theta', 'E_delta', 'alpha_beta_ratio',
                   'peak_frequency', 'spectral_centroid', 'spectral_slope']
        df_features = pd.DataFrame(features, columns=columns)

        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(df_features)
        df_scaled = pd.DataFrame(X_scaled, columns=columns)
        df_scaled['label'] = labels

        processed_data_filename = 'processed_data.csv'
        df_scaled.to_csv(processed_data_filename, index=False)

        scaler_filename = 'scaler.pkl'
        with open(scaler_filename, 'wb') as file:
            pickle.dump(scaler, file)

        return "Data preprocessing complete! Download the processed data and scaler below.", processed_data_filename, scaler_filename

    except Exception as e:
        # UI boundary: report the failure as a status string instead of raising.
        return f"An error occurred during preprocessing: {e}", None, None

def train_model():
    """Preprocess the loaded data, grid-search an RBF SVC and pickle the winner.

    The processed CSV written by ``preprocess_data`` is read back, split
    80/20 with a fixed seed, and the training part is used for a 5-fold
    ``GridSearchCV`` over ``C`` and ``gamma``. The best estimator is saved
    to ``model.pkl``.

    Returns:
        ``(status_message, model_path, scaler_path)``; both paths are
        ``None`` when preprocessing or training fails.
    """
    try:
        status, data_path, scaler_path = preprocess_data()
        if data_path is None:
            # Preprocessing already produced an error message; pass it on.
            return status, None, None

        features = pd.read_csv(data_path)
        target = features.pop('label')

        # Only the training portion of the split is used below.
        X_train, _, y_train, _ = train_test_split(
            features, target, test_size=0.2, random_state=42
        )

        search = GridSearchCV(
            estimator=SVC(probability=True),
            param_grid={
                'C': [0.1, 1, 10, 100],
                'gamma': ['scale', 'auto', 0.1, 0.01, 0.001, 0.0001],
                'kernel': ['rbf'],
            },
            cv=5,
            verbose=2,
            n_jobs=-1,
        )
        search.fit(X_train, y_train)

        model_path = 'model.pkl'
        with open(model_path, 'wb') as out:
            pickle.dump(search.best_estimator_, out)

        return "Training complete! Download the model and scaler below.", model_path, scaler_path

    except Exception as err:
        message = f"An error occurred during training: {err}"
        print(message)
        return message, None, None

# Build the Gradio interface. Components are created in the order below;
# the event wiring connects them to the callbacks defined above.
with gr.Blocks() as demo:
    # Step 1: CSV upload and a read-only preview of the loaded data.
    file_input = gr.File(label="Upload CSV File")
    data_preview = gr.Dataframe(label="Data Preview", interactive=False)
    # Step 2: table of (start, end, label) rows passed to label_data,
    # plus a read-only view of the result.
    ranges_input = gr.Dataframe(headers=["Start Index", "End Index", "Label"], label="Ranges for Labeling")
    labeled_data_preview = gr.Dataframe(label="Labeled Data Preview", interactive=False)

    # Step 3: outputs of train_model (status text, model file, scaler file).
    training_status = gr.Textbox(label="Training Status")
    model_file = gr.File(label="Download Trained Model")
    scaler_file = gr.File(label="Download Scaler")

    # Uploading a file loads it via get_data_preview and shows its head.
    file_input.upload(get_data_preview, inputs=file_input, outputs=data_preview)
    label_button = gr.Button("Label Data")
    # label_data returns the tail of the labelled data for display.
    label_button.click(label_data, inputs=[ranges_input], outputs=labeled_data_preview, queue=True)
    train_button = gr.Button("Train Model")
    # train_model takes no inputs; it reads the shared global_data itself.
    train_button.click(train_model, outputs=[training_status, model_file, scaler_file])

# Start the app as soon as this module is executed.
demo.launch()