Umang-Bansal committed on
Commit
dda0def
1 Parent(s): e7b69b7

Update app.py

Files changed (1)
  1. app.py +149 -106
app.py CHANGED
@@ -1,131 +1,174 @@
  import gradio as gr
  import pandas as pd
  import numpy as np
- from sklearn.model_selection import train_test_split, GridSearchCV
- from sklearn.svm import SVC
  from sklearn.preprocessing import StandardScaler
  import scipy
  from scipy import signal
  import pickle

  def get_data_preview(file):
-     data = pd.read_csv(file.name)
-     return data.head()
-
- def label_data(file, start, end, label):
-     data = pd.read_csv(file.name)
-     data.loc[start:end, 'label'] = label  # Label the specified range
-     return data
-
- def preprocess_data(data):
-     data.drop(columns=data.columns[0], axis=1, inplace=True)
-     data.columns = ['raw_eeg', 'label']
-     raw_data = data['raw_eeg']
-     labels_old = data['label']
-
-     sampling_rate = 512
-     notch_freq = 50.0
-     lowcut, highcut = 0.5, 30.0
-
-     nyquist = (0.5 * sampling_rate)
-     notch_freq_normalized = notch_freq / nyquist
-     b_notch, a_notch = signal.iirnotch(notch_freq_normalized, Q=0.05, fs=sampling_rate)
-
-     lowcut_normalized = lowcut / nyquist
-     highcut_normalized = highcut / nyquist
-     b_bandpass, a_bandpass = signal.butter(4, [lowcut_normalized, highcut_normalized], btype='band')
-
-     features = []
-     labels = []
-
-     def calculate_psd_features(segment, sampling_rate):
-         f, psd_values = scipy.signal.welch(segment, fs=sampling_rate, nperseg=len(segment))
-         alpha_indices = np.where((f >= 8) & (f <= 13))
-         beta_indices = np.where((f >= 14) & (f <= 30))
-         theta_indices = np.where((f >= 4) & (f <= 7))
-         delta_indices = np.where((f >= 0.5) & (f <= 3))
-         energy_alpha = np.sum(psd_values[alpha_indices])
-         energy_beta = np.sum(psd_values[beta_indices])
-         energy_theta = np.sum(psd_values[theta_indices])
-         energy_delta = np.sum(psd_values[delta_indices])
-         alpha_beta_ratio = energy_alpha / energy_beta
-         return {
-             'E_alpha': energy_alpha,
-             'E_beta': energy_beta,
-             'E_theta': energy_theta,
-             'E_delta': energy_delta,
-             'alpha_beta_ratio': alpha_beta_ratio
-         }
-
-     def calculate_additional_features(segment, sampling_rate):
-         f, psd = scipy.signal.welch(segment, fs=sampling_rate, nperseg=len(segment))
-         peak_frequency = f[np.argmax(psd)]
-         spectral_centroid = np.sum(f * psd) / np.sum(psd)
-         log_f = np.log(f[1:])
-         log_psd = np.log(psd[1:])
-         spectral_slope = np.polyfit(log_f, log_psd, 1)[0]
-         return {
-             'peak_frequency': peak_frequency,
-             'spectral_centroid': spectral_centroid,
-             'spectral_slope': spectral_slope
-         }
-
-     for i in range(0, len(raw_data) - 512, 256):
-         segment = raw_data.loc[i:i+512]
-         segment = pd.to_numeric(segment, errors='coerce')
-         segment = signal.filtfilt(b_notch, a_notch, segment)
-         segment = signal.filtfilt(b_bandpass, a_bandpass, segment)
-         segment_features = calculate_psd_features(segment, 512)
-         additional_features = calculate_additional_features(segment, 512)
-         segment_features = {**segment_features, **additional_features}
-         features.append(segment_features)
-         labels.append(labels_old[i])
-
-     columns = ['E_alpha', 'E_beta', 'E_theta', 'E_delta', 'alpha_beta_ratio', 'peak_frequency', 'spectral_centroid', 'spectral_slope']
-     df_features = pd.DataFrame(features, columns=columns)
-     df_features['label'] = labels
-     return df_features
-
- def train_model(data):
-     scaler = StandardScaler()
-     X = data.drop('label', axis=1)
-     y = data['label']
-     X_scaled = scaler.fit_transform(X)
-     X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
-
-     param_grid = {'C': [0.1, 1, 10, 100], 'gamma': ['scale', 'auto', 0.1, 0.01, 0.001, 0.0001], 'kernel': ['rbf']}
-     svc = SVC(probability=True)
-     grid_search = GridSearchCV(estimator=svc, param_grid=param_grid, cv=5, verbose=2, n_jobs=-1)
-     grid_search.fit(X_train, y_train)
-
-     model = grid_search.best_estimator_
-     model_filename = 'model.pkl'
-     scaler_filename = 'scaler.pkl'
-
-     with open(model_filename, 'wb') as file:
-         pickle.dump(model, file)
-
-     with open(scaler_filename, 'wb') as file:
-         pickle.dump(scaler, file)
-
-     return f"Training complete! Model and scaler saved.", gr.File(model_filename), gr.File(scaler_filename)
-

  with gr.Blocks() as demo:
      file_input = gr.File(label="Upload CSV File")
      data_preview = gr.Dataframe(label="Data Preview", interactive=False)
-     start_input = gr.Number(label="Start Index", value=0)
-     end_input = gr.Number(label="End Index", value=100)
-     label_input = gr.Number(label="Label Value", value=1)
      labeled_data_preview = gr.Dataframe(label="Labeled Data Preview", interactive=False)
      training_status = gr.Textbox(label="Training Status")
      model_file = gr.File(label="Download Trained Model")
      scaler_file = gr.File(label="Download Scaler")

      file_input.upload(get_data_preview, inputs=file_input, outputs=data_preview)
      label_button = gr.Button("Label Data")
-     label_button.click(label_data, inputs=[file_input, start_input, end_input, label_input], outputs=labeled_data_preview)
      train_button = gr.Button("Train Model")
-     train_button.click(train_model, inputs=labeled_data_preview, outputs=[training_status, model_file, scaler_file])

  demo.launch()
 
  import gradio as gr
  import pandas as pd
  import numpy as np
  from sklearn.preprocessing import StandardScaler
  import scipy
  from scipy import signal
  import pickle
+ from sklearn.svm import SVC
+ from sklearn.model_selection import train_test_split, GridSearchCV
+
+ # Global variable to store the uploaded data
+ global_data = None

  def get_data_preview(file):
+     global global_data
+     global_data = pd.read_csv(file.name)
+     global_data['label'] = np.nan  # Initialize a label column
+     global_data['label'] = global_data['label'].astype(object)  # Ensure the label column can hold different types
+     print("Data preview:\n", global_data.head())
+     return global_data.head()
+
+ def label_data(ranges):
+     global global_data
+     print("Ranges received for labeling:", ranges)
+     for i, (start, end, label) in enumerate(ranges.values):
+         start = int(start)
+         end = int(end)
+         print(f"Processing range {i}: start={start}, end={end}, label={label}")
+         if start < 0 or start >= len(global_data):
+             print(f"Invalid range: start={start}, end={end}, label={label}")
+             continue
+         if end >= len(global_data):
+             print(f"End index {end} exceeds data length {len(global_data)}. Adjusting to {len(global_data) - 1}.")
+             end = len(global_data) - 1
+         global_data.loc[start:end, 'label'] = label
+     print("Data after labeling:\n", global_data.tail())
+     return global_data.tail()
+
+ def preprocess_data():
+     global global_data
+     try:
+         global_data.drop(columns=global_data.columns[0], axis=1, inplace=True)
+         global_data.columns = ['raw_eeg', 'label']
+         raw_data = global_data['raw_eeg']
+         labels_old = global_data['label']
+
+         sampling_rate = 512
+         notch_freq = 50.0
+         lowcut, highcut = 0.5, 30.0
+
+         nyquist = (0.5 * sampling_rate)
+         notch_freq_normalized = notch_freq / nyquist
+         b_notch, a_notch = signal.iirnotch(notch_freq_normalized, Q=0.05, fs=sampling_rate)
+
+         lowcut_normalized = lowcut / nyquist
+         highcut_normalized = highcut / nyquist
+         b_bandpass, a_bandpass = signal.butter(4, [lowcut_normalized, highcut_normalized], btype='band')
+
+         features = []
+         labels = []
+
+         def calculate_psd_features(segment, sampling_rate):
+             f, psd_values = scipy.signal.welch(segment, fs=sampling_rate, nperseg=len(segment))
+             alpha_indices = np.where((f >= 8) & (f <= 13))
+             beta_indices = np.where((f >= 14) & (f <= 30))
+             theta_indices = np.where((f >= 4) & (f <= 7))
+             delta_indices = np.where((f >= 0.5) & (f <= 3))
+             energy_alpha = np.sum(psd_values[alpha_indices])
+             energy_beta = np.sum(psd_values[beta_indices])
+             energy_theta = np.sum(psd_values[theta_indices])
+             energy_delta = np.sum(psd_values[delta_indices])
+             alpha_beta_ratio = energy_alpha / energy_beta
+             return {
+                 'E_alpha': energy_alpha,
+                 'E_beta': energy_beta,
+                 'E_theta': energy_theta,
+                 'E_delta': energy_delta,
+                 'alpha_beta_ratio': alpha_beta_ratio
+             }
+
+         def calculate_additional_features(segment, sampling_rate):
+             f, psd = scipy.signal.welch(segment, fs=sampling_rate, nperseg=len(segment))
+             peak_frequency = f[np.argmax(psd)]
+             spectral_centroid = np.sum(f * psd) / np.sum(psd)
+             log_f = np.log(f[1:])
+             log_psd = np.log(psd[1:])
+             spectral_slope = np.polyfit(log_f, log_psd, 1)[0]
+             return {
+                 'peak_frequency': peak_frequency,
+                 'spectral_centroid': spectral_centroid,
+                 'spectral_slope': spectral_slope
+             }
+
+         for i in range(0, len(raw_data) - 512, 256):
+             print(f"Processing segment {i} to {i + 512}")
+             segment = raw_data.loc[i:i+512]
+             segment = pd.to_numeric(segment, errors='coerce')
+             segment = signal.filtfilt(b_notch, a_notch, segment)
+             segment = signal.filtfilt(b_bandpass, a_bandpass, segment)
+             segment_features = calculate_psd_features(segment, 512)
+             additional_features = calculate_additional_features(segment, 512)
+             segment_features = {**segment_features, **additional_features}
+             features.append(segment_features)
+             labels.append(labels_old[i])
+
+         columns = ['E_alpha', 'E_beta', 'E_theta', 'E_delta', 'alpha_beta_ratio', 'peak_frequency', 'spectral_centroid', 'spectral_slope']
+         df_features = pd.DataFrame(features, columns=columns)
+         df_features['label'] = labels
+
+         scaler = StandardScaler()
+         X_scaled = scaler.fit_transform(df_features.drop('label', axis=1))
+         df_scaled = pd.DataFrame(X_scaled, columns=columns)
+         df_scaled['label'] = df_features['label']
+
+         processed_data_filename = 'processed_data.csv'
+         df_scaled.to_csv(processed_data_filename, index=False)
+
+         scaler_filename = 'scaler.pkl'
+         with open(scaler_filename, 'wb') as file:
+             pickle.dump(scaler, file)
+
+         return "Data preprocessing complete! Download the processed data and scaler below.", processed_data_filename, scaler_filename

+     except Exception as e:
+         print(f"An error occurred during preprocessing: {e}")
+         return f"An error occurred during preprocessing: {e}", None, None
+
+ def train_model():
+     global global_data
+     try:
+         preprocess_status, processed_data_filename, scaler_filename = preprocess_data()
+         if processed_data_filename is None:
+             return preprocess_status, None, None
+
+         df_scaled = pd.read_csv(processed_data_filename)
+         X = df_scaled.drop('label', axis=1)
+         y = df_scaled['label']
+
+         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+
+         param_grid = {'C': [0.1, 1, 10, 100], 'gamma': ['scale', 'auto', 0.1, 0.01, 0.001, 0.0001], 'kernel': ['rbf']}
+         svc = SVC(probability=True)
+         grid_search = GridSearchCV(estimator=svc, param_grid=param_grid, cv=5, verbose=2, n_jobs=-1)
+         grid_search.fit(X_train, y_train)
+
+         model = grid_search.best_estimator_
+         model_filename = 'model.pkl'
+
+         with open(model_filename, 'wb') as file:
+             pickle.dump(model, file)
+
+         return "Training complete! Download the model and scaler below.", model_filename, scaler_filename

+     except Exception as e:
+         print(f"An error occurred during training: {e}")
+         return f"An error occurred during training: {e}", None, None

  with gr.Blocks() as demo:
      file_input = gr.File(label="Upload CSV File")
      data_preview = gr.Dataframe(label="Data Preview", interactive=False)
+     ranges_input = gr.Dataframe(headers=["Start Index", "End Index", "Label"], label="Ranges for Labeling")
      labeled_data_preview = gr.Dataframe(label="Labeled Data Preview", interactive=False)
+
      training_status = gr.Textbox(label="Training Status")
      model_file = gr.File(label="Download Trained Model")
      scaler_file = gr.File(label="Download Scaler")

      file_input.upload(get_data_preview, inputs=file_input, outputs=data_preview)
      label_button = gr.Button("Label Data")
+     label_button.click(label_data, inputs=[ranges_input], outputs=labeled_data_preview, queue=True)
      train_button = gr.Button("Train Model")
+     train_button.click(train_model, outputs=[training_status, model_file, scaler_file])

  demo.launch()
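
Note on the notch filter (present in both versions): scipy.signal.iirnotch interprets w0 in Hz whenever fs is supplied, so passing the pre-normalized frequency together with fs=512 places the notch near 0.2 Hz rather than 50 Hz, and Q=0.05 makes the notch extremely wide. A minimal corrected sketch; Q=30 is an assumed, typical value for mains-hum removal, not taken from this commit:

    from scipy import signal

    sampling_rate = 512
    notch_freq = 50.0  # mains interference, in Hz
    # With fs supplied, iirnotch expects w0 in Hz (no manual normalization);
    # Q=30 (an assumption) keeps the notch narrow around 50 Hz.
    b_notch, a_notch = signal.iirnotch(notch_freq, Q=30.0, fs=sampling_rate)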
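
Note on windowing: raw_data.loc[i:i+512] slices by label and is inclusive at both ends, so each segment actually holds 513 samples while the loop advances as if windows were 512 samples long. If exact 512-sample windows are intended, positional slicing is the usual fix (window and hop sizes copied from the commit):

    window, hop = 512, 256
    for i in range(0, len(raw_data) - window, hop):
        # .iloc uses an exclusive stop, so this is exactly `window` samples
        segment = raw_data.iloc[i:i + window]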
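
Note on scaling: preprocess_data fits StandardScaler on the full feature table before train_model performs the train/test split, so test-fold statistics leak into the scaler. A leakage-free sketch, assuming df_features is the unscaled feature table built inside preprocess_data:

    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler

    X = df_features.drop('label', axis=1)
    y = df_features['label']
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    # Fit the scaler on the training split only, then apply it unchanged to the test split.
    scaler = StandardScaler().fit(X_train)
    X_train_scaled = scaler.transform(X_train)
    X_test_scaled = scaler.transform(X_test)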