Umang-Bansal committed on
Commit
5e22eaa
1 Parent(s): 14176e4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -0
app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ from sklearn.model_selection import train_test_split, GridSearchCV
5
+ from sklearn.svm import SVC
6
+ from sklearn.preprocessing import StandardScaler
7
+ import scipy
8
+ from scipy import signal
9
+ import pickle
10
+
11
def get_data_preview(file):
    """Return the first rows of an uploaded CSV file for display.

    Args:
        file: The uploaded file, given either as an object whose ``.name``
            attribute holds the path on disk, or as the path itself
            (``str`` / ``os.PathLike``).

    Returns:
        pandas.DataFrame: The result of ``DataFrame.head()`` on the parsed
        CSV (its first five rows).
    """
    import os

    # Paths are used directly; anything else is expected to expose `.name`.
    # (A pathlib.Path also has `.name`, but that is only the basename, so
    # path-like objects must be checked for first.)
    csv_path = file if isinstance(file, (str, os.PathLike)) else file.name
    return pd.read_csv(csv_path).head()
14
+
15
def label_data(file, start, end, label):
    """Load the uploaded CSV and write ``label`` into a range of rows.

    The CSV is re-read on every call, so only the range given in this call
    is labelled in the returned frame. If the file has no ``label`` column,
    one is created and rows outside the range are left as NaN.

    Args:
        file: The uploaded file, given either as an object whose ``.name``
            attribute holds the path on disk, or as the path itself
            (``str`` / ``os.PathLike``).
        start: First row label to tag. ``None`` leaves the slice open at
            the start.
        end: Last row label to tag (inclusive, because ``.loc`` slices by
            label). ``None`` leaves the slice open at the end.
        label: Value written into the ``label`` column for those rows.

    Returns:
        pandas.DataFrame: The full CSV contents with the ``label`` column
        updated.
    """
    import os

    csv_path = file if isinstance(file, (str, os.PathLike)) else file.name
    data = pd.read_csv(csv_path)

    # Numeric UI widgets can hand over floats (e.g. 100.0); row bounds must
    # be plain integers to address the default integer index reliably.
    first = None if start is None else int(start)
    last = None if end is None else int(end)

    data.loc[first:last, 'label'] = label  # Label the specified range
    return data
19
+
20
def _band_energies(freqs, psd):
    """Sum the PSD over fixed frequency bands and add the alpha/beta ratio."""

    def band_energy(low_hz, high_hz):
        # Both band edges are inclusive.
        return np.sum(psd[(freqs >= low_hz) & (freqs <= high_hz)])

    energy_alpha = band_energy(8, 13)
    energy_beta = band_energy(14, 30)
    energy_theta = band_energy(4, 7)
    energy_delta = band_energy(0.5, 3)
    return {
        'E_alpha': energy_alpha,
        'E_beta': energy_beta,
        'E_theta': energy_theta,
        'E_delta': energy_delta,
        # Not guarded: a zero beta energy produces inf/nan here.
        'alpha_beta_ratio': energy_alpha / energy_beta,
    }


def _spectral_shape(freqs, psd):
    """Return peak frequency, spectral centroid and log-log spectral slope."""
    # The first bin (0 Hz) is skipped because log(0) is undefined.
    slope = np.polyfit(np.log(freqs[1:]), np.log(psd[1:]), 1)[0]
    return {
        'peak_frequency': freqs[np.argmax(psd)],
        'spectral_centroid': np.sum(freqs * psd) / np.sum(psd),
        'spectral_slope': slope,
    }


def preprocess_data(data, sampling_rate=512, window_size=512, step=256):
    """Turn a labelled raw-signal table into per-window spectral features.

    The first column of ``data`` is discarded and the remaining two columns
    are interpreted, in order, as ``raw_eeg`` and ``label``. The signal is
    cut into overlapping windows; each window is notch-filtered at 50 Hz,
    band-passed to 0.5-30 Hz (both zero-phase via ``filtfilt``), and
    summarised by features derived from one Welch PSD estimate.

    Fixes relative to the previous implementation:
      * The notch is designed at 50 Hz. Previously a Nyquist-normalized
        frequency was passed together with ``fs``, which placed the notch
        at roughly 0.2 Hz. The quality factor is now 30, because the old
        ``Q=0.05`` would demand a bandwidth far above Nyquist at 50 Hz.
      * Windows contain exactly ``window_size`` samples (label-based
        ``.loc`` slicing is end-inclusive and produced one extra sample).
      * The last complete window is no longer skipped.
      * The caller's DataFrame is not modified.
      * The PSD is computed once per window instead of twice.

    Args:
        data: DataFrame with exactly three columns: one that is ignored,
            the raw signal, and the label.
        sampling_rate: Sampling frequency in Hz used for filter design and
            the PSD.
        window_size: Number of samples per window.
        step: Hop between consecutive window starts, in samples.

    Returns:
        pandas.DataFrame: One row per window with the columns ``E_alpha``,
        ``E_beta``, ``E_theta``, ``E_delta``, ``alpha_beta_ratio``,
        ``peak_frequency``, ``spectral_centroid``, ``spectral_slope`` and
        ``label`` (the label of the window's first sample). Empty when the
        signal is shorter than one window.

    Raises:
        ValueError: If ``data`` does not have exactly three columns.
    """
    # drop() without inplace returns a new frame, so renaming the columns
    # below cannot leak back into the caller's object.
    frame = data.drop(columns=data.columns[0])
    frame.columns = ['raw_eeg', 'label']

    # Convert once up front; values that cannot be parsed become NaN.
    raw_signal = pd.to_numeric(frame['raw_eeg'], errors='coerce').to_numpy(dtype=float)
    labels_old = frame['label'].to_numpy()

    notch_freq = 50.0
    notch_quality = 30.0
    lowcut, highcut = 0.5, 30.0

    # With fs supplied, both designs take their frequencies in Hz.
    b_notch, a_notch = signal.iirnotch(notch_freq, Q=notch_quality, fs=sampling_rate)
    b_bandpass, a_bandpass = signal.butter(
        4, [lowcut, highcut], btype='band', fs=sampling_rate
    )

    features = []
    labels = []
    # "+ 1" keeps the final window whose end coincides with the signal end.
    for start in range(0, len(raw_signal) - window_size + 1, step):
        segment = raw_signal[start:start + window_size]
        segment = signal.filtfilt(b_notch, a_notch, segment)
        segment = signal.filtfilt(b_bandpass, a_bandpass, segment)

        freqs, psd = signal.welch(segment, fs=sampling_rate, nperseg=len(segment))
        features.append({**_band_energies(freqs, psd), **_spectral_shape(freqs, psd)})
        labels.append(labels_old[start])

    columns = ['E_alpha', 'E_beta', 'E_theta', 'E_delta', 'alpha_beta_ratio',
               'peak_frequency', 'spectral_centroid', 'spectral_slope']
    df_features = pd.DataFrame(features, columns=columns)
    df_features['label'] = labels
    return df_features
88
+
89
def train_model(data):
    """Grid-search an RBF SVM on ``data`` and pickle the model and scaler.

    Rows whose ``label`` is missing are ignored. The remaining rows are
    split 80/20 with a fixed seed; the scaler is fitted on the training
    part only, so no statistics of the held-out rows leak into training,
    and the held-out accuracy of the selected model is printed.

    Args:
        data: DataFrame with a ``label`` column; every other column is
            used as a feature.

    Returns:
        tuple: A status message, a ``gr.File`` for ``model.pkl`` and a
        ``gr.File`` for ``scaler.pkl``.

    Raises:
        KeyError: If ``data`` has no ``label`` column.
    """
    # NOTE(review): preprocess_data() is defined in this file but is not
    # invoked here, so the model is fitted on whatever columns `data`
    # carries -- confirm whether feature extraction was meant to run first.

    # Unlabelled rows (NaN label) cannot be used as classification targets.
    labelled = data.dropna(subset=['label'])
    X = labelled.drop('label', axis=1)
    y = labelled['label']
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Fit the scaler on the training split only, then reuse it unchanged
    # for the held-out split.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    param_grid = {
        'C': [0.1, 1, 10, 100],
        'gamma': ['scale', 'auto', 0.1, 0.01, 0.001, 0.0001],
        'kernel': ['rbf'],
    }
    svc = SVC(probability=True)
    grid_search = GridSearchCV(estimator=svc, param_grid=param_grid, cv=5, verbose=2, n_jobs=-1)
    grid_search.fit(X_train, y_train)

    model = grid_search.best_estimator_
    print(f"Held-out accuracy: {model.score(X_test, y_test):.3f}")

    model_filename = 'model.pkl'
    scaler_filename = 'scaler.pkl'

    with open(model_filename, 'wb') as file:
        pickle.dump(model, file)

    with open(scaler_filename, 'wb') as file:
        pickle.dump(scaler, file)

    return "Training complete! Model and scaler saved.", gr.File(model_filename), gr.File(scaler_filename)
112
+
113
+
114
# Build the UI. Components are created in the order they should appear on
# the page, so the statement order below is significant.
with gr.Blocks() as demo:
    # Upload widget and a read-only preview of the uploaded CSV.
    file_input = gr.File(label="Upload CSV File")
    data_preview = gr.Dataframe(label="Data Preview", interactive=False)

    # Row range and value used when labelling.
    start_input = gr.Number(label="Start Index", value=0)
    end_input = gr.Number(label="End Index", value=100)
    label_input = gr.Number(label="Label Value", value=1)

    # Outputs: the labelled table, the training status and the two
    # downloadable artefacts.
    labeled_data_preview = gr.Dataframe(label="Labeled Data Preview", interactive=False)
    training_status = gr.Textbox(label="Training Status")
    model_file = gr.File(label="Download Trained Model")
    scaler_file = gr.File(label="Download Scaler")

    # Show the preview as soon as a file has been uploaded.
    file_input.upload(
        fn=get_data_preview,
        inputs=file_input,
        outputs=data_preview,
    )

    # Labelling re-reads the uploaded file and fills the labelled preview.
    label_button = gr.Button("Label Data")
    label_button.click(
        fn=label_data,
        inputs=[file_input, start_input, end_input, label_input],
        outputs=labeled_data_preview,
    )

    # Training consumes the labelled preview table directly.
    train_button = gr.Button("Train Model")
    train_button.click(
        fn=train_model,
        inputs=labeled_data_preview,
        outputs=[training_status, model_file, scaler_file],
    )

demo.launch()