"""Streamlit helpers for loading, plotting and classifying EEG recordings (EDF).

The classifier is a braindecode ``Deep4Net`` wrapped in an ``EEGClassifier``
whose parameters are loaded from a local ``.pt`` file.
"""
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import mne
import numpy as np
import streamlit as st
import torch
from braindecode import EEGClassifier
from braindecode.models import Deep4Net, ShallowFBCSPNet, EEGNetv4, TCN
from braindecode.training.losses import CroppedLoss

# Parameter file loaded by ``build_model`` unless another path is given.
DEFAULT_PT_PATH = (
    "./Deep4Net_trained_tuh_scaling_wN_WAug_DefArgs_index8_number2700_state_dict_100.pt"
)


def set_button_state(output, col):
    """Store *output* in session state and render a coloured result badge in *col*.

    Parameters
    ----------
    output
        Predicted class. ``0`` renders a green "Normal" badge, ``1`` a red
        "Abnormal" badge and anything else a gray "Unknown" badge. The value
        is compared with ``==`` in a boolean context, so a one-element numpy
        array (presumably what ``predict`` returns) also works.
    col
        Streamlit container (e.g. an element of ``st.columns``) to render into.
    """
    st.session_state.output = output

    if st.session_state.output == 0:
        button_color, button_text = "green", "Normal"
    elif st.session_state.output == 1:
        button_color, button_text = "red", "Abnormal"
    else:
        button_color, button_text = "gray", "Unknown"

    # Keep the HTML free of leading indentation: Markdown renders lines
    # indented by 4+ spaces as a code block instead of interpreting the HTML.
    badge_html = (
        f'<div style="background-color:{button_color};color:white;'
        "padding:0.5em 1em;border-radius:0.5em;text-align:center;"
        f'font-weight:bold;">{button_text}</div>'
    )
    col.markdown(badge_html, unsafe_allow_html=True)


def predict(raw, clf, n_chans=21, n_samples=6000):
    """Classify the beginning of the recording *raw* with *clf*.

    Only the first ``n_chans`` channels and the first ``n_samples`` samples
    are used. The defaults (21 x 6000) presumably match the model's training
    input -- NOTE(review): confirm, including the channel order of uploads.

    Returns
    -------
    Whatever ``clf.predict`` returns for a batch containing one recording.
    """
    data = raw.get_data()[:n_chans, :n_samples]
    # MNE returns float64 while the torch weights are float32; cast so the
    # forward pass does not fail on a dtype mismatch.
    x = np.expand_dims(data.astype(np.float32), axis=0)
    return clf.predict(x)


def build_model(model_name, n_classes, n_chans, input_window_samples,
                drop_prob=0.5, lr=0.01, pt_path=DEFAULT_PT_PATH):
    """Build a cropped-decoding Deep4Net classifier and load trained parameters.

    Parameters
    ----------
    model_name
        Currently ignored: a ``Deep4Net`` is always built.
    n_classes, n_chans, input_window_samples
        Number of output classes, input channels and input window length (in
        samples), passed to ``Deep4Net``.
    drop_prob
        Dropout probability of the network.
    lr
        AdamW learning rate (only relevant if the classifier is trained further).
    pt_path
        Parameter file loaded with ``clf.load_params(f_params=...)``.

    Returns
    -------
    EEGClassifier
        Initialised classifier with the parameters from ``pt_path`` loaded.
    """
    n_start_chans = 25
    n_chan_factor = 2
    # Architecture hyper-parameters presumably have to match the checkpoint
    # exactly, otherwise loading the parameters fails on shape mismatches.
    model = Deep4Net(
        n_chans,
        n_classes,
        n_filters_time=n_start_chans,
        n_filters_spat=n_start_chans,
        input_window_samples=input_window_samples,
        n_filters_2=int(n_start_chans * n_chan_factor),
        n_filters_3=int(n_start_chans * (n_chan_factor ** 2.0)),
        n_filters_4=int(n_start_chans * (n_chan_factor ** 3.0)),
        final_conv_length=1,
        stride_before_pool=True,
        drop_prob=drop_prob,
    )

    clf = EEGClassifier(
        model,
        cropped=True,
        criterion=CroppedLoss,
        criterion__loss_function=torch.nn.functional.nll_loss,
        optimizer=torch.optim.AdamW,
        optimizer__lr=lr,
        iterator_train__shuffle=False,
        callbacks=[],
    )
    clf.initialize()
    clf.load_params(f_params=pt_path)
    return clf


def preprocessing_and_plotting(raw):
    """Render the first 10 seconds of *raw* as a Streamlit figure.

    No preprocessing is applied; only MNE's ``remove_dc`` display option is used.
    NOTE(review): ``st.pyplot`` needs a matplotlib figure, so this assumes
    MNE's matplotlib browser backend is active (not the Qt browser).
    """
    fig = raw.plot(
        duration=10,
        scalings="auto",
        remove_dc=True,
        show_scrollbars=False,
        show=False,
    )
    st.pyplot(fig)
    # st.pyplot does not clear an explicitly passed figure; close it so
    # figures do not accumulate across Streamlit reruns.
    plt.close(fig)


def read_file(edf_file):
    """Load an uploaded EDF file into an in-memory MNE ``Raw`` object.

    *edf_file* is expected to be a Streamlit ``UploadedFile`` (it must provide
    ``getvalue()`` and ``name``). The bytes are written to a private temporary
    directory rather than a fixed file in the working directory, so concurrent
    sessions cannot overwrite each other's upload.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        edf_path = Path(tmp_dir) / "edf_file.edf"
        edf_path.write_bytes(edf_file.getvalue())
        # Preload so the data remains available after the temporary
        # directory is removed at the end of this block.
        raw = mne.io.read_raw_edf(edf_path, preload=True)
    st.write(f"Loaded {edf_file.name} with {raw.info['nchan']} channels")
    return raw