"""Streamlit helpers to load an EDF recording, plot it, run a braindecode
Deep4Net classifier on it, and display the result as a Normal/Abnormal label."""
import os
import tempfile

import matplotlib.pyplot as plt
import mne
import numpy as np
import streamlit as st
import torch
from braindecode import EEGClassifier
# ShallowFBCSPNet, EEGNetv4 and TCN are not used below; kept for callers that
# may import them from this module.
from braindecode.models import Deep4Net, ShallowFBCSPNet, EEGNetv4, TCN
from braindecode.training import CroppedLoss


def set_button_state(output, col):
    """Store a prediction in session state and render it as a coloured label.

    Args:
        output: Classifier output. ``0`` is rendered as a green "Normal"
            label, ``1`` as a red "Abnormal" label, anything else as a grey
            "Unknown" label. It is stored as ``st.session_state.output``.
        col: Streamlit container (e.g. a column) the label is written into.
    """
    st.session_state.output = output

    # NOTE(review): if ``output`` is a numpy array (as returned by predict),
    # these comparisons only work for a single-element array.
    if st.session_state.output == 0:
        button_color, button_text = "green", "Normal"
    elif st.session_state.output == 1:
        button_color, button_text = "red", "Abnormal"
    else:
        button_color, button_text = "gray", "Unknown"

    # Keep the HTML free of leading indentation: Markdown would otherwise
    # render 4-space-indented lines as a code block instead of HTML.
    html = (
        f'<button style="background-color: {button_color}; color: white; '
        "border: none; border-radius: 8px; padding: 0.5em 1.5em; "
        f'font-size: 1.2em; cursor: default;">{button_text}</button>'
    )
    col.markdown(html, unsafe_allow_html=True)


def predict(raw, clf, n_chans=21, n_samples=6000):
    """Run ``clf`` on the beginning of a recording.

    Args:
        raw: MNE Raw object to classify.
        clf: Fitted/initialised classifier exposing ``predict``.
        n_chans: Number of leading channels to keep (default 21).
        n_samples: Number of leading samples to keep (default 6000).

    Returns:
        Whatever ``clf.predict`` returns for a batch of one example.
    """
    # get_data() is (n_channels, n_times); slicing keeps at most
    # n_chans x n_samples.
    # NOTE(review): MNE returns volts; confirm this matches the scaling used
    # when the model was trained.
    data = raw.get_data()[:n_chans, :n_samples]
    # MNE returns float64 but the torch model weights are float32; feeding
    # float64 into the conv layers raises a dtype-mismatch error.
    x = np.expand_dims(data.astype(np.float32), axis=0)
    return clf.predict(x)


def build_model(model_name, n_classes, n_chans, input_window_samples,
                drop_prob=0.5, lr=0.01, params_path='./'):
    """Build a cropped-decoding Deep4Net classifier and load saved weights.

    Args:
        model_name: Currently ignored; a Deep4Net is always built.
        n_classes: Number of output classes.
        n_chans: Number of input EEG channels.
        input_window_samples: Number of time samples per input window.
        drop_prob: Dropout probability passed to Deep4Net.
        lr: Learning rate for the AdamW optimizer.
        params_path: Path passed to ``clf.load_params(f_params=...)``.

    Returns:
        An initialised ``EEGClassifier`` with parameters loaded from
        ``params_path``.
    """
    # NOTE(review): the default './' is a directory, not a parameter file, so
    # load_params will fail with it; pass the real checkpoint path.
    # NOTE(review): model_name is not used to select an architecture.
    n_start_chans = 25
    n_chan_factor = 2

    model = Deep4Net(
        n_chans,
        n_classes,
        n_filters_time=n_start_chans,
        n_filters_spat=n_start_chans,
        input_window_samples=input_window_samples,
        n_filters_2=int(n_start_chans * n_chan_factor),
        n_filters_3=int(n_start_chans * n_chan_factor ** 2),
        n_filters_4=int(n_start_chans * n_chan_factor ** 3),
        final_conv_length=1,
        stride_before_pool=True,
        drop_prob=drop_prob,
    )

    clf = EEGClassifier(
        model,
        cropped=True,
        criterion=CroppedLoss,
        criterion__loss_function=torch.nn.functional.nll_loss,
        optimizer=torch.optim.AdamW,
        optimizer__lr=lr,
        iterator_train__shuffle=False,
        callbacks=[],
    )
    # The network must be initialised before saved parameters can be loaded.
    clf.initialize()
    clf.load_params(f_params=params_path)
    return clf


def preprocessing_and_plotting(raw):
    """Plot the first channel of ``raw`` (in µV) in the Streamlit app."""
    channel = raw.ch_names[0]
    st.write(f"Selected channel: {channel}")

    # raw[channel] returns (data, times) with data shaped (1, n_times).
    data, times = raw[channel]

    fig, ax = plt.subplots()
    # MNE stores EEG samples in volts; scale so values match the µV label.
    ax.plot(times, data.T * 1e6)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Amplitude (µV)")
    ax.set_title(f"EEG signal of {channel}")
    st.pyplot(fig)
    # Release the figure; otherwise every Streamlit rerun leaves another open
    # figure in pyplot's global state.
    plt.close(fig)


def read_file(edf_file):
    """Load an uploaded EDF file into an MNE Raw object.

    The upload's bytes are written to a private temporary ``.edf`` file
    (so concurrent sessions do not overwrite each other's file). The data
    are read fully into memory and the temporary file is then deleted.

    Args:
        edf_file: Uploaded file object exposing ``getvalue()`` and ``name``.

    Returns:
        The preloaded ``mne.io.Raw`` object.
    """
    fd, tmp_path = tempfile.mkstemp(suffix=".edf")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(edf_file.getvalue())
        # preload=True is required because the file is removed right after.
        raw = mne.io.read_raw_edf(tmp_path, preload=True)
    finally:
        os.remove(tmp_path)

    st.write(f"Loaded {edf_file.name} with {raw.info['nchan']} channels")
    return raw