Mohammad Javad Darvishi commited on
Commit
13f97c3
1 Parent(s): db8f37b

'first working version of the demo'

Browse files
Files changed (2) hide show
  1. app.py +20 -2
  2. misc.py +92 -0
app.py CHANGED
@@ -3,17 +3,35 @@ import streamlit as st
3
  import mne
4
  import matplotlib.pyplot as plt
5
  import os
 
 
 
6
  from misc import *
7
 
 
 
 
 
8
 
 
9
  # Load the edf file
10
- edf_file = st.file_uploader("Upload an EEG edf file", type="edf")
 
 
11
 
12
 
13
  if edf_file is not None:
14
-
15
  # Read the file
16
  raw = read_file(edf_file)
17
 
18
  # Preprocess and plot the data
19
  preprocessing_and_plotting(raw)
 
 
 
 
 
 
 
 
 
3
  import mne
4
  import matplotlib.pyplot as plt
5
  import os
6
+ import streamlit as st
7
+ import random
8
+
9
  from misc import *
10
 
11
+ import streamlit as st
12
+
13
+ # Create two columns with st.columns (new way)
14
+ col1, col2 = st.columns(2)
15
 
16
+ # Create the upload button in the first column
17
  # Load the edf file
18
+ edf_file = col1.file_uploader("Upload an EEG edf file", type="edf")
19
+ # Create the result placeholder button in the second column
20
+ col2.button('Result:')
21
 
22
 
23
  if edf_file is not None:
24
+
25
  # Read the file
26
  raw = read_file(edf_file)
27
 
28
  # Preprocess and plot the data
29
  preprocessing_and_plotting(raw)
30
+
31
+ # Build the model
32
+ clf = build_model(model_name='deep4net', n_classes=2, n_chans=21, input_window_samples=6000)
33
+
34
+ output = predict(raw,clf)
35
+
36
+ # # Print the output
37
+ set_button_state (output,col2)
misc.py CHANGED
@@ -3,7 +3,99 @@ import mne
3
  import streamlit as st
4
  import matplotlib.pyplot as plt
5
 
 
 
 
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  def preprocessing_and_plotting(raw):
9
  # Select the first channel
 
3
  import streamlit as st
4
  import matplotlib.pyplot as plt
5
 
6
+ from braindecode import EEGClassifier
7
+ from braindecode.models import Deep4Net,ShallowFBCSPNet,EEGNetv4, TCN
8
+ from braindecode.training.losses import CroppedLoss
9
 
10
+ import torch
11
+ import numpy as np
12
+
13
+ def set_button_state(output,col):
14
+ # Generate a random output value of 0 or 1
15
+ # output = 2023 #random.randint(0, 1)
16
+
17
+ # Store the output value in session state
18
+ st.session_state.output = output
19
+
20
+ # Define the button color and text based on the output value
21
+ if st.session_state.output == 0:
22
+ button_color = "green"
23
+ button_text = "Normal"
24
+ elif st.session_state.output == 1:
25
+ button_color = "red"
26
+ button_text = "Abnormal"
27
+ # elif st.session_state.output == 3:
28
+ # button_color = "yellow"
29
+ # button_text = "Waiting"
30
+ else:
31
+ button_color = "gray"
32
+ button_text = "Unknown"
33
+
34
+ # Create a custom HTML button with CSS styling
35
+ col.markdown(f"""
36
+ <style>
37
+ .custom-button {{
38
+ background-color: {button_color};
39
+ color: black;
40
+ padding: 10px 20px;
41
+ border: none;
42
+ border-radius: 5px;
43
+ cursor: pointer;
44
+ }}
45
+ </style>
46
+ <button class="custom-button">Output: {button_text}</button>
47
+ """, unsafe_allow_html=True)
48
+
49
+
50
+ def predict(raw,clf):
51
+ x = np.expand_dims(raw.get_data()[:21, :6000], axis=0)
52
+ output = clf.predict(x)
53
+ return output
54
+
55
+
56
+ def build_model(model_name, n_classes, n_chans, input_window_samples, drop_prob=0.5, lr=0.01):#, weight_decay, batch_size, n_epochs, wandb_run, checkpoint, optimizer__param_groups, window_train_set, window_val):
57
+ n_start_chans = 25
58
+ final_conv_length = 1
59
+ n_chan_factor = 2
60
+ stride_before_pool = True
61
+ # input_window_samples =6000
62
+ model = Deep4Net(
63
+ n_chans, n_classes,
64
+ n_filters_time=n_start_chans,
65
+ n_filters_spat=n_start_chans,
66
+ input_window_samples=input_window_samples,
67
+ n_filters_2=int(n_start_chans * n_chan_factor),
68
+ n_filters_3=int(n_start_chans * (n_chan_factor ** 2.0)),
69
+ n_filters_4=int(n_start_chans * (n_chan_factor ** 3.0)),
70
+ final_conv_length=final_conv_length,
71
+ stride_before_pool=stride_before_pool,
72
+ drop_prob=drop_prob)
73
+
74
+ clf = EEGClassifier(
75
+ model,
76
+ cropped=True,
77
+ criterion=CroppedLoss,
78
+ # criterion=CroppedLoss_sd,
79
+ criterion__loss_function=torch.nn.functional.nll_loss,
80
+ optimizer=torch.optim.AdamW,
81
+ optimizer__lr=lr,
82
+ iterator_train__shuffle=False,
83
+ # iterator_train__sampler = ImbalancedDatasetSampler(window_train_set, labels=window_train_set.get_metadata().target),
84
+ # batch_size=batch_size,
85
+ callbacks=[
86
+ # EarlyStopping(patience=5),
87
+ # StochasticWeightAveraging(swa_utils, swa_start=1, verbose=1, swa_lr=lr),
88
+ # "accuracy", "balanced_accuracy","f1",("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
89
+ # checkpoint,
90
+ ], #"accuracy",
91
+ # device='cuda'
92
+ )
93
+ clf.initialize()
94
+ pt_path = './Deep4Net_trained_tuh_scaling_wN_WAug_DefArgs_index8_number2700_state_dict_100.pt'
95
+ clf.load_params(f_params=pt_path)
96
+
97
+ return clf
98
+
99
 
100
  def preprocessing_and_plotting(raw):
101
  # Select the first channel